diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java index ba283bc4d877b..d481ee955b63b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java @@ -36,7 +36,7 @@ public abstract class EsqlArithmeticOperation extends ArithmeticOperation implem * used just for its symbol. * The rest of the methods should not be triggered hence the UOE. */ - enum OperationSymbol implements BinaryArithmeticOperation { + public enum OperationSymbol implements BinaryArithmeticOperation { ADD("+"), SUB("-"), MUL("*"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java index a3b50ba9bc6d6..b7f72f5bfe842 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java @@ -222,4 +222,8 @@ public String formatIncompatibleTypesMessage() { ); } + @Override + public String toString() { + return left() + symbol() + right(); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index f10ad235cf255..9b2171cf21c3b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -58,7 +58,6 @@ import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.LiteralsOnTheRight; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneLiteralsInOrderBy; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SetAsOptimized; -import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics; import org.elasticsearch.xpack.ql.plan.logical.Filter; import org.elasticsearch.xpack.ql.plan.logical.Limit; import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; @@ -152,9 +151,10 @@ protected static Batch operators() { // needs to occur before BinaryComparison combinations (see class) new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.PropagateEquals(), new PropagateNullable(), + new OptimizerRules.CombineBinaryComparisons(), new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.BooleanFunctionEqualsElimination(), new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.CombineDisjunctionsToIn(), - new SimplifyComparisonsArithmetics(EsqlDataTypes::areCompatible), + new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics(EsqlDataTypes::areCompatible), // prune/elimination new PruneFilters(), new PruneColumns(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java index 3b7245b38ae4c..f8edf29b8c08a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java @@ -7,7 +7,11 @@ package org.elasticsearch.xpack.esql.optimizer; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; @@ -49,12 +53,17 @@ import org.elasticsearch.xpack.ql.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.ql.expression.predicate.logical.Not; import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.ArithmeticOperation; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryArithmeticOperation; +import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.plan.QueryPlan; import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.ql.type.DataType; import org.elasticsearch.xpack.ql.type.DataTypes; import org.elasticsearch.xpack.ql.util.CollectionUtils; +import java.time.DateTimeException; import java.time.ZoneId; import java.util.ArrayList; import java.util.Iterator; @@ -64,12 +73,21 @@ import java.util.List; import java.util.Map; import java.util.Set; - +import java.util.function.BiFunction; + +import static java.lang.Math.signum; +import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.ADD; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.DIV; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MOD; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MUL; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.SUB; import static org.elasticsearch.xpack.ql.common.Failure.fail; import static org.elasticsearch.xpack.ql.expression.Literal.FALSE; import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineOr; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitOr; +import static org.elasticsearch.xpack.ql.tree.Source.EMPTY; class OptimizerRules { @@ -612,4 +630,217 @@ private static Expression propagate(Or or) { return updated ? Predicates.combineOr(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : or; } } + + /** + * Simplifies arithmetic expressions with BinaryComparisons and fixed point fields, such as: (int + 2) / 3 > 4 => int > 10 + */ + public static final class SimplifyComparisonsArithmetics extends + org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerExpressionRule { + BiFunction typesCompatible; + + SimplifyComparisonsArithmetics(BiFunction typesCompatible) { + super(org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection.UP); + this.typesCompatible = typesCompatible; + } + + @Override + protected Expression rule(EsqlBinaryComparison bc) { + // optimize only once the expression has a literal on the right side of the binary comparison + if (bc.right() instanceof Literal) { + if (bc.left() instanceof ArithmeticOperation) { + return simplifyBinaryComparison(bc); + } + if (bc.left() instanceof Neg) { + return foldNegation(bc); + } + } + return bc; + } + + private Expression simplifyBinaryComparison(EsqlBinaryComparison comparison) { + EsqlArithmeticOperation operation = (EsqlArithmeticOperation) comparison.left(); + BinaryArithmeticOperation function = operation.function(); + // Modulo can't be simplified. + if (function.equals(MOD)) { + return comparison; + } + OperationSimplifier simplification = null; + if (isMulOrDiv(function)) { + simplification = new MulDivSimplifier(comparison); + } else if (function.equals(ADD) || function.equals(SUB)) { + simplification = new AddSubSimplifier(comparison); + } + + return (simplification == null || simplification.isUnsafe(typesCompatible)) ? comparison : simplification.apply(); + } + + private static boolean isMulOrDiv(BinaryArithmeticOperation op) { + return op.equals(MUL) || op.equals(DIV); + } + + private static Expression foldNegation(EsqlBinaryComparison bc) { + Literal bcLiteral = (Literal) bc.right(); + Expression literalNeg = tryFolding(new Neg(bcLiteral.source(), bcLiteral)); + return literalNeg == null ? bc : bc.reverse().replaceChildren(asList(((Neg) bc.left()).field(), literalNeg)); + } + + private static Expression tryFolding(Expression expression) { + if (expression.foldable()) { + try { + expression = new Literal(expression.source(), expression.fold(), expression.dataType()); + } catch (ArithmeticException | DateTimeException e) { + // null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped. + expression = null; + } + } + return expression; + } + + private abstract static class OperationSimplifier { + final EsqlBinaryComparison comparison; + final Literal bcLiteral; + final EsqlArithmeticOperation operation; + final Expression opLeft; + final Expression opRight; + final Literal opLiteral; + + OperationSimplifier(EsqlBinaryComparison comparison) { + this.comparison = comparison; + operation = (EsqlArithmeticOperation) comparison.left(); + bcLiteral = (Literal) comparison.right(); + + opLeft = operation.left(); + opRight = operation.right(); + + if (opLeft instanceof Literal) { + opLiteral = (Literal) opLeft; + } else if (opRight instanceof Literal) { + opLiteral = (Literal) opRight; + } else { + opLiteral = null; + } + } + + // can it be quickly fast-tracked that the operation can't be reduced? + final boolean isUnsafe(BiFunction typesCompatible) { + if (opLiteral == null) { + // one of the arithm. operands must be a literal, otherwise the operation wouldn't simplify anything + return true; + } + + // Only operations on fixed point literals are supported, since optimizing float point operations can also change the + // outcome of the filtering: + // x + 1e18 > 1e18::long will yield different results with a field value in [-2^6, 2^6], optimised vs original; + // x * (1 + 1e-15d) > 1 : same with a field value of (1 - 1e-15d) + // so consequently, int fields optimisation requiring FP arithmetic isn't possible either: (x - 1e-15) * (1 + 1e-15) > 1. + if (opLiteral.dataType().isRational() || bcLiteral.dataType().isRational()) { + return true; + } + + // the Literal will be moved to the right of the comparison, but only if data-compatible with what's there + if (typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType()) == false) { + return true; + } + + return isOpUnsafe(); + } + + final Expression apply() { + // force float point folding for FlP field + Literal bcl = operation.dataType().isRational() + // ? Literal.of(bcLiteral, ((Number) bcLiteral.value()).doubleValue()) + ? new Literal(bcLiteral.source(), ((Number) bcLiteral.value()).doubleValue(), DataTypes.DOUBLE) + : bcLiteral; + + Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse() + .create(bcl.source(), bcl, opRight); + bcRightExpression = tryFolding(bcRightExpression); + return bcRightExpression != null + ? postProcess((EsqlBinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression))) + : comparison; + } + + // operation-specific operations: + // - fast-tracking of simplification unsafety + abstract boolean isOpUnsafe(); + + // - post optimisation adjustments + Expression postProcess(EsqlBinaryComparison binaryComparison) { + return binaryComparison; + } + } + + private static class AddSubSimplifier extends SimplifyComparisonsArithmetics.OperationSimplifier { + + AddSubSimplifier(EsqlBinaryComparison comparison) { + super(comparison); + } + + @Override + boolean isOpUnsafe() { + // no ADD/SUB with floating fields + if (operation.dataType().isRational()) { + return true; + } + + if (operation.function().equals(SUB) && opRight instanceof Literal == false) { // such as: 1 - x > -MAX + // if next simplification step would fail on overflow anyway, skip the optimisation already + return tryFolding(new Sub(EMPTY, opLeft, bcLiteral)) == null; + } + + return false; + } + } + + private static class MulDivSimplifier extends SimplifyComparisonsArithmetics.OperationSimplifier { + + private final boolean isDiv; // and not MUL. + private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal) + + MulDivSimplifier(EsqlBinaryComparison comparison) { + super(comparison); + isDiv = operation.function().equals(DIV); + opRightSign = sign(opRight); + } + + @Override + boolean isOpUnsafe() { + // Integer divisions are not safe to optimise: x / 5 > 1 <=/=> x > 5 for x in [6, 9]; same for the `==` comp + if (operation.dataType().isInteger() && isDiv) { + return true; + } + + // If current operation is a multiplication, it's inverse will be a division: safe only if outcome is still integral. + if (isDiv == false && opLeft.dataType().isInteger()) { + long opLiteralValue = ((Number) opLiteral.value()).longValue(); + return opLiteralValue == 0 || ((Number) bcLiteral.value()).longValue() % opLiteralValue != 0; + } + + // can't move a 0 in Mul/Div comparisons + return opRightSign == 0; + } + + @Override + Expression postProcess(EsqlBinaryComparison binaryComparison) { + // negative multiplication/division changes the direction of the comparison + return opRightSign < 0 ? binaryComparison.reverse() : binaryComparison; + } + + private static int sign(Object obj) { + int sign = 1; + if (obj instanceof Number) { + sign = (int) signum(((Number) obj).doubleValue()); + } else if (obj instanceof Literal) { + sign = sign(((Literal) obj).value()); + } else if (obj instanceof Neg) { + sign = -sign(((Neg) obj).field()); + } else if (obj instanceof EsqlArithmeticOperation operation) { + if (isMulOrDiv(operation.function())) { + sign = sign(operation.left()) * sign(operation.right()); + } + } + return sign; + } + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypes.java index 44f6844544698..a3f6b42f18962 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypes.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypes.java @@ -230,6 +230,7 @@ public static boolean isRepresentable(DataType t) { && isCounterType(t) == false; } + @Deprecated public static boolean areCompatible(DataType left, DataType right) { if (left == right) { return true; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index aa31e21d850ab..0ad8851aa0368 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -65,6 +65,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; @@ -108,6 +109,7 @@ import org.elasticsearch.xpack.ql.plan.logical.Limit; import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.ql.plan.logical.OrderBy; +import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; import org.elasticsearch.xpack.ql.type.DataTypes; @@ -137,6 +139,11 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.localSource; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.EQ; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GT; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE; import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_POINT; import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_SHAPE; import static org.elasticsearch.xpack.ql.TestUtils.getFieldAttribute; @@ -173,13 +180,14 @@ public class LogicalPlanOptimizerTests extends ESTestCase { private static final Literal ONE = L(1); private static final Literal TWO = L(2); private static final Literal THREE = L(3); - private static EsqlParser parser; private static Analyzer analyzer; private static LogicalPlanOptimizer logicalOptimizer; private static Map mapping; private static Map mappingAirports; + private static Map mappingTypes; private static Analyzer analyzerAirports; + private static Analyzer analyzerTypes; private static EnrichResolution enrichResolution; private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer { @@ -219,6 +227,15 @@ public static void init() { new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultAirports, enrichResolution), TEST_VERIFIER ); + + // Some tests need additional types, so we load that index here and use it in the plan_types() function. + mappingTypes = loadMapping("mapping-all-types.json"); + EsIndex types = new EsIndex("types", mappingTypes, Set.of("types")); + IndexResolution getIndexResultTypes = IndexResolution.valid(types); + analyzerTypes = new Analyzer( + new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution), + TEST_VERIFIER + ); } public void testEmptyProjections() { @@ -4366,6 +4383,87 @@ private LogicalPlan planAirports(String query) { return optimized; } + private LogicalPlan planTypes(String query) { + return logicalOptimizer.optimize(analyzerTypes.analyze(parser.createStatement(query))); + } + + private EsqlBinaryComparison extractPlannedBinaryComparison(String expression) { + LogicalPlan plan = planTypes("FROM types | WHERE " + expression); + + assertTrue(plan instanceof UnaryPlan); + UnaryPlan unaryPlan = (UnaryPlan) plan; + assertTrue("Epxected top level Filter, foung [" + unaryPlan.child().toString() + "]", unaryPlan.child() instanceof Filter); + Filter filter = (Filter) unaryPlan.child(); + assertTrue( + "Expected filter condition to be a binary comparison but found [" + filter.condition() + "]", + filter.condition() instanceof EsqlBinaryComparison + ); + return (EsqlBinaryComparison) filter.condition(); + } + + private void doTestSimplifyComparisonArithmetics( + String expression, + String fieldName, + EsqlBinaryComparison.BinaryComparisonOperation opType, + Object bound + ) { + EsqlBinaryComparison bc = extractPlannedBinaryComparison(expression); + assertEquals(opType, bc.getFunctionType()); + + assertTrue( + "Expected left side of comparison to be a field attribute but found [" + bc.left() + "]", + bc.left() instanceof FieldAttribute + ); + FieldAttribute attribute = (FieldAttribute) bc.left(); + assertEquals(fieldName, attribute.name()); + + assertTrue("Expected right side of comparison to be a literal but found [" + bc.right() + "]", bc.right() instanceof Literal); + Literal literal = (Literal) bc.right(); + assertEquals(bound, literal.value()); + } + + public void testSimplifyComparisonArithmeticCommutativeVsNonCommutativeOps() { + doTestSimplifyComparisonArithmetics("integer + 2 > 3", "integer", GT, 1); + doTestSimplifyComparisonArithmetics("2 + integer > 3", "integer", GT, 1); + doTestSimplifyComparisonArithmetics("integer - 2 > 3", "integer", GT, 5); + doTestSimplifyComparisonArithmetics("2 - integer > 3", "integer", LT, -1); + doTestSimplifyComparisonArithmetics("integer * 2 > 4", "integer", GT, 2); + doTestSimplifyComparisonArithmetics("2 * integer > 4", "integer", GT, 2); + } + + @AwaitsFix(bugUrl = "") + public void testSimplifyComparisonArithmeticCommutativeVsNonCommutativeOps_WithFloats() { + doTestSimplifyComparisonArithmetics("float / 2 > 4", "float", GT, 8d); + doTestSimplifyComparisonArithmetics("2 / float < 4", "float", GT, .5); + } + + public void testSimplifyComparisonArithmeticWithMultipleOps() { + // i >= 3 + doTestSimplifyComparisonArithmetics("((integer + 1) * 2 - 4) * 4 >= 16", "integer", GTE, 3); + } + + public void testSimplifyComparisonArithmeticWithFieldNegation() { + doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120", "integer", LTE, 5); + } + + public void testSimplifyComparisonArithmeticWithFieldDoubleNegation() { + doTestSimplifyComparisonArithmetics("12 * -(-integer - 5) <= 120", "integer", LTE, 5); + } + + public void testSimplifyComparisonArithmeticWithConjunction() { + doTestSimplifyComparisonArithmetics("12 * (-integer - 5) == -120 AND integer < 6 ", "integer", EQ, 5); + } + + @AwaitsFix(bugUrl = "") + public void testSimplifyComparisonArithmeticWithDisjunction() { + doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120 OR integer < 5", "integer", LTE, 5); + } + + public void testSimplifyComparisonArithmeticWithFloatsAndDirectionChange() { + doTestSimplifyComparisonArithmetics("float / -2 < 4", "float", GT, -8d); + doTestSimplifyComparisonArithmetics("float * -2 < 4", "float", GT, -2d); + } + private void assertNullLiteral(Expression expression) { assertEquals(Literal.class, expression.getClass()); assertNull(expression.fold()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java index ee10c661029c0..425f65b087e27 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java @@ -8,6 +8,10 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -15,6 +19,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; +import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import org.elasticsearch.xpack.ql.TestUtils; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.FieldAttribute; @@ -39,14 +44,17 @@ import static org.elasticsearch.xpack.ql.expression.Literal.NULL; import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; import static org.elasticsearch.xpack.ql.tree.Source.EMPTY; +import static org.elasticsearch.xpack.ql.type.DataTypes.DOUBLE; import static org.hamcrest.Matchers.contains; public class OptimizerRulesTests extends ESTestCase { + private static final Literal ZERO = new Literal(Source.EMPTY, 0, DataTypes.INTEGER); private static final Literal ONE = new Literal(Source.EMPTY, 1, DataTypes.INTEGER); private static final Literal TWO = new Literal(Source.EMPTY, 2, DataTypes.INTEGER); private static final Literal THREE = new Literal(Source.EMPTY, 3, DataTypes.INTEGER); private static final Literal FOUR = new Literal(Source.EMPTY, 4, DataTypes.INTEGER); private static final Literal FIVE = new Literal(Source.EMPTY, 5, DataTypes.INTEGER); + public static final Literal TWO_DOUBLE = new Literal(EMPTY, 2d, DataTypes.DOUBLE); private static Equals equalsOf(Expression left, Expression right) { return new Equals(EMPTY, left, right, null); @@ -76,6 +84,10 @@ private static FieldAttribute getFieldAttribute() { return TestUtils.getFieldAttribute("a"); } + private static OptimizerRules.SimplifyComparisonsArithmetics simplifyComparisonsArithmetics() { + return new OptimizerRules.SimplifyComparisonsArithmetics(EsqlDataTypes::areCompatible); + } + // // CombineDisjunction in Equals // @@ -500,4 +512,111 @@ public void testEliminateRangeByEqualsInInterval() { Expression exp = rule.rule(new And(EMPTY, eq1, r)); assertEquals(eq1, exp); } + + // Tests for SimplifyComparisonArithmetics + // a + 1 == 0 -> a == -1 + public void testSimplifyEqualsAddition() { + FieldAttribute fa = getFieldAttribute(); + Equals before = equalsOf(new Add(EMPTY, fa, ONE), ZERO); + Equals expected = equalsOf(fa, new Literal(EMPTY, -1, DataTypes.INTEGER)); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + assertEquals(expected, actual); + } + + // a - 1 == 0 -> a == 1 + public void testSimplifyEqualsSubtraction() { + FieldAttribute fa = getFieldAttribute(); + Equals before = equalsOf(new Sub(EMPTY, fa, ONE), ZERO); + Equals expected = equalsOf(fa, ONE); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + assertEquals(expected, actual); + } + + // a * 2 == 4 -> a == 2 + public void testSimplifyEqualsMultiplicationExact() { + FieldAttribute fa = getFieldAttribute(); + Equals before = equalsOf(new Mul(EMPTY, fa, TWO), FOUR); + Equals expected = equalsOf(fa, TWO); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + assertEquals(expected, actual); + } + + // a * 3 == 4 -> no change (integers) + public void testSimplifyEqualsMultiplicationUneven() { + FieldAttribute fa = getFieldAttribute(); + Equals before = equalsOf(new Mul(EMPTY, fa, THREE), FOUR); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + // Since these are all integers, we aren't able to rearrange this. Although it seems like we could rewrite the whole thing to + // false, since no integer a satisfies this equality. + assertEquals(before, actual); + } + + // a * 0 == 2 -> Multiplications and divisions involving a zero are not optimized + public void testSimplifyEqualsMultiplicationZero() { + FieldAttribute fa = getFieldAttribute(); + Equals before = equalsOf(new Mul(EMPTY, fa, ZERO), TWO); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + assertEquals(before, actual); + } + + // a / 2 == 2 -> not optimized for integers + public void testSimplifyEqualsDivideInteger() { + FieldAttribute fa = getFieldAttribute(); + Equals before = equalsOf(new Div(EMPTY, fa, TWO), TWO); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + assertEquals(before, actual); + } + + // a / 2 == 2 -> Also not optimized, due to floating point rounding + public void testSimplifyEqualsDivideDouble() { + FieldAttribute fa = TestUtils.getFieldAttribute("a", DOUBLE); + Equals before = equalsOf(new Div(EMPTY, fa, TWO_DOUBLE), TWO_DOUBLE); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + assertEquals(before, actual); + } + + // a / 2 == 0 -> Multiplications and divisions involving a zero are not optimized + public void testSimplifyEqualsDivideZero() { + FieldAttribute fa = getFieldAttribute(); + Equals before = equalsOf(new Div(EMPTY, fa, TWO), ZERO); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + assertEquals(before, actual); + } + + // a % 3 == 2 -> Mod is not an invertable operation, so it can't be optimized in this way + public void testSimplifyEqualsMod() { + FieldAttribute fa = getFieldAttribute(); + Equals before = equalsOf(new Div(EMPTY, fa, THREE), TWO); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + Expression actual = rule.rule(before); + assertEquals(before, actual); + } + + // a * -2 > 4 -> a < -2 + public void testSimplifyGreaterThanWithNegativeMultiplication() { + FieldAttribute fa = getFieldAttribute(); + GreaterThan before = greaterThanOf(new Mul(EMPTY, fa, new Literal(EMPTY, -2, DataTypes.INTEGER)), FOUR); + OptimizerRules.SimplifyComparisonsArithmetics rule = simplifyComparisonsArithmetics(); + + LessThan expected = lessThanOf(fa, new Literal(EMPTY, -2, DataTypes.INTEGER)); + Expression actual = rule.rule(before); + assertEquals(expected, actual); + } }