001 /*
002 * Java Genetic Algorithm Library (jenetics-6.1.0).
003 * Copyright (c) 2007-2020 Franz Wilhelmstötter
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 *
017 * Author:
018 * Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com)
019 */
020 package io.jenetics.prog.regression;
021
022 import static java.lang.Math.abs;
023 import static java.lang.Math.sqrt;
024 import static java.lang.String.format;
025
026 // https://blog.algorithmia.com/introduction-to-loss-functions/
027 // https://towardsdatascience.com/common-loss-functions-in-machine-learning-46af0ffc4d23
028
/**
 * This function evaluates how well an evolved program tree fits the given
 * sample data set. If the predictions are totally off, the loss function will
 * output a higher value. If they're pretty good, it'll output a lower number.
 * It is the essential part of the <em>overall</em> {@link Error} function.
 *
 * <pre>{@code
 * final Error<Double> error = Error.of(LossFunction::mse);
 * }</pre>
 *
 * @see <a href="https://en.wikipedia.org/wiki/Loss_function">Loss function</a>
 *
 * @param <T> the type of the calculated and expected function values
 *
 * @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmst&ouml;tter</a>
 * @version 5.0
 * @since 5.0
 */
@FunctionalInterface
public interface LossFunction<T> {

    /**
     * Calculates the error between the expected function values and the
     * values calculated by the actual {@link io.jenetics.prog.ProgramGene}.
     *
     * @param calculated the currently calculated function values
     * @param expected the expected function values
     * @return the error value
     * @throws IllegalArgumentException if the length of the two arrays are not
     *         equal
     * @throws NullPointerException if one of the arrays is {@code null}
     */
    double apply(final T[] calculated, final T[] expected);

    /**
     * Mean square error is measured as the average of squared difference
     * between predictions and actual observations. For two empty arrays the
     * result is {@code 0}.
     *
     * @see #rmse(Double[], Double[])
     *
     * @param calculated the function values calculated with the current program
     *        tree
     * @param expected the expected function value as given by the sample points
     * @return the mean square error
     * @throws IllegalArgumentException if the length of the two arrays are not
     *         equal
     * @throws NullPointerException if one of the {@code Double[]} arrays, or
     *         one of their elements, is {@code null}
     */
    static double mse(final Double[] calculated, final Double[] expected) {
        if (expected.length != calculated.length) {
            throw new IllegalArgumentException(format(
                "Expected result and calculated results have different " +
                "length: %d != %d",
                expected.length, calculated.length
            ));
        }

        double result = 0;
        for (int i = 0; i < expected.length; ++i) {
            // Unbox and subtract each pair only once per iteration.
            final double diff = expected[i] - calculated[i];
            result += diff*diff;
        }
        // Skip the division for empty input, so the result is 0 and not NaN.
        if (expected.length > 0) {
            result = result/expected.length;
        }

        return result;
    }

    /**
     * Root mean square error is measured as the square root of the average of
     * squared difference between predictions and actual observations, i.e. the
     * square root of {@link #mse(Double[], Double[])}.
     *
     * @see #mse(Double[], Double[])
     *
     * @param calculated the function values calculated with the current program
     *        tree
     * @param expected the expected function value as given by the sample points
     * @return the root mean square error
     * @throws IllegalArgumentException if the length of the two arrays are not
     *         equal
     * @throws NullPointerException if one of the {@code Double[]} arrays, or
     *         one of their elements, is {@code null}
     */
    static double rmse(final Double[] calculated, final Double[] expected) {
        return sqrt(mse(calculated, expected));
    }

    /**
     * Mean absolute error is measured as the average of sum of absolute
     * differences between predictions and actual observations. For two empty
     * arrays the result is {@code 0}.
     *
     * @param calculated the function values calculated with the current program
     *        tree
     * @param expected the expected function value as given by the sample points
     * @return the mean absolute error
     * @throws IllegalArgumentException if the length of the two arrays are not
     *         equal
     * @throws NullPointerException if one of the {@code Double[]} arrays, or
     *         one of their elements, is {@code null}
     */
    static double mae(final Double[] calculated, final Double[] expected) {
        if (expected.length != calculated.length) {
            throw new IllegalArgumentException(format(
                "Expected result and calculated results have different " +
                "length: %d != %d",
                expected.length, calculated.length
            ));
        }

        double result = 0;
        for (int i = 0; i < expected.length; ++i) {
            result += abs(expected[i] - calculated[i]);
        }
        // Skip the division for empty input, so the result is 0 and not NaN.
        if (expected.length > 0) {
            result = result/expected.length;
        }

        return result;
    }

}
|