diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java index da14cc2404cce..929acb8420b95 100644 --- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java +++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/optimizer/Optimizer.java @@ -42,6 +42,7 @@ import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerRule; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PropagateEquals; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneLiteralsInOrderBy; +import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ReplaceRegexMatch; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ReplaceSurrogateFunction; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SetAsOptimized; @@ -87,6 +88,7 @@ protected Iterable.Batch> batches() { new CombineBinaryComparisons(), new CombineDisjunctionsToIn(), new PushDownAndCombineFilters(), + new SimplifyComparisonsArithmetics(DataTypes::areCompatible), // prune/elimination new PruneFilters(), new PruneLiteralsInOrderBy(), diff --git a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/optimizer/OptimizerTests.java b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/optimizer/OptimizerTests.java index a849f1ad875b7..5531d8fe238d3 100644 --- a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/optimizer/OptimizerTests.java +++ b/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/optimizer/OptimizerTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThan; import org.elasticsearch.xpack.ql.expression.predicate.regex.Like; import org.elasticsearch.xpack.ql.index.EsIndex; @@ -665,6 +666,28 @@ public void testDifferentKeyFromDisjunction() { assertEquals(filter, filterCondition(child2.children().get(0))); } + // ((a + 1) - 3) * 4 >= 16 -> a >= 6. + public void testReduceBinaryComparisons() { + LogicalPlan plan = accept("foo where ((pid + 1) - 3) * 4 >= 16"); + assertNotNull(plan); + List filters = plan.collectFirstChildren(x -> x instanceof Filter); + assertNotNull(filters); + assertEquals(1, filters.size()); + assertTrue(filters.get(0) instanceof Filter); + Filter filter = (Filter) filters.get(0); + + assertTrue(filter.condition() instanceof And); + And and = (And) filter.condition(); + assertTrue(and.right() instanceof GreaterThanOrEqual); + GreaterThanOrEqual gte = (GreaterThanOrEqual) and.right(); + + assertTrue(gte.left() instanceof FieldAttribute); + assertEquals("pid", ((FieldAttribute) gte.left()).name()); + + assertTrue(gte.right() instanceof Literal); + assertEquals(6, ((Literal) gte.right()).value()); + } + private static Attribute timestamp() { return new FieldAttribute(EMPTY, "test", new EsField("field", INTEGER, emptyMap(), true)); } diff --git a/x-pack/plugin/eql/src/test/resources/queryfolder_tests.txt b/x-pack/plugin/eql/src/test/resources/queryfolder_tests.txt index acd611c3849f6..64f15402fae95 100644 --- a/x-pack/plugin/eql/src/test/resources/queryfolder_tests.txt +++ b/x-pack/plugin/eql/src/test/resources/queryfolder_tests.txt @@ -528,35 +528,35 @@ process where wildcard~(process_path, "*\\red_ttp\\wininit.*", "*\\abc\\*", "*de addOperator -process where serial_event_id + 2 == 41 +process where serial_event_id + 2 == -2147483647 ; "script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq( InternalQlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))", -"params":{"v0":"serial_event_id","v1":2,"v2":41} +"params":{"v0":"serial_event_id","v1":2,"v2":-2147483647} ; addOperatorReversed -process where 2 + serial_event_id == 41 +process where 2 + serial_event_id == -2147483647 ; "script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq( InternalQlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))", -"params":{"v0":"serial_event_id","v1":2,"v2":41} +"params":{"v0":"serial_event_id","v1":2,"v2":-2147483647} ; addFunction -process where add(serial_event_id, 2) == 41 +process where add(serial_event_id, 2) == -2147483647 ; "script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq( InternalQlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))", -"params":{"v0":"serial_event_id","v1":2,"v2":41} +"params":{"v0":"serial_event_id","v1":2,"v2":-2147483647} ; addFunctionReversed -process where add(2, serial_event_id) == 41 +process where add(2, serial_event_id) == -2147483647 ; "script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq( InternalQlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))", -"params":{"v0":"serial_event_id","v1":2,"v2":41} +"params":{"v0":"serial_event_id","v1":2,"v2":-2147483647} ; divideOperator @@ -656,35 +656,35 @@ InternalQlScriptUtils.mul(InternalQlScriptUtils.docValue(doc,params.v0),params.v ; subtractOperator -process where serial_event_id - 2 == 41 +process where serial_event_id - 2 == 2147483647 ; "script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq( InternalQlScriptUtils.sub(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))", -"params":{"v0":"serial_event_id","v1":2,"v2":41} +"params":{"v0":"serial_event_id","v1":2,"v2":2147483647} ; subtractOperatorReversed -process where 43 - serial_event_id == 41 +process where 43 - serial_event_id == -2147483647 ; "script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq( InternalQlScriptUtils.sub(params.v0,InternalQlScriptUtils.docValue(doc,params.v1)),params.v2))", -"params":{"v0":43,"v1":"serial_event_id","v2":41} +"params":{"v0":43,"v1":"serial_event_id","v2":-2147483647} ; subtractFunction -process where subtract(serial_event_id, 2) == 41 +process where subtract(serial_event_id, 2) == 2147483647 ; "script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq( InternalQlScriptUtils.sub(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))", -"params":{"v0":"serial_event_id","v1":2,"v2":41} +"params":{"v0":"serial_event_id","v1":2,"v2":2147483647} ; subtractFunctionReversed -process where subtract(43, serial_event_id) == 41 +process where subtract(43, serial_event_id) == -2147483647 ; "script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq( InternalQlScriptUtils.sub(params.v0,InternalQlScriptUtils.docValue(doc,params.v1)),params.v2))", -"params":{"v0":43,"v1":"serial_event_id","v2":41} +"params":{"v0":43,"v1":"serial_event_id","v2":-2147483647} ; eventQueryDefaultLimit diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Add.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Add.java index cc57c6cd1bb31..6e1c64907c0f3 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Add.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Add.java @@ -12,7 +12,7 @@ /** * Addition function ({@code a + b}). */ -public class Add extends DateTimeArithmeticOperation { +public class Add extends DateTimeArithmeticOperation implements BinaryComparisonInversible { public Add(Source source, Expression left, Expression right) { super(source, left, right, DefaultBinaryArithmeticOperation.ADD); } @@ -30,4 +30,9 @@ protected Add replaceChildren(Expression left, Expression right) { public Add swapLeftAndRight() { return new Add(source(), right(), left()); } + + @Override + public ArithmeticOperationFactory binaryComparisonInverse() { + return Sub::new; + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/BinaryComparisonInversible.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/BinaryComparisonInversible.java new file mode 100644 index 0000000000000..8771ea3e77b77 --- /dev/null +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/BinaryComparisonInversible.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic; + +import org.elasticsearch.xpack.ql.expression.Expression; +import org.elasticsearch.xpack.ql.tree.Source; + +/* + * Factory interface for arithmetic operations that have an inverse in reference to a binary comparison. + * For instance the division is multiplication's inverse, substitution addition's, log exponentiation's a.s.o. + * Not all operations - like modulo - are invertible. + */ +public interface BinaryComparisonInversible { + + interface ArithmeticOperationFactory { + ArithmeticOperation create(Source source, Expression left, Expression right); + } + + ArithmeticOperationFactory binaryComparisonInverse(); +} diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Div.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Div.java index bf2538eee6441..ba1e74dc8f9e0 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Div.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Div.java @@ -14,7 +14,7 @@ /** * Division function ({@code a / b}). */ -public class Div extends ArithmeticOperation { +public class Div extends ArithmeticOperation implements BinaryComparisonInversible { public Div(Source source, Expression left, Expression right) { super(source, left, right, DefaultBinaryArithmeticOperation.DIV); @@ -34,4 +34,9 @@ protected Div replaceChildren(Expression newLeft, Expression newRight) { public DataType dataType() { return DataTypeConverter.commonType(left().dataType(), right().dataType()); } + + @Override + public ArithmeticOperationFactory binaryComparisonInverse() { + return Mul::new; + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Mul.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Mul.java index d78f1984cb8e0..a9156471faa4d 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Mul.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Mul.java @@ -16,7 +16,7 @@ /** * Multiplication function ({@code a * b}). */ -public class Mul extends ArithmeticOperation { +public class Mul extends ArithmeticOperation implements BinaryComparisonInversible { public Mul(Source source, Expression left, Expression right) { super(source, left, right, DefaultBinaryArithmeticOperation.MUL); @@ -52,4 +52,9 @@ protected Mul replaceChildren(Expression newLeft, Expression newRight) { public Mul swapLeftAndRight() { return new Mul(source(), right(), left()); } + + @Override + public ArithmeticOperationFactory binaryComparisonInverse() { + return Div::new; + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Sub.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Sub.java index c5d7fc920e843..f1338e47d8d33 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Sub.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/arithmetic/Sub.java @@ -12,7 +12,7 @@ /** * Subtraction function ({@code a - b}). */ -public class Sub extends DateTimeArithmeticOperation { +public class Sub extends DateTimeArithmeticOperation implements BinaryComparisonInversible { public Sub(Source source, Expression left, Expression right) { super(source, left, right, DefaultBinaryArithmeticOperation.SUB); @@ -27,4 +27,9 @@ protected NodeInfo info() { protected Sub replaceChildren(Expression newLeft, Expression newRight) { return new Sub(source(), newLeft, newRight); } + + @Override + public ArithmeticOperationFactory binaryComparisonInverse() { + return Add::new; + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/BinaryComparison.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/BinaryComparison.java index 3774a14952077..225bd3854ee66 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/BinaryComparison.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/BinaryComparison.java @@ -49,4 +49,11 @@ protected Pipe makePipe() { public static Integer compare(Object left, Object right) { return Comparisons.compare(left, right); } + + /** + * Reverses the direction of this comparison on the comparison axis. + * Some operations like Greater/LessThan/OrEqual will behave as if the operands of a numerical comparison get multiplied with a + * negative number. Others like Not/Equal can be immutable to this operation. + */ + public abstract BinaryComparison reverse(); } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/Equals.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/Equals.java index 3f8fc3bf9283f..49e276bc60e1a 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/Equals.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/Equals.java @@ -42,4 +42,9 @@ public Equals swapLeftAndRight() { public BinaryComparison negate() { return new NotEquals(source(), left(), right(), zoneId()); } + + @Override + public BinaryComparison reverse() { + return this; + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/GreaterThan.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/GreaterThan.java index c9c7b6729f5d2..d20065302dbb8 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/GreaterThan.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/GreaterThan.java @@ -38,4 +38,9 @@ public LessThan swapLeftAndRight() { public LessThanOrEqual negate() { return new LessThanOrEqual(source(), left(), right(), zoneId()); } + + @Override + public BinaryComparison reverse() { + return new LessThan(source(), left(), right(), zoneId()); + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/GreaterThanOrEqual.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/GreaterThanOrEqual.java index f5da6e7349b69..c196ead8b0531 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/GreaterThanOrEqual.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/GreaterThanOrEqual.java @@ -38,4 +38,9 @@ public LessThanOrEqual swapLeftAndRight() { public LessThan negate() { return new LessThan(source(), left(), right(), zoneId()); } + + @Override + public BinaryComparison reverse() { + return new LessThanOrEqual(source(), left(), right(), zoneId()); + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/LessThan.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/LessThan.java index 4dfa00186bdcb..487910ac5c746 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/LessThan.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/LessThan.java @@ -38,4 +38,9 @@ public GreaterThan swapLeftAndRight() { public GreaterThanOrEqual negate() { return new GreaterThanOrEqual(source(), left(), right(), zoneId()); } + + @Override + public BinaryComparison reverse() { + return new GreaterThan(source(), left(), right(), zoneId()); + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/LessThanOrEqual.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/LessThanOrEqual.java index 71a963a2da5c3..c3a194a1353d0 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/LessThanOrEqual.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/LessThanOrEqual.java @@ -38,4 +38,9 @@ public GreaterThanOrEqual swapLeftAndRight() { public GreaterThan negate() { return new GreaterThan(source(), left(), right(), zoneId()); } + + @Override + public BinaryComparison reverse() { + return new GreaterThanOrEqual(source(), left(), right(), zoneId()); + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/NotEquals.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/NotEquals.java index ff0bb72255611..1e069ee3d8d91 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/NotEquals.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/NotEquals.java @@ -38,4 +38,9 @@ public NotEquals swapLeftAndRight() { public BinaryComparison negate() { return new Equals(source(), left(), right(), zoneId()); } + + @Override + public BinaryComparison reverse() { + return this; + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/NullEquals.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/NullEquals.java index 8402e89fd31a9..5a1ba0bd7cd4e 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/NullEquals.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/expression/predicate/operator/comparison/NullEquals.java @@ -41,4 +41,9 @@ public NullEquals swapLeftAndRight() { public Nullability nullable() { return Nullability.FALSE; } + + @Override + public BinaryComparison reverse() { + return this; + } } diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java index a125d621f8052..b2d55fe151a7f 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java @@ -20,6 +20,10 @@ import org.elasticsearch.xpack.ql.expression.predicate.logical.Not; import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.ArithmeticOperation; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.Neg; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.Sub; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan; @@ -36,9 +40,11 @@ import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.ql.plan.logical.OrderBy; import org.elasticsearch.xpack.ql.rule.Rule; +import org.elasticsearch.xpack.ql.type.DataType; import org.elasticsearch.xpack.ql.type.DataTypes; import org.elasticsearch.xpack.ql.util.CollectionUtils; +import java.time.DateTimeException; import java.time.ZoneId; import java.util.ArrayList; import java.util.Iterator; @@ -48,7 +54,10 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiFunction; +import static java.lang.Math.signum; +import static java.util.Arrays.asList; import static org.elasticsearch.xpack.ql.expression.Literal.FALSE; import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineAnd; @@ -57,6 +66,12 @@ import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitAnd; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitOr; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.subtract; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.ADD; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.DIV; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.MOD; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.MUL; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.SUB; +import static org.elasticsearch.xpack.ql.tree.Source.EMPTY; import static org.elasticsearch.xpack.ql.util.CollectionUtils.combine; @@ -226,6 +241,7 @@ private Expression simplifyNot(Not n) { } } + // TODO: should this be renamed to just `LiteralsOnTheRight`? It swaps all literals, not just booleans. Or `MaybeLiteralsOnTheRight`? public static final class BooleanLiteralsOnTheRight extends OptimizerExpressionRule { public BooleanLiteralsOnTheRight() { @@ -1137,6 +1153,221 @@ protected Expression rule(Expression e) { } } + // Simplifies arithmetic expressions with BinaryComparisons and fixed point fields, such as: (int + 2) / 3 > 4 => int > 10 + public static final class SimplifyComparisonsArithmetics extends OptimizerExpressionRule { + BiFunction typesCompatible; + + public SimplifyComparisonsArithmetics(BiFunction typesCompatible) { + super(TransformDirection.UP); + this.typesCompatible = typesCompatible; + } + + @Override + protected Expression rule(Expression e) { + return (e instanceof BinaryComparison) ? simplify((BinaryComparison) e) : e; + } + + private Expression simplify(BinaryComparison bc) { + // optimize only once the expression has a literal on the right side of the binary comparison + if (bc.right() instanceof Literal) { + if (bc.left() instanceof ArithmeticOperation) { + return simplifyBinaryComparison(bc); + } + if (bc.left() instanceof Neg) { + return foldNegation(bc); + } + } + return bc; + } + + private Expression simplifyBinaryComparison(BinaryComparison comparison) { + ArithmeticOperation operation = (ArithmeticOperation) comparison.left(); + // Use symbol comp: SQL operations aren't available in this package (as dependencies) + String opSymbol = operation.symbol(); + // Modulo can't be simplified. + if (opSymbol == MOD.symbol()) { + return comparison; + } + OperationSimplifier simplification = null; + if (isMulOrDiv(opSymbol)) { + simplification = new MulDivSimplifier(comparison); + } else if (opSymbol == ADD.symbol() || opSymbol == SUB.symbol()) { + simplification = new AddSubSimplifier(comparison); + } + + return (simplification == null || simplification.isUnsafe(typesCompatible)) ? comparison : simplification.apply(); + } + + private static boolean isMulOrDiv(String opSymbol) { + return opSymbol == MUL.symbol() || opSymbol == DIV.symbol(); + } + + private static Expression foldNegation(BinaryComparison bc) { + Literal bcLiteral = (Literal) bc.right(); + Expression literalNeg = tryFolding(new Neg(bcLiteral.source(), bcLiteral)); + return literalNeg == null ? bc : bc.reverse().replaceChildren(asList(((Neg) bc.left()).field(), literalNeg)); + } + + private static Expression tryFolding(Expression expression) { + if (expression.foldable()) { + try { + expression = new Literal(expression.source(), expression.fold(), expression.dataType()); + } catch (ArithmeticException | DateTimeException e) { + // null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped. + expression = null; + } + } + return expression; + } + + private abstract static class OperationSimplifier { + final BinaryComparison comparison; + final Literal bcLiteral; + final ArithmeticOperation operation; + final Expression opLeft; + final Expression opRight; + final Literal opLiteral; + + OperationSimplifier(BinaryComparison comparison) { + this.comparison = comparison; + operation = (ArithmeticOperation) comparison.left(); + bcLiteral = (Literal) comparison.right(); + + opLeft = operation.left(); + opRight = operation.right(); + + if (opLeft instanceof Literal) { + opLiteral = (Literal) opLeft; + } else if (opRight instanceof Literal) { + opLiteral = (Literal) opRight; + } else { + opLiteral = null; + } + } + + // can it be quickly fast-tracked that the operation can't be reduced? + final boolean isUnsafe(BiFunction typesCompatible) { + if (opLiteral == null) { + // one of the arithm. operands must be a literal, otherwise the operation wouldn't simplify anything + return true; + } + + // Only operations on fixed point literals are supported, since optimizing float point operations can also change the + // outcome of the filtering: + // x + 1e18 > 1e18::long will yield different results with a field value in [-2^6, 2^6], optimised vs original; + // x * (1 + 1e-15d) > 1 : same with a field value of (1 - 1e-15d) + // so consequently, int fields optimisation requiring FP arithmetic isn't possible either: (x - 1e-15) * (1 + 1e-15) > 1. + if (opLiteral.dataType().isRational() || bcLiteral.dataType().isRational()) { + return true; + } + + // the Literal will be moved to the right of the comparison, but only if data-compatible with what's there + if (typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType()) == false) { + return true; + } + + return isOpUnsafe(); + } + + final Expression apply() { + // force float point folding for FlP field + Literal bcl = operation.dataType().isRational() + ? Literal.of(bcLiteral, ((Number) bcLiteral.value()).doubleValue()) + : bcLiteral; + + Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse() + .create(bcl.source(), bcl, opRight); + bcRightExpression = tryFolding(bcRightExpression); + return bcRightExpression != null + ? postProcess((BinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression))) + : comparison; + } + + // operation-specific operations: + // - fast-tracking of simplification unsafety + abstract boolean isOpUnsafe(); + + // - post optimisation adjustments + Expression postProcess(BinaryComparison binaryComparison) { + return binaryComparison; + } + } + + private static class AddSubSimplifier extends OperationSimplifier { + + AddSubSimplifier(BinaryComparison comparison) { + super(comparison); + } + + @Override + boolean isOpUnsafe() { + // no ADD/SUB with floating fields + if (operation.dataType().isRational()) { + return true; + } + + if (operation.symbol() == SUB.symbol() && opRight instanceof Literal == false) { // such as: 1 - x > -MAX + // if next simplification step would fail on overflow anyways, skip the optimisation already + return tryFolding(new Sub(EMPTY, opLeft, bcLiteral)) == null; + } + + return false; + } + } + + private static class MulDivSimplifier extends OperationSimplifier { + + private final boolean isDiv; // and not MUL. + private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal) + + MulDivSimplifier(BinaryComparison comparison) { + super(comparison); + isDiv = operation.symbol() == DIV.symbol(); + opRightSign = sign(opRight); + } + + @Override + boolean isOpUnsafe() { + // Integer divisions are not safe to optimise: x / 5 > 1 <=/=> x > 5 for x in [6, 9]; same for the `==` comp + if (operation.dataType().isInteger() && isDiv) { + return true; + } + + // If current operation is a multiplication, it's inverse will be a division: safe only if outcome is still integral. + if (isDiv == false && opLeft.dataType().isInteger()) { + long opLiteralValue = ((Number) opLiteral.value()).longValue(); + return opLiteralValue == 0 || ((Number) bcLiteral.value()).longValue() % opLiteralValue != 0; + } + + // can't move a 0 in Mul/Div comparisons + return opRightSign == 0; + } + + @Override + Expression postProcess(BinaryComparison binaryComparison) { + // negative multiplication/division changes the direction of the comparison + return opRightSign < 0 ? binaryComparison.reverse() : binaryComparison; + } + + private static int sign(Object obj) { + int sign = 1; + if (obj instanceof Number) { + sign = (int) signum(((Number) obj).doubleValue()); + } else if (obj instanceof Literal) { + sign = sign(((Literal) obj).value()); + } else if (obj instanceof Neg) { + sign = -sign(((Neg) obj).field()); + } else if (obj instanceof ArithmeticOperation) { + ArithmeticOperation operation = (ArithmeticOperation) obj; + if (isMulOrDiv(operation.symbol())) { + sign = sign(operation.left()) * sign(operation.right()); + } + } + return sign; + } + } + } + public abstract static class PruneFilters extends OptimizerRule { @Override diff --git a/x-pack/plugin/ql/src/test/resources/mapping-multi-field-variation.json b/x-pack/plugin/ql/src/test/resources/mapping-multi-field-variation.json index 95bb0c34877ae..4d7c021bb69d0 100644 --- a/x-pack/plugin/ql/src/test/resources/mapping-multi-field-variation.json +++ b/x-pack/plugin/ql/src/test/resources/mapping-multi-field-variation.json @@ -2,6 +2,7 @@ "properties" : { "bool" : { "type" : "boolean" }, "int" : { "type" : "integer" }, + "float" : { "type" : "float" }, "text" : { "type" : "text" }, "keyword" : { "type" : "keyword" }, "date" : { "type" : "date" }, diff --git a/x-pack/plugin/sql/qa/server/src/main/resources/arithmetic.csv-spec b/x-pack/plugin/sql/qa/server/src/main/resources/arithmetic.csv-spec index 31b2417932aa4..9550091f3df37 100644 --- a/x-pack/plugin/sql/qa/server/src/main/resources/arithmetic.csv-spec +++ b/x-pack/plugin/sql/qa/server/src/main/resources/arithmetic.csv-spec @@ -23,7 +23,51 @@ nullArithmetics schema::a:i|b:d|c:s|d:s|e:l|f:i|g:i|h:i|i:i|j:i|k:d SELECT null + 2 AS a, null * 1.5 AS b, null + null AS c, null - null AS d, null - 1234567890123 AS e, 123 - null AS f, null / 5 AS g, 5 / null AS h, null % 5 AS i, 5 % null AS j, null + 5.5 - (null * (null * 3)) AS k; - a | b | c | d | e | f | g | h | i | j | k + a | b | c | d | e | f | g | h | i | j | k ---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+--------------- -null |null |null |null |null |null |null |null |null |null |null +null |null |null |null |null |null |null |null |null |null |null +; + +optimizedIntervalFilterPlus +SELECT emp_no x, hire_date h FROM test_emp WHERE hire_date + INTERVAL 20 YEAR > CAST('2010-01-01T00:00:00' AS TIMESTAMP) LIMIT 10; + + x | h +---------------+------------------------ +10008 |1994-09-15T00:00:00.000Z +10011 |1990-01-22T00:00:00.000Z +10012 |1992-12-18T00:00:00.000Z +10016 |1995-01-27T00:00:00.000Z +10017 |1993-08-03T00:00:00.000Z +10019 |1999-04-30T00:00:00.000Z +10020 |1991-01-26T00:00:00.000Z +10022 |1995-08-22T00:00:00.000Z +10024 |1997-05-19T00:00:00.000Z +10026 |1995-03-20T00:00:00.000Z +; + +optimizedIntervalFilterMinus +SELECT emp_no x, hire_date h FROM test_emp WHERE hire_date - INTERVAL 10 YEAR > CAST('1980-01-01T00:00:00' AS TIMESTAMP) LIMIT 10; + + x | h +---------------+------------------------ +10008 |1994-09-15T00:00:00.000Z +10011 |1990-01-22T00:00:00.000Z +10012 |1992-12-18T00:00:00.000Z +10016 |1995-01-27T00:00:00.000Z +10017 |1993-08-03T00:00:00.000Z +10019 |1999-04-30T00:00:00.000Z +10020 |1991-01-26T00:00:00.000Z +10022 |1995-08-22T00:00:00.000Z +10024 |1997-05-19T00:00:00.000Z +10026 |1995-03-20T00:00:00.000Z +; + +optimizedBinaryCompArithmeticWithNegationOfIntMinVal +SELECT IIF(languages < 4, -2147483645 - languages, 0) AS I FROM test_emp WHERE -I * 3 > 6 GROUP BY I; + + I:i +--------------- +-2147483648 +-2147483647 +-2147483646 ; diff --git a/x-pack/plugin/sql/qa/server/src/main/resources/arithmetic.sql-spec b/x-pack/plugin/sql/qa/server/src/main/resources/arithmetic.sql-spec index d94123ea0642f..5eb07570b74ea 100644 --- a/x-pack/plugin/sql/qa/server/src/main/resources/arithmetic.sql-spec +++ b/x-pack/plugin/sql/qa/server/src/main/resources/arithmetic.sql-spec @@ -87,3 +87,166 @@ orderByModulo SELECT emp_no FROM test_emp ORDER BY emp_no % 10000 LIMIT 10; orderByMul SELECT emp_no FROM test_emp ORDER BY emp_no * 2 LIMIT 10; + +// arithmetic optimiser +plusMulIntFieldPlus +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (salary + 50000) > 201000; +mulIntFieldPlus +SELECT emp_no FROM test_emp WHERE -2 * (salary + 50000) - 1000 <= -201000; +plusIntFieldPlusDiv +SELECT emp_no FROM test_emp WHERE 1000 + (salary + 50000) / 2 > 101000; +intFieldPlusDiv +SELECT emp_no FROM test_emp WHERE (salary + 50000) / 2 + 1000 > 101000; +plusMulPlusIntField +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (50000 + salary) > 201000; +mulPlusIntFieldPlus +SELECT emp_no FROM test_emp WHERE -2 * (50000 + salary) - 1000 <= -201000; +plusPlusIntFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + (50000 + salary) / 2 > 51000; +plusIntFieldDivPlus +SELECT emp_no FROM test_emp WHERE (50000 + salary) / 2 + 1000 > 51000; +plusMulIntFieldMinus +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (salary - 10000) > 60000; +mulIntFieldMinusPlus +SELECT emp_no FROM test_emp WHERE -2 * (salary - 10000) - 1000 <= -61000; +plusIntFieldMinusDiv +SELECT emp_no FROM test_emp WHERE 1000 + (salary - 10000) / 2 > 16000; +intFieldMinusDivPlus +SELECT emp_no FROM test_emp WHERE (salary - 10000) / 2 + 1000 > 16000; +plusMulMinusIntField +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (100000 - salary) > 101000; +mulMinusIntFieldPlus +SELECT emp_no FROM test_emp WHERE -2 * (100000 - salary) - 1000 <= -101000; +plusMinusIntFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + (100000 - salary) / 2 > 26000; +minusIntFieldDivPlus +SELECT emp_no FROM test_emp WHERE (100000 - salary) / 2 + 1000 > 25000; +plusMulIntFieldMul +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (salary * 2) > 201000; +mulIntFieldMulPlus +SELECT emp_no FROM test_emp WHERE 2 * (-salary * 2) - 1000 <= -200000; +plusIntFieldMulDiv +SELECT emp_no FROM test_emp WHERE 1000 + (salary * 2) / 2 > 50100; +intFieldMulDivPlus +SELECT emp_no FROM test_emp WHERE (salary * 2) / 2 + 1000 > 50100; +plusMulMulIntField +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (2 * salary) > 201000; +mulMulIntFieldPlus +SELECT emp_no FROM test_emp WHERE 2 * (2 * -salary) - 1000 <= -201000; +plusMulIntFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + (2 * salary) / 2 > 51000; +mulIntFieldDivPlus +SELECT emp_no FROM test_emp WHERE (2 * salary) / 2 + 1000 > 51000; +plusMulIntFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (salary / 2) > 61000; +mulIntFieldDivPlus +SELECT emp_no FROM test_emp WHERE 2 * (-salary / 2) - 1000 <= -61000; +plusIntFieldDivDiv +SELECT emp_no FROM test_emp WHERE 1000 + (salary / 2) / 2 > 16000; +intFieldDivDivPlus +SELECT emp_no FROM test_emp WHERE (salary / 2) / 2 + 1000 > 16000; +plusMulDivIntField +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (1000000000 / salary) > 61000; +mulDivIntFieldPlus +SELECT emp_no FROM test_emp WHERE 2 * (1000000000 / -salary) - 1000 <= -61000; +plusDivIntFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + (1000000000 / salary) / 2 > 16000; +divIntFieldDivPlus +SELECT emp_no FROM test_emp WHERE (1000000000 / salary) / 2 + 1000 > 16000; + +plusMulFloatFieldPlus +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (salary::FLOAT + 50000) > 201000; +mulFloatFieldPlus +SELECT emp_no FROM test_emp WHERE -2 * (salary::FLOAT + 50000) - 1000 <= -201000; +plusFloatFieldPlusDiv +SELECT emp_no FROM test_emp WHERE 1000 + (salary::FLOAT + 50000) / 2 > 101000; +floatFieldPlusDiv +SELECT emp_no FROM test_emp WHERE (salary::FLOAT + 50000) / 2 + 1000 > 101000; +plusMulPlusFloatField +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (50000 + salary::FLOAT) > 201000; +mulPlusFloatFieldPlus +SELECT emp_no FROM test_emp WHERE -2 * (50000 + salary::FLOAT) - 1000 <= -201000; +plusPlusFloatFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + (50000 + salary::FLOAT) / 2 > 51000; +plusFloatFieldDivPlus +SELECT emp_no FROM test_emp WHERE (50000 + salary::FLOAT) / 2 + 1000 > 51000; +plusMulFloatFieldMinus +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (salary::FLOAT - 10000) > 60000; +mulFloatFieldMinusPlus +SELECT emp_no FROM test_emp WHERE -2 * (salary::FLOAT - 10000) - 1000 <= -61000; +plusFloatFieldMinusDiv +SELECT emp_no FROM test_emp WHERE 1000 + (salary::FLOAT - 10000) / 2 > 16000; +floatFieldMinusDivPlus +SELECT emp_no FROM test_emp WHERE (salary::FLOAT - 10000) / 2 + 1000 > 16000; +plusMulMinusFloatField +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (100000 - salary::FLOAT) > 101000; +mulMinusFloatFieldPlus +SELECT emp_no FROM test_emp WHERE -2 * (100000 - salary::FLOAT) - 1000 <= -101000; +plusMinusFloatFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + (100000 - salary::FLOAT) / 2 > 26000; +minusFloatFieldDivPlus +SELECT emp_no FROM test_emp WHERE (100000 - salary::FLOAT) / 2 + 1000 > 25000; +plusMulFloatFieldMul +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (salary::FLOAT * 2) > 201000; +mulFloatFieldMulPlus +SELECT emp_no FROM test_emp WHERE 2 * (-salary::FLOAT * 2) - 1000 <= -200000; +plusFloatFieldMulDiv +SELECT emp_no FROM test_emp WHERE 1000 + (salary::FLOAT * 2) / 2 > 50100; +floatFieldMulDivPlus +SELECT emp_no FROM test_emp WHERE (salary::FLOAT * 2) / 2 + 1000 > 50100; +plusMulMulFloatField +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (2 * salary::FLOAT) > 201000; +mulMulFloatFieldPlus +SELECT emp_no FROM test_emp WHERE 2 * (2 * -salary::FLOAT) - 1000 <= -201000; +plusMulFloatFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + (2 * salary::FLOAT) / 2 > 51000; +mulFloatFieldDivPlus +SELECT emp_no FROM test_emp WHERE (2 * salary::FLOAT) / 2 + 1000 > 51000; +plusMulFloatFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (salary::FLOAT / 2) > 61000; +mulFloatFieldDivPlus +SELECT emp_no FROM test_emp WHERE 2 * (-salary::FLOAT / 2) - 1000 <= -61000; +plusFloatFieldDivDiv +SELECT emp_no FROM test_emp WHERE 1000 + (salary::FLOAT / 2) / 2 > 16000; +floatFieldDivDivPlus +SELECT emp_no FROM test_emp WHERE (salary::FLOAT / 2) / 2 + 1000 > 16000; +plusMulDivFloatField +SELECT emp_no FROM test_emp WHERE 1000 + 2 * (1000000000 / salary::FLOAT) > 61000; +mulDivFloatFieldPlus +SELECT emp_no FROM test_emp WHERE 2 * (1000000000 / -salary::FLOAT) - 1000 <= -61000; +plusDivFloatFieldDiv +SELECT emp_no FROM test_emp WHERE 1000 + (1000000000 / salary::FLOAT) / 2 > 16000; +divFloatFieldDivPlus +SELECT emp_no FROM test_emp WHERE (1000000000 / salary::FLOAT) / 2 + 1000 > 16000; + +noOptimisationOnLongOverflowAdd +SELECT emp_no FROM test_emp WHERE salary - 2 < 9223372036854775807; +noOptimisationOnLongUnderflowSub +SELECT emp_no FROM test_emp WHERE -salary + 2 < -9223372036854775807; +noOptimisationOnIntOverflowAdd +SELECT emp_no FROM test_emp WHERE salary::INT - 2 < 2147483647; +noOptimisationOnIntUnderflowSub +SELECT emp_no FROM test_emp WHERE -salary::INT + 2 < -2147483648; +noOptimisationOnOverflowMul +SELECT emp_no FROM test_emp WHERE salary / 10 < 1.7976931348623157E308; +noOptimisationOnPrecisionLossOnFloatFieldAdd +SELECT emp_no FROM test_emp WHERE 1 - salary::FLOAT/1E21 < 1; +noOptimisationOnPrecisionLossOnFloatFieldDiv +SELECT emp_no FROM test_emp WHERE (1 - salary::FLOAT / 1E21) * (1 + 1E-15) > 1; +noOptimisationOnIntegralDivByZero +SELECT emp_no FROM test_emp WHERE (5/4 - 1) * salary > 1; +noOptimisationOnFloatDivByZero +SELECT emp_no FROM test_emp WHERE (5/4 - 1) * salary::FLOAT > 1; + +// negations +negationDenominator +SELECT emp_no FROM test_emp WHERE 1./-salary > 1/-6E4; +chainedNegationDenominator +SELECT emp_no FROM test_emp WHERE 1./(-(-(-salary)) * -1) < 1/6E4; +negationProductDenominator +SELECT emp_no FROM test_emp WHERE 1/(-salary::FLOAT * -emp_no) > 1/5E8; +negationNumeratorAndDenominator +SELECT emp_no FROM test_emp WHERE 3 * (-languages)/(-salary::FLOAT * -emp_no) > -3/6E7; +negationsDoubleDenominator +SELECT emp_no FROM test_emp WHERE ((1000000000 / -salary) / 2) * (-1 / (2 / -emp_no::FLOAT)) / -1000 >= 50000; + diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Add.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Add.java index ef286d54a84a6..433647043724c 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Add.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Add.java @@ -6,13 +6,14 @@ package org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic; import org.elasticsearch.xpack.ql.expression.Expression; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; /** * Addition function ({@code a + b}). */ -public class Add extends DateTimeArithmeticOperation { +public class Add extends DateTimeArithmeticOperation implements BinaryComparisonInversible { public Add(Source source, Expression left, Expression right) { super(source, left, right, SqlBinaryArithmeticOperation.ADD); } @@ -31,4 +32,9 @@ protected Add replaceChildren(Expression left, Expression right) { public Add swapLeftAndRight() { return new Add(source(), right(), left()); } + + @Override + public ArithmeticOperationFactory binaryComparisonInverse() { + return Sub::new; + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Div.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Div.java index 67da39f46eaf3..1de9fc741973a 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Div.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Div.java @@ -6,6 +6,8 @@ package org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic; import org.elasticsearch.xpack.ql.expression.Expression; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.Mul; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; @@ -14,7 +16,7 @@ /** * Division function ({@code a / b}). */ -public class Div extends SqlArithmeticOperation { +public class Div extends SqlArithmeticOperation implements BinaryComparisonInversible { public Div(Source source, Expression left, Expression right) { super(source, left, right, SqlBinaryArithmeticOperation.DIV); @@ -34,4 +36,9 @@ protected Div replaceChildren(Expression newLeft, Expression newRight) { public DataType dataType() { return SqlDataTypeConverter.commonType(left().dataType(), right().dataType()); } + + @Override + public ArithmeticOperationFactory binaryComparisonInverse() { + return Mul::new; + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Mul.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Mul.java index 8827a42355cf9..f26eb168c2810 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Mul.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Mul.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic; import org.elasticsearch.xpack.ql.expression.Expression; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; @@ -17,7 +18,7 @@ /** * Multiplication function ({@code a * b}). */ -public class Mul extends SqlArithmeticOperation { +public class Mul extends SqlArithmeticOperation implements BinaryComparisonInversible { private DataType dataType; @@ -71,4 +72,9 @@ protected Mul replaceChildren(Expression newLeft, Expression newRight) { public Mul swapLeftAndRight() { return new Mul(source(), right(), left()); } + + @Override + public ArithmeticOperationFactory binaryComparisonInverse() { + return Div::new; + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticOperation.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticOperation.java index 04be1399a1115..661a7c3f6500f 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticOperation.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticOperation.java @@ -52,7 +52,7 @@ public enum SqlBinaryArithmeticOperation implements BinaryArithmeticOperation { return IntervalArithmetics.add((Temporal) r, ((IntervalDayTime) l).interval()); } - throw new QlIllegalArgumentException("Cannot compute [+] between [{}] [{}]", l.getClass().getSimpleName(), + throw new QlIllegalArgumentException("Cannot compute [+] between [{}] and [{}]", l.getClass().getSimpleName(), r.getClass().getSimpleName()); }, "+"), SUB((Object l, Object r) -> { @@ -77,7 +77,7 @@ public enum SqlBinaryArithmeticOperation implements BinaryArithmeticOperation { throw new QlIllegalArgumentException("Cannot subtract a date from an interval; do you mean the reverse?"); } - throw new QlIllegalArgumentException("Cannot compute [-] between [{}] [{}]", l.getClass().getSimpleName(), + throw new QlIllegalArgumentException("Cannot compute [-] between [{}] and [{}]", l.getClass().getSimpleName(), r.getClass().getSimpleName()); }, "-"), MUL((Object l, Object r) -> { @@ -99,7 +99,7 @@ public enum SqlBinaryArithmeticOperation implements BinaryArithmeticOperation { return ((IntervalDayTime) l).mul(((Number) r).longValue()); } - throw new QlIllegalArgumentException("Cannot compute [*] between [{}] [{}]", l.getClass().getSimpleName(), + throw new QlIllegalArgumentException("Cannot compute [*] between [{}] and [{}]", l.getClass().getSimpleName(), r.getClass().getSimpleName()); }, "*"), DIV(Arithmetics::div, "/"), @@ -151,4 +151,4 @@ public static SqlBinaryArithmeticOperation read(StreamInput in) throws IOExcepti private static Object unwrapJodaTime(Object o) { return o instanceof JodaCompatibleZonedDateTime ? ((JodaCompatibleZonedDateTime) o).getZonedDateTime() : o; } -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Sub.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Sub.java index 77f8eed05afa3..7ced4a81eb46d 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Sub.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Sub.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic; import org.elasticsearch.xpack.ql.expression.Expression; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.sql.type.SqlDataTypes; @@ -15,7 +16,7 @@ /** * Subtraction function ({@code a - b}). */ -public class Sub extends DateTimeArithmeticOperation { +public class Sub extends DateTimeArithmeticOperation implements BinaryComparisonInversible { public Sub(Source source, Expression left, Expression right) { super(source, left, right, SqlBinaryArithmeticOperation.SUB); @@ -43,4 +44,9 @@ protected TypeResolution resolveWithIntervals() { } return TypeResolution.TYPE_RESOLVED; } + + @Override + public ArithmeticOperationFactory binaryComparisonInverse() { + return Add::new; + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java index 1f1afd3d823b3..038cc1c481418 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java @@ -44,6 +44,7 @@ import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneLiteralsInOrderBy; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ReplaceRegexMatch; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SetAsOptimized; +import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection; import org.elasticsearch.xpack.ql.plan.logical.Aggregate; import org.elasticsearch.xpack.ql.plan.logical.EsRelation; @@ -89,6 +90,7 @@ import org.elasticsearch.xpack.sql.plan.logical.SubQueryAlias; import org.elasticsearch.xpack.sql.session.EmptyExecutable; import org.elasticsearch.xpack.sql.session.SingletonExecutable; +import org.elasticsearch.xpack.sql.type.SqlDataTypes; import java.time.ZoneId; import java.util.ArrayList; @@ -148,6 +150,7 @@ protected Iterable.Batch> batches() { new PropagateEquals(), new CombineBinaryComparisons(), new CombineDisjunctionsToIn(), + new SimplifyComparisonsArithmetics(SqlDataTypes::areCompatible), // prune/elimination new PruneLiteralsInGroupBy(), new PruneDuplicatesInGroupBy(), @@ -158,7 +161,7 @@ protected Iterable.Batch> batches() { new PruneCast(), // order by alignment of the aggs new SortAggregateOnOrderBy() - ); + ); Batch aggregate = new Batch("Aggregation Rewrite", new ReplaceMinMaxWithTopHits(), diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/SqlDataTypes.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/SqlDataTypes.java index 52a24ef583193..f1f03a3717d46 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/SqlDataTypes.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/SqlDataTypes.java @@ -296,6 +296,7 @@ public static boolean areCompatible(DataType left, DataType right) { || (DataTypes.isString(left) && DataTypes.isString(right)) || (left.isNumeric() && right.isNumeric()) || (isDateBased(left) && isDateBased(right)) + || (isInterval(left) && isDateBased(right)) || (isDateBased(left) && isInterval(right)) || (isInterval(left) && isInterval(right) && Intervals.compatibleInterval(left, right) != null); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/FieldAttributeTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/FieldAttributeTests.java index f93417c6ff95b..db5472424dc55 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/FieldAttributeTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/FieldAttributeTests.java @@ -165,7 +165,7 @@ public void testDottedFieldPathTypo() { public void testStarExpansionExcludesObjectAndUnsupportedTypes() { LogicalPlan plan = plan("SELECT * FROM test"); List list = ((Project) plan).projections(); - assertThat(list, hasSize(11)); + assertThat(list, hasSize(12)); List names = Expressions.names(list); assertThat(names, not(hasItem("some"))); assertThat(names, not(hasItem("some.dotted"))); diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticTests.java index 3b3f6c003732f..238330918173c 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticTests.java @@ -136,7 +136,7 @@ public void testAddDayTimeIntervalToTimeReverse() { public void testAddNumberToIntervalIllegal() { Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR); QlIllegalArgumentException expect = expectThrows(QlIllegalArgumentException.class, () -> add(r, L(1))); - assertEquals("Cannot compute [+] between [IntervalDayTime] [Integer]", expect.getMessage()); + assertEquals("Cannot compute [+] between [IntervalDayTime] and [Integer]", expect.getMessage()); } public void testSubYearMonthIntervals() { @@ -210,7 +210,7 @@ public void testSubDayTimeIntervalToTime() { public void testSubNumberFromIntervalIllegal() { Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR); QlIllegalArgumentException expect = expectThrows(QlIllegalArgumentException.class, () -> sub(r, L(1))); - assertEquals("Cannot compute [-] between [IntervalDayTime] [Integer]", expect.getMessage()); + assertEquals("Cannot compute [-] between [IntervalDayTime] and [Integer]", expect.getMessage()); } public void testMulIntervalNumber() { diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerRunTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerRunTests.java index 4a54a1e48b78b..a8d1832f92b5a 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerRunTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerRunTests.java @@ -5,11 +5,27 @@ */ package org.elasticsearch.xpack.sql.optimizer; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ql.expression.Expression; +import org.elasticsearch.xpack.ql.expression.FieldAttribute; +import org.elasticsearch.xpack.ql.expression.Literal; +import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute; import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThanOrEqual; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThan; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThanOrEqual; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.NotEquals; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.NullEquals; import org.elasticsearch.xpack.ql.index.EsIndex; import org.elasticsearch.xpack.ql.index.IndexResolution; +import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanLiteralsOnTheRight; +import org.elasticsearch.xpack.ql.plan.logical.Filter; import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.ql.type.EsField; import org.elasticsearch.xpack.sql.SqlTestUtils; import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer; @@ -18,7 +34,19 @@ import org.elasticsearch.xpack.sql.stats.Metrics; import org.elasticsearch.xpack.sql.types.SqlTypesTests; +import java.time.ZonedDateTime; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation.EQ; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation.GT; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation.GTE; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation.LT; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation.LTE; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation.NEQ; +import static org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation.NULLEQ; public class OptimizerRunTests extends ESTestCase { @@ -27,6 +55,18 @@ public class OptimizerRunTests extends ESTestCase { private final FunctionRegistry functionRegistry; private final Analyzer analyzer; private final Optimizer optimizer; + private static final Map> COMPARISONS = new HashMap<>() { + { + put(EQ.symbol(), Equals.class); + put(NULLEQ.symbol(), NullEquals.class); + put(NEQ.symbol(), NotEquals.class); + put(GT.symbol(), GreaterThan.class); + put(GTE.symbol(), GreaterThanOrEqual.class); + put(LT.symbol(), LessThan.class); + put(LTE.symbol(), LessThanOrEqual.class); + } + }; + private static final BooleanLiteralsOnTheRight LITERALS_ON_THE_RIGHT = new BooleanLiteralsOnTheRight(); public OptimizerRunTests() { parser = new SqlParser(); @@ -48,4 +88,182 @@ public void testWhereClause() { LogicalPlan p = plan("SELECT some.string l FROM test WHERE int IS NOT NULL AND int < 10005 ORDER BY int"); assertNotNull(p); } -} \ No newline at end of file + + public void testSimplifyComparisonArithmeticCommutativeVsNonCommutativeOps() { + doTestSimplifyComparisonArithmetics("int + 2 > 3", "int", ">", 1); + doTestSimplifyComparisonArithmetics("2 + int > 3", "int", ">", 1); + doTestSimplifyComparisonArithmetics("int - 2 > 3", "int", ">", 5); + doTestSimplifyComparisonArithmetics("2 - int > 3", "int", "<", -1); + doTestSimplifyComparisonArithmetics("int * 2 > 4", "int", ">", 2); + doTestSimplifyComparisonArithmetics("2 * int > 4", "int", ">", 2); + doTestSimplifyComparisonArithmetics("float / 2 > 4", "float", ">", 8d); + doTestSimplifyComparisonArithmetics("2 / float < 4", "float", ">", .5); + } + + public void testSimplifyComparisonArithmeticWithMultipleOps() { + // i >= 3 + doTestSimplifyComparisonArithmetics("((int + 1) * 2 - 4) * 4 >= 16", "int", ">=", 3); + } + + public void testSimplifyComparisonArithmeticWithFieldNegation() { + doTestSimplifyComparisonArithmetics("12 * (-int - 5) >= -120", "int", "<=", 5); + } + + public void testSimplifyComparisonArithmeticWithFieldDoubleNegation() { + doTestSimplifyComparisonArithmetics("12 * -(-int - 5) <= 120", "int", "<=", 5); + } + + public void testSimplifyComparisonArithmeticWithConjunction() { + doTestSimplifyComparisonArithmetics("12 * (-int - 5) = -120 AND int < 6 ", "int", "==", 5); + } + + public void testSimplifyComparisonArithmeticWithDisjunction() { + doTestSimplifyComparisonArithmetics("12 * (-int - 5) >= -120 OR int < 5", "int", "<=", 5); + } + + public void testSimplifyComparisonArithmeticWithFloatsAndDirectionChange() { + doTestSimplifyComparisonArithmetics("float / -2 < 4", "float", ">", -8d); + doTestSimplifyComparisonArithmetics("float * -2 < 4", "float", ">", -2d); + } + + public void testSimplyComparisonArithmeticWithUnfoldedProd() { + assertSemanticMatching("int * int >= 3", "((int * int + 1) * 2 - 4) * 4 >= 16"); + } + + public void testSimplifyComparisonArithmeticSkippedOnIntegerArithmeticalOverflow() { + assertNotSimplified("int - 1 " + randomBinaryComparison() + " " + Long.MAX_VALUE); + assertNotSimplified("1 - int " + randomBinaryComparison() + " " + Long.MIN_VALUE); + assertNotSimplified("int - 1 " + randomBinaryComparison() + " " + Integer.MAX_VALUE); + assertNotSimplified("1 - int " + randomBinaryComparison() + " " + Integer.MIN_VALUE); + } + + public void testSimplifyComparisonArithmeticSkippedOnIntegerArithmeticalOverflowOnNegation() { + assertNotSimplified("-int " + randomBinaryComparison() + " " + Long.MIN_VALUE); + assertNotSimplified("-int " + randomBinaryComparison() + " " + Integer.MIN_VALUE); + } + + public void testSimplifyComparisonArithmeticSkippedOnFloatingPointArithmeticalOverflow() { + assertNotSimplified("float / 10 " + randomBinaryComparison() + " " + Float.MAX_VALUE); + assertNotSimplified("float / " + Float.MAX_VALUE +" " + randomBinaryComparison() + " 10"); + assertNotSimplified("float / 10 " + randomBinaryComparison() + " " + Double.MAX_VALUE); + assertNotSimplified("float / " + Double.MAX_VALUE + " " + randomBinaryComparison() + " 10"); + // note: the "reversed" test (i.e.: MAX_VALUE / float < literal) would require a floating literal, which is skipped for other + // reason (see testSimplifyComparisonArithmeticSkippedOnFloats()) + } + + public void testSimplifyComparisonArithmeticSkippedOnNegatingOverflow() { + assertNotSimplified("-int " + randomBinaryComparison() + " " + Long.MIN_VALUE); + assertNotSimplified("-int " + randomBinaryComparison() + " " + Integer.MIN_VALUE); + } + + public void testSimplifyComparisonArithmeticSkippedOnDateOverflow() { + assertNotSimplified("date - INTERVAL 999999999 YEAR > '2010-01-01T01:01:01'::DATETIME"); + assertNotSimplified("date + INTERVAL -999999999 YEAR > '2010-01-01T01:01:01'::DATETIME"); + } + + public void testSimplifyComparisonArithmeticSkippedOnMulDivByZero() { + assertNotSimplified("float / 0 " + randomBinaryComparison() + " 1"); + assertNotSimplified("float * 0 " + randomBinaryComparison() + " 1"); + assertNotSimplified("int / 0 " + randomBinaryComparison() + " 1"); + assertNotSimplified("int * 0 " + randomBinaryComparison() + " 1"); + } + + public void testSimplifyComparisonArithmeticSkippedOnDiv() { + assertNotSimplified("int / 4 " + randomBinaryComparison() + " 1"); + assertNotSimplified("4 / int " + randomBinaryComparison() + " 1"); + } + + public void testSimplifyComparisonArithmeticSkippedOnResultingFloatLiteral() { + assertNotSimplified("int * 2 " + randomBinaryComparison() + " 3"); + } + + public void testSimplifyComparisonArithmeticSkippedOnFloatFieldWithPlusMinus() { + assertNotSimplified("float + 4 " + randomBinaryComparison() + " 1"); + assertNotSimplified("4 + float " + randomBinaryComparison() + " 1"); + assertNotSimplified("float - 4 " + randomBinaryComparison() + " 1"); + assertNotSimplified("4 - float " + randomBinaryComparison() + " 1"); + } + + public void testSimplifyComparisonArithmeticSkippedOnFloats() { + for (String field : List.of("int", "float")) { + for (Tuple nr : List.of(new Tuple<>(.4, 1), new Tuple<>(1, .4))) { + assertNotSimplified(field + " + " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2()); + assertNotSimplified(field + " - " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2()); + assertNotSimplified(nr.v1()+ " + " + field + " " + randomBinaryComparison() + " " + nr.v2()); + assertNotSimplified(nr.v1()+ " - " + field + " " + randomBinaryComparison() + " " + nr.v2()); + } + } + } + + public void testSimplifyComparisonArithmeticWithDateTime() { + doTestSimplifyComparisonArithmetics("date - INTERVAL 1 MONTH > '2010-01-01T01:01:01'::DATETIME", "date", ">", + ZonedDateTime.parse("2010-02-01T01:01:01Z")); + } + + public void testSimplifyComparisonArithmeticWithDate() { + doTestSimplifyComparisonArithmetics("date + INTERVAL 1 YEAR <= '2011-01-01T00:00:00'::DATE", "date", "<=", + ZonedDateTime.parse("2010-01-01T00:00:00Z")); + } + + public void testSimplifyComparisonArithmeticWithDateAndMultiplication() { + // the multiplication should be folded, but check + doTestSimplifyComparisonArithmetics("date + 2 * INTERVAL 1 YEAR <= '2012-01-01T00:00:00'::DATE", "date", "<=", + ZonedDateTime.parse("2010-01-01T00:00:00Z")); + } + + private void doTestSimplifyComparisonArithmetics(String expression, String fieldName, String compSymbol, Object bound) { + BinaryComparison bc = extractPlannedBinaryComparison(expression); + assertTrue(COMPARISONS.get(compSymbol).isInstance(bc)); + + assertTrue(bc.left() instanceof FieldAttribute); + FieldAttribute attribute = (FieldAttribute) bc.left(); + assertEquals(fieldName, attribute.name()); + + assertTrue(bc.right() instanceof Literal); + Literal literal = (Literal) bc.right(); + assertEquals(bound, literal.value()); + } + + private void assertNotSimplified(String condition) { + assertSemanticMatching(extractPlannedBinaryComparison(condition), parser.createExpression(condition)); + } + + private void assertSemanticMatching(String expected, String provided) { + BinaryComparison bc = extractPlannedBinaryComparison(provided); + Expression exp = parser.createExpression(expected); + assertSemanticMatching(bc, exp); + } + + private BinaryComparison extractPlannedBinaryComparison(String expression) { + LogicalPlan plan = planWithArithmeticCondition(expression); + + assertTrue(plan instanceof UnaryPlan); + UnaryPlan unaryPlan = (UnaryPlan) plan; + assertTrue(unaryPlan.child() instanceof Filter); + Filter filter = (Filter) unaryPlan.child(); + assertTrue(filter.condition() instanceof BinaryComparison); + return (BinaryComparison) filter.condition(); + } + + private LogicalPlan planWithArithmeticCondition(String condition) { + return plan("SELECT some.string FROM test WHERE " + condition); + } + + private static void assertSemanticMatching(Expression fieldAttributeExp, Expression unresolvedAttributeExp) { + Expression unresolvedUpdated = unresolvedAttributeExp + .transformUp(LITERALS_ON_THE_RIGHT::rule) + .transformUp(x -> x.foldable() ? new Literal(x.source(), x.fold(), x.dataType()) : x); + + List resolvedFields = fieldAttributeExp.collectFirstChildren(x -> x instanceof FieldAttribute); + for (Expression field : resolvedFields) { + FieldAttribute fa = (FieldAttribute) field; + unresolvedUpdated = unresolvedUpdated.transformDown(UnresolvedAttribute.class, x -> x.name().equals(fa.name()) ? fa : x); + } + + assertTrue(unresolvedUpdated.semanticEquals(fieldAttributeExp)); + } + + private static String randomBinaryComparison() { + return randomFrom(COMPARISONS.keySet().stream().map(x -> EQ.symbol().equals(x) ? "=" : x).collect(Collectors.toSet())); + } +} diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plan/logical/command/sys/SysColumnsTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plan/logical/command/sys/SysColumnsTests.java index 5911f68b2a9fc..8ccf92039b0b6 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plan/logical/command/sys/SysColumnsTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plan/logical/command/sys/SysColumnsTests.java @@ -54,7 +54,8 @@ public class SysColumnsTests extends ESTestCase { private static final String CLUSTER_NAME = "cluster"; private static final Map MAPPING1 = loadMapping("mapping-multi-field-with-nested.json", true); private static final Map MAPPING2 = loadMapping("mapping-multi-field-variation.json", true); - private static final int FIELD_COUNT = 16; + private static final int FIELD_COUNT1 = 16; + private static final int FIELD_COUNT2 = 17; private final SqlParser parser = new SqlParser(); @@ -62,7 +63,7 @@ private void sysColumnsInMode(Mode mode) { Class typeClass = mode == Mode.ODBC ? Short.class : Integer.class; List> rows = new ArrayList<>(); SysColumns.fillInRows("test", "index", MAPPING2, null, rows, null, mode); - assertEquals(FIELD_COUNT, rows.size()); + assertEquals(FIELD_COUNT2, rows.size()); assertEquals(24, rows.get(0).size()); List row = rows.get(0); @@ -72,39 +73,42 @@ private void sysColumnsInMode(Mode mode) { assertDriverType("int", Types.INTEGER, true, 11, 4, typeClass, row); row = rows.get(2); - assertDriverType("text", Types.VARCHAR, false, Integer.MAX_VALUE, Integer.MAX_VALUE, typeClass, row); + assertDriverType("float", Types.REAL, true, 15, 4, typeClass, row); row = rows.get(3); - assertDriverType("keyword", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); + assertDriverType("text", Types.VARCHAR, false, Integer.MAX_VALUE, Integer.MAX_VALUE, typeClass, row); row = rows.get(4); - assertDriverType("date", Types.TIMESTAMP, false, 34, 8, typeClass, row); + assertDriverType("keyword", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); row = rows.get(5); - assertDriverType("date_nanos", Types.TIMESTAMP, false, 34, 8, typeClass, row); + assertDriverType("date", Types.TIMESTAMP, false, 34, 8, typeClass, row); row = rows.get(6); - assertDriverType("some.dotted.field", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); + assertDriverType("date_nanos", Types.TIMESTAMP, false, 34, 8, typeClass, row); row = rows.get(7); - assertDriverType("some.string", Types.VARCHAR, false, Integer.MAX_VALUE, Integer.MAX_VALUE, typeClass, row); + assertDriverType("some.dotted.field", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); row = rows.get(8); - assertDriverType("some.string.normalized", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); + assertDriverType("some.string", Types.VARCHAR, false, Integer.MAX_VALUE, Integer.MAX_VALUE, typeClass, row); row = rows.get(9); - assertDriverType("some.string.typical", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); + assertDriverType("some.string.normalized", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); row = rows.get(10); - assertDriverType("some.ambiguous", Types.VARCHAR, false, Integer.MAX_VALUE, Integer.MAX_VALUE, typeClass, row); + assertDriverType("some.string.typical", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); row = rows.get(11); - assertDriverType("some.ambiguous.one", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); + assertDriverType("some.ambiguous", Types.VARCHAR, false, Integer.MAX_VALUE, Integer.MAX_VALUE, typeClass, row); row = rows.get(12); - assertDriverType("some.ambiguous.two", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); + assertDriverType("some.ambiguous.one", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); row = rows.get(13); + assertDriverType("some.ambiguous.two", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); + + row = rows.get(14); assertDriverType("some.ambiguous.normalized", Types.VARCHAR, false, Short.MAX_VALUE - 1, Integer.MAX_VALUE, typeClass, row); } @@ -162,7 +166,7 @@ private static Object sqlDataTypeSub(List list) { public void testSysColumnsNoArg() { executeCommand("SYS COLUMNS", emptyList(), r -> { - assertEquals(FIELD_COUNT, r.size()); + assertEquals(FIELD_COUNT1, r.size()); assertEquals(CLUSTER_NAME, r.column(0)); // no index specified assertEquals("test", r.column(2)); @@ -177,7 +181,7 @@ public void testSysColumnsNoArg() { public void testSysColumnsWithCatalogWildcard() { executeCommand("SYS COLUMNS CATALOG 'cluster' TABLE LIKE 'test' LIKE '%'", emptyList(), r -> { - assertEquals(FIELD_COUNT, r.size()); + assertEquals(FIELD_COUNT1, r.size()); assertEquals(CLUSTER_NAME, r.column(0)); assertEquals("test", r.column(2)); assertEquals("bool", r.column(3)); @@ -190,7 +194,7 @@ public void testSysColumnsWithCatalogWildcard() { public void testSysColumnsWithMissingCatalog() { executeCommand("SYS COLUMNS TABLE LIKE 'test' LIKE '%'", emptyList(), r -> { - assertEquals(FIELD_COUNT, r.size()); + assertEquals(FIELD_COUNT1, r.size()); assertEquals(CLUSTER_NAME, r.column(0)); assertEquals("test", r.column(2)); assertEquals("bool", r.column(3)); @@ -203,7 +207,7 @@ public void testSysColumnsWithMissingCatalog() { public void testSysColumnsWithNullCatalog() { executeCommand("SYS COLUMNS CATALOG ? TABLE LIKE 'test' LIKE '%'", singletonList(new SqlTypedParamValue("keyword", null)), r -> { - assertEquals(FIELD_COUNT, r.size()); + assertEquals(FIELD_COUNT1, r.size()); assertEquals(CLUSTER_NAME, r.column(0)); assertEquals("test", r.column(2)); assertEquals("bool", r.column(3)); @@ -220,8 +224,8 @@ public void testSysColumnsTypesInOdbcMode() { } public void testSysColumnsPaginationInOdbcMode() { - assertEquals(FIELD_COUNT, executeCommandInOdbcModeAndCountRows("SYS COLUMNS")); - assertEquals(FIELD_COUNT, executeCommandInOdbcModeAndCountRows("SYS COLUMNS TABLE LIKE 'test'")); + assertEquals(FIELD_COUNT1, executeCommandInOdbcModeAndCountRows("SYS COLUMNS")); + assertEquals(FIELD_COUNT1, executeCommandInOdbcModeAndCountRows("SYS COLUMNS TABLE LIKE 'test'")); } private int executeCommandInOdbcModeAndCountRows(String sql) { @@ -285,7 +289,7 @@ private Tuple sql(String sql, List para } private static void checkOdbcShortTypes(SchemaRowSet r) { - assertEquals(FIELD_COUNT, r.size()); + assertEquals(FIELD_COUNT1, r.size()); // https://github.com/elastic/elasticsearch/issues/35376 // cols that need to be of short type: DATA_TYPE, DECIMAL_DIGITS, NUM_PREC_RADIX, NULLABLE, SQL_DATA_TYPE, SQL_DATETIME_SUB List cols = Arrays.asList(4, 8, 9, 10, 13, 14); diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java index cac47897efb3d..6cd6e0535d137 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java @@ -2128,15 +2128,15 @@ public void testNoCountDoesNotTrackHits() { public void testZonedDateTimeInScripts() { PhysicalPlan p = optimizeAndPlan( - "SELECT date FROM test WHERE date + INTERVAL 1 YEAR > CAST('2019-03-11T12:34:56.000Z' AS DATETIME)"); + "SELECT date FROM test WHERE date - INTERVAL 999999999 YEAR > CAST('2019-03-11T12:34:56.000Z' AS DATETIME)"); assertEquals(EsQueryExec.class, p.getClass()); EsQueryExec eqe = (EsQueryExec) p; assertThat(eqe.queryContainer().toString().replaceAll("\\s+", ""), containsString( "\"script\":{\"script\":{\"source\":\"InternalQlScriptUtils.nullSafeFilter(" - + "InternalQlScriptUtils.gt(InternalSqlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0)," + + "InternalQlScriptUtils.gt(InternalSqlScriptUtils.sub(InternalQlScriptUtils.docValue(doc,params.v0)," + "InternalSqlScriptUtils.intervalYearMonth(params.v1,params.v2)),InternalSqlScriptUtils.asDateTime(params.v3)))\"," + "\"lang\":\"painless\"," - + "\"params\":{\"v0\":\"date\",\"v1\":\"P1Y\",\"v2\":\"INTERVAL_YEAR\",\"v3\":\"2019-03-11T12:34:56.000Z\"}},")); + + "\"params\":{\"v0\":\"date\",\"v1\":\"P999999999Y\",\"v2\":\"INTERVAL_YEAR\",\"v3\":\"2019-03-11T12:34:56.000Z\"}},")); } public void testChronoFieldBasedDateTimeFunctionsWithMathIntervalAndGroupBy() { diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/type/SqlDataTypesTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/type/SqlDataTypesTests.java index 5b2f5fabc8f1c..80829d6a3984c 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/type/SqlDataTypesTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/type/SqlDataTypesTests.java @@ -35,6 +35,7 @@ import static org.elasticsearch.xpack.sql.type.SqlDataTypes.INTERVAL_YEAR; import static org.elasticsearch.xpack.sql.type.SqlDataTypes.INTERVAL_YEAR_TO_MONTH; import static org.elasticsearch.xpack.sql.type.SqlDataTypes.TIME; +import static org.elasticsearch.xpack.sql.type.SqlDataTypes.areCompatible; import static org.elasticsearch.xpack.sql.type.SqlDataTypes.defaultPrecision; import static org.elasticsearch.xpack.sql.type.SqlDataTypes.isInterval; import static org.elasticsearch.xpack.sql.type.SqlDataTypes.metaSqlDataType; @@ -138,6 +139,27 @@ public void testIncompatibleInterval() { assertNull(compatibleInterval(INTERVAL_MINUTE_TO_SECOND, INTERVAL_MONTH)); } + public void testIntervalCompabitilityWithDateTimes() { + for (DataType intervalType : asList(INTERVAL_YEAR, + INTERVAL_MONTH, + INTERVAL_DAY, + INTERVAL_HOUR, + INTERVAL_MINUTE, + INTERVAL_SECOND, + INTERVAL_YEAR_TO_MONTH, + INTERVAL_DAY_TO_HOUR, + INTERVAL_DAY_TO_MINUTE, + INTERVAL_DAY_TO_SECOND, + INTERVAL_HOUR_TO_MINUTE, + INTERVAL_HOUR_TO_SECOND, + INTERVAL_MINUTE_TO_SECOND)) { + for (DataType dateTimeType: asList(DATE, DATETIME)) { + assertTrue(areCompatible(intervalType, dateTimeType)); + assertTrue(areCompatible(dateTimeType, intervalType)); + } + } + } + public void testEsToDataType() { List types = new ArrayList<>(Arrays.asList("null", "boolean", "bool", "byte", "tinyint",