LossFunction.java
001 /*
002  * Java Genetic Algorithm Library (jenetics-6.1.0).
003  * Copyright (c) 2007-2020 Franz Wilhelmstötter
004  *
005  * Licensed under the Apache License, Version 2.0 (the "License");
006  * you may not use this file except in compliance with the License.
007  * You may obtain a copy of the License at
008  *
009  *      http://www.apache.org/licenses/LICENSE-2.0
010  *
011  * Unless required by applicable law or agreed to in writing, software
012  * distributed under the License is distributed on an "AS IS" BASIS,
013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014  * See the License for the specific language governing permissions and
015  * limitations under the License.
016  *
017  * Author:
018  *    Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com)
019  */
020 package io.jenetics.prog.regression;
021 
022 import static java.lang.Math.abs;
023 import static java.lang.Math.sqrt;
024 import static java.lang.String.format;
025 
026 // https://blog.algorithmia.com/introduction-to-loss-functions/
027 // https://towardsdatascience.com/common-loss-functions-in-machine-learning-46af0ffc4d23
028 
029 /**
030  * This function evaluates how well an evolved program tree fits the given
031  * sample data set. If the predictions are totally off, the loss function will
032  * output a higher value. If they’re pretty good, it’ll output a lower number.
033  * It is the essential part of the <em>overall</em> {@link Error} function.
034  *
035  <pre>{@code
036  * final Error<Double> error = Error.of(LossFunction::mse);
037  * }</pre>
038  *
039  @see <a href="https://en.wikipedia.org/wiki/Loss_function">Loss function</a>
040  *
041  @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a>
042  @version 5.0
043  @since 5.0
044  */
045 @FunctionalInterface
046 public interface LossFunction<T> {
047 
048     /**
049      * Calculates the error between the expected function values and the
050      * values calculated by the actual {@link io.jenetics.prog.ProgramGene}.
051      *
052      @param calculated the currently calculated function value
053      @param expected the expected function values
054      @return the error value
055      @throws IllegalArgumentException if the length of the two arrays are not
056      *         equal
057      @throws NullPointerException if one of the {@code double[]} arrays is
058      *         {@code null}
059      */
060     double apply(final T[] calculated, final T[] expected);
061 
062     /**
063      * Mean square error is measured as the average of squared difference
064      * between predictions and actual observations.
065      *
066      @see #rmse(Double[], Double[])
067      *
068      @param calculated the function values calculated with the current program
069      *        tree
070      @param expected the expected function value as given by the sample points
071      @return the mean square error
072      @throws IllegalArgumentException if the length of the two arrays are not
073      *         equal
074      @throws NullPointerException if one of the {@code double[]} arrays is
075      *         {@code null}
076      */
077     static double mse(final Double[] calculated, final Double[] expected) {
078         if (expected.length != calculated.length) {
079             throw new IllegalArgumentException(format(
080                 "Expected result and calculated results have different " +
081                     "length: %d != %d",
082                 expected.length, calculated.length
083             ));
084         }
085 
086         double result = 0;
087         for (int i = 0; i < expected.length; ++i) {
088             result += (expected[i- calculated[i])*(expected[i- calculated[i]);
089         }
090         if (expected.length > 0) {
091             result = result/expected.length;
092         }
093 
094         return result;
095     }
096 
097     /**
098      * Root mean square error is measured as the average of squared difference
099      * between predictions and actual observations.
100      *
101      @see #mse(Double[], Double[])
102      *
103      @param calculated the function values calculated with the current program
104      *        tree
105      @param expected the expected function value as given by the sample points
106      @return the mean square error
107      @throws IllegalArgumentException if the length of the two arrays are not
108      *         equal
109      @throws NullPointerException if one of the {@code double[]} arrays is
110      *         {@code null}
111      */
112     static double rmse(final Double[] calculated, final Double[] expected) {
113         return sqrt(mse(calculated, expected));
114     }
115 
116     /**
117      * Mean absolute error is measured as the average of sum of absolute
118      * differences between predictions and actual observations.
119      *
120      @param calculated the function values calculated with the current program
121      *        tree
122      @param expected the expected function value as given by the sample points
123      @return the mean absolute error
124      @throws IllegalArgumentException if the length of the two arrays are not
125      *         equal
126      @throws NullPointerException if one of the {@code double[]} arrays is
127      *         {@code null}
128      */
129     static double mae(final Double[] calculated, final Double[] expected) {
130         if (expected.length != calculated.length) {
131             throw new IllegalArgumentException(format(
132                 "Expected result and calculated results have different " +
133                     "length: %d != %d",
134                 expected.length, calculated.length
135             ));
136         }
137 
138         double result = 0;
139         for (int i = 0; i < expected.length; ++i) {
140             result += abs(expected[i- calculated[i]);
141         }
142         if (expected.length > 0) {
143             result = result/expected.length;
144         }
145 
146         return result;
147     }
148 
149 }