Skip to content

Commit

Permalink
[ES|QL] Add CombineBinaryComparisons rule (#110548)
Browse files Browse the repository at this point in the history
* add CombineBinaryComparisons
  • Loading branch information
fang-xing-esql authored Jul 30, 2024
1 parent d10bb67 commit ac94a6c
Show file tree
Hide file tree
Showing 8 changed files with 659 additions and 3 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/110548.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 110548
summary: "[ES|QL] Add `CombineBinaryComparisons` rule"
area: ES|QL
type: enhancement
issues:
- 108525
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,13 @@
import static org.junit.Assert.assertTrue;

public final class EsqlTestUtils {

public static final Literal ONE = new Literal(Source.EMPTY, 1, DataType.INTEGER);
public static final Literal TWO = new Literal(Source.EMPTY, 2, DataType.INTEGER);
public static final Literal THREE = new Literal(Source.EMPTY, 3, DataType.INTEGER);
public static final Literal FOUR = new Literal(Source.EMPTY, 4, DataType.INTEGER);
public static final Literal FIVE = new Literal(Source.EMPTY, 5, DataType.INTEGER);
private static final Literal SIX = new Literal(Source.EMPTY, 6, DataType.INTEGER);
public static final Literal SIX = new Literal(Source.EMPTY, 6, DataType.INTEGER);

public static Equals equalsOf(Expression left, Expression right) {
return new Equals(EMPTY, left, right, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -986,3 +986,36 @@ s:double | emp_no:integer | salary:integer
1.0 | 10002 | 56371
1.0 | 10041 | 56415
;

CombineBinaryComparisonsMv
required_capability: combine_binary_comparisons

row x = [1,2,3]
| where 12 * (-x - 5) >= -120 OR x < 5
;
warning:Line 2:15: evaluation of [-x] failed, treating result as null. Only first 20 failures recorded.
warning:Line 2:15: java.lang.IllegalArgumentException: single-value function encountered multi-value
warning:Line 2:34: evaluation of [x < 5] failed, treating result as null. Only first 20 failures recorded.
warning:Line 2:34: java.lang.IllegalArgumentException: single-value function encountered multi-value

x:integer
;

CombineBinaryComparisonsEmp
required_capability: combine_binary_comparisons

from employees
| where salary_change.int == 2 OR salary_change.int > 1
| keep emp_no, salary_change.int
| sort emp_no
;
warning:Line 2:35: evaluation of [salary_change.int > 1] failed, treating result as null. Only first 20 failures recorded.
warning:Line 2:35: java.lang.IllegalArgumentException: single-value function encountered multi-value

emp_no:integer |salary_change.int:integer
10044 | 8
10046 | 2
10066 | 5
10079 | 7
10086 | 13
;
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,12 @@ public enum Cap {
/**
* Support for match operator
*/
MATCH_OPERATOR(true);
MATCH_OPERATOR(true),

/**
* Add CombineBinaryComparisons rule.
*/
COMBINE_BINARY_COMPARISONS;

private final boolean snapshotOnly;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.xpack.esql.optimizer.rules.AddDefaultTopN;
import org.elasticsearch.xpack.esql.optimizer.rules.BooleanFunctionEqualsElimination;
import org.elasticsearch.xpack.esql.optimizer.rules.BooleanSimplification;
import org.elasticsearch.xpack.esql.optimizer.rules.CombineBinaryComparisons;
import org.elasticsearch.xpack.esql.optimizer.rules.CombineDisjunctionsToIn;
import org.elasticsearch.xpack.esql.optimizer.rules.CombineEvals;
import org.elasticsearch.xpack.esql.optimizer.rules.CombineProjections;
Expand Down Expand Up @@ -207,6 +208,7 @@ protected static Batch<LogicalPlan> operators() {
new PropagateEquals(),
new PropagateNullable(),
new BooleanFunctionEqualsElimination(),
new CombineBinaryComparisons(),
new CombineDisjunctionsToIn(),
new SimplifyComparisonsArithmetics(DataType::areCompatible),
// prune/elimination
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.optimizer.rules;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;

import java.util.ArrayList;
import java.util.List;

public final class CombineBinaryComparisons extends OptimizerRules.OptimizerExpressionRule<BinaryLogic> {

public CombineBinaryComparisons() {
super(OptimizerRules.TransformDirection.DOWN);
}

@Override
public Expression rule(BinaryLogic e) {
if (e instanceof And and) {
return combine(and);
} else if (e instanceof Or or) {
return combine(or);
}
return e;
}

// combine conjunction
private static Expression combine(And and) {
List<BinaryComparison> bcs = new ArrayList<>();
List<Expression> exps = new ArrayList<>();
boolean changed = false;
List<Expression> andExps = Predicates.splitAnd(and);

andExps.sort((o1, o2) -> {
if (o1 instanceof NotEquals && o2 instanceof NotEquals) {
return 0; // keep NotEquals' order
} else if (o1 instanceof NotEquals || o2 instanceof NotEquals) {
return o1 instanceof NotEquals ? 1 : -1; // push NotEquals up
} else {
return 0; // keep non-Ranges' and non-NotEquals' order
}
});
for (Expression ex : andExps) {
if (ex instanceof BinaryComparison bc && (ex instanceof Equals || ex instanceof NotEquals) == false) {
if (bc.right().foldable() && (findExistingComparison(bc, bcs, true))) {
changed = true;
} else {
bcs.add(bc);
}
} else if (ex instanceof NotEquals neq) {
if (neq.right().foldable() && notEqualsIsRemovableFromConjunction(neq, bcs)) {
// the non-equality can simply be dropped: either superfluous or has been merged with an updated range/inequality
changed = true;
} else { // not foldable OR not overlapping
exps.add(ex);
}
} else {
exps.add(ex);
}
}
return changed ? Predicates.combineAnd(CollectionUtils.combine(exps, bcs)) : and;
}

// combine disjunction
private static Expression combine(Or or) {
List<BinaryComparison> bcs = new ArrayList<>();
List<Expression> exps = new ArrayList<>();
boolean changed = false;
for (Expression ex : Predicates.splitOr(or)) {
if (ex instanceof BinaryComparison bc) {
if (bc.right().foldable() && findExistingComparison(bc, bcs, false)) {
changed = true;
} else {
bcs.add(bc);
}
} else {
exps.add(ex);
}
}
return changed ? Predicates.combineOr(CollectionUtils.combine(exps, bcs)) : or;
}

/**
* Find commonalities between the given comparison in the given list.
* The method can be applied both for conjunctive (AND) or disjunctive purposes (OR).
*/
private static boolean findExistingComparison(BinaryComparison main, List<BinaryComparison> bcs, boolean conjunctive) {
Object value = main.right().fold();
// NB: the loop modifies the list (hence why the int is used)
for (int i = 0; i < bcs.size(); i++) {
BinaryComparison other = bcs.get(i);
// skip if cannot evaluate
if (other.right().foldable() == false) {
continue;
}
// if bc is a higher/lower value or gte vs gt, use it instead
if ((other instanceof GreaterThan || other instanceof GreaterThanOrEqual)
&& (main instanceof GreaterThan || main instanceof GreaterThanOrEqual)) {
if (main.left().semanticEquals(other.left())) {
Integer compare = BinaryComparison.compare(value, other.right().fold());
if (compare != null) {
// AND
if ((conjunctive &&
// a > 3 AND a > 2 -> a > 3
(compare > 0 ||
// a > 2 AND a >= 2 -> a > 2
(compare == 0 && main instanceof GreaterThan && other instanceof GreaterThanOrEqual))) ||
// OR
(conjunctive == false &&
// a > 2 OR a > 3 -> a > 2
(compare < 0 ||
// a >= 2 OR a > 2 -> a >= 2
(compare == 0 && main instanceof GreaterThanOrEqual && other instanceof GreaterThan)))) {
bcs.remove(i);
bcs.add(i, main);
}
// found a match
return true;
}
return false;
}
}
// if bc is a lower/higher value or lte vs lt, use it instead
else if ((other instanceof LessThan || other instanceof LessThanOrEqual)
&& (main instanceof LessThan || main instanceof LessThanOrEqual)) {
if (main.left().semanticEquals(other.left())) {
Integer compare = BinaryComparison.compare(value, other.right().fold());
if (compare != null) {
// AND
if ((conjunctive &&
// a < 2 AND a < 3 -> a < 2
(compare < 0 ||
// a < 2 AND a <= 2 -> a < 2
(compare == 0 && main instanceof LessThan && other instanceof LessThanOrEqual))) ||
// OR
(conjunctive == false &&
// a < 2 OR a < 3 -> a < 3
(compare > 0 ||
// a <= 2 OR a < 2 -> a <= 2
(compare == 0 && main instanceof LessThanOrEqual && other instanceof LessThan)))) {
bcs.remove(i);
bcs.add(i, main);
}
// found a match
return true;
}
return false;
}
}
}
return false;
}

private static boolean notEqualsIsRemovableFromConjunction(NotEquals notEquals, List<BinaryComparison> bcs) {
Object neqVal = notEquals.right().fold();
Integer comp;

// check on "condition-overlapping" inequalities:
// a != 2 AND a > 3 -> a > 3 (discard NotEquals)
// a != 2 AND a >= 2 -> a > 2 (discard NotEquals plus update inequality)
// a != 2 AND a > 1 -> nop (do nothing)
//
// a != 2 AND a < 3 -> nop
// a != 2 AND a <= 2 -> a < 2
// a != 2 AND a < 1 -> a < 1
for (int i = 0; i < bcs.size(); i++) {
BinaryComparison bc = bcs.get(i);
if (notEquals.left().semanticEquals(bc.left())) {
if (bc instanceof LessThan || bc instanceof LessThanOrEqual) {
comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null;
if (comp != null) {
if (comp >= 0) {
if (comp == 0 && bc instanceof LessThanOrEqual) { // a != 2 AND a <= 2 -> a < 2
bcs.set(i, new LessThan(bc.source(), bc.left(), bc.right(), bc.zoneId()));
} // else : comp > 0 (a != 2 AND a </<= 1 -> a </<= 1), or == 0 && bc i.of "<" (a != 2 AND a < 2 -> a < 2)
return true;
} // else: comp < 0 : a != 2 AND a </<= 3 -> nop
} // else: non-comparable, nop
} else if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) {
comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null;
if (comp != null) {
if (comp <= 0) {
if (comp == 0 && bc instanceof GreaterThanOrEqual) { // a != 2 AND a >= 2 -> a > 2
bcs.set(i, new GreaterThan(bc.source(), bc.left(), bc.right(), bc.zoneId()));
} // else: comp < 0 (a != 2 AND a >/>= 3 -> a >/>= 3), or == 0 && bc i.of ">" (a != 2 AND a > 2 -> a > 2)
return true;
} // else: comp > 0 : a != 2 AND a >/>= 1 -> nop
} // else: non-comparable, nop
} // else: other non-relevant type
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4824,7 +4824,6 @@ public void testSimplifyComparisonArithmeticWithConjunction() {
doTestSimplifyComparisonArithmetics("12 * (-integer - 5) == -120 AND integer < 6 ", "integer", EQ, 5);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108525")
public void testSimplifyComparisonArithmeticWithDisjunction() {
doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120 OR integer < 5", "integer", LTE, 5);
}
Expand Down
Loading

0 comments on commit ac94a6c

Please sign in to comment.