From 395efe8ca6b04e6bf6d59099fe8fa6bf8a001fa4 Mon Sep 17 00:00:00 2001 From: Mark Tozzi Date: Fri, 22 Mar 2024 11:38:34 -0400 Subject: [PATCH] copy over BooleanFuncitonEqualsEliminiation and tests --- .../xpack/esql/optimizer/OptimizerRules.java | 35 +++++++++++++++++ .../esql/optimizer/OptimizerRulesTests.java | 38 +++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java index 645924907b6f5..ff64e72f8ef44 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.NotEquals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.Eval; @@ -34,7 +35,10 @@ import org.elasticsearch.xpack.ql.expression.AttributeSet; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.Expressions; +import org.elasticsearch.xpack.ql.expression.function.Function; +import org.elasticsearch.xpack.ql.expression.predicate.logical.Not; import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.plan.QueryPlan; import org.elasticsearch.xpack.ql.plan.logical.Aggregate; import org.elasticsearch.xpack.ql.plan.logical.EsRelation; @@ -50,6 +54,8 @@ import java.util.Set; import static org.elasticsearch.xpack.ql.common.Failure.fail; +import static org.elasticsearch.xpack.ql.expression.Literal.FALSE; +import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineOr; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitOr; @@ -230,4 +236,33 @@ protected Expression rule(Or or) { return e; } } + + /** + * This rule must always be placed after {@link org.elasticsearch.xpack.ql.optimizer.OptimizerRules.LiteralsOnTheRight}, since it looks + * at TRUE/FALSE literals' existence on the right hand-side of the {@link Equals}/{@link NotEquals} expressions. + */ + public static final class BooleanFunctionEqualsElimination extends org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerExpressionRule { + + BooleanFunctionEqualsElimination() { + super(org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection.UP); + } + + @Override + protected Expression rule(BinaryComparison bc) { + if ((bc instanceof Equals || bc instanceof NotEquals) && bc.left() instanceof Function) { + // for expression "==" or "!=" TRUE/FALSE, return the expression itself or its negated variant + + // TODO: Replace use of QL Not with ESQL Not + if (TRUE.equals(bc.right())) { + return bc instanceof Equals ? bc.left() : new Not(bc.left().source(), bc.left()); + } + if (FALSE.equals(bc.right())) { + return bc instanceof Equals ? new Not(bc.left().source(), bc.left()) : bc.left(); + } + } + + return bc; + } + } + } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java index dd9704d57b12a..301d514aa8700 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java @@ -15,7 +15,10 @@ import org.elasticsearch.xpack.ql.expression.FieldAttribute; import org.elasticsearch.xpack.ql.expression.Literal; import org.elasticsearch.xpack.ql.expression.predicate.logical.And; +import org.elasticsearch.xpack.ql.expression.predicate.logical.Not; import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; +import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.ql.plan.logical.Filter; import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.ql.tree.Source; @@ -25,7 +28,11 @@ import static java.util.Arrays.asList; import static org.elasticsearch.xpack.ql.TestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.ql.TestUtils.notEqualsOf; import static org.elasticsearch.xpack.ql.TestUtils.relation; +import static org.elasticsearch.xpack.ql.expression.Literal.FALSE; +import static org.elasticsearch.xpack.ql.expression.Literal.NULL; +import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; import static org.elasticsearch.xpack.ql.tree.Source.EMPTY; import static org.hamcrest.Matchers.contains; @@ -143,4 +150,35 @@ public void testOrWithNonCombinableExpressions() { assertEquals(fa, in.value()); assertThat(in.list(), contains(ONE, THREE)); } + + // Test BooleanFunctionEqualsElimination + public void testBoolEqualsSimplificationOnExpressions() { + OptimizerRules.BooleanFunctionEqualsElimination s = new OptimizerRules.BooleanFunctionEqualsElimination(); + Expression exp = new GreaterThan(EMPTY, getFieldAttribute("a"), new Literal(EMPTY, 0, DataTypes.INTEGER), null); + + assertEquals(exp, s.rule(new Equals(EMPTY, exp, TRUE))); + // TODO: Replace use of QL Not with ESQL Not + assertEquals(new Not(EMPTY, exp), s.rule(new Equals(EMPTY, exp, FALSE))); + } + + public void testBoolEqualsSimplificationOnFields() { + OptimizerRules.BooleanFunctionEqualsElimination s = new OptimizerRules.BooleanFunctionEqualsElimination(); + + FieldAttribute field = getFieldAttribute("a"); + + List comparisons = asList( + new Equals(EMPTY, field, TRUE), + new Equals(EMPTY, field, FALSE), + notEqualsOf(field, TRUE), + notEqualsOf(field, FALSE), + new Equals(EMPTY, NULL, TRUE), + new Equals(EMPTY, NULL, FALSE), + notEqualsOf(NULL, TRUE), + notEqualsOf(NULL, FALSE) + ); + + for (BinaryComparison comparison : comparisons) { + assertEquals(comparison, s.rule(comparison)); + } + } }