001/*
002 * Java Genetic Algorithm Library (jenetics-8.3.0).
003 * Copyright (c) 2007-2025 Franz Wilhelmstötter
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 *
017 * Author:
018 *    Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com)
019 */
020package io.jenetics.prog.regression;
021
022import static java.util.Objects.requireNonNull;
023
024import java.util.function.DoubleBinaryOperator;
025
026import io.jenetics.ext.util.Tree;
027
028import io.jenetics.prog.op.Op;
029
030/**
031 * This function calculates the <em>overall</em> error of a given program tree.
032 * The error is calculated from the {@link LossFunction} and, if desired, the
033 * program {@link Complexity}.
034 * {@snippet lang="java":
035 * final Error<Double> error = Error.of(LossFunction::mse, Complexity.ofNodeCount(50));
036 * }
037 *
038 * @see LossFunction
039 * @see Complexity
040 *
041 * @param <T> the sample type
042 *
043 * @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a>
044 * @version 5.0
045 * @since 5.0
046 */
047@FunctionalInterface
048public interface Error<T> {
049
050        /**
051         * Calculates the <em>overall</em> error of a given program tree. The error
052         * is calculated from the {@link LossFunction} and, if desired, the program
053         * {@link Complexity}.
054         *
055         * @param program the program tree which calculated the {@code calculated}
056         *        values
057         * @param calculated the calculated function values
058         * @param expected the expected function values
059         * @return the overall program error
060         * @throws NullPointerException if one of the arguments is {@code null}
061         */
062        double apply(
063                final Tree<? extends Op<T>, ?> program,
064                final T[] calculated,
065                final T[] expected
066        );
067
068        /**
069         * Creates an error function which only uses the given {@code loss} function
070         * for calculating the program error
071         *
072         * @param <T> the sample type
073         * @param loss the loss function to use for calculating the program error
074         * @return an error function which uses the loss function for error
075         *         calculation
076         * @throws NullPointerException if the given {@code loss} function is
077         *         {@code null}
078         */
079        static <T> Error<T> of(final LossFunction<T> loss) {
080                requireNonNull(loss);
081                return (p, c, e) -> loss.apply(c, e);
082        }
083
084        /**
085         * Creates an error function by combining the given {@code loss} function
086         * and program {@code complexity}. The loss function and program complexity
087         * is combined in the following way: {@code error = loss + loss*complexity}.
088         * The complexity function penalizes programs which grow to big.
089         *
090         * @param <T> the sample type
091         * @param loss the loss function
092         * @param complexity the program complexity measure
093         * @return a new error function by combining the given loss and complexity
094         *         function
095         * @throws NullPointerException if one of the functions is {@code null}
096         */
097        static <T> Error<T>
098        of(final LossFunction<T> loss, final Complexity<T> complexity) {
099                return of(loss, complexity, (lss, cpx) -> lss + lss*cpx);
100        }
101
102        /**
103         * Creates an error function by combining the given {@code loss} function
104         * and program {@code complexity}. The loss function and program complexity
105         * is combined in the following way: {@code error = loss + loss*complexity}.
106         * The complexity function penalizes programs which grow to big.
107         *
108         * @param <T> the sample type
109         * @param loss the loss function
110         * @param complexity the program complexity measure
111         * @param compose the function which composes the {@code loss} and
112         *        {@code complexity} function
113         * @return a new error function by combining the given loss and complexity
114         *         function
115         * @throws NullPointerException if one of the functions is {@code null}
116         */
117        static <T> Error<T> of(
118                final LossFunction<T> loss,
119                final Complexity<T> complexity,
120                final DoubleBinaryOperator compose
121        ) {
122                requireNonNull(loss);
123                requireNonNull(complexity);
124                requireNonNull(compose);
125
126                return (p, c, e) ->
127                        compose.applyAsDouble(loss.apply(c, e), complexity.apply(p));
128        }
129
130}