Skip to content

Commit

Permalink
add tests for SimplifyComparisonArithmetics optimization rule (#108744)
Browse files Browse the repository at this point in the history
This adds in the tests from OptimizerRunTests in SQL to apply to ESQL. I've opened issues and applied the AwaitsFix annotation for those of the tests that are currently failing.
  • Loading branch information
not-napoleon authored May 17, 2024
1 parent bb5cac9 commit 9f54d9a
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ public static BinaryComparisonOperation readFromStream(StreamInput in) throws IO
throw new IOException("No BinaryComparisonOperation found for id [" + id + "]");
}

public String symbol() {
return symbol;
}

public EsqlBinaryComparison buildNewInstance(Source source, Expression lhs, Expression rhs) {
return constructor.apply(source, lhs, rhs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.aggregation.QuantileStates;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.TestBlockFactory;
Expand Down Expand Up @@ -65,6 +66,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
Expand Down Expand Up @@ -95,19 +97,23 @@
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.Nullability;
import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.ql.expression.predicate.Predicates;
import org.elasticsearch.xpack.ql.expression.predicate.logical.And;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.expression.predicate.regex.RLikePattern;
import org.elasticsearch.xpack.ql.expression.predicate.regex.WildcardPattern;
import org.elasticsearch.xpack.ql.index.EsIndex;
import org.elasticsearch.xpack.ql.index.IndexResolution;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules;
import org.elasticsearch.xpack.ql.plan.logical.Filter;
import org.elasticsearch.xpack.ql.plan.logical.Limit;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.plan.logical.OrderBy;
import org.elasticsearch.xpack.ql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
Expand All @@ -117,6 +123,7 @@
import org.junit.BeforeClass;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -137,6 +144,11 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.localSource;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.EQ;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GT;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LT;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_POINT;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_SHAPE;
import static org.elasticsearch.xpack.ql.TestUtils.getFieldAttribute;
Expand Down Expand Up @@ -173,16 +185,18 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
private static final Literal ONE = L(1);
private static final Literal TWO = L(2);
private static final Literal THREE = L(3);

private static EsqlParser parser;
private static Analyzer analyzer;
private static LogicalPlanOptimizer logicalOptimizer;
private static Map<String, EsField> mapping;
private static Map<String, EsField> mappingAirports;
private static Map<String, EsField> mappingTypes;
private static Analyzer analyzerAirports;
private static Analyzer analyzerTypes;
private static Map<String, EsField> mappingExtra;
private static Analyzer analyzerExtra;
private static EnrichResolution enrichResolution;
private static final OptimizerRules.LiteralsOnTheRight LITERALS_ON_THE_RIGHT = new OptimizerRules.LiteralsOnTheRight();

private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer {
static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
Expand Down Expand Up @@ -222,6 +236,15 @@ public static void init() {
TEST_VERIFIER
);

// Some tests need additional types, so we load that index here and use it in the plan_types() function.
mappingTypes = loadMapping("mapping-all-types.json");
EsIndex types = new EsIndex("types", mappingTypes, Set.of("types"));
IndexResolution getIndexResultTypes = IndexResolution.valid(types);
analyzerTypes = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution),
TEST_VERIFIER
);

// Some tests use mappings from mapping-extra.json to be able to test more types so we load it here
mappingExtra = loadMapping("mapping-extra.json");
EsIndex extra = new EsIndex("extra", mappingExtra, Set.of("extra"));
Expand Down Expand Up @@ -4438,11 +4461,207 @@ private LogicalPlan planExtra(String query) {
return optimized;
}

private LogicalPlan planTypes(String query) {
return logicalOptimizer.optimize(analyzerTypes.analyze(parser.createStatement(query)));
}

private EsqlBinaryComparison extractPlannedBinaryComparison(String expression) {
LogicalPlan plan = planTypes("FROM types | WHERE " + expression);

return extractPlannedBinaryComparison(plan);
}

private static EsqlBinaryComparison extractPlannedBinaryComparison(LogicalPlan plan) {
assertTrue("Expected unary plan, found [" + plan + "]", plan instanceof UnaryPlan);
UnaryPlan unaryPlan = (UnaryPlan) plan;
assertTrue("Epxected top level Filter, foung [" + unaryPlan.child().toString() + "]", unaryPlan.child() instanceof Filter);
Filter filter = (Filter) unaryPlan.child();
assertTrue(
"Expected filter condition to be a binary comparison but found [" + filter.condition() + "]",
filter.condition() instanceof EsqlBinaryComparison
);
return (EsqlBinaryComparison) filter.condition();
}

private void doTestSimplifyComparisonArithmetics(
String expression,
String fieldName,
EsqlBinaryComparison.BinaryComparisonOperation opType,
Object bound
) {
EsqlBinaryComparison bc = extractPlannedBinaryComparison(expression);
assertEquals(opType, bc.getFunctionType());

assertTrue(
"Expected left side of comparison to be a field attribute but found [" + bc.left() + "]",
bc.left() instanceof FieldAttribute
);
FieldAttribute attribute = (FieldAttribute) bc.left();
assertEquals(fieldName, attribute.name());

assertTrue("Expected right side of comparison to be a literal but found [" + bc.right() + "]", bc.right() instanceof Literal);
Literal literal = (Literal) bc.right();
assertEquals(bound, literal.value());
}

private void assertSemanticMatching(String expected, String provided) {
BinaryComparison bc = extractPlannedBinaryComparison(provided);
LogicalPlan exp = analyzerTypes.analyze(parser.createStatement("FROM types | WHERE " + expected));
assertSemanticMatching(bc, extractPlannedBinaryComparison(exp));
}

private static void assertSemanticMatching(Expression fieldAttributeExp, Expression unresolvedAttributeExp) {
Expression unresolvedUpdated = unresolvedAttributeExp.transformUp(
LITERALS_ON_THE_RIGHT.expressionToken(),
LITERALS_ON_THE_RIGHT::rule
).transformUp(x -> x.foldable() ? new Literal(x.source(), x.fold(), x.dataType()) : x);

List<Expression> resolvedFields = fieldAttributeExp.collectFirstChildren(x -> x instanceof FieldAttribute);
for (Expression field : resolvedFields) {
FieldAttribute fa = (FieldAttribute) field;
unresolvedUpdated = unresolvedUpdated.transformDown(UnresolvedAttribute.class, x -> x.name().equals(fa.name()) ? fa : x);
}

assertTrue(unresolvedUpdated.semanticEquals(fieldAttributeExp));
}

private Expression getComparisonFromLogicalPlan(LogicalPlan plan) {
List<Expression> expressions = new ArrayList<>();
plan.forEachExpression(Expression.class, expressions::add);
return expressions.get(0);
}

private void assertNotSimplified(String comparison) {
String query = "FROM types | WHERE " + comparison;
Expression optimized = getComparisonFromLogicalPlan(planTypes(query));
Expression raw = getComparisonFromLogicalPlan(analyzerTypes.analyze(parser.createStatement(query)));

assertTrue(raw.semanticEquals(optimized));
}

private static String randomBinaryComparison() {
return randomFrom(EsqlBinaryComparison.BinaryComparisonOperation.values()).symbol();
}

public void testSimplifyComparisonArithmeticCommutativeVsNonCommutativeOps() {
doTestSimplifyComparisonArithmetics("integer + 2 > 3", "integer", GT, 1);
doTestSimplifyComparisonArithmetics("2 + integer > 3", "integer", GT, 1);
doTestSimplifyComparisonArithmetics("integer - 2 > 3", "integer", GT, 5);
doTestSimplifyComparisonArithmetics("2 - integer > 3", "integer", LT, -1);
doTestSimplifyComparisonArithmetics("integer * 2 > 4", "integer", GT, 2);
doTestSimplifyComparisonArithmetics("2 * integer > 4", "integer", GT, 2);

}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388")
public void testSimplifyComparisonArithmeticsWithFloatingPoints() {
doTestSimplifyComparisonArithmetics("float / 2 > 4", "float", GT, 8d);
}

public void testAssertSemanticMatching() {
// This test is just to verify that the complicated assert logic is working on a known-good case
assertSemanticMatching("integer > 1", "integer + 2 > 3");
}

public void testSimplyComparisonArithmeticWithUnfoldedProd() {
assertSemanticMatching("integer * integer >= 3", "((integer * integer + 1) * 2 - 4) * 4 >= 16");
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108524")
public void testSimplifyComparisionArithmetics_floatDivision() {
doTestSimplifyComparisonArithmetics("2 / float < 4", "float", GT, .5);
}

public void testSimplifyComparisonArithmeticWithMultipleOps() {
// i >= 3
doTestSimplifyComparisonArithmetics("((integer + 1) * 2 - 4) * 4 >= 16", "integer", GTE, 3);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
public void testSimplifyComparisonArithmeticWithFieldNegation() {
doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120", "integer", LTE, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
public void testSimplifyComparisonArithmeticWithFieldDoubleNegation() {
doTestSimplifyComparisonArithmetics("12 * -(-integer - 5) <= 120", "integer", LTE, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743")
public void testSimplifyComparisonArithmeticWithConjunction() {
doTestSimplifyComparisonArithmetics("12 * (-integer - 5) == -120 AND integer < 6 ", "integer", EQ, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108525")
public void testSimplifyComparisonArithmeticWithDisjunction() {
doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120 OR integer < 5", "integer", LTE, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388")
public void testSimplifyComparisonArithmeticWithFloatsAndDirectionChange() {
doTestSimplifyComparisonArithmetics("float / -2 < 4", "float", GT, -8d);
doTestSimplifyComparisonArithmetics("float * -2 < 4", "float", GT, -2d);
}

private void assertNullLiteral(Expression expression) {
assertEquals(Literal.class, expression.getClass());
assertNull(expression.fold());
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
public void testSimplifyComparisonArithmeticSkippedOnIntegerArithmeticalOverflow() {
assertNotSimplified("integer - 1 " + randomBinaryComparison() + " " + Long.MAX_VALUE);
assertNotSimplified("1 - integer " + randomBinaryComparison() + " " + Long.MIN_VALUE);
assertNotSimplified("integer - 1 " + randomBinaryComparison() + " " + Integer.MAX_VALUE);
assertNotSimplified("1 - integer " + randomBinaryComparison() + " " + Integer.MIN_VALUE);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
public void testSimplifyComparisonArithmeticSkippedOnNegatingOverflow() {
assertNotSimplified("-integer " + randomBinaryComparison() + " " + Long.MIN_VALUE);
assertNotSimplified("-integer " + randomBinaryComparison() + " " + Integer.MIN_VALUE);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
public void testSimplifyComparisonArithmeticSkippedOnDateOverflow() {
assertNotSimplified("date - 999999999 years > to_datetime(\"2010-01-01T01:01:01\")");
assertNotSimplified("date + -999999999 years > to_datetime(\"2010-01-01T01:01:01\")");
}

public void testSimplifyComparisonArithmeticSkippedOnMulDivByZero() {
assertNotSimplified("float / 0 " + randomBinaryComparison() + " 1");
assertNotSimplified("float * 0 " + randomBinaryComparison() + " 1");
assertNotSimplified("integer / 0 " + randomBinaryComparison() + " 1");
assertNotSimplified("integer * 0 " + randomBinaryComparison() + " 1");
}

public void testSimplifyComparisonArithmeticSkippedOnDiv() {
assertNotSimplified("integer / 4 " + randomBinaryComparison() + " 1");
assertNotSimplified("4 / integer " + randomBinaryComparison() + " 1");
}

public void testSimplifyComparisonArithmeticSkippedOnResultingFloatLiteral() {
assertNotSimplified("integer * 2 " + randomBinaryComparison() + " 3");
}

public void testSimplifyComparisonArithmeticSkippedOnFloatFieldWithPlusMinus() {
assertNotSimplified("float + 4 " + randomBinaryComparison() + " 1");
assertNotSimplified("4 + float " + randomBinaryComparison() + " 1");
assertNotSimplified("float - 4 " + randomBinaryComparison() + " 1");
assertNotSimplified("4 - float " + randomBinaryComparison() + " 1");
}

public void testSimplifyComparisonArithmeticSkippedOnFloats() {
for (String field : List.of("integer", "float")) {
for (Tuple<? extends Number, ? extends Number> nr : List.of(new Tuple<>(.4, 1), new Tuple<>(1, .4))) {
assertNotSimplified(field + " + " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2());
assertNotSimplified(field + " - " + nr.v1() + " " + randomBinaryComparison() + " " + nr.v2());
assertNotSimplified(nr.v1() + " + " + field + " " + randomBinaryComparison() + " " + nr.v2());
assertNotSimplified(nr.v1() + " - " + field + " " + randomBinaryComparison() + " " + nr.v2());
}
}
}

public static WildcardLike wildcardLike(Expression left, String exp) {
return new WildcardLike(EMPTY, left, new WildcardPattern(exp));
}
Expand Down

0 comments on commit 9f54d9a

Please sign in to comment.