-
Notifications
You must be signed in to change notification settings - Fork 25k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ESQL] Migrate SimplifiyComparisonArithmetics optimization rule #108200
Changes from all commits
7510b81
9c43f65
e983ba9
50f4bb6
ae42bc5
105ac4d
60b8a43
6f89e11
d10b26b
2be0db0
fdc1ba8
c3b78ec
f207177
a96de12
26adf8d
5b28504
1f7591b
dbcd864
2824681
03a8d32
8650437
7daf532
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pr: 108200 | ||
summary: "[ESQL] Migrate `SimplifiyComparisonArithmetics` optimization rule" | ||
area: ES|QL | ||
type: bug | ||
issues: | ||
- 108388 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,11 @@ | |
|
||
package org.elasticsearch.xpack.esql.optimizer; | ||
|
||
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation; | ||
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; | ||
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; | ||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; | ||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; | ||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; | ||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; | ||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; | ||
|
@@ -49,12 +53,17 @@ | |
import org.elasticsearch.xpack.ql.expression.predicate.logical.BinaryLogic; | ||
import org.elasticsearch.xpack.ql.expression.predicate.logical.Not; | ||
import org.elasticsearch.xpack.ql.expression.predicate.logical.Or; | ||
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.ArithmeticOperation; | ||
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryArithmeticOperation; | ||
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible; | ||
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; | ||
import org.elasticsearch.xpack.ql.plan.QueryPlan; | ||
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; | ||
import org.elasticsearch.xpack.ql.type.DataType; | ||
import org.elasticsearch.xpack.ql.type.DataTypes; | ||
import org.elasticsearch.xpack.ql.util.CollectionUtils; | ||
|
||
import java.time.DateTimeException; | ||
import java.time.ZoneId; | ||
import java.util.ArrayList; | ||
import java.util.Iterator; | ||
|
@@ -64,12 +73,21 @@ | |
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
import java.util.function.BiFunction; | ||
|
||
import static java.lang.Math.signum; | ||
import static java.util.Arrays.asList; | ||
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.ADD; | ||
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.DIV; | ||
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MOD; | ||
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MUL; | ||
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.SUB; | ||
import static org.elasticsearch.xpack.ql.common.Failure.fail; | ||
import static org.elasticsearch.xpack.ql.expression.Literal.FALSE; | ||
import static org.elasticsearch.xpack.ql.expression.Literal.TRUE; | ||
import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineOr; | ||
import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitOr; | ||
import static org.elasticsearch.xpack.ql.tree.Source.EMPTY; | ||
|
||
class OptimizerRules { | ||
|
||
|
@@ -612,4 +630,216 @@ private static Expression propagate(Or or) { | |
return updated ? Predicates.combineOr(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : or; | ||
} | ||
} | ||
|
||
/** | ||
* Simplifies arithmetic expressions with BinaryComparisons and fixed point fields, such as: (int + 2) / 3 > 4 => int > 10 | ||
*/ | ||
public static final class SimplifyComparisonsArithmetics extends | ||
org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerExpressionRule<EsqlBinaryComparison> { | ||
BiFunction<DataType, DataType, Boolean> typesCompatible; | ||
|
||
SimplifyComparisonsArithmetics(BiFunction<DataType, DataType, Boolean> typesCompatible) { | ||
super(org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection.UP); | ||
this.typesCompatible = typesCompatible; | ||
} | ||
|
||
@Override | ||
protected Expression rule(EsqlBinaryComparison bc) { | ||
// optimize only once the expression has a literal on the right side of the binary comparison | ||
if (bc.right() instanceof Literal) { | ||
if (bc.left() instanceof ArithmeticOperation) { | ||
return simplifyBinaryComparison(bc); | ||
} | ||
if (bc.left() instanceof Neg) { | ||
return foldNegation(bc); | ||
} | ||
} | ||
return bc; | ||
} | ||
|
||
private Expression simplifyBinaryComparison(EsqlBinaryComparison comparison) { | ||
EsqlArithmeticOperation operation = (EsqlArithmeticOperation) comparison.left(); | ||
BinaryArithmeticOperation function = operation.function(); | ||
// Modulo can't be simplified. | ||
if (function.equals(MOD)) { | ||
return comparison; | ||
} | ||
OperationSimplifier simplification = null; | ||
if (isMulOrDiv(function)) { | ||
simplification = new MulDivSimplifier(comparison); | ||
} else if (function.equals(ADD) || function.equals(SUB)) { | ||
simplification = new AddSubSimplifier(comparison); | ||
} | ||
|
||
return (simplification == null || simplification.isUnsafe(typesCompatible)) ? comparison : simplification.apply(); | ||
} | ||
|
||
private static boolean isMulOrDiv(BinaryArithmeticOperation op) { | ||
return op.equals(MUL) || op.equals(DIV); | ||
} | ||
|
||
private static Expression foldNegation(EsqlBinaryComparison bc) { | ||
Literal bcLiteral = (Literal) bc.right(); | ||
Expression literalNeg = tryFolding(new Neg(bcLiteral.source(), bcLiteral)); | ||
return literalNeg == null ? bc : bc.reverse().replaceChildren(asList(((Neg) bc.left()).field(), literalNeg)); | ||
} | ||
|
||
private static Expression tryFolding(Expression expression) { | ||
if (expression.foldable()) { | ||
try { | ||
expression = new Literal(expression.source(), expression.fold(), expression.dataType()); | ||
} catch (ArithmeticException | DateTimeException e) { | ||
// null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped. | ||
expression = null; | ||
} | ||
} | ||
return expression; | ||
} | ||
|
||
private abstract static class OperationSimplifier { | ||
final EsqlBinaryComparison comparison; | ||
final Literal bcLiteral; | ||
final EsqlArithmeticOperation operation; | ||
final Expression opLeft; | ||
final Expression opRight; | ||
final Literal opLiteral; | ||
|
||
OperationSimplifier(EsqlBinaryComparison comparison) { | ||
this.comparison = comparison; | ||
operation = (EsqlArithmeticOperation) comparison.left(); | ||
bcLiteral = (Literal) comparison.right(); | ||
|
||
opLeft = operation.left(); | ||
opRight = operation.right(); | ||
|
||
if (opLeft instanceof Literal) { | ||
opLiteral = (Literal) opLeft; | ||
} else if (opRight instanceof Literal) { | ||
opLiteral = (Literal) opRight; | ||
} else { | ||
opLiteral = null; | ||
} | ||
} | ||
|
||
// can it be quickly fast-tracked that the operation can't be reduced? | ||
final boolean isUnsafe(BiFunction<DataType, DataType, Boolean> typesCompatible) { | ||
if (opLiteral == null) { | ||
// one of the arithm. operands must be a literal, otherwise the operation wouldn't simplify anything | ||
return true; | ||
} | ||
|
||
// Only operations on fixed point literals are supported, since optimizing float point operations can also change the | ||
// outcome of the filtering: | ||
// x + 1e18 > 1e18::long will yield different results with a field value in [-2^6, 2^6], optimised vs original; | ||
// x * (1 + 1e-15d) > 1 : same with a field value of (1 - 1e-15d) | ||
// so consequently, int fields optimisation requiring FP arithmetic isn't possible either: (x - 1e-15) * (1 + 1e-15) > 1. | ||
if (opLiteral.dataType().isRational() || bcLiteral.dataType().isRational()) { | ||
return true; | ||
} | ||
|
||
// the Literal will be moved to the right of the comparison, but only if data-compatible with what's there | ||
if (typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType()) == false) { | ||
return true; | ||
} | ||
|
||
return isOpUnsafe(); | ||
} | ||
|
||
final Expression apply() { | ||
// force float point folding for FlP field | ||
Literal bcl = operation.dataType().isRational() | ||
? new Literal(bcLiteral.source(), ((Number) bcLiteral.value()).doubleValue(), DataTypes.DOUBLE) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is changed from the original rule. The original version read |
||
: bcLiteral; | ||
|
||
Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse() | ||
.create(bcl.source(), bcl, opRight); | ||
bcRightExpression = tryFolding(bcRightExpression); | ||
return bcRightExpression != null | ||
? postProcess((EsqlBinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression))) | ||
: comparison; | ||
} | ||
|
||
// operation-specific operations: | ||
// - fast-tracking of simplification unsafety | ||
abstract boolean isOpUnsafe(); | ||
|
||
// - post optimisation adjustments | ||
Expression postProcess(EsqlBinaryComparison binaryComparison) { | ||
return binaryComparison; | ||
} | ||
} | ||
|
||
private static class AddSubSimplifier extends SimplifyComparisonsArithmetics.OperationSimplifier { | ||
|
||
AddSubSimplifier(EsqlBinaryComparison comparison) { | ||
super(comparison); | ||
} | ||
|
||
@Override | ||
boolean isOpUnsafe() { | ||
// no ADD/SUB with floating fields | ||
if (operation.dataType().isRational()) { | ||
return true; | ||
} | ||
|
||
if (operation.function().equals(SUB) && opRight instanceof Literal == false) { // such as: 1 - x > -MAX | ||
// if next simplification step would fail on overflow anyway, skip the optimisation already | ||
return tryFolding(new Sub(EMPTY, opLeft, bcLiteral)) == null; | ||
} | ||
|
||
return false; | ||
} | ||
} | ||
|
||
private static class MulDivSimplifier extends SimplifyComparisonsArithmetics.OperationSimplifier { | ||
|
||
private final boolean isDiv; // and not MUL. | ||
private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal) | ||
|
||
MulDivSimplifier(EsqlBinaryComparison comparison) { | ||
super(comparison); | ||
isDiv = operation.function().equals(DIV); | ||
opRightSign = sign(opRight); | ||
} | ||
|
||
@Override | ||
boolean isOpUnsafe() { | ||
// Integer divisions are not safe to optimise: x / 5 > 1 <=/=> x > 5 for x in [6, 9]; same for the `==` comp | ||
if (operation.dataType().isInteger() && isDiv) { | ||
return true; | ||
} | ||
|
||
// If current operation is a multiplication, it's inverse will be a division: safe only if outcome is still integral. | ||
if (isDiv == false && opLeft.dataType().isInteger()) { | ||
long opLiteralValue = ((Number) opLiteral.value()).longValue(); | ||
return opLiteralValue == 0 || ((Number) bcLiteral.value()).longValue() % opLiteralValue != 0; | ||
} | ||
|
||
// can't move a 0 in Mul/Div comparisons | ||
return opRightSign == 0; | ||
} | ||
|
||
@Override | ||
Expression postProcess(EsqlBinaryComparison binaryComparison) { | ||
// negative multiplication/division changes the direction of the comparison | ||
return opRightSign < 0 ? binaryComparison.reverse() : binaryComparison; | ||
} | ||
|
||
private static int sign(Object obj) { | ||
int sign = 1; | ||
if (obj instanceof Number) { | ||
sign = (int) signum(((Number) obj).doubleValue()); | ||
} else if (obj instanceof Literal) { | ||
sign = sign(((Literal) obj).value()); | ||
} else if (obj instanceof Neg) { | ||
sign = -sign(((Neg) obj).field()); | ||
} else if (obj instanceof EsqlArithmeticOperation operation) { | ||
if (isMulOrDiv(operation.function())) { | ||
sign = sign(operation.left()) * sign(operation.right()); | ||
} | ||
} | ||
return sign; | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe this is the source of the issue in #108519, as ES|QL will never throw from folding. It's not as simple as just checking if we got a non-null value in the literal here though, as generating the null literal will have put the warning on the stack already.