001/*
002 * Java Genetic Algorithm Library (jenetics-8.3.0).
003 * Copyright (c) 2007-2025 Franz Wilhelmstötter
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 *
017 * Author:
018 *    Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com)
019 */
020package io.jenetics.prog.regression;
021
022import static java.lang.Math.abs;
023import static java.lang.Math.sqrt;
024import static java.lang.String.format;
025
026// https://blog.algorithmia.com/introduction-to-loss-functions/
027// https://towardsdatascience.com/common-loss-functions-in-machine-learning-46af0ffc4d23
028
/**
 * This function evaluates how well an evolved program tree fits the given
 * sample data set. If the predictions are far off, the loss function outputs
 * a higher value; if they are close, it outputs a lower one.
 * It is the essential part of the <em>overall</em> {@link Error} function.
 * {@snippet lang="java":
 * final Error<Double> error = Error.of(LossFunction::mse);
 * }
 *
 * @see <a href="https://en.wikipedia.org/wiki/Loss_function">Loss function</a>
 *
 * @param <T> the type of the compared (calculated and expected) values
 *
 * @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a>
 * @version 5.0
 * @since 5.0
 */
@FunctionalInterface
public interface LossFunction<T> {

        /**
         * Calculates the error between the expected function values and the
         * values calculated by the actual {@link io.jenetics.prog.ProgramGene}.
         *
         * @param calculated the currently calculated function values
         * @param expected the expected function values
         * @return the error value
         * @throws IllegalArgumentException if the length of the two arrays is not
         *         equal
         * @throws NullPointerException if one of the arrays is {@code null}
         */
        double apply(final T[] calculated, final T[] expected);

        /**
         * Mean square error (MSE) is measured as the average of the squared
         * differences between predictions and actual observations:
         * {@code sum((expected[i] - calculated[i])^2)/n}. Two empty arrays
         * yield an error of {@code 0}.
         *
         * @see #rmse(Double[], Double[])
         * @see #mae(Double[], Double[])
         *
         * @param calculated the function values calculated with the current program
         *        tree
         * @param expected the expected function value as given by the sample points
         * @return the mean square error
         * @throws IllegalArgumentException if the length of the two arrays is not
         *         equal
         * @throws NullPointerException if one of the arrays, or one of their
         *         elements, is {@code null}
         */
        static double mse(final Double[] calculated, final Double[] expected) {
                checkSameLength(calculated, expected);

                double sum = 0;
                for (int i = 0; i < expected.length; ++i) {
                        // Unbox and subtract once per element instead of twice.
                        final double diff = expected[i] - calculated[i];
                        sum += diff*diff;
                }

                return mean(sum, expected.length);
        }

        /**
         * Root-mean-square error (RMSE) is the square root of the average of the
         * squared differences between predictions and actual observations, i.e.
         * {@code sqrt(mse(calculated, expected))}. Unlike the MSE, it is
         * expressed in the same unit as the observed values. Two empty arrays
         * yield an error of {@code 0}.
         *
         * @see #mse(Double[], Double[])
         *
         * @param calculated the function values calculated with the current program
         *        tree
         * @param expected the expected function value as given by the sample points
         * @return the root-mean-square error
         * @throws IllegalArgumentException if the length of the two arrays is not
         *         equal
         * @throws NullPointerException if one of the arrays, or one of their
         *         elements, is {@code null}
         */
        static double rmse(final Double[] calculated, final Double[] expected) {
                return sqrt(mse(calculated, expected));
        }

        /**
         * Mean absolute error (MAE) is measured as the average of the absolute
         * differences between predictions and actual observations:
         * {@code sum(|expected[i] - calculated[i]|)/n}. Two empty arrays yield
         * an error of {@code 0}.
         *
         * @see #mse(Double[], Double[])
         *
         * @param calculated the function values calculated with the current program
         *        tree
         * @param expected the expected function value as given by the sample points
         * @return the mean absolute error
         * @throws IllegalArgumentException if the length of the two arrays is not
         *         equal
         * @throws NullPointerException if one of the arrays, or one of their
         *         elements, is {@code null}
         */
        static double mae(final Double[] calculated, final Double[] expected) {
                checkSameLength(calculated, expected);

                double sum = 0;
                for (int i = 0; i < expected.length; ++i) {
                        sum += abs(expected[i] - calculated[i]);
                }

                return mean(sum, expected.length);
        }

        /**
         * Shared precondition of the loss functions: both arrays must have the
         * same length.
         *
         * @param calculated the calculated values
         * @param expected the expected values
         * @throws IllegalArgumentException if the array lengths differ
         * @throws NullPointerException if one of the arrays is {@code null}
         */
        private static void checkSameLength(
                final Object[] calculated,
                final Object[] expected
        ) {
                if (expected.length != calculated.length) {
                        throw new IllegalArgumentException(format(
                                "Expected result and calculated results have different " +
                                        "length: %d != %d",
                                expected.length, calculated.length
                        ));
                }
        }

        /**
         * Divides the accumulated {@code sum} by {@code count}. When
         * {@code count} is zero, the (zero) sum is returned unchanged, so that
         * empty inputs yield {@code 0} instead of {@code NaN}.
         *
         * @param sum the accumulated error terms
         * @param count the number of accumulated terms
         * @return the average error term
         */
        private static double mean(final double sum, final int count) {
                return count > 0 ? sum/count : sum;
        }

}