001 /*
002 * Java Genetic Algorithm Library (jenetics-5.1.0).
003 * Copyright (c) 2007-2019 Franz Wilhelmstötter
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 *
017 * Author:
018 * Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com)
019 */
020 package io.jenetics.prog.regression;
021
022 import static java.util.Objects.requireNonNull;
023
024 import java.util.function.DoubleBinaryOperator;
025
026 import io.jenetics.ext.util.Tree;
027
028 import io.jenetics.prog.op.Op;
029
030 /**
031 * This function calculates the <em>overall</em> error of a given program tree.
032 * The error is calculated from the {@link LossFunction} and, if desired, the
033 * program {@link Complexity}.
034 *
035 * <pre>{@code
036 * final Error<Double> error = Error.of(LossFunction::mse, Complexity.ofNodeCount(50));
037 * }</pre>
038 *
039 * @see LossFunction
040 * @see Complexity
041 *
042 * @param <T> the sample type
043 *
044 * @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a>
045 * @version 5.0
046 * @since 5.0
047 */
048 @FunctionalInterface
049 public interface Error<T> {
050
051 /**
052 * Calculates the <em>overall</em> error of a given program tree. The error
053 * is calculated from the {@link LossFunction} and, if desired, the program
054 * {@link Complexity}.
055 *
056 * @param program the program tree which calculated the {@code calculated}
057 * values
058 * @param calculated the calculated function values
059 * @param expected the expected function values
060 * @return the overall program error
061 * @throws NullPointerException if one of the arguments is {@code null}
062 */
063 public double apply(
064 final Tree<? extends Op<T>, ?> program,
065 final T[] calculated,
066 final T[] expected
067 );
068
069
070 /**
071 * Creates an error function which only uses the given {@code loss} function
072 * for calculating the program error
073 *
074 * @param <T> the sample type
075 * @param loss the loss function to use for calculating the program error
076 * @return an error function which uses the loss function for error
077 * calculation
078 * @throws NullPointerException if the given {@code loss} function is
079 * {@code null}
080 */
081 public static <T> Error<T> of(final LossFunction<T> loss) {
082 requireNonNull(loss);
083
084 return (p, c, e) -> loss.apply(c, e);
085 }
086
087 /**
088 * Creates an error function by combining the given {@code loss} function
089 * and program {@code complexity}. The loss function and program complexity
090 * is combined in the following way: {@code error = loss + loss*complexity}.
091 * The complexity function penalizes programs which grows to big.
092 *
093 * @param <T> the sample type
094 * @param loss the loss function
095 * @param complexity the program complexity measure
096 * @return a new error function by combining the given loss and complexity
097 * function
098 * @throws NullPointerException if one of the functions is {@code null}
099 */
100 public static <T> Error<T>
101 of(final LossFunction<T> loss, final Complexity<T> complexity) {
102 return of(loss, complexity, (lss, cpx) -> lss + lss*cpx);
103 }
104
105 /**
106 * Creates an error function by combining the given {@code loss} function
107 * and program {@code complexity}. The loss function and program complexity
108 * is combined in the following way: {@code error = loss + loss*complexity}.
109 * The complexity function penalizes programs which grows to big.
110 *
111 * @param <T> the sample type
112 * @param loss the loss function
113 * @param complexity the program complexity measure
114 * @param compose the function which composes the {@code loss} and
115 * {@code complexity} function
116 * @return a new error function by combining the given loss and complexity
117 * function
118 * @throws NullPointerException if one of the functions is {@code null}
119 */
120 public static <T> Error<T> of(
121 final LossFunction<T> loss,
122 final Complexity<T> complexity,
123 final DoubleBinaryOperator compose
124 ) {
125 requireNonNull(loss);
126 requireNonNull(complexity);
127 requireNonNull(compose);
128
129 return (p, c, e) ->
130 compose.applyAsDouble(loss.apply(c, e), complexity.apply(p));
131 }
132
133 }
|