From 50b2e953f1cb47e143f663b12d73203a60b850e1 Mon Sep 17 00:00:00 2001 From: Mark Tozzi Date: Tue, 19 Mar 2024 12:48:40 -0400 Subject: [PATCH] copy ql disjunction rule into esql --- .../esql/optimizer/LogicalPlanOptimizer.java | 60 ++++++++++++++++++- .../optimizer/LogicalPlanOptimizerTests.java | 31 ++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 59f0d46bf618a..883cb1a739a60 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -73,7 +73,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; @@ -83,6 +85,8 @@ import static java.util.Collections.singleton; import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputExpressions; import static org.elasticsearch.xpack.ql.expression.Expressions.asAttributes; +import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineOr; +import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitOr; import static org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PropagateEquals; import static org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection; import static org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection.DOWN; @@ -1129,7 +1133,10 @@ private static Project pushDownPastProject(UnaryPlan parent) { * This rule does NOT check for type compatibility as that phase has been * already be verified in the analyzer. 
*/ - public static class CombineDisjunctionsToIn extends OptimizerRules.CombineDisjunctionsToIn { + public static class CombineDisjunctionsToIn extends OptimizerRules.OptimizerExpressionRule<Or> { + public CombineDisjunctionsToIn() { + super(TransformDirection.UP); + } protected In createIn(Expression key, List<Expression> values, ZoneId zoneId) { return new In(key.source(), key, values); @@ -1138,6 +1145,57 @@ protected In createIn(Expression key, List<Expression> values, ZoneId zoneId) { protected Equals createEquals(Expression k, Set<Expression> v, ZoneId finalZoneId) { return new Equals(k.source(), k, v.iterator().next(), finalZoneId); } + + @Override + protected Expression rule(Or or) { + Expression e = or; + // look only at equals and In + List<Expression> exps = splitOr(e); + + Map<Expression, Set<Expression>> found = new LinkedHashMap<>(); + ZoneId zoneId = null; + List<Expression> ors = new LinkedList<>(); + + for (Expression exp : exps) { + if (exp instanceof Equals eq) { + // consider only equals against foldables + if (eq.right().foldable()) { + found.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right()); + } else { + ors.add(exp); + } + if (zoneId == null) { + zoneId = eq.zoneId(); + } + } else if (exp instanceof In in) { + found.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(in.list()); + if (zoneId == null) { + zoneId = in.zoneId(); + } + } else { + ors.add(exp); + } + } + + if (found.isEmpty() == false) { + // combine equals alongside the existing ors + final ZoneId finalZoneId = zoneId; + found.forEach( + (k, v) -> { ors.add(v.size() == 1 ? createEquals(k, v, finalZoneId) : createIn(k, new ArrayList<>(v), finalZoneId)); } + ); + + // TODO: this makes a QL `or`, not an ESQL `or` + Expression combineOr = combineOr(ors); + // check the result semantically since the result might be different in order + // but be actually the same which can trigger a loop + // e.g. 
a == 1 OR a == 2 OR null --> null OR a in (1,2) --> literalsOnTheRight --> cycle + if (e.semanticEquals(combineOr) == false) { + e = combineOr; + } + } + + return e; + } } static class ReplaceLimitAndSortAsTopN extends OptimizerRules.OptimizerRule { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 1ce383f2327ad..905154cb50466 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -315,6 +315,37 @@ public void testQlComparisonOptimizationsApply() { assertThat(con.value(), equalTo(5)); } + public void testCombineDisjunctionToInEquals() { + LogicalPlan plan = plan(""" + from test + | where emp_no == 1 or emp_no == 2 + """); + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var condition = as(filter.condition(), In.class); + assertThat(condition.list(), equalTo(List.of(new Literal(EMPTY, 1, INTEGER), new Literal(EMPTY, 2, INTEGER)))); + } + + public void testCombineDisjunctionToInMixed() { + LogicalPlan plan = plan(""" + from test + | where emp_no == 1 or emp_no in (2) + """); + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var condition = as(filter.condition(), In.class); + assertThat(condition.list(), equalTo(List.of(new Literal(EMPTY, 1, INTEGER), new Literal(EMPTY, 2, INTEGER)))); + } + public void testCombineDisjunctionToInFromIn() { + LogicalPlan plan = plan(""" + from test + | where emp_no in (1) or emp_no in (2) + """); + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var condition = as(filter.condition(), In.class); + assertThat(condition.list(), equalTo(List.of(new Literal(EMPTY, 1, INTEGER), new 
Literal(EMPTY, 2, INTEGER)))); + } public void testCombineProjectionWithPruning() { var plan = plan(""" from test