/*
 * Java Genetic Algorithm Library (jenetics-7.1.0).
 * Copyright (c) 2007-2022 Franz Wilhelmstötter
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Author:
 *    Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com)
 */
package io.jenetics.prog.regression;

import static java.lang.Math.abs;
import static java.lang.Math.sqrt;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

// https://blog.algorithmia.com/introduction-to-loss-functions/
// https://towardsdatascience.com/common-loss-functions-in-machine-learning-46af0ffc4d23

/**
 * This function evaluates how well an evolved program tree fits the given
 * sample data set. If the predictions are totally off, the loss function will
 * output a higher value. If they’re pretty good, it’ll output a lower number.
 * It is the essential part of the <em>overall</em> {@link Error} function.
 *
 * <pre>{@code
 * final Error<Double> error = Error.of(LossFunction::mse);
 * }</pre>
 *
 * @param <T> the type of the calculated and expected function values
 *
 * @see <a href="https://en.wikipedia.org/wiki/Loss_function">Loss function</a>
 *
 * @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a>
 * @version 5.0
 * @since 5.0
 */
@FunctionalInterface
public interface LossFunction<T> {

        /**
         * Calculates the error between the expected function values and the
         * values calculated by the actual {@link io.jenetics.prog.ProgramGene}.
         *
         * @param calculated the currently calculated function values
         * @param expected the expected function values
         * @return the error value
         * @throws IllegalArgumentException if the length of the two arrays are not
         *         equal
         * @throws NullPointerException if one of the arrays is {@code null}
         */
        double apply(final T[] calculated, final T[] expected);

        /**
         * Mean square error is measured as the average of squared difference
         * between predictions and actual observations. For empty arrays the
         * result is {@code 0}.
         *
         * @see #rmse(Double[], Double[])
         *
         * @param calculated the function values calculated with the current program
         *        tree
         * @param expected the expected function value as given by the sample points
         * @return the mean square error
         * @throws IllegalArgumentException if the length of the two arrays are not
         *         equal
         * @throws NullPointerException if one of the {@code Double[]} arrays, or
         *         one of their elements, is {@code null}
         */
        static double mse(final Double[] calculated, final Double[] expected) {
                requireNonNull(calculated, "calculated");
                requireNonNull(expected, "expected");
                if (expected.length != calculated.length) {
                        throw new IllegalArgumentException(format(
                                "Expected result and calculated results have different " +
                                        "length: %d != %d",
                                expected.length, calculated.length
                        ));
                }

                double sum = 0;
                for (int i = 0; i < expected.length; ++i) {
                        // Unbox and subtract once per sample instead of twice.
                        final double diff = expected[i] - calculated[i];
                        sum += diff*diff;
                }

                // Skip the division for empty input: 0/0 would yield NaN.
                return expected.length > 0 ? sum/expected.length : sum;
        }

        /**
         * Root-mean-square error is the square root of the mean square error,
         * i.e. the square root of the average of squared difference between
         * predictions and actual observations.
         *
         * @see #mse(Double[], Double[])
         *
         * @param calculated the function values calculated with the current program
         *        tree
         * @param expected the expected function value as given by the sample points
         * @return the root-mean-square error
         * @throws IllegalArgumentException if the length of the two arrays are not
         *         equal
         * @throws NullPointerException if one of the {@code Double[]} arrays, or
         *         one of their elements, is {@code null}
         */
        static double rmse(final Double[] calculated, final Double[] expected) {
                return sqrt(mse(calculated, expected));
        }

        /**
         * Mean absolute error is measured as the average of sum of absolute
         * differences between predictions and actual observations. For empty
         * arrays the result is {@code 0}.
         *
         * @param calculated the function values calculated with the current program
         *        tree
         * @param expected the expected function value as given by the sample points
         * @return the mean absolute error
         * @throws IllegalArgumentException if the length of the two arrays are not
         *         equal
         * @throws NullPointerException if one of the {@code Double[]} arrays, or
         *         one of their elements, is {@code null}
         */
        static double mae(final Double[] calculated, final Double[] expected) {
                requireNonNull(calculated, "calculated");
                requireNonNull(expected, "expected");
                if (expected.length != calculated.length) {
                        throw new IllegalArgumentException(format(
                                "Expected result and calculated results have different " +
                                        "length: %d != %d",
                                expected.length, calculated.length
                        ));
                }

                double sum = 0;
                for (int i = 0; i < expected.length; ++i) {
                        sum += abs(expected[i] - calculated[i]);
                }

                // Skip the division for empty input: 0/0 would yield NaN.
                return expected.length > 0 ? sum/expected.length : sum;
        }

}