From 9f54d9a804f17eb20f090e011f0ef016081d3a20 Mon Sep 17 00:00:00 2001 From: Mark Tozzi Date: Fri, 17 May 2024 09:43:48 -0400 Subject: [PATCH] add tests for SimplifyComparisonArithmetics optimization rule (#108744) This adds in the tests from OptimizerRunTests in SQL to apply to ESQL. I've opened issues and applied the AwaitsFix annotation for those of the tests that are currently failing. --- .../comparison/EsqlBinaryComparison.java | 4 + .../optimizer/LogicalPlanOptimizerTests.java | 221 +++++++++++++++++- 2 files changed, 224 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java index a3b50ba9bc6d6..a8a8166f7c06f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java @@ -86,6 +86,10 @@ public static BinaryComparisonOperation readFromStream(StreamInput in) throws IO throw new IOException("No BinaryComparisonOperation found for id [" + id + "]"); } + public String symbol() { + return symbol; + } + public EsqlBinaryComparison buildNewInstance(Source source, Expression lhs, Expression rhs) { return constructor.apply(source, lhs, rhs); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 533cb45902594..0520ea76e3a6a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.logging.LoggerMessageFormat; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.compute.aggregation.QuantileStates; +import org.elasticsearch.core.Tuple; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.TestBlockFactory; @@ -65,6 +66,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; @@ -95,19 +97,23 @@ import org.elasticsearch.xpack.ql.expression.NamedExpression; import org.elasticsearch.xpack.ql.expression.Nullability; import org.elasticsearch.xpack.ql.expression.ReferenceAttribute; +import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute; import org.elasticsearch.xpack.ql.expression.predicate.Predicates; import org.elasticsearch.xpack.ql.expression.predicate.logical.And; import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.expression.predicate.regex.RLikePattern; import org.elasticsearch.xpack.ql.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.ql.index.EsIndex; import org.elasticsearch.xpack.ql.index.IndexResolution; +import org.elasticsearch.xpack.ql.optimizer.OptimizerRules; import org.elasticsearch.xpack.ql.plan.logical.Filter; import org.elasticsearch.xpack.ql.plan.logical.Limit; import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.ql.plan.logical.OrderBy; +import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; import org.elasticsearch.xpack.ql.type.DataTypes; @@ -117,6 +123,7 @@ import org.junit.BeforeClass; import java.lang.reflect.Constructor; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -137,6 +144,11 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.localSource; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.EQ; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GT; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT; +import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE; import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_POINT; import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_SHAPE; import static org.elasticsearch.xpack.ql.TestUtils.getFieldAttribute; @@ -173,16 +185,18 @@ public class LogicalPlanOptimizerTests extends ESTestCase { private static final Literal ONE = L(1); private static final Literal TWO = L(2); private static final Literal THREE = L(3); - private static EsqlParser parser; private static Analyzer analyzer; private static LogicalPlanOptimizer logicalOptimizer; private static Map mapping; private static Map mappingAirports; + private static Map mappingTypes; private static Analyzer analyzerAirports; + private static Analyzer analyzerTypes; private static Map mappingExtra; private static Analyzer analyzerExtra; private static EnrichResolution enrichResolution; + private static final OptimizerRules.LiteralsOnTheRight LITERALS_ON_THE_RIGHT = new OptimizerRules.LiteralsOnTheRight(); private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer { static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); @@ -222,6 +236,15 @@ public static void init() { TEST_VERIFIER ); + // Some tests need additional types, so we load that index here and use it in the plan_types() function. + mappingTypes = loadMapping("mapping-all-types.json"); + EsIndex types = new EsIndex("types", mappingTypes, Set.of("types")); + IndexResolution getIndexResultTypes = IndexResolution.valid(types); + analyzerTypes = new Analyzer( + new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution), + TEST_VERIFIER + ); + // Some tests use mappings from mapping-extra.json to be able to test more types so we load it here mappingExtra = loadMapping("mapping-extra.json"); EsIndex extra = new EsIndex("extra", mappingExtra, Set.of("extra")); @@ -4438,11 +4461,207 @@ private LogicalPlan planExtra(String query) { return optimized; } + private LogicalPlan planTypes(String query) { + return logicalOptimizer.optimize(analyzerTypes.analyze(parser.createStatement(query))); + } + + private EsqlBinaryComparison extractPlannedBinaryComparison(String expression) { + LogicalPlan plan = planTypes("FROM types | WHERE " + expression); + + return extractPlannedBinaryComparison(plan); + } + + private static EsqlBinaryComparison extractPlannedBinaryComparison(LogicalPlan plan) { + assertTrue("Expected unary plan, found [" + plan + "]", plan instanceof UnaryPlan); + UnaryPlan unaryPlan = (UnaryPlan) plan; + assertTrue("Epxected top level Filter, foung [" + unaryPlan.child().toString() + "]", unaryPlan.child() instanceof Filter); + Filter filter = (Filter) unaryPlan.child(); + assertTrue( + "Expected filter condition to be a binary comparison but found [" + filter.condition() + "]", + filter.condition() instanceof EsqlBinaryComparison + ); + return (EsqlBinaryComparison) filter.condition(); + } + + private void doTestSimplifyComparisonArithmetics( + String expression, + String fieldName, + EsqlBinaryComparison.BinaryComparisonOperation opType, + Object bound + ) { + EsqlBinaryComparison bc = extractPlannedBinaryComparison(expression); + assertEquals(opType, bc.getFunctionType()); + + assertTrue( + "Expected left side of comparison to be a field attribute but found [" + bc.left() + "]", + bc.left() instanceof FieldAttribute + ); + FieldAttribute attribute = (FieldAttribute) bc.left(); + assertEquals(fieldName, attribute.name()); + + assertTrue("Expected right side of comparison to be a literal but found [" + bc.right() + "]", bc.right() instanceof Literal); + Literal literal = (Literal) bc.right(); + assertEquals(bound, literal.value()); + } + + private void assertSemanticMatching(String expected, String provided) { + BinaryComparison bc = extractPlannedBinaryComparison(provided); + LogicalPlan exp = analyzerTypes.analyze(parser.createStatement("FROM types | WHERE " + expected)); + assertSemanticMatching(bc, extractPlannedBinaryComparison(exp)); + } + + private static void assertSemanticMatching(Expression fieldAttributeExp, Expression unresolvedAttributeExp) { + Expression unresolvedUpdated = unresolvedAttributeExp.transformUp( + LITERALS_ON_THE_RIGHT.expressionToken(), + LITERALS_ON_THE_RIGHT::rule + ).transformUp(x -> x.foldable() ? new Literal(x.source(), x.fold(), x.dataType()) : x); + + List resolvedFields = fieldAttributeExp.collectFirstChildren(x -> x instanceof FieldAttribute); + for (Expression field : resolvedFields) { + FieldAttribute fa = (FieldAttribute) field; + unresolvedUpdated = unresolvedUpdated.transformDown(UnresolvedAttribute.class, x -> x.name().equals(fa.name()) ? fa : x); + } + + assertTrue(unresolvedUpdated.semanticEquals(fieldAttributeExp)); + } + + private Expression getComparisonFromLogicalPlan(LogicalPlan plan) { + List expressions = new ArrayList<>(); + plan.forEachExpression(Expression.class, expressions::add); + return expressions.get(0); + } + + private void assertNotSimplified(String comparison) { + String query = "FROM types | WHERE " + comparison; + Expression optimized = getComparisonFromLogicalPlan(planTypes(query)); + Expression raw = getComparisonFromLogicalPlan(analyzerTypes.analyze(parser.createStatement(query))); + + assertTrue(raw.semanticEquals(optimized)); + } + + private static String randomBinaryComparison() { + return randomFrom(EsqlBinaryComparison.BinaryComparisonOperation.values()).symbol(); + } + + public void testSimplifyComparisonArithmeticCommutativeVsNonCommutativeOps() { + doTestSimplifyComparisonArithmetics("integer + 2 > 3", "integer", GT, 1); + doTestSimplifyComparisonArithmetics("2 + integer > 3", "integer", GT, 1); + doTestSimplifyComparisonArithmetics("integer - 2 > 3", "integer", GT, 5); + doTestSimplifyComparisonArithmetics("2 - integer > 3", "integer", LT, -1); + doTestSimplifyComparisonArithmetics("integer * 2 > 4", "integer", GT, 2); + doTestSimplifyComparisonArithmetics("2 * integer > 4", "integer", GT, 2); + + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388") + public void testSimplifyComparisonArithmeticsWithFloatingPoints() { + doTestSimplifyComparisonArithmetics("float / 2 > 4", "float", GT, 8d); + } + + public void testAssertSemanticMatching() { + // This test is just to verify that the complicated assert logic is working on a known-good case + assertSemanticMatching("integer > 1", "integer + 2 > 3"); + } + + public void testSimplyComparisonArithmeticWithUnfoldedProd() { + assertSemanticMatching("integer * integer >= 3", "((integer * integer + 1) * 2 - 4) * 4 >= 16"); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108524") + public void testSimplifyComparisionArithmetics_floatDivision() { + doTestSimplifyComparisonArithmetics("2 / float < 4", "float", GT, .5); + } + + public void testSimplifyComparisonArithmeticWithMultipleOps() { + // i >= 3 + doTestSimplifyComparisonArithmetics("((integer + 1) * 2 - 4) * 4 >= 16", "integer", GTE, 3); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743") + public void testSimplifyComparisonArithmeticWithFieldNegation() { + doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120", "integer", LTE, 5); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743") + public void testSimplifyComparisonArithmeticWithFieldDoubleNegation() { + doTestSimplifyComparisonArithmetics("12 * -(-integer - 5) <= 120", "integer", LTE, 5); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743") + public void testSimplifyComparisonArithmeticWithConjunction() { + doTestSimplifyComparisonArithmetics("12 * (-integer - 5) == -120 AND integer < 6 ", "integer", EQ, 5); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108525") + public void testSimplifyComparisonArithmeticWithDisjunction() { + doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120 OR integer < 5", "integer", LTE, 5); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388") + public void testSimplifyComparisonArithmeticWithFloatsAndDirectionChange() { + doTestSimplifyComparisonArithmetics("float / -2 < 4", "float", GT, -8d); + doTestSimplifyComparisonArithmetics("float * -2 < 4", "float", GT, -2d); + } + private void assertNullLiteral(Expression expression) { assertEquals(Literal.class, expression.getClass()); assertNull(expression.fold()); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519") + public void testSimplifyComparisonArithmeticSkippedOnIntegerArithmeticalOverflow() { + assertNotSimplified("integer - 1 " + randomBinaryComparison() + " " + Long.MAX_VALUE); + assertNotSimplified("1 - integer " + randomBinaryComparison() + " " + Long.MIN_VALUE); + assertNotSimplified("integer - 1 " + randomBinaryComparison() + " " + Integer.MAX_VALUE); + assertNotSimplified("1 - integer " + randomBinaryComparison() + " " + Integer.MIN_VALUE); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519") + public void testSimplifyComparisonArithmeticSkippedOnNegatingOverflow() { + assertNotSimplified("-integer " + randomBinaryComparison() + " " + Long.MIN_VALUE); + assertNotSimplified("-integer " + randomBinaryComparison() + " " + Integer.MIN_VALUE); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519") + public void testSimplifyComparisonArithmeticSkippedOnDateOverflow() { + assertNotSimplified("date - 999999999 years > to_datetime(\"2010-01-01T01:01:01\")"); + assertNotSimplified("date + -999999999 years > to_datetime(\"2010-01-01T01:01:01\")"); + } + + public void testSimplifyComparisonArithmeticSkippedOnMulDivByZero() { + assertNotSimplified("float / 0 " + randomBinaryComparison() + " 1"); + assertNotSimplified("float * 0 " + randomBinaryComparison() + " 1"); + assertNotSimplified("integer / 0 " + randomBinaryComparison() + " 1"); + assertNotSimplified("integer * 0 " + randomBinaryComparison() + " 1"); + } + + public void testSimplifyComparisonArithmeticSkippedOnDiv() { + assertNotSimplified("integer / 4 " + randomBinaryComparison() + " 1"); + assertNotSimplified("4 / integer " + randomBinaryComparison() + " 1"); + } + + public void testSimplifyComparisonArithmeticSkippedOnResultingFloatLiteral() { + assertNotSimplified("integer * 2 " + randomBinaryComparison() + " 3"); + } + + public void testSimplifyComparisonArithmeticSkippedOnFloatFieldWithPlusMinus() { + assertNotSimplified("float + 4 " + randomBinaryComparison() + " 1"); + assertNotSimplified("4 + float " + randomBinaryComparison() + " 1"); + assertNotSimplified("float - 4 " + randomBinaryComparison() + " 1"); + assertNotSimplified("4 - float " + randomBinaryComparison() + " 1"); + } + + public void testSimplifyComparisonArithmeticSkippedOnFloats() { + for (String field : List.of("integer", "float")) { + for (Tuple nr : List.of(new Tuple<>(.4, 1), new Tuple<>(1, .4))) { + assertNotSimplified(field + " + " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2()); + assertNotSimplified(field + " - " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2()); + assertNotSimplified(nr.v1() + " + " + field + " " + randomBinaryComparison() + " " + nr.v2()); + assertNotSimplified(nr.v1() + " - " + field + " " + randomBinaryComparison() + " " + nr.v2()); + } + } + } + public static WildcardLike wildcardLike(Expression left, String exp) { return new WildcardLike(EMPTY, left, new WildcardPattern(exp)); }