001/* 002 * Java Genetic Algorithm Library (jenetics-8.1.0). 003 * Copyright (c) 2007-2024 Franz Wilhelmstötter 004 * 005 * Licensed under the Apache License, Version 2.0 (the "License"); 006 * you may not use this file except in compliance with the License. 007 * You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 * 017 * Author: 018 * Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com) 019 */ 020package io.jenetics.prog.regression; 021 022import static java.lang.Math.abs; 023import static java.lang.Math.sqrt; 024import static java.lang.String.format; 025 026// https://blog.algorithmia.com/introduction-to-loss-functions/ 027// https://towardsdatascience.com/common-loss-functions-in-machine-learning-46af0ffc4d23 028 029/** 030 * This function evaluates how well an evolved program tree fits the given 031 * sample data set. If the predictions are totally off, the loss function will 032 * output a higher value. If they're pretty good, it’ll output a lower number. 033 * It is the essential part of the <em>overall</em> {@link Error} function. 034 * 035 * {@snippet lang="java": 036 * final Error<Double> error = Error.of(LossFunction::mse); 037 * } 038 * 039 * @see <a href="https://en.wikipedia.org/wiki/Loss_function">Loss function</a> 040 * 041 * @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a> 042 * @version 5.0 043 * @since 5.0 044 */ 045@FunctionalInterface 046public interface LossFunction<T> { 047 048 /** 049 * Calculates the error between the expected function values and the 050 * values calculated by the actual {@link io.jenetics.prog.ProgramGene}. 051 * 052 * @param calculated the currently calculated function value 053 * @param expected the expected function values 054 * @return the error value 055 * @throws IllegalArgumentException if the length of the two arrays is not 056 * equal 057 * @throws NullPointerException if one of the {@code double[]} arrays is 058 * {@code null} 059 */ 060 double apply(final T[] calculated, final T[] expected); 061 062 /** 063 * Mean square error is measured as the average of squared difference 064 * between predictions and actual observations. 065 * 066 * @see #rmse(Double[], Double[]) 067 * 068 * @param calculated the function values calculated with the current program 069 * tree 070 * @param expected the expected function value as given by the sample points 071 * @return the mean square error 072 * @throws IllegalArgumentException if the length of the two arrays is not 073 * equal 074 * @throws NullPointerException if one of the {@code double[]} arrays is 075 * {@code null} 076 */ 077 static double mse(final Double[] calculated, final Double[] expected) { 078 if (expected.length != calculated.length) { 079 throw new IllegalArgumentException(format( 080 "Expected result and calculated results have different " + 081 "length: %d != %d", 082 expected.length, calculated.length 083 )); 084 } 085 086 double result = 0; 087 for (int i = 0; i < expected.length; ++i) { 088 result += (expected[i] - calculated[i])*(expected[i] - calculated[i]); 089 } 090 if (expected.length > 0) { 091 result = result/expected.length; 092 } 093 094 return result; 095 } 096 097 /** 098 * Root-mean-square error is measured as the average of squared difference 099 * between predictions and actual observations. 100 * 101 * @see #mse(Double[], Double[]) 102 * 103 * @param calculated the function values calculated with the current program 104 * tree 105 * @param expected the expected function value as given by the sample points 106 * @return the mean square error 107 * @throws IllegalArgumentException if the length of the two arrays is not 108 * equal 109 * @throws NullPointerException if one of the {@code double[]} arrays is 110 * {@code null} 111 */ 112 static double rmse(final Double[] calculated, final Double[] expected) { 113 return sqrt(mse(calculated, expected)); 114 } 115 116 /** 117 * Mean absolute error is measured as the average of sum of absolute 118 * differences between predictions and actual observations. 119 * 120 * @param calculated the function values calculated with the current program 121 * tree 122 * @param expected the expected function value as given by the sample points 123 * @return the mean absolute error 124 * @throws IllegalArgumentException if the length of the two arrays is not 125 * equal 126 * @throws NullPointerException if one of the {@code double[]} arrays is 127 * {@code null} 128 */ 129 static double mae(final Double[] calculated, final Double[] expected) { 130 if (expected.length != calculated.length) { 131 throw new IllegalArgumentException(format( 132 "Expected result and calculated results have different " + 133 "length: %d != %d", 134 expected.length, calculated.length 135 )); 136 } 137 138 double result = 0; 139 for (int i = 0; i < expected.length; ++i) { 140 result += abs(expected[i] - calculated[i]); 141 } 142 if (expected.length > 0) { 143 result = result/expected.length; 144 } 145 146 return result; 147 } 148 149}