diff --git a/docs/changelog/110548.yaml b/docs/changelog/110548.yaml new file mode 100644 index 0000000000000..d1c0952920889 --- /dev/null +++ b/docs/changelog/110548.yaml @@ -0,0 +1,6 @@ +pr: 110548 +summary: "[ES|QL] Add `CombineBinaryComparisons` rule" +area: ES|QL +type: enhancement +issues: + - 108525 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index 4e69dffa13bc8..a310f8028800c 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -101,12 +101,13 @@ import static org.junit.Assert.assertTrue; public final class EsqlTestUtils { + public static final Literal ONE = new Literal(Source.EMPTY, 1, DataType.INTEGER); public static final Literal TWO = new Literal(Source.EMPTY, 2, DataType.INTEGER); public static final Literal THREE = new Literal(Source.EMPTY, 3, DataType.INTEGER); public static final Literal FOUR = new Literal(Source.EMPTY, 4, DataType.INTEGER); public static final Literal FIVE = new Literal(Source.EMPTY, 5, DataType.INTEGER); - private static final Literal SIX = new Literal(Source.EMPTY, 6, DataType.INTEGER); + public static final Literal SIX = new Literal(Source.EMPTY, 6, DataType.INTEGER); public static Equals equalsOf(Expression left, Expression right) { return new Equals(EMPTY, left, right, null); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec index c8cb6cf88a4f0..b399734151412 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec @@ -986,3 +986,36 @@ s:double | emp_no:integer | salary:integer 1.0 | 10002 | 56371 1.0 | 10041 | 56415 ; + +CombineBinaryComparisonsMv +required_capability: combine_binary_comparisons + +row x = [1,2,3] +| where 12 * (-x - 5) >= -120 OR x < 5 +; +warning:Line 2:15: evaluation of [-x] failed, treating result as null. Only first 20 failures recorded. +warning:Line 2:15: java.lang.IllegalArgumentException: single-value function encountered multi-value +warning:Line 2:34: evaluation of [x < 5] failed, treating result as null. Only first 20 failures recorded. +warning:Line 2:34: java.lang.IllegalArgumentException: single-value function encountered multi-value + +x:integer +; + +CombineBinaryComparisonsEmp +required_capability: combine_binary_comparisons + +from employees +| where salary_change.int == 2 OR salary_change.int > 1 +| keep emp_no, salary_change.int +| sort emp_no +; +warning:Line 2:35: evaluation of [salary_change.int > 1] failed, treating result as null. Only first 20 failures recorded. +warning:Line 2:35: java.lang.IllegalArgumentException: single-value function encountered multi-value + +emp_no:integer |salary_change.int:integer +10044 | 8 +10046 | 2 +10066 | 5 +10079 | 7 +10086 | 13 +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 76cf95494f7ca..76562bbe6ebf0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -186,7 +186,12 @@ public enum Cap { /** * Support for match operator */ - MATCH_OPERATOR(true); + MATCH_OPERATOR(true), + + /** + * Add CombineBinaryComparisons rule. + */ + COMBINE_BINARY_COMPARISONS; private final boolean snapshotOnly; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 96be1249a76ee..285dac0b4641c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.AddDefaultTopN; import org.elasticsearch.xpack.esql.optimizer.rules.BooleanFunctionEqualsElimination; import org.elasticsearch.xpack.esql.optimizer.rules.BooleanSimplification; +import org.elasticsearch.xpack.esql.optimizer.rules.CombineBinaryComparisons; import org.elasticsearch.xpack.esql.optimizer.rules.CombineDisjunctionsToIn; import org.elasticsearch.xpack.esql.optimizer.rules.CombineEvals; import org.elasticsearch.xpack.esql.optimizer.rules.CombineProjections; @@ -207,6 +208,7 @@ protected static Batch operators() { new PropagateEquals(), new PropagateNullable(), new BooleanFunctionEqualsElimination(), + new CombineBinaryComparisons(), new CombineDisjunctionsToIn(), new SimplifyComparisonsArithmetics(DataType::areCompatible), // prune/elimination diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineBinaryComparisons.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineBinaryComparisons.java new file mode 100644 index 0000000000000..0d1d5baf920d7 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineBinaryComparisons.java @@ -0,0 +1,210 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; +import org.elasticsearch.xpack.esql.core.util.CollectionUtils; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; + +import java.util.ArrayList; +import java.util.List; + +public final class CombineBinaryComparisons extends OptimizerRules.OptimizerExpressionRule { + + public CombineBinaryComparisons() { + super(OptimizerRules.TransformDirection.DOWN); + } + + @Override + public Expression rule(BinaryLogic e) { + if (e instanceof And and) { + return combine(and); + } else if (e instanceof Or or) { + return combine(or); + } + return e; + } + + // combine conjunction + private static Expression combine(And and) { + List bcs = new ArrayList<>(); + List exps = new ArrayList<>(); + boolean changed = false; + List andExps = Predicates.splitAnd(and); + + andExps.sort((o1, o2) -> { + if (o1 instanceof NotEquals && o2 instanceof NotEquals) { + return 0; // keep NotEquals' order + } else if (o1 instanceof NotEquals || o2 instanceof NotEquals) { + return o1 instanceof NotEquals ? 1 : -1; // push NotEquals up + } else { + return 0; // keep non-Ranges' and non-NotEquals' order + } + }); + for (Expression ex : andExps) { + if (ex instanceof BinaryComparison bc && (ex instanceof Equals || ex instanceof NotEquals) == false) { + if (bc.right().foldable() && (findExistingComparison(bc, bcs, true))) { + changed = true; + } else { + bcs.add(bc); + } + } else if (ex instanceof NotEquals neq) { + if (neq.right().foldable() && notEqualsIsRemovableFromConjunction(neq, bcs)) { + // the non-equality can simply be dropped: either superfluous or has been merged with an updated range/inequality + changed = true; + } else { // not foldable OR not overlapping + exps.add(ex); + } + } else { + exps.add(ex); + } + } + return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, bcs)) : and; + } + + // combine disjunction + private static Expression combine(Or or) { + List bcs = new ArrayList<>(); + List exps = new ArrayList<>(); + boolean changed = false; + for (Expression ex : Predicates.splitOr(or)) { + if (ex instanceof BinaryComparison bc) { + if (bc.right().foldable() && findExistingComparison(bc, bcs, false)) { + changed = true; + } else { + bcs.add(bc); + } + } else { + exps.add(ex); + } + } + return changed ? Predicates.combineOr(CollectionUtils.combine(exps, bcs)) : or; + } + + /** + * Find commonalities between the given comparison in the given list. + * The method can be applied both for conjunctive (AND) or disjunctive purposes (OR). + */ + private static boolean findExistingComparison(BinaryComparison main, List bcs, boolean conjunctive) { + Object value = main.right().fold(); + // NB: the loop modifies the list (hence why the int is used) + for (int i = 0; i < bcs.size(); i++) { + BinaryComparison other = bcs.get(i); + // skip if cannot evaluate + if (other.right().foldable() == false) { + continue; + } + // if bc is a higher/lower value or gte vs gt, use it instead + if ((other instanceof GreaterThan || other instanceof GreaterThanOrEqual) + && (main instanceof GreaterThan || main instanceof GreaterThanOrEqual)) { + if (main.left().semanticEquals(other.left())) { + Integer compare = BinaryComparison.compare(value, other.right().fold()); + if (compare != null) { + // AND + if ((conjunctive && + // a > 3 AND a > 2 -> a > 3 + (compare > 0 || + // a > 2 AND a >= 2 -> a > 2 + (compare == 0 && main instanceof GreaterThan && other instanceof GreaterThanOrEqual))) || + // OR + (conjunctive == false && + // a > 2 OR a > 3 -> a > 2 + (compare < 0 || + // a >= 2 OR a > 2 -> a >= 2 + (compare == 0 && main instanceof GreaterThanOrEqual && other instanceof GreaterThan)))) { + bcs.remove(i); + bcs.add(i, main); + } + // found a match + return true; + } + return false; + } + } + // if bc is a lower/higher value or lte vs lt, use it instead + else if ((other instanceof LessThan || other instanceof LessThanOrEqual) + && (main instanceof LessThan || main instanceof LessThanOrEqual)) { + if (main.left().semanticEquals(other.left())) { + Integer compare = BinaryComparison.compare(value, other.right().fold()); + if (compare != null) { + // AND + if ((conjunctive && + // a < 2 AND a < 3 -> a < 2 + (compare < 0 || + // a < 2 AND a <= 2 -> a < 2 + (compare == 0 && main instanceof LessThan && other instanceof LessThanOrEqual))) || + // OR + (conjunctive == false && + // a < 2 OR a < 3 -> a < 3 + (compare > 0 || + // a <= 2 OR a < 2 -> a <= 2 + (compare == 0 && main instanceof LessThanOrEqual && other instanceof LessThan)))) { + bcs.remove(i); + bcs.add(i, main); + } + // found a match + return true; + } + return false; + } + } + } + return false; + } + + private static boolean notEqualsIsRemovableFromConjunction(NotEquals notEquals, List bcs) { + Object neqVal = notEquals.right().fold(); + Integer comp; + + // check on "condition-overlapping" inequalities: + // a != 2 AND a > 3 -> a > 3 (discard NotEquals) + // a != 2 AND a >= 2 -> a > 2 (discard NotEquals plus update inequality) + // a != 2 AND a > 1 -> nop (do nothing) + // + // a != 2 AND a < 3 -> nop + // a != 2 AND a <= 2 -> a < 2 + // a != 2 AND a < 1 -> a < 1 + for (int i = 0; i < bcs.size(); i++) { + BinaryComparison bc = bcs.get(i); + if (notEquals.left().semanticEquals(bc.left())) { + if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { + comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null; + if (comp != null) { + if (comp >= 0) { + if (comp == 0 && bc instanceof LessThanOrEqual) { // a != 2 AND a <= 2 -> a < 2 + bcs.set(i, new LessThan(bc.source(), bc.left(), bc.right(), bc.zoneId())); + } // else : comp > 0 (a != 2 AND a a a < 2) + return true; + } // else: comp < 0 : a != 2 AND a nop + } // else: non-comparable, nop + } else if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) { + comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null; + if (comp != null) { + if (comp <= 0) { + if (comp == 0 && bc instanceof GreaterThanOrEqual) { // a != 2 AND a >= 2 -> a > 2 + bcs.set(i, new GreaterThan(bc.source(), bc.left(), bc.right(), bc.zoneId())); + } // else: comp < 0 (a != 2 AND a >/>= 3 -> a >/>= 3), or == 0 && bc i.of ">" (a != 2 AND a > 2 -> a > 2) + return true; + } // else: comp > 0 : a != 2 AND a >/>= 1 -> nop + } // else: non-comparable, nop + } // else: other non-relevant type + } + } + return false; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 43c3cb92dff66..b98a26e5946b0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -4829,7 +4829,6 @@ public void testSimplifyComparisonArithmeticWithConjunction() { doTestSimplifyComparisonArithmetics("12 * (-integer - 5) == -120 AND integer < 6 ", "integer", EQ, 5); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108525") public void testSimplifyComparisonArithmeticWithDisjunction() { doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120 OR integer < 5", "integer", LTE, 5); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineBinaryComparisonsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineBinaryComparisonsTests.java new file mode 100644 index 0000000000000..3c98b7fa23e8b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineBinaryComparisonsTests.java @@ -0,0 +1,400 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; + +import java.util.List; + +import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.FIVE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.FOUR; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.L; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.ONE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.SIX; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.fieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOrEqualOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOrEqualOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; +import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; +import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; + +public class CombineBinaryComparisonsTests extends ESTestCase { + + private static final Expression DUMMY_EXPRESSION = + new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 0); + + public void testCombineBinaryComparisonsNotComparable() { + FieldAttribute fa = getFieldAttribute(); + LessThanOrEqual lte = lessThanOrEqualOf(fa, SIX); + LessThan lt = lessThanOf(fa, FALSE); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + + And and = new And(EMPTY, lte, lt); + Expression exp = rule.rule(and); + assertEquals(exp, and); + } + + // a <= 6 AND a < 5 -> a < 5 + public void testCombineBinaryComparisonsUpper() { + FieldAttribute fa = getFieldAttribute(); + LessThanOrEqual lte = lessThanOrEqualOf(fa, SIX); + LessThan lt = lessThanOf(fa, FIVE); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + + Expression exp = rule.rule(new And(EMPTY, lte, lt)); + assertEquals(LessThan.class, exp.getClass()); + LessThan r = (LessThan) exp; + assertEquals(FIVE, r.right()); + } + + // 6 <= a AND 5 < a -> 6 <= a + public void testCombineBinaryComparisonsLower() { + FieldAttribute fa = getFieldAttribute(); + GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, SIX); + GreaterThan gt = greaterThanOf(fa, FIVE); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + + Expression exp = rule.rule(new And(EMPTY, gte, gt)); + assertEquals(GreaterThanOrEqual.class, exp.getClass()); + GreaterThanOrEqual r = (GreaterThanOrEqual) exp; + assertEquals(SIX, r.right()); + } + + // 5 <= a AND 5 < a -> 5 < a + public void testCombineBinaryComparisonsInclude() { + FieldAttribute fa = getFieldAttribute(); + GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, FIVE); + GreaterThan gt = greaterThanOf(fa, FIVE); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + + Expression exp = rule.rule(new And(EMPTY, gte, gt)); + assertEquals(GreaterThan.class, exp.getClass()); + GreaterThan r = (GreaterThan) exp; + assertEquals(FIVE, r.right()); + } + + // 3 <= a AND 4 < a AND a <= 7 AND a < 6 -> 4 < a AND a < 6 + public void testCombineMultipleBinaryComparisons() { + FieldAttribute fa = getFieldAttribute(); + GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, THREE); + GreaterThan gt = greaterThanOf(fa, FOUR); + LessThanOrEqual lte = lessThanOrEqualOf(fa, L(7)); + LessThan lt = lessThanOf(fa, SIX); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + + Expression exp = rule.rule(new And(EMPTY, gte, new And(EMPTY, gt, new And(EMPTY, lt, lte)))); + assertEquals(And.class, exp.getClass()); + And and = (And) exp; + assertEquals(gt, and.left()); + assertEquals(lt, and.right()); + } + + // 3 <= a AND TRUE AND 4 < a AND a != 5 AND a <= 7 -> 4 < a AND a <= 7 AND a != 5 AND TRUE + public void testCombineMixedMultipleBinaryComparisons() { + FieldAttribute fa = getFieldAttribute(); + GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, THREE); + GreaterThan gt = greaterThanOf(fa, FOUR); + LessThanOrEqual lte = lessThanOrEqualOf(fa, L(7)); + Expression ne = notEqualsOf(fa, FIVE); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + + // TRUE AND a != 5 AND 4 < a <= 7 + Expression exp = rule.rule(new And(EMPTY, gte, new And(EMPTY, TRUE, new And(EMPTY, gt, new And(EMPTY, ne, lte))))); + assertEquals(And.class, exp.getClass()); + And and = ((And) exp); + assertEquals(And.class, and.right().getClass()); + And right = (And) and.right(); + assertEquals(gt, right.left()); + assertEquals(lte, right.right()); + assertEquals(And.class, and.left().getClass()); + And left = (And) and.left(); + assertEquals(TRUE, left.left()); + assertEquals(ne, left.right()); + } + + // 1 <= a AND a < 5 -> 1 <= a AND a < 5 + public void testCombineComparisonsIntoRange() { + FieldAttribute fa = getFieldAttribute(); + GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, ONE); + LessThan lt = lessThanOf(fa, FIVE); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(new And(EMPTY, gte, lt)); + assertEquals(And.class, exp.getClass()); + + And and = (And) exp; + assertEquals(gte, and.left()); + assertEquals(lt, and.right()); + } + + // a != 2 AND a > 3 -> a > 3 + public void testCombineBinaryComparisonsConjunction_Neq2AndGt3() { + FieldAttribute fa = getFieldAttribute(); + + NotEquals neq = notEqualsOf(fa, TWO); + GreaterThan gt = greaterThanOf(fa, THREE); + And and = new And(EMPTY, neq, gt); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(and); + assertEquals(gt, exp); + } + + // a != 2 AND a >= 2 -> a > 2 + public void testCombineBinaryComparisonsConjunction_Neq2AndGte2() { + FieldAttribute fa = getFieldAttribute(); + + NotEquals neq = notEqualsOf(fa, TWO); + GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO); + And and = new And(EMPTY, neq, gte); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(and); + assertEquals(GreaterThan.class, exp.getClass()); + GreaterThan gt = (GreaterThan) exp; + assertEquals(TWO, gt.right()); + } + + // a != 2 AND a >= 1 -> nop + public void testCombineBinaryComparisonsConjunction_Neq2AndGte1() { + FieldAttribute fa = getFieldAttribute(); + + NotEquals neq = notEqualsOf(fa, TWO); + GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, ONE); + And and = new And(EMPTY, neq, gte); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(and); + assertEquals(And.class, exp.getClass()); // can't optimize + } + + // a != 2 AND a <= 3 -> nop + public void testCombineBinaryComparisonsConjunction_Neq2AndLte3() { + FieldAttribute fa = getFieldAttribute(); + + NotEquals neq = notEqualsOf(fa, TWO); + LessThanOrEqual lte = lessThanOrEqualOf(fa, THREE); + And and = new And(EMPTY, neq, lte); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(and); + assertEquals(and, exp); // can't optimize + } + + // a != 2 AND a <= 2 -> a < 2 + public void testCombineBinaryComparisonsConjunction_Neq2AndLte2() { + FieldAttribute fa = getFieldAttribute(); + + NotEquals neq = notEqualsOf(fa, TWO); + LessThanOrEqual lte = lessThanOrEqualOf(fa, TWO); + And and = new And(EMPTY, neq, lte); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(and); + assertEquals(LessThan.class, exp.getClass()); + LessThan lt = (LessThan) exp; + assertEquals(TWO, lt.right()); + } + + // a != 2 AND a <= 1 -> a <= 1 + public void testCombineBinaryComparisonsConjunction_Neq2AndLte1() { + FieldAttribute fa = getFieldAttribute(); + + NotEquals neq = notEqualsOf(fa, TWO); + LessThanOrEqual lte = lessThanOrEqualOf(fa, ONE); + And and = new And(EMPTY, neq, lte); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(and); + assertEquals(lte, exp); + } + + // Disjunction + + public void testCombineBinaryComparisonsDisjunctionNotComparable() { + FieldAttribute fa = getFieldAttribute(); + + GreaterThan gt1 = greaterThanOf(fa, ONE); + GreaterThan gt2 = greaterThanOf(fa, FALSE); + + Or or = new Or(EMPTY, gt1, gt2); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(or); + assertEquals(exp, or); + } + + // 2 < a OR 1 < a OR 3 < a -> 1 < a + public void testCombineBinaryComparisonsDisjunctionLowerBound() { + FieldAttribute fa = getFieldAttribute(); + + GreaterThan gt1 = greaterThanOf(fa, ONE); + GreaterThan gt2 = greaterThanOf(fa, TWO); + GreaterThan gt3 = greaterThanOf(fa, THREE); + + Or or = new Or(EMPTY, gt1, new Or(EMPTY, gt2, gt3)); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(or); + assertEquals(GreaterThan.class, exp.getClass()); + + GreaterThan gt = (GreaterThan) exp; + assertEquals(ONE, gt.right()); + } + + // 2 < a OR 1 < a OR 3 <= a -> 1 < a + public void testCombineBinaryComparisonsDisjunctionIncludeLowerBounds() { + FieldAttribute fa = getFieldAttribute(); + + GreaterThan gt1 = greaterThanOf(fa, ONE); + GreaterThan gt2 = greaterThanOf(fa, TWO); + GreaterThanOrEqual gte3 = greaterThanOrEqualOf(fa, THREE); + + Or or = new Or(EMPTY, new Or(EMPTY, gt1, gt2), gte3); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(or); + assertEquals(GreaterThan.class, exp.getClass()); + + GreaterThan gt = (GreaterThan) exp; + assertEquals(ONE, gt.right()); + } + + // a < 1 OR a < 2 OR a < 3 -> a < 3 + public void testCombineBinaryComparisonsDisjunctionUpperBound() { + FieldAttribute fa = getFieldAttribute(); + + LessThan lt1 = lessThanOf(fa, ONE); + LessThan lt2 = lessThanOf(fa, TWO); + LessThan lt3 = lessThanOf(fa, THREE); + + Or or = new Or(EMPTY, new Or(EMPTY, lt1, lt2), lt3); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(or); + assertEquals(LessThan.class, exp.getClass()); + + LessThan lt = (LessThan) exp; + assertEquals(THREE, lt.right()); + } + + // a < 2 OR a <= 2 OR a < 1 -> a <= 2 + public void testCombineBinaryComparisonsDisjunctionIncludeUpperBounds() { + FieldAttribute fa = getFieldAttribute(); + + LessThan lt1 = lessThanOf(fa, ONE); + LessThan lt2 = lessThanOf(fa, TWO); + LessThanOrEqual lte2 = lessThanOrEqualOf(fa, TWO); + + Or or = new Or(EMPTY, lt2, new Or(EMPTY, lte2, lt1)); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(or); + assertEquals(LessThanOrEqual.class, exp.getClass()); + + LessThanOrEqual lte = (LessThanOrEqual) exp; + assertEquals(TWO, lte.right()); + } + + // a < 2 OR 3 < a OR a < 1 OR 4 < a -> a < 2 OR 3 < a + public void testCombineBinaryComparisonsDisjunctionOfLowerAndUpperBounds() { + FieldAttribute fa = getFieldAttribute(); + + LessThan lt1 = lessThanOf(fa, ONE); + LessThan lt2 = lessThanOf(fa, TWO); + + GreaterThan gt3 = greaterThanOf(fa, THREE); + GreaterThan gt4 = greaterThanOf(fa, FOUR); + + Or or = new Or(EMPTY, new Or(EMPTY, lt2, gt3), new Or(EMPTY, lt1, gt4)); + + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(or); + assertEquals(Or.class, exp.getClass()); + + Or ro = (Or) exp; + + assertEquals(LessThan.class, ro.left().getClass()); + LessThan lt = (LessThan) ro.left(); + assertEquals(TWO, lt.right()); + assertEquals(GreaterThan.class, ro.right().getClass()); + GreaterThan gt = (GreaterThan) ro.right(); + assertEquals(THREE, gt.right()); + } + + // (a = 1 AND b = 3 AND c = 4) OR (a = 2 AND b = 3 AND c = 4) -> (b = 3 AND c = 4) AND (a = 1 OR a = 2) + public void testBooleanSimplificationCommonExpressionSubstraction() { + FieldAttribute fa = EsqlTestUtils.getFieldAttribute("a"); + FieldAttribute fb = EsqlTestUtils.getFieldAttribute("b"); + FieldAttribute fc = EsqlTestUtils.getFieldAttribute("c"); + + Expression a1 = equalsOf(fa, ONE); + Expression a2 = equalsOf(fa, TWO); + And common = new And(EMPTY, equalsOf(fb, THREE), equalsOf(fc, FOUR)); + And left = new And(EMPTY, a1, common); + And right = new And(EMPTY, a2, common); + Or or = new Or(EMPTY, left, right); + + Expression exp = new BooleanSimplification().rule(or); + assertEquals(new And(EMPTY, common, new Or(EMPTY, a1, a2)), exp); + } + + public void testBinaryComparisonAndOutOfRangeNotEqualsDifferentFields() { + FieldAttribute doubleOne = fieldAttribute("double", DOUBLE); + FieldAttribute doubleTwo = fieldAttribute("double2", DOUBLE); + FieldAttribute intOne = fieldAttribute("int", INTEGER); + FieldAttribute datetimeOne = fieldAttribute("datetime", INTEGER); + FieldAttribute keywordOne = fieldAttribute("keyword", KEYWORD); + FieldAttribute keywordTwo = fieldAttribute("keyword2", KEYWORD); + + List testCases = asList( + // double > 10 AND integer != -10 + new And(EMPTY, greaterThanOf(doubleOne, L(10)), notEqualsOf(intOne, L(-10))), + // keyword > '5' AND keyword2 != '48' + new And(EMPTY, greaterThanOf(keywordOne, L("5")), notEqualsOf(keywordTwo, L("48"))), + // keyword != '2021' AND datetime <= '2020-12-04T17:48:22.954240Z' + new And(EMPTY, notEqualsOf(keywordOne, L("2021")), lessThanOrEqualOf(datetimeOne, L("2020-12-04T17:48:22.954240Z"))), + // double > 10.1 AND double2 != -10.1 + new And(EMPTY, greaterThanOf(doubleOne, L(10.1d)), notEqualsOf(doubleTwo, L(-10.1d))) + ); + + for (And and : testCases) { + CombineBinaryComparisons rule = new CombineBinaryComparisons(); + Expression exp = rule.rule(and); + assertEquals("Rule should not have transformed [" + and.nodeString() + "]", and, exp); + } + } + +}