Error.java
001 /*
002  * Java Genetic Algorithm Library (jenetics-6.1.0).
003  * Copyright (c) 2007-2020 Franz Wilhelmstötter
004  *
005  * Licensed under the Apache License, Version 2.0 (the "License");
006  * you may not use this file except in compliance with the License.
007  * You may obtain a copy of the License at
008  *
009  *      http://www.apache.org/licenses/LICENSE-2.0
010  *
011  * Unless required by applicable law or agreed to in writing, software
012  * distributed under the License is distributed on an "AS IS" BASIS,
013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014  * See the License for the specific language governing permissions and
015  * limitations under the License.
016  *
017  * Author:
018  *    Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com)
019  */
020 package io.jenetics.prog.regression;
021 
022 import static java.util.Objects.requireNonNull;
023 
024 import java.util.function.DoubleBinaryOperator;
025 
026 import io.jenetics.ext.util.Tree;
027 
028 import io.jenetics.prog.op.Op;
029 
030 /**
031  * This function calculates the <em>overall</em> error of a given program tree.
032  * The error is calculated from the {@link LossFunction} and, if desired, the
033  * program {@link Complexity}.
034  *
035  <pre>{@code
036  * final Error<Double> error = Error.of(LossFunction::mse, Complexity.ofNodeCount(50));
037  * }</pre>
038  *
039  @see LossFunction
040  @see Complexity
041  *
042  @param <T> the sample type
043  *
044  @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a>
045  @version 5.0
046  @since 5.0
047  */
048 @FunctionalInterface
049 public interface Error<T> {
050 
051     /**
052      * Calculates the <em>overall</em> error of a given program tree. The error
053      * is calculated from the {@link LossFunction} and, if desired, the program
054      {@link Complexity}.
055      *
056      @param program the program tree which calculated the {@code calculated}
057      *        values
058      @param calculated the calculated function values
059      @param expected the expected function values
060      @return the overall program error
061      @throws NullPointerException if one of the arguments is {@code null}
062      */
063     double apply(
064         final Tree<? extends Op<T>, ?> program,
065         final T[] calculated,
066         final T[] expected
067     );
068 
069 
070     /**
071      * Creates an error function which only uses the given {@code loss} function
072      * for calculating the program error
073      *
074      @param <T> the sample type
075      @param loss the loss function to use for calculating the program error
076      @return an error function which uses the loss function for error
077      *         calculation
078      @throws NullPointerException if the given {@code loss} function is
079      *         {@code null}
080      */
081     static <T> Error<T> of(final LossFunction<T> loss) {
082         requireNonNull(loss);
083         return (p, c, e-> loss.apply(c, e);
084     }
085 
086     /**
087      * Creates an error function by combining the given {@code loss} function
088      * and program {@code complexity}. The loss function and program complexity
089      * is combined in the following way: {@code error = loss + loss*complexity}.
090      * The complexity function penalizes programs which grows to big.
091      *
092      @param <T> the sample type
093      @param loss the loss function
094       @param complexity the program complexity measure
095      @return a new error function by combining the given loss and complexity
096      *         function
097      @throws NullPointerException if one of the functions is {@code null}
098      */
099     static <T> Error<T>
100     of(final LossFunction<T> loss, final Complexity<T> complexity) {
101         return of(loss, complexity, (lss, cpx-> lss + lss*cpx);
102     }
103 
104     /**
105      * Creates an error function by combining the given {@code loss} function
106      * and program {@code complexity}. The loss function and program complexity
107      * is combined in the following way: {@code error = loss + loss*complexity}.
108      * The complexity function penalizes programs which grows to big.
109      *
110      @param <T> the sample type
111      @param loss the loss function
112      @param complexity the program complexity measure
113      @param compose the function which composes the {@code loss} and
114      *        {@code complexity} function
115      @return a new error function by combining the given loss and complexity
116      *         function
117      @throws NullPointerException if one of the functions is {@code null}
118      */
119     static <T> Error<T> of(
120         final LossFunction<T> loss,
121         final Complexity<T> complexity,
122         final DoubleBinaryOperator compose
123     ) {
124         requireNonNull(loss);
125         requireNonNull(complexity);
126         requireNonNull(compose);
127 
128         return (p, c, e->
129             compose.applyAsDouble(loss.apply(c, e), complexity.apply(p));
130     }
131 
132 }