Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ESQL] Add tests for SimplifyComparisonArithmetics optimization rule #108744

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ public static BinaryComparisonOperation readFromStream(StreamInput in) throws IO
throw new IOException("No BinaryComparisonOperation found for id [" + id + "]");
}

public String symbol() {
return symbol;
}

public EsqlBinaryComparison buildNewInstance(Source source, Expression lhs, Expression rhs) {
return constructor.apply(source, lhs, rhs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.aggregation.QuantileStates;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.TestBlockFactory;
Expand Down Expand Up @@ -65,6 +66,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
Expand Down Expand Up @@ -95,19 +97,23 @@
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.Nullability;
import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.ql.expression.predicate.Predicates;
import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.expression.predicate.regex.RLikePattern;
import org.elasticsearch.xpack.ql.expression.predicate.regex.WildcardPattern;
import org.elasticsearch.xpack.ql.index.EsIndex;
import org.elasticsearch.xpack.ql.index.IndexResolution;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules;
import org.elasticsearch.xpack.ql.plan.logical.Filter;
import org.elasticsearch.xpack.ql.plan.logical.Limit;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.plan.logical.OrderBy;
import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
Expand All @@ -117,6 +123,7 @@
import org.junit.BeforeClass;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -137,6 +144,11 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.localSource;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.EQ;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GT;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_POINT;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_SHAPE;
import static org.elasticsearch.xpack.ql.TestUtils.getFieldAttribute;
Expand Down Expand Up @@ -173,14 +185,16 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
private static final Literal ONE = L(1);
private static final Literal TWO = L(2);
private static final Literal THREE = L(3);

private static EsqlParser parser;
private static Analyzer analyzer;
private static LogicalPlanOptimizer logicalOptimizer;
private static Map<String, EsField> mapping;
private static Map<String, EsField> mappingAirports;
private static Map<String, EsField> mappingTypes;
private static Analyzer analyzerAirports;
private static Analyzer analyzerTypes;
private static EnrichResolution enrichResolution;
private static final OptimizerRules.LiteralsOnTheRight LITERALS_ON_THE_RIGHT = new OptimizerRules.LiteralsOnTheRight();

private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer {
static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
Expand Down Expand Up @@ -219,6 +233,15 @@ public static void init() {
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultAirports, enrichResolution),
TEST_VERIFIER
);

// Some tests need additional types, so we load that index here and use it in the plan_types() function.
mappingTypes = loadMapping("mapping-all-types.json");
EsIndex types = new EsIndex("types", mappingTypes, Set.of("types"));
IndexResolution getIndexResultTypes = IndexResolution.valid(types);
analyzerTypes = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution),
TEST_VERIFIER
);
}

public void testEmptyProjections() {
Expand Down Expand Up @@ -4366,11 +4389,207 @@ private LogicalPlan planAirports(String query) {
return optimized;
}

private LogicalPlan planTypes(String query) {
return logicalOptimizer.optimize(analyzerTypes.analyze(parser.createStatement(query)));
}

private EsqlBinaryComparison extractPlannedBinaryComparison(String expression) {
LogicalPlan plan = planTypes("FROM types | WHERE " + expression);

return extractPlannedBinaryComparison(plan);
}

private static EsqlBinaryComparison extractPlannedBinaryComparison(LogicalPlan plan) {
assertTrue("Expected unary plan, found [" + plan + "]", plan instanceof UnaryPlan);
UnaryPlan unaryPlan = (UnaryPlan) plan;
assertTrue("Epxected top level Filter, foung [" + unaryPlan.child().toString() + "]", unaryPlan.child() instanceof Filter);
Filter filter = (Filter) unaryPlan.child();
assertTrue(
"Expected filter condition to be a binary comparison but found [" + filter.condition() + "]",
filter.condition() instanceof EsqlBinaryComparison
);
return (EsqlBinaryComparison) filter.condition();
}

private void doTestSimplifyComparisonArithmetics(
String expression,
String fieldName,
EsqlBinaryComparison.BinaryComparisonOperation opType,
Object bound
) {
EsqlBinaryComparison bc = extractPlannedBinaryComparison(expression);
assertEquals(opType, bc.getFunctionType());

assertTrue(
"Expected left side of comparison to be a field attribute but found [" + bc.left() + "]",
bc.left() instanceof FieldAttribute
);
FieldAttribute attribute = (FieldAttribute) bc.left();
assertEquals(fieldName, attribute.name());

assertTrue("Expected right side of comparison to be a literal but found [" + bc.right() + "]", bc.right() instanceof Literal);
Literal literal = (Literal) bc.right();
assertEquals(bound, literal.value());
}

private void assertSemanticMatching(String expected, String provided) {
BinaryComparison bc = extractPlannedBinaryComparison(provided);
LogicalPlan exp = analyzerTypes.analyze(parser.createStatement("FROM types | WHERE " + expected));
assertSemanticMatching(bc, extractPlannedBinaryComparison(exp));
}

private static void assertSemanticMatching(Expression fieldAttributeExp, Expression unresolvedAttributeExp) {
Expression unresolvedUpdated = unresolvedAttributeExp.transformUp(
LITERALS_ON_THE_RIGHT.expressionToken(),
LITERALS_ON_THE_RIGHT::rule
).transformUp(x -> x.foldable() ? new Literal(x.source(), x.fold(), x.dataType()) : x);

List<Expression> resolvedFields = fieldAttributeExp.collectFirstChildren(x -> x instanceof FieldAttribute);
for (Expression field : resolvedFields) {
FieldAttribute fa = (FieldAttribute) field;
unresolvedUpdated = unresolvedUpdated.transformDown(UnresolvedAttribute.class, x -> x.name().equals(fa.name()) ? fa : x);
}

assertTrue(unresolvedUpdated.semanticEquals(fieldAttributeExp));
}

private Expression getComparisonFromLogicalPlan(LogicalPlan plan) {
List<Expression> expressions = new ArrayList<>();
plan.forEachExpression(Expression.class, expressions::add);
return expressions.get(0);
}

private void assertNotSimplified(String comparison) {
String query = "FROM types | WHERE " + comparison;
Expression optimized = getComparisonFromLogicalPlan(planTypes(query));
Expression raw = getComparisonFromLogicalPlan(analyzerTypes.analyze(parser.createStatement(query)));

assertTrue(raw.semanticEquals(optimized));
}

private static String randomBinaryComparison() {
return randomFrom(EsqlBinaryComparison.BinaryComparisonOperation.values()).symbol();
}

public void testSimplifyComparisonArithmeticCommutativeVsNonCommutativeOps() {
doTestSimplifyComparisonArithmetics("integer + 2 > 3", "integer", GT, 1);
doTestSimplifyComparisonArithmetics("2 + integer > 3", "integer", GT, 1);
doTestSimplifyComparisonArithmetics("integer - 2 > 3", "integer", GT, 5);
doTestSimplifyComparisonArithmetics("2 - integer > 3", "integer", LT, -1);
doTestSimplifyComparisonArithmetics("integer * 2 > 4", "integer", GT, 2);
doTestSimplifyComparisonArithmetics("2 * integer > 4", "integer", GT, 2);

}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388")
public void testSimplifyComparisonArithmeticsWithFloatingPoints() {
doTestSimplifyComparisonArithmetics("float / 2 > 4", "float", GT, 8d);
}

public void testAssertSemanticMatching() {
// This test is just to verify that the complicated assert logic is working on a known-good case
assertSemanticMatching("integer > 1", "integer + 2 > 3");
}

public void testSimplyComparisonArithmeticWithUnfoldedProd() {
assertSemanticMatching("integer * integer >= 3", "((integer * integer + 1) * 2 - 4) * 4 >= 16");
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108524")
public void testSimplifyComparisionArithmetics_floatDivision() {
doTestSimplifyComparisonArithmetics("2 / float < 4", "float", GT, .5);
}

public void testSimplifyComparisonArithmeticWithMultipleOps() {
// i >= 3
doTestSimplifyComparisonArithmetics("((integer + 1) * 2 - 4) * 4 >= 16", "integer", GTE, 3);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
public void testSimplifyComparisonArithmeticWithFieldNegation() {
doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120", "integer", LTE, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
public void testSimplifyComparisonArithmeticWithFieldDoubleNegation() {
doTestSimplifyComparisonArithmetics("12 * -(-integer - 5) <= 120", "integer", LTE, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
public void testSimplifyComparisonArithmeticWithConjunction() {
doTestSimplifyComparisonArithmetics("12 * (-integer - 5) == -120 AND integer < 6 ", "integer", EQ, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108525")
public void testSimplifyComparisonArithmeticWithDisjunction() {
doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120 OR integer < 5", "integer", LTE, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388")
public void testSimplifyComparisonArithmeticWithFloatsAndDirectionChange() {
doTestSimplifyComparisonArithmetics("float / -2 < 4", "float", GT, -8d);
doTestSimplifyComparisonArithmetics("float * -2 < 4", "float", GT, -2d);
}

private void assertNullLiteral(Expression expression) {
assertEquals(Literal.class, expression.getClass());
assertNull(expression.fold());
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
public void testSimplifyComparisonArithmeticSkippedOnIntegerArithmeticalOverflow() {
assertNotSimplified("integer - 1 " + randomBinaryComparison() + " " + Long.MAX_VALUE);
assertNotSimplified("1 - integer " + randomBinaryComparison() + " " + Long.MIN_VALUE);
assertNotSimplified("integer - 1 " + randomBinaryComparison() + " " + Integer.MAX_VALUE);
assertNotSimplified("1 - integer " + randomBinaryComparison() + " " + Integer.MIN_VALUE);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
public void testSimplifyComparisonArithmeticSkippedOnNegatingOverflow() {
assertNotSimplified("-integer " + randomBinaryComparison() + " " + Long.MIN_VALUE);
assertNotSimplified("-integer " + randomBinaryComparison() + " " + Integer.MIN_VALUE);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
public void testSimplifyComparisonArithmeticSkippedOnDateOverflow() {
assertNotSimplified("date - 999999999 years > to_datetime(\"2010-01-01T01:01:01\")");
assertNotSimplified("date + -999999999 years > to_datetime(\"2010-01-01T01:01:01\")");
}

public void testSimplifyComparisonArithmeticSkippedOnMulDivByZero() {
assertNotSimplified("float / 0 " + randomBinaryComparison() + " 1");
assertNotSimplified("float * 0 " + randomBinaryComparison() + " 1");
assertNotSimplified("integer / 0 " + randomBinaryComparison() + " 1");
assertNotSimplified("integer * 0 " + randomBinaryComparison() + " 1");
}

public void testSimplifyComparisonArithmeticSkippedOnDiv() {
assertNotSimplified("integer / 4 " + randomBinaryComparison() + " 1");
assertNotSimplified("4 / integer " + randomBinaryComparison() + " 1");
}

public void testSimplifyComparisonArithmeticSkippedOnResultingFloatLiteral() {
assertNotSimplified("integer * 2 " + randomBinaryComparison() + " 3");
}

public void testSimplifyComparisonArithmeticSkippedOnFloatFieldWithPlusMinus() {
assertNotSimplified("float + 4 " + randomBinaryComparison() + " 1");
assertNotSimplified("4 + float " + randomBinaryComparison() + " 1");
assertNotSimplified("float - 4 " + randomBinaryComparison() + " 1");
assertNotSimplified("4 - float " + randomBinaryComparison() + " 1");
}

public void testSimplifyComparisonArithmeticSkippedOnFloats() {
for (String field : List.of("integer", "float")) {
for (Tuple<? extends Number, ? extends Number> nr : List.of(new Tuple<>(.4, 1), new Tuple<>(1, .4))) {
assertNotSimplified(field + " + " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2());
assertNotSimplified(field + " - " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2());
assertNotSimplified(nr.v1() + " + " + field + " " + randomBinaryComparison() + " " + nr.v2());
assertNotSimplified(nr.v1() + " - " + field + " " + randomBinaryComparison() + " " + nr.v2());
}
}
}

public static WildcardLike wildcardLike(Expression left, String exp) {
return new WildcardLike(EMPTY, left, new WildcardPattern(exp));
}
Expand Down