001/* 002 * Java Genetic Algorithm Library (jenetics-8.3.0). 003 * Copyright (c) 2007-2025 Franz Wilhelmstötter 004 * 005 * Licensed under the Apache License, Version 2.0 (the "License"); 006 * you may not use this file except in compliance with the License. 007 * You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 * 017 * Author: 018 * Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com) 019 */ 020package io.jenetics.prog.regression; 021 022import static java.lang.Math.abs; 023import static java.lang.Math.sqrt; 024import static java.lang.String.format; 025 026// https://blog.algorithmia.com/introduction-to-loss-functions/ 027// https://towardsdatascience.com/common-loss-functions-in-machine-learning-46af0ffc4d23 028 029/** 030 * This function evaluates how well an evolved program tree fits the given 031 * sample data set. If the predictions are totally off, the loss function will 032 * output a higher value. If they're pretty good, it’ll output a lower number. 033 * It is the essential part of the <em>overall</em> {@link Error} function. 034 * {@snippet lang="java": 035 * final Error<Double> error = Error.of(LossFunction::mse); 036 * } 037 * 038 * @see <a href="https://en.wikipedia.org/wiki/Loss_function">Loss function</a> 039 * 040 * @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a> 041 * @version 5.0 042 * @since 5.0 043 */ 044@FunctionalInterface 045public interface LossFunction<T> { 046 047 /** 048 * Calculates the error between the expected function values and the 049 * values calculated by the actual {@link io.jenetics.prog.ProgramGene}. 050 * 051 * @param calculated the currently calculated function value 052 * @param expected the expected function values 053 * @return the error value 054 * @throws IllegalArgumentException if the length of the two arrays is not 055 * equal 056 * @throws NullPointerException if one of the {@code double[]} arrays is 057 * {@code null} 058 */ 059 double apply(final T[] calculated, final T[] expected); 060 061 /** 062 * Mean square error is measured as the average of squared difference 063 * between predictions and actual observations. 064 * 065 * @see #rmse(Double[], Double[]) 066 * 067 * @param calculated the function values calculated with the current program 068 * tree 069 * @param expected the expected function value as given by the sample points 070 * @return the mean square error 071 * @throws IllegalArgumentException if the length of the two arrays is not 072 * equal 073 * @throws NullPointerException if one of the {@code double[]} arrays is 074 * {@code null} 075 */ 076 static double mse(final Double[] calculated, final Double[] expected) { 077 if (expected.length != calculated.length) { 078 throw new IllegalArgumentException(format( 079 "Expected result and calculated results have different " + 080 "length: %d != %d", 081 expected.length, calculated.length 082 )); 083 } 084 085 double result = 0; 086 for (int i = 0; i < expected.length; ++i) { 087 result += (expected[i] - calculated[i])*(expected[i] - calculated[i]); 088 } 089 if (expected.length > 0) { 090 result = result/expected.length; 091 } 092 093 return result; 094 } 095 096 /** 097 * Root-mean-square error is measured as the average of squared difference 098 * between predictions and actual observations. 099 * 100 * @see #mse(Double[], Double[]) 101 * 102 * @param calculated the function values calculated with the current program 103 * tree 104 * @param expected the expected function value as given by the sample points 105 * @return the mean square error 106 * @throws IllegalArgumentException if the length of the two arrays is not 107 * equal 108 * @throws NullPointerException if one of the {@code double[]} arrays is 109 * {@code null} 110 */ 111 static double rmse(final Double[] calculated, final Double[] expected) { 112 return sqrt(mse(calculated, expected)); 113 } 114 115 /** 116 * Mean absolute error is measured as the average of sum of absolute 117 * differences between predictions and actual observations. 118 * 119 * @param calculated the function values calculated with the current program 120 * tree 121 * @param expected the expected function value as given by the sample points 122 * @return the mean absolute error 123 * @throws IllegalArgumentException if the length of the two arrays is not 124 * equal 125 * @throws NullPointerException if one of the {@code double[]} arrays is 126 * {@code null} 127 */ 128 static double mae(final Double[] calculated, final Double[] expected) { 129 if (expected.length != calculated.length) { 130 throw new IllegalArgumentException(format( 131 "Expected result and calculated results have different " + 132 "length: %d != %d", 133 expected.length, calculated.length 134 )); 135 } 136 137 double result = 0; 138 for (int i = 0; i < expected.length; ++i) { 139 result += abs(expected[i] - calculated[i]); 140 } 141 if (expected.length > 0) { 142 result = result/expected.length; 143 } 144 145 return result; 146 } 147 148}