001/*
002 * Java Genetic Algorithm Library (jenetics-8.1.0).
003 * Copyright (c) 2007-2024 Franz Wilhelmstötter
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 *
017 * Author:
018 *    Franz Wilhelmstötter (franz.wilhelmstoetter@gmail.com)
019 */
020package  io.jenetics.prog.op;
021
022import static java.nio.charset.StandardCharsets.UTF_8;
023import static java.util.Comparator.comparing;
024import static java.util.Objects.requireNonNull;
025import static java.util.stream.Collectors.toCollection;
026import static io.jenetics.ext.internal.util.FormulaParser.TokenType.FUNCTION;
027import static io.jenetics.ext.internal.util.FormulaParser.TokenType.UNARY_OPERATOR;
028import static io.jenetics.internal.util.SerialIO.readInt;
029import static io.jenetics.internal.util.SerialIO.writeInt;
030import static io.jenetics.prog.op.MathTokenType.COMMA;
031import static io.jenetics.prog.op.MathTokenType.DIV;
032import static io.jenetics.prog.op.MathTokenType.IDENTIFIER;
033import static io.jenetics.prog.op.MathTokenType.LPAREN;
034import static io.jenetics.prog.op.MathTokenType.MINUS;
035import static io.jenetics.prog.op.MathTokenType.MOD;
036import static io.jenetics.prog.op.MathTokenType.NUMBER;
037import static io.jenetics.prog.op.MathTokenType.PLUS;
038import static io.jenetics.prog.op.MathTokenType.POW;
039import static io.jenetics.prog.op.MathTokenType.RPAREN;
040import static io.jenetics.prog.op.MathTokenType.TIMES;
041
042import java.io.DataInput;
043import java.io.DataOutput;
044import java.io.IOException;
045import java.io.InvalidObjectException;
046import java.io.ObjectInputStream;
047import java.io.Serial;
048import java.io.Serializable;
049import java.util.TreeSet;
050import java.util.function.Function;
051import java.util.function.Supplier;
052
053import io.jenetics.internal.util.Lazy;
054import io.jenetics.util.ISeq;
055
056import io.jenetics.ext.internal.parser.ParsingException;
057import io.jenetics.ext.internal.parser.Token;
058import io.jenetics.ext.internal.util.FormulaParser;
059import io.jenetics.ext.internal.util.FormulaParser.TokenType;
060import io.jenetics.ext.rewriting.TreeRewriteRule;
061import io.jenetics.ext.rewriting.TreeRewriter;
062import io.jenetics.ext.util.FlatTreeNode;
063import io.jenetics.ext.util.Tree;
064import io.jenetics.ext.util.TreeNode;
065
066/**
067 * Contains methods for parsing mathematical expression.
068 *
069 * @author <a href="mailto:franz.wilhelmstoetter@gmail.com">Franz Wilhelmstötter</a>
070 * @since 4.1
071 * @version 7.1
072 */
073public final class MathExpr
074        implements Function<Double[], Double>, Serializable
075{
076
077        @Serial
078        private static final long serialVersionUID = 1L;
079
080        private static final FormulaParser<Token<String>> FORMULA_PARSER =
081                FormulaParser.<Token<String>>builder()
082                        .lparen(t -> t.type() == LPAREN)
083                        .rparen(t -> t.type() == RPAREN)
084                        .separator(t -> t.type() == COMMA)
085                        .unaryOperators(t -> t.type() == PLUS || t.type() == MINUS)
086                        .binaryOperators(ops -> ops
087                                .add(11, t -> t.type() == PLUS || t.type() == MINUS)
088                                .add(12, t -> t.type() == TIMES || t.type() == DIV || t.type() == MOD)
089                                .add(13, t -> t.type() == POW)
090                        )
091                        .identifiers(t -> t.type() == IDENTIFIER || t.type() == NUMBER)
092                        .functions(t -> MathOp.NAMES.contains(t.value()))
093                        .build();
094
095        /**
096         * This tree-rewriter rewrites constant expressions to its single value.
097         *
098         * {@snippet lang="java":
099         * final TreeNode<Op<Double>> tree = MathExpr.parseTree("1 + 2*(6 + 7)");
100         * MathExpr.CONST_REWRITER.rewrite(tree);
101         * assertEquals(tree.getValue(), Const.of(27.0));
102         * }
103         *
104         * @since 5.0
105         */
106        public static final TreeRewriter<Op<Double>> CONST_REWRITER =
107                ConstRewriter.DOUBLE;
108
109        /**
110         * This rewriter implements some common arithmetic identities, in exactly
111         * this order.
112         * <pre> {@code
113         *     sub($x,$x) ->  0
114         *     sub($x,0)  ->  $x
115         *     add($x,0)  ->  $x
116         *     add(0,$x)  ->  $x
117         *     add($x,$x) ->  mul(2,$x)
118         *     div($x,$x) ->  1
119         *     div(0,$x)  ->  0
120         *     mul($x,0)  ->  0
121         *     mul(0,$x)  ->  0
122         *     mul($x,1)  ->  $x
123         *     mul(1,$x)  ->  $x
124         *     mul($x,$x) ->  pow($x,2)
125         *     pow($x,0)  ->  1
126         *     pow(0,$x)  ->  0
127         *     pow($x,1)  ->  $x
128         *     pow(1,$x)  ->  1
129         * } </pre>
130         *
131         * @since 5.0
132         */
133        public static final TreeRewriter<Op<Double>> ARITHMETIC_REWRITER =
134                TreeRewriter.concat(
135                        compile("sub($x,$x) -> 0"),
136                        compile("sub($x,0) -> $x"),
137                        compile("add($x,0) -> $x"),
138                        compile("add(0,$x) -> $x"),
139                        compile("add($x,$x) -> mul(2,$x)"),
140                        compile("div($x,$x) -> 1"),
141                        compile("div(0,$x) -> 0"),
142                        compile("mul($x,0) -> 0"),
143                        compile("mul(0,$x) -> 0"),
144                        compile("mul($x,1) -> $x"),
145                        compile("mul(1,$x) -> $x"),
146                        compile("mul($x,$x) -> pow($x,2)"),
147                        compile("pow($x,0) -> 1"),
148                        compile("pow(0,$x) -> 0"),
149                        compile("pow($x,1) -> $x"),
150                        compile("pow(1,$x) -> 1")
151                );
152
153        private static TreeRewriter<Op<Double>> compile(final String rule) {
154                return TreeRewriteRule.parse(rule, MathOp::toMathOp);
155        }
156
157        /**
158         * Combination of the {@link #ARITHMETIC_REWRITER} and the
159         * {@link #CONST_REWRITER}, in this specific order.
160         *
161         * @since 5.0
162         */
163        public static final TreeRewriter<Op<Double>> REWRITER = TreeRewriter.concat(
164                ARITHMETIC_REWRITER,
165                CONST_REWRITER
166        );
167
168        private final FlatTreeNode<Op<Double>> _tree;
169
170        private final Lazy<ISeq<Var<Double>>> _vars;
171
172        // Primary constructor.
173        private MathExpr(final FlatTreeNode<Op<Double>> tree) {
174                _tree = requireNonNull(tree);
175                _vars = Lazy.of(() -> ISeq.of(
176                        _tree.stream()
177                                .filter(node -> node.value() instanceof Var)
178                                .map(node -> (Var<Double>)node.value())
179                                .collect(toCollection(() -> new TreeSet<>(comparing(Var::name))))
180                ));
181        }
182
183        /**
184         * Create a new {@code MathExpr} object from the given operation tree.
185         *
186         * @param tree the underlying operation tree
187         * @throws NullPointerException if the given {@code program} is {@code null}
188         * @throws IllegalArgumentException if the given operation tree is invalid,
189         *         which means there is at least one node where the operation arity
190         *         and the node child count differ.
191         */
192        public MathExpr(final Tree<? extends Op<Double>, ?> tree) {
193                this(FlatTreeNode.ofTree(tree));
194                Program.check(tree);
195        }
196
197        /**
198         * Return the variable list of this <em>math</em> expression.
199         *
200         * @return the variable list of this <em>math</em> expression
201         */
202        public ISeq<Var<Double>> vars() {
203                return _vars.get();
204        }
205
206        /**
207         * Return the operation tree underlying {@code this} math expression.
208         *
209         * @since 7.1
210         *
211         * @return the operation tree s
212         */
213        public Tree<Op<Double>, ?> tree() {
214                return _tree;
215        }
216
217        /**
218         * Return the math expression as an operation tree.
219         *
220         * @return a new expression tree
221         * @deprecated Will be removed, use {@link #tree()} instead
222         */
223        @Deprecated(forRemoval = true)
224        public TreeNode<Op<Double>> toTree() {
225                return TreeNode.ofTree(_tree);
226        }
227
228        /**
229         * @see #eval(double...)
230         * @see #eval(String, double...)
231         */
232        @Override
233        public Double apply(final Double[] args) {
234                return Program.eval(_tree, args);
235        }
236
237        /**
238         * Convenient method, which lets you apply the program function without
239         * explicitly create a wrapper array.
240         *
241         * {@snippet lang="java":
242         *  final double result = MathExpr.parse("2*z + 3*x - y").eval(3, 2, 1);
243         *  assert result == 9.0;
244         * }
245         *
246         * @see #apply(Double[])
247         * @see #eval(String, double...)
248         *
249         * @param args the function arguments
250         * @return the evaluated value
251         * @throws NullPointerException if the given variable array is {@code null}
252         * @throws IllegalArgumentException if the length of the argument array
253         *         is smaller than the program arity
254         */
255        public double eval(final double... args) {
256                final double val = apply(box(args));
257                return val == -0.0 ? 0.0 : val;
258        }
259
260        @Override
261        public int hashCode() {
262                return Tree.hashCode(_tree);
263        }
264
265        @Override
266        public boolean equals(final Object obj) {
267                return obj == this ||
268                        obj instanceof MathExpr expr &&
269                        _tree.equals(expr._tree);
270        }
271
272        /**
273         * Return the string representation of this {@code MathExpr} object. The
274         * string returned by this method can be parsed again and will result in the
275         * same expression object.
276         * {@snippet lang="java":
277         *  final String expr = "5.0 + 6.0*x + sin(x)^34.0 + (1.0 + sin(x*5.0)/4.0) + 6.5";
278         *  final MathExpr tree = MathExpr.parse(expr);
279         *  assert tree.toString().equals(expr);
280         * }
281         *
282         * @return the expression string
283         */
284        @Override
285        public String toString() {
286                return format(_tree);
287        }
288
289        /**
290         * Simplifying {@code this} expression by applying the given {@code rewriter}
291         * and the given rewrite {@code limit}.
292         *
293         * @param rewriter the rewriter used for simplifying {@code this} expression
294         * @param limit the rewrite limit
295         * @return a newly created math expression object
296         * @throws NullPointerException if the {@code rewriter} is {@code null}
297         * @throws IllegalArgumentException if the {@code limit} is smaller than
298         *         zero
299         */
300        public MathExpr simplify(
301                final TreeRewriter<Op<Double>> rewriter,
302                final int limit
303        ) {
304                final TreeNode<Op<Double>> tree = TreeNode.ofTree(tree());
305                rewriter.rewrite(tree, limit);
306                return new MathExpr(FlatTreeNode.ofTree(tree));
307        }
308
309        /**
310         * Simplifying {@code this} expression by applying the given {@code rewriter}.
311         *
312         * @param rewriter the rewriter used for simplifying {@code this} expression
313         * @return a newly created math expression object
314         * @throws NullPointerException if the {@code rewriter} is {@code null}
315         */
316        public MathExpr simplify(final TreeRewriter<Op<Double>> rewriter) {
317                return simplify(rewriter, Integer.MAX_VALUE);
318        }
319
320        /**
321         * Simplifies {@code this} expression by applying the default
322         * {@link #REWRITER}.
323         *
324         * @return a newly created math expression object
325         */
326        public MathExpr simplify() {
327                return simplify(REWRITER);
328        }
329
330        private static Double[] box(final double... values) {
331                final Double[] result = new Double[values.length];
332                for (int i = values.length; --i >= 0;) {
333                        result[i] = values[i];
334                }
335                return result;
336        }
337
338        private static Op<Double> toOp(
339                final Token<String> token,
340                final TokenType type
341        ) {
342                return switch ((MathTokenType)token.type()) {
343                        case PLUS -> type == UNARY_OPERATOR ? MathOp.ID : MathOp.ADD;
344                        case MINUS -> type == UNARY_OPERATOR ? MathOp.NEG : MathOp.SUB;
345                        case TIMES -> MathOp.MUL;
346                        case DIV -> MathOp.DIV;
347                        case MOD -> MathOp.MOD;
348                        case POW -> MathOp.POW;
349                        case NUMBER -> Const.of(Double.parseDouble(token.value()));
350                        case IDENTIFIER -> {
351                                if (type == FUNCTION) {
352                                        yield MathOp.toMathOp(token.value());
353                                } else {
354                                        yield switch (token.value()) {
355                                                case "π", "PI" -> MathOp.PI;
356                                                default -> Var.of(token.value());
357                                        };
358                                }
359                        }
360                        default -> throw new ParsingException("Unknown token: " + token);
361                };
362        }
363
364
365        /* *************************************************************************
366         *  Java object serialization
367         * ************************************************************************/
368
369        @Serial
370        private Object writeReplace() {
371                return new SerialProxy(SerialProxy.MATH_EXPR, this);
372        }
373
374        @Serial
375        private void readObject(final ObjectInputStream stream)
376                throws InvalidObjectException
377        {
378                throw new InvalidObjectException("Serialization proxy required.");
379        }
380
381        void write(final DataOutput out) throws IOException {
382                final byte[] data = toString().getBytes(UTF_8);
383                writeInt(data.length, out);
384                out.write(data);
385        }
386
387        static MathExpr read(final DataInput in) throws IOException {
388                final byte[] data = new byte[readInt(in)];
389                in.readFully(data);
390                return parse(new String(data, UTF_8));
391        }
392
393        /* *************************************************************************
394         * Static helper methods.
395         * ************************************************************************/
396
397        /**
398         * Return the string representation of the given {@code tree} object. The
399         * string returned by this method can be parsed again and will result in the
400         * same expression object.
401         * {@snippet lang="java":
402         *  final String expr = "5.0 + 6.0*x + sin(x)^34.0 + (1.0 + sin(x*5.0)/4.0) + 6.5";
403         *  final MathExpr tree = MathExpr.parse(expr);
404         *  assert MathExpr.format(tree.tree()).equals(expr);
405         * }
406         *
407         * @since 4.3
408         *
409         * @param tree the tree object to convert to a string
410         * @return a new expression string
411         * @throws NullPointerException if the given {@code tree} is {@code null}
412         */
413        public static String format(final Tree<? extends Op<Double>, ?> tree) {
414                return MathExprFormatter.format(tree);
415        }
416
417        /**
418         * Parses the given {@code expression} into an AST tree.
419         *
420         * @param expression the expression string
421         * @return the tree representation of the given {@code expression}
422         * @throws NullPointerException if the given {@code expression} is {@code null}
423         * @throws IllegalArgumentException if the given expression is invalid or
424         *         can't be parsed.
425         */
426        public static MathExpr parse(final String expression) {
427                return new MathExpr(FlatTreeNode.ofTree(parseTree(expression)));
428        }
429
430        /**
431         * Parses the given mathematical expression string and returns the
432         * mathematical expression tree. The expression may contain all functions
433         * defined in {@link MathOp}.
434         * {@snippet lang="java":
435         * final Tree<? extends Op<Double>, ?> tree = MathExpr
436         *     .parseTree("5 + 6*x + sin(x)^34 + (1 + sin(x*5)/4)/6");
437         * }
438         * The example above will lead to the following tree:
439         * <pre> {@code
440         *  add
441         *  ├── add
442         *  │   ├── add
443         *  │   │   ├── 5.0
444         *  │   │   └── mul
445         *  │   │       ├── 6.0
446         *  │   │       └── x
447         *  │   └── pow
448         *  │       ├── sin
449         *  │       │   └── x
450         *  │       └── 34.0
451         *  └── div
452         *      ├── add
453         *      │   ├── 1.0
454         *      │   └── div
455         *      │       ├── sin
456         *      │       │   └── mul
457         *      │       │       ├── x
458         *      │       │       └── 5.0
459         *      │       └── 4.0
460         *      └── 6.0
461         * } </pre>
462         *
463         * @param expression the expression string
464         * @return the parsed expression tree
465         * @throws NullPointerException if the given {@code expression} is {@code null}
466         * @throws IllegalArgumentException if the given expression is invalid or
467         *         can't be parsed.
468         */
469        public static Tree<Op<Double>, ?> parseTree(final String expression) {
470                final var tokenizer = new MathStringTokenizer(expression);
471                return parseTree(tokenizer::next);
472        }
473
474        private static <V> Tree<Op<Double>, ?>
475        parseTree(final Supplier<Token<String>> tokens) {
476                final TreeNode<Op<Double>> tree = FORMULA_PARSER.parse(tokens, MathExpr::toOp);
477                Var.reindex(tree);
478                return FlatTreeNode.ofTree(tree);
479        }
480
481        /**
482         * Evaluates the given {@code expression} with the given arguments.
483         *
484         * {@snippet lang="java":
485         *  final double result = MathExpr.eval("2*z + 3*x - y", 3, 2, 1);
486         *  assert result == 9.0;
487         * }
488         *
489         * @see #apply(Double[])
490         * @see #eval(double...)
491         *
492         * @param expression the expression to evaluate
493         * @param args the expression arguments, in alphabetical order
494         * @return the evaluation result
495         * @throws NullPointerException if the given {@code expression} is
496         *         {@code null}
497         * @throws IllegalArgumentException if the given operation tree is invalid,
498         *         which means there is at least one node where the operation arity
499         *         and the node child count differ.
500         */
501        public static double eval(final String expression, final double... args) {
502                return parse(expression).eval(args);
503        }
504
505        /**
506         * Evaluates the given {@code expression} with the given arguments.
507         *
508         * @see #apply(Double[])
509         * @see #eval(double...)
510         * @see #eval(String, double...)
511         *
512         * @since 4.4
513         *
514         * @param expression the expression to evaluate
515         * @param args the expression arguments, in alphabetical order
516         * @return the evaluation result
517         * @throws NullPointerException if the given {@code expression} is
518         *         {@code null}
519         */
520        public static double eval(
521                final Tree<? extends Op<Double>, ?> expression,
522                final double... args
523        ) {
524                return Program.eval(expression, box(args));
525        }
526
527        /**
528         * Applies the {@link #REWRITER} to the given (mutable) {@code tree}. The
529         * tree rewrite is done in place.
530         *
531         * @see TreeRewriter#rewrite(TreeNode, int)
532         *
533         * @since 5.0
534         *
535         * @param tree the tree to be rewritten
536         * @param limit the maximal number this rewrite rule is applied to the given
537         *        tree. This guarantees the termination of the rewrite method.
538         * @return the number of rewrites applied to the input {@code tree}
539         * @throws NullPointerException if the given {@code tree} is {@code null}
540         * @throws IllegalArgumentException if the {@code limit} is smaller than
541         *         one
542         */
543        public static int rewrite(final TreeNode<Op<Double>> tree, final int limit) {
544                return REWRITER.rewrite(tree, limit);
545        }
546
547        /**
548         * Applies the {@link #REWRITER} to the given (mutable) {@code tree}. The
549         * tree rewrite is done in place. The limit of the applied rewrites is set
550         * unlimited ({@link Integer#MAX_VALUE}).
551         *
552         * @see #rewrite(TreeNode, int)
553         * @see TreeRewriter#rewrite(TreeNode)
554         *
555         * @since 5.0
556         *
557         * @param tree the tree to be rewritten
558         * @return {@code true} if the tree has been changed (rewritten) by this
559         *         method, {@code false} if the tree hasn't been changed
560         * @throws NullPointerException if the given {@code tree} is {@code null}
561         */
562        public static int rewrite(final TreeNode<Op<Double>> tree) {
563                return rewrite(tree, Integer.MAX_VALUE);
564        }
565
566}