diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTest.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTest.java new file mode 100644 index 0000000000000..6e275598fb07f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTest.java @@ -0,0 +1,146 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.LessThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; +import org.elasticsearch.xpack.ql.expression.Expression; +import org.elasticsearch.xpack.ql.expression.FieldAttribute; +import org.elasticsearch.xpack.ql.expression.Literal; +import org.elasticsearch.xpack.ql.expression.predicate.logical.And; +import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; +import org.elasticsearch.xpack.ql.plan.logical.Filter; +import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.ql.tree.Source; +import org.elasticsearch.xpack.ql.type.DataTypes; + +import java.util.List; + +import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.ql.TestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.ql.TestUtils.relation; +import static org.elasticsearch.xpack.ql.tree.Source.EMPTY; +import static org.hamcrest.Matchers.contains; + +public class OptimizerRulesTest extends ESTestCase { + private static final Literal ONE = new Literal(Source.EMPTY, 1, DataTypes.INTEGER); + private static final Literal TWO = new Literal(Source.EMPTY, 2, DataTypes.INTEGER); + private static final Literal THREE = new Literal(Source.EMPTY, 3, DataTypes.INTEGER); + + private static Equals equalsOf(Expression left, Expression right) { + return new Equals(EMPTY, left, right, null); + } + + private static LessThan lessThanOf(Expression left, Expression right) { + return new LessThan(EMPTY, left, right, null); + } + + // + // CombineDisjunction in Equals + // + public void testTwoEqualsWithOr() { + FieldAttribute fa = getFieldAttribute("a"); + + Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); + Expression e = new OptimizerRules.CombineDisjunctionsToIn().rule(or); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + } + + public void testTwoEqualsWithSameValue() { + FieldAttribute fa = getFieldAttribute("a"); + + Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, ONE)); + Expression e = new OptimizerRules.CombineDisjunctionsToIn().rule(or); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals(ONE, eq.right()); + } + + public void testOneEqualsOneIn() { + FieldAttribute fa = getFieldAttribute("a"); + + Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, List.of(TWO))); + Expression e = new OptimizerRules.CombineDisjunctionsToIn().rule(or); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + } + + public void testOneEqualsOneInWithSameValue() { + FieldAttribute fa = getFieldAttribute("a"); + + Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, asList(ONE, TWO))); + Expression e = new OptimizerRules.CombineDisjunctionsToIn().rule(or); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + } + + public void testSingleValueInToEquals() { + FieldAttribute fa = getFieldAttribute("a"); + + Equals equals = equalsOf(fa, ONE); + Or or = new Or(EMPTY, equals, new In(EMPTY, fa, List.of(ONE))); + Expression e = new OptimizerRules.CombineDisjunctionsToIn().rule(or); + assertEquals(equals, e); + } + + public void testEqualsBehindAnd() { + FieldAttribute fa = getFieldAttribute("a"); + + And and = new And(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); + Filter dummy = new Filter(EMPTY, relation(), and); + LogicalPlan transformed = new OptimizerRules.CombineDisjunctionsToIn().apply(dummy); + assertSame(dummy, transformed); + assertEquals(and, ((Filter) transformed).condition()); + } + + public void testTwoEqualsDifferentFields() { + FieldAttribute fieldOne = getFieldAttribute("ONE"); + FieldAttribute fieldTwo = getFieldAttribute("TWO"); + + Or or = new Or(EMPTY, equalsOf(fieldOne, ONE), equalsOf(fieldTwo, TWO)); + Expression e = new OptimizerRules.CombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testMultipleIn() { + FieldAttribute fa = getFieldAttribute("a"); + + Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), new In(EMPTY, fa, List.of(TWO))); + Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE))); + Expression e = new OptimizerRules.CombineDisjunctionsToIn().rule(secondOr); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO, THREE)); + } + + public void testOrWithNonCombinableExpressions() { + FieldAttribute fa = getFieldAttribute("a"); + + Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), lessThanOf(fa, TWO)); + Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE))); + Expression e = new OptimizerRules.CombineDisjunctionsToIn().rule(secondOr); + assertEquals(Or.class, e.getClass()); + Or or = (Or) e; + assertEquals(or.left(), firstOr.right()); + assertEquals(In.class, or.right().getClass()); + In in = (In) or.right(); + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, THREE)); + } +}