diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index af8ad7a1fc435..93505fa4f20fc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; import org.elasticsearch.xpack.ql.expression.predicate.regex.RegexMatch; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules; -import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.BooleanFunctionEqualsElimination; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ConstantFolding; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.LiteralsOnTheRight; import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneLiteralsInOrderBy; @@ -127,7 +126,7 @@ protected static Batch operators() { // needs to occur before BinaryComparison combinations (see class) new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.PropagateEquals(), new PropagateNullable(), - new BooleanFunctionEqualsElimination(), + new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.BooleanFunctionEqualsElimination(), new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.CombineDisjunctionsToIn(), new SimplifyComparisonsArithmetics(EsqlDataTypes::areCompatible), // prune/elimination diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java index 3ae662580a200..38ac596135abb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java @@ -41,10 +41,12 @@ import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.Expressions; import org.elasticsearch.xpack.ql.expression.Literal; +import org.elasticsearch.xpack.ql.expression.function.Function; import org.elasticsearch.xpack.ql.expression.predicate.Predicates; import org.elasticsearch.xpack.ql.expression.predicate.Range; import org.elasticsearch.xpack.ql.expression.predicate.logical.And; import org.elasticsearch.xpack.ql.expression.predicate.logical.BinaryLogic; +import org.elasticsearch.xpack.ql.expression.predicate.logical.Not; import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.plan.QueryPlan; @@ -65,6 +67,7 @@ import java.util.Set; import static org.elasticsearch.xpack.ql.common.Failure.fail; +import static org.elasticsearch.xpack.ql.expression.Literal.FALSE; import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineOr; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitOr; @@ -257,6 +260,35 @@ protected Expression rule(Or or) { } } + /** + * This rule must always be placed after {@link org.elasticsearch.xpack.ql.optimizer.OptimizerRules.LiteralsOnTheRight}, since it looks + * at TRUE/FALSE literals' existence on the right hand-side of the {@link Equals}/{@link NotEquals} expressions. + */ + public static final class BooleanFunctionEqualsElimination extends + org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerExpressionRule { + + BooleanFunctionEqualsElimination() { + super(org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection.UP); + } + + @Override + protected Expression rule(BinaryComparison bc) { + if ((bc instanceof Equals || bc instanceof NotEquals) && bc.left() instanceof Function) { + // for expression "==" or "!=" TRUE/FALSE, return the expression itself or its negated variant + + // TODO: Replace use of QL Not with ESQL Not + if (TRUE.equals(bc.right())) { + return bc instanceof Equals ? bc.left() : new Not(bc.left().source(), bc.left()); + } + if (FALSE.equals(bc.right())) { + return bc instanceof Equals ? new Not(bc.left().source(), bc.left()) : bc.left(); + } + } + + return bc; + } + } + /** * Propagate Equals to eliminate conjuncted Ranges or BinaryComparisons. * When encountering a different Equals, non-containing {@link Range} or {@link BinaryComparison}, the conjunction becomes false. diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java index 1aac8efbe6f65..01fcd222a5141 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java @@ -23,7 +23,9 @@ import org.elasticsearch.xpack.ql.expression.predicate.Predicates; import org.elasticsearch.xpack.ql.expression.predicate.Range; import org.elasticsearch.xpack.ql.expression.predicate.logical.And; +import org.elasticsearch.xpack.ql.expression.predicate.logical.Not; import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.plan.logical.Filter; import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.ql.tree.Source; @@ -32,11 +34,10 @@ import java.util.List; import static java.util.Arrays.asList; -import static org.elasticsearch.xpack.ql.TestUtils.equalsOf; -import static org.elasticsearch.xpack.ql.TestUtils.nullEqualsOf; import static org.elasticsearch.xpack.ql.TestUtils.rangeOf; import static org.elasticsearch.xpack.ql.TestUtils.relation; import static org.elasticsearch.xpack.ql.expression.Literal.FALSE; +import static org.elasticsearch.xpack.ql.expression.Literal.NULL; import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; import static org.elasticsearch.xpack.ql.tree.Source.EMPTY; import static org.hamcrest.Matchers.contains; @@ -182,6 +183,39 @@ public void testOrWithNonCombinableExpressions() { assertThat(in.list(), contains(ONE, THREE)); } + // Test BooleanFunctionEqualsElimination + public void testBoolEqualsSimplificationOnExpressions() { + OptimizerRules.BooleanFunctionEqualsElimination s = new OptimizerRules.BooleanFunctionEqualsElimination(); + Expression exp = new GreaterThan(EMPTY, getFieldAttribute(), new Literal(EMPTY, 0, DataTypes.INTEGER), null); + + assertEquals(exp, s.rule(new Equals(EMPTY, exp, TRUE))); + // TODO: Replace use of QL Not with ESQL Not + assertEquals(new Not(EMPTY, exp), s.rule(new Equals(EMPTY, exp, FALSE))); + } + + public void testBoolEqualsSimplificationOnFields() { + OptimizerRules.BooleanFunctionEqualsElimination s = new OptimizerRules.BooleanFunctionEqualsElimination(); + + FieldAttribute field = getFieldAttribute(); + + List comparisons = asList( + new Equals(EMPTY, field, TRUE), + new Equals(EMPTY, field, FALSE), + notEqualsOf(field, TRUE), + notEqualsOf(field, FALSE), + new Equals(EMPTY, NULL, TRUE), + new Equals(EMPTY, NULL, FALSE), + notEqualsOf(NULL, TRUE), + notEqualsOf(NULL, FALSE) + ); + + for (BinaryComparison comparison : comparisons) { + assertEquals(comparison, s.rule(comparison)); + } + } + + // Test Propagate Equals + // a == 1 AND a == 2 -> FALSE public void testDualEqualsConjunction() { FieldAttribute fa = getFieldAttribute();