diff --git a/docs/changelog/103099.yaml b/docs/changelog/103099.yaml new file mode 100644 index 0000000000000..c3fd3f9d7b8e4 --- /dev/null +++ b/docs/changelog/103099.yaml @@ -0,0 +1,6 @@ +pr: 103099 +summary: "ESQL: Simpify IS NULL/IS NOT NULL evaluation" +area: ES|QL +type: enhancement +issues: + - 103097 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java index 3451a3981d3e3..e05dd9a00c567 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.esql.optimizer; +import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.stats.SearchStats; import org.elasticsearch.xpack.ql.expression.Alias; +import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.FieldAttribute; import org.elasticsearch.xpack.ql.expression.Literal; import org.elasticsearch.xpack.ql.expression.NamedExpression; @@ -37,7 +39,13 @@ public LocalLogicalPlanOptimizer(LocalLogicalOptimizerContext localLogicalOptimi @Override protected List> batches() { - var local = new Batch<>("Local rewrite", new ReplaceTopNWithLimitAndSort(), new ReplaceMissingFieldWithNull()); + var local = new Batch<>( + "Local rewrite", + Limiter.ONCE, + new ReplaceTopNWithLimitAndSort(), + new ReplaceMissingFieldWithNull(), + new InferIsNotNull() + ); var rules = new ArrayList>(); rules.add(local); @@ -116,6 +124,14 @@ else if (plan instanceof Project project) { } } + static class InferIsNotNull extends OptimizerRules.InferIsNotNull { + + @Override + protected boolean skipExpression(Expression e) { + return e instanceof Coalesce; + } + } + abstract static class ParameterizedOptimizerRule extends ParameterizedRule { public final LogicalPlan apply(LogicalPlan plan, P context) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index bc46189e13827..ac2426f485fcc 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; @@ -21,9 +22,11 @@ import org.elasticsearch.xpack.ql.expression.Expressions; import org.elasticsearch.xpack.ql.expression.Literal; import org.elasticsearch.xpack.ql.expression.ReferenceAttribute; +import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.ql.index.EsIndex; import org.elasticsearch.xpack.ql.index.IndexResolution; import org.elasticsearch.xpack.ql.plan.logical.EsRelation; +import org.elasticsearch.xpack.ql.plan.logical.Filter; import org.elasticsearch.xpack.ql.plan.logical.Limit; import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.ql.plan.logical.Project; @@ -35,6 +38,7 @@ import java.util.Map; import static org.elasticsearch.xpack.esql.EsqlTestUtils.L; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_SEARCH_STATS; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; @@ -263,6 +267,38 @@ public void testMissingFieldInFilterNoProjection() { ); } + public void testIsNotNullOnCoalesce() { + var plan = localPlan(""" + from test + | where coalesce(emp_no, salary) is not null + """); + + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var inn = as(filter.condition(), IsNotNull.class); + var coalesce = as(inn.children().get(0), Coalesce.class); + assertThat(Expressions.names(coalesce.children()), contains("emp_no", "salary")); + var source = as(filter.child(), EsRelation.class); + } + + public void testIsNotNullOnExpression() { + var plan = localPlan(""" + from test + | eval x = emp_no + 1 + | where x is not null + """); + + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var inn = as(filter.condition(), IsNotNull.class); + assertThat(Expressions.names(inn.children()), contains("x")); + var eval = as(filter.child(), Eval.class); + filter = as(eval.child(), Filter.class); + inn = as(filter.condition(), IsNotNull.class); + assertThat(Expressions.names(inn.children()), contains("emp_no")); + var source = as(filter.child(), EsRelation.class); + } + private LocalRelation asEmptyRelation(Object o) { var empty = as(o, LocalRelation.class); assertThat(empty.supplier(), is(LocalSupplier.EMPTY)); @@ -285,6 +321,10 @@ private LogicalPlan localPlan(LogicalPlan plan, SearchStats searchStats) { return localPlan; } + private LogicalPlan localPlan(String query) { + return localPlan(plan(query), TEST_SEARCH_STATS); + } + @Override protected List filteredWarnings() { return withDefaultLimitWarning(super.filteredWarnings()); diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java index 5d9736726b46f..f084b5cda4abe 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRules.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xpack.ql.expression.Alias; +import org.elasticsearch.xpack.ql.expression.Attribute; +import org.elasticsearch.xpack.ql.expression.AttributeMap; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.Expressions; import org.elasticsearch.xpack.ql.expression.Literal; @@ -69,6 +71,7 @@ import static java.lang.Math.signum; import static java.util.Arrays.asList; +import static java.util.Collections.emptySet; import static org.elasticsearch.xpack.ql.expression.Literal.FALSE; import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineAnd; @@ -1785,6 +1788,93 @@ protected Expression nonNullify(Expression exp, Expression nonNullExp) { } } + /** + * Simplify IsNotNull targets by resolving the underlying expression to its root fields with unknown + * nullability. + * e.g. + * (x + 1) / 2 IS NOT NULL --> x IS NOT NULL AND (x+1) / 2 IS NOT NULL + * SUBSTRING(x, 3) > 4 IS NOT NULL --> x IS NOT NULL AND SUBSTRING(x, 3) > 4 IS NOT NULL + * When dealing with multiple fields, a conjunction/disjunction based on the predicate: + * (x + y) / 4 IS NOT NULL --> x IS NOT NULL AND y IS NOT NULL AND (x + y) / 4 IS NOT NULL + * This handles the case of fields nested inside functions or expressions in order to avoid: + * - having to evaluate the whole expression + * - not pushing down the filter due to expression evaluation + * IS NULL cannot be simplified since it leads to a disjunction which prevents the filter to be + * pushed down: + * (x + 1) IS NULL --> x IS NULL OR x + 1 IS NULL + * and x IS NULL cannot be pushed down + *
+ * Implementation-wise this rule goes bottom-up, keeping an alias up to date to the current plan + * and then looks for replacing the target. + */ + public static class InferIsNotNull extends Rule { + + @Override + public LogicalPlan apply(LogicalPlan plan) { + // the alias map is shared across the whole plan + AttributeMap aliases = new AttributeMap<>(); + // traverse bottom-up to pick up the aliases as we go + plan = plan.transformUp(p -> inspectPlan(p, aliases)); + return plan; + } + + private LogicalPlan inspectPlan(LogicalPlan plan, AttributeMap aliases) { + // inspect just this plan properties + plan.forEachExpression(Alias.class, a -> aliases.put(a.toAttribute(), a.child())); + // now go about finding isNull/isNotNull + LogicalPlan newPlan = plan.transformExpressionsOnlyUp(IsNotNull.class, inn -> inferNotNullable(inn, aliases)); + return newPlan; + } + + private Expression inferNotNullable(IsNotNull inn, AttributeMap aliases) { + Expression result = inn; + Set refs = resolveExpressionAsRootAttributes(inn.field(), aliases); + // no refs found or could not detect - return the original function + if (refs.size() > 0) { + // add IsNull for the filters along with the initial inn + var innList = CollectionUtils.combine(refs.stream().map(r -> (Expression) new IsNotNull(inn.source(), r)).toList(), inn); + result = Predicates.combineAnd(innList); + } + return result; + } + + /** + * Unroll the expression to its references to get to the root fields + * that really matter for filtering. + */ + protected Set resolveExpressionAsRootAttributes(Expression exp, AttributeMap aliases) { + Set resolvedExpressions = new LinkedHashSet<>(); + boolean changed = doResolve(exp, aliases, resolvedExpressions); + return changed ? resolvedExpressions : emptySet(); + } + + private boolean doResolve(Expression exp, AttributeMap aliases, Set resolvedExpressions) { + boolean changed = false; + // check if the expression can be skipped or is not nullabe + if (skipExpression(exp) || exp.nullable() == Nullability.FALSE) { + resolvedExpressions.add(exp); + } else { + for (Expression e : exp.references()) { + Expression resolved = aliases.resolve(e, e); + // found a root attribute, bail out + if (resolved instanceof Attribute a && resolved == e) { + resolvedExpressions.add(a); + // don't mark things as change if the original expression hasn't been broken down + changed |= resolved != exp; + } else { + // go further + changed |= doResolve(resolved, aliases, resolvedExpressions); + } + } + } + return changed; + } + + protected boolean skipExpression(Expression e) { + return false; + } + } + public static final class SetAsOptimized extends Rule { @Override diff --git a/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java index b14e46a96a9e6..1cab7dd87195b 100644 --- a/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/optimizer/OptimizerRulesTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.ql.expression.Literal; import org.elasticsearch.xpack.ql.expression.Nullability; import org.elasticsearch.xpack.ql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.ql.expression.function.scalar.string.StartsWith; import org.elasticsearch.xpack.ql.expression.predicate.BinaryOperator; import org.elasticsearch.xpack.ql.expression.predicate.Predicates; import org.elasticsearch.xpack.ql.expression.predicate.Range; @@ -1768,7 +1769,90 @@ public void testPushDownFilterThroughAgg() throws Exception { // expected Filter expected = new Filter(EMPTY, new Aggregate(EMPTY, combinedFilter, emptyList(), emptyList()), aggregateCondition); assertEquals(expected, new PushDownAndCombineFilters().apply(fb)); + } + + public void testIsNotNullOnIsNullField() { + EsRelation relation = relation(); + var fieldA = getFieldAttribute("a"); + Expression inn = isNotNull(fieldA); + Filter f = new Filter(EMPTY, relation, inn); + + assertEquals(f, new OptimizerRules.InferIsNotNull().apply(f)); + } + + public void testIsNotNullOnOperatorWithOneField() { + EsRelation relation = relation(); + var fieldA = getFieldAttribute("a"); + Expression inn = isNotNull(new Add(EMPTY, fieldA, ONE)); + Filter f = new Filter(EMPTY, relation, inn); + Filter expected = new Filter(EMPTY, relation, new And(EMPTY, isNotNull(fieldA), inn)); + + assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f)); + } + + public void testIsNotNullOnOperatorWithTwoFields() { + EsRelation relation = relation(); + var fieldA = getFieldAttribute("a"); + var fieldB = getFieldAttribute("b"); + Expression inn = isNotNull(new Add(EMPTY, fieldA, fieldB)); + Filter f = new Filter(EMPTY, relation, inn); + Filter expected = new Filter(EMPTY, relation, new And(EMPTY, new And(EMPTY, isNotNull(fieldA), isNotNull(fieldB)), inn)); + + assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f)); + } + + public void testIsNotNullOnFunctionWithOneField() { + EsRelation relation = relation(); + var fieldA = getFieldAttribute("a"); + var pattern = L("abc"); + Expression inn = isNotNull( + new And(EMPTY, new TestStartsWith(EMPTY, fieldA, pattern, false), greaterThanOf(new Add(EMPTY, ONE, TWO), THREE)) + ); + + Filter f = new Filter(EMPTY, relation, inn); + Filter expected = new Filter(EMPTY, relation, new And(EMPTY, isNotNull(fieldA), inn)); + + assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f)); + } + + public void testIsNotNullOnFunctionWithTwoFields() { + EsRelation relation = relation(); + var fieldA = getFieldAttribute("a"); + var fieldB = getFieldAttribute("b"); + var pattern = L("abc"); + Expression inn = isNotNull(new TestStartsWith(EMPTY, fieldA, fieldB, false)); + + Filter f = new Filter(EMPTY, relation, inn); + Filter expected = new Filter(EMPTY, relation, new And(EMPTY, new And(EMPTY, isNotNull(fieldA), isNotNull(fieldB)), inn)); + + assertEquals(expected, new OptimizerRules.InferIsNotNull().apply(f)); + } + + public static class TestStartsWith extends StartsWith { + + public TestStartsWith(Source source, Expression input, Expression pattern, boolean caseInsensitive) { + super(source, input, pattern, caseInsensitive); + } + + @Override + public Expression replaceChildren(List newChildren) { + return new TestStartsWith(source(), newChildren.get(0), newChildren.get(1), isCaseInsensitive()); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, TestStartsWith::new, input(), pattern(), isCaseInsensitive()); + } + } + + public void testIsNotNullOnFunctionWithTwoField() {} + + private IsNotNull isNotNull(Expression field) { + return new IsNotNull(EMPTY, field); + } + private IsNull isNull(Expression field) { + return new IsNull(EMPTY, field); } private Literal nullOf(DataType dataType) {