Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ESQL] Migrate SimplifyComparisonsArithmetics optimization rule #108200

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/108200.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 108200
summary: "[ESQL] Migrate `SimplifyComparisonsArithmetics` optimization rule"
area: ES|QL
type: bug
issues:
- 108388
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class EsqlArithmeticOperation extends ArithmeticOperation implem
* used just for its symbol.
* The rest of the methods should not be triggered hence the UOE.
*/
enum OperationSymbol implements BinaryArithmeticOperation {
public enum OperationSymbol implements BinaryArithmeticOperation {
ADD("+"),
SUB("-"),
MUL("*"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ public static BinaryComparisonOperation readFromStream(StreamInput in) throws IO
/**
 * Builds a comparison by handing the given source and operands to this operation's {@code constructor}.
 *
 * @param source the source passed to the constructor
 * @param lhs the left-hand side expression of the new comparison
 * @param rhs the right-hand side expression of the new comparison
 * @return whatever instance {@code constructor} produces for these arguments
 */
public EsqlBinaryComparison buildNewInstance(Source source, Expression lhs, Expression rhs) {
return constructor.apply(source, lhs, rhs);
}

/**
 * @return the value of this operation's {@code symbol} field
 */
public String symbol() {
return symbol;
}
}

protected EsqlBinaryComparison(
Expand Down Expand Up @@ -222,4 +226,8 @@ public String formatIncompatibleTypesMessage() {
);
}

/**
 * Renders this comparison as the left operand, immediately followed by the symbol and the right operand
 * (no separators are inserted between the three parts).
 */
@Override
public String toString() {
    StringBuilder rendered = new StringBuilder();
    rendered.append(left());
    rendered.append(symbol());
    rendered.append(right());
    return rendered.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.LiteralsOnTheRight;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneLiteralsInOrderBy;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SetAsOptimized;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics;
import org.elasticsearch.xpack.ql.plan.logical.Filter;
import org.elasticsearch.xpack.ql.plan.logical.Limit;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
Expand Down Expand Up @@ -154,7 +153,7 @@ protected static Batch<LogicalPlan> operators() {
new PropagateNullable(),
new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.BooleanFunctionEqualsElimination(),
new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.CombineDisjunctionsToIn(),
new SimplifyComparisonsArithmetics(EsqlDataTypes::areCompatible),
new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics(EsqlDataTypes::areCompatible),
// prune/elimination
new PruneFilters(),
new PruneColumns(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

package org.elasticsearch.xpack.esql.optimizer;

import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
Expand Down Expand Up @@ -49,12 +53,17 @@
import org.elasticsearch.xpack.ql.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.ArithmeticOperation;
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryArithmeticOperation;
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.plan.QueryPlan;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
import org.elasticsearch.xpack.ql.util.CollectionUtils;

import java.time.DateTimeException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Iterator;
Expand All @@ -64,12 +73,21 @@
import java.util.List;
import java.util.Map;
import java.util.Set;

import java.util.function.BiFunction;

import static java.lang.Math.signum;
import static java.util.Arrays.asList;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.ADD;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.DIV;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MOD;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MUL;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.SUB;
import static org.elasticsearch.xpack.ql.common.Failure.fail;
import static org.elasticsearch.xpack.ql.expression.Literal.FALSE;
import static org.elasticsearch.xpack.ql.expression.Literal.TRUE;
import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineOr;
import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitOr;
import static org.elasticsearch.xpack.ql.tree.Source.EMPTY;

class OptimizerRules {

Expand Down Expand Up @@ -612,4 +630,216 @@ private static Expression propagate(Or or) {
return updated ? Predicates.combineOr(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : or;
}
}

/**
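* Simplifies BinaryComparisons between an arithmetic expression over a fixed point field and a literal, by moving the
* arithmetic over to the literal, such as: (int + 2) * 3 > 12 => int > 2. Integer divisions and modulo are left untouched.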
*/
public static final class SimplifyComparisonsArithmetics extends
org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerExpressionRule<EsqlBinaryComparison> {
BiFunction<DataType, DataType, Boolean> typesCompatible;

/**
 * Creates the rule, passing {@code TransformDirection.UP} to the superclass.
 *
 * @param typesCompatible check stored on the rule and later handed to {@code OperationSimplifier#isUnsafe}, which calls it
 *                        with the data types of the comparison literal and of the arithmetic operation's literal
 */
SimplifyComparisonsArithmetics(BiFunction<DataType, DataType, Boolean> typesCompatible) {
super(org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection.UP);
this.typesCompatible = typesCompatible;
}

/**
 * Entry point of the rule: comparisons with a literal on the right are dispatched on the shape of their left side;
 * everything else is returned as is.
 */
@Override
protected Expression rule(EsqlBinaryComparison bc) {
    // optimize only once the expression has a literal on the right side of the binary comparison
    if (bc.right() instanceof Literal == false) {
        return bc;
    }
    Expression lhs = bc.left();
    if (lhs instanceof ArithmeticOperation) {
        return simplifyBinaryComparison(bc);
    }
    return lhs instanceof Neg ? foldNegation(bc) : bc;
}

/**
 * Tries to move the literal operand of the arithmetic operation found on the left side of {@code comparison} over to the
 * literal on the comparison's right side.
 * <p>
 * Multiplications/divisions are handled by {@link MulDivSimplifier}, additions/subtractions by {@link AddSubSimplifier}.
 * Any other operation (such as modulo), any left side that is not an {@link EsqlArithmeticOperation}, and any
 * simplification reported as unsafe leave the comparison untouched.
 *
 * @param comparison a comparison whose right side is a {@link Literal}
 * @return the simplified expression, or {@code comparison} itself if no simplification was applied
 */
private Expression simplifyBinaryComparison(EsqlBinaryComparison comparison) {
    // The caller only checks for the generic ArithmeticOperation, while the simplifiers need the ES|QL flavour: anything
    // else is returned unchanged instead of failing on a blind cast.
    if (comparison.left() instanceof EsqlArithmeticOperation operation) {
        BinaryArithmeticOperation function = operation.function();
        OperationSimplifier simplification;
        if (isMulOrDiv(function)) {
            simplification = new MulDivSimplifier(comparison);
        } else if (function.equals(ADD) || function.equals(SUB)) {
            simplification = new AddSubSimplifier(comparison);
        } else {
            // Modulo (or any other operation) can't be simplified.
            return comparison;
        }
        return simplification.isUnsafe(typesCompatible) ? comparison : simplification.apply();
    }
    return comparison;
}

/**
 * @return {@code true} if {@code op} equals the {@code MUL} or the {@code DIV} operation symbol
 */
private static boolean isMulOrDiv(BinaryArithmeticOperation op) {
    if (op.equals(MUL)) {
        return true;
    }
    return op.equals(DIV);
}

/**
 * Rewrites a comparison of a negation against a literal, by negating the literal instead and reversing the comparison.
 * The comparison is returned unchanged if folding the negated literal is reported as failed (a {@code null} result).
 */
private static Expression foldNegation(EsqlBinaryComparison bc) {
    Literal rightLiteral = (Literal) bc.right();
    Expression negatedLiteral = tryFolding(new Neg(rightLiteral.source(), rightLiteral));
    if (negatedLiteral == null) {
        return bc;
    }
    Expression reversed = bc.reverse();
    Expression negatedField = ((Neg) bc.left()).field();
    return reversed.replaceChildren(asList(negatedField, negatedLiteral));
}

/**
 * Replaces a foldable expression with a {@link Literal} holding its folded value, keeping the expression's source and data
 * type. A non-foldable expression is returned unchanged.
 *
 * @param expression the expression to fold
 * @return the folded literal, the unchanged (non-foldable) expression, or {@code null} if folding threw an
 *         {@link ArithmeticException} or a {@link DateTimeException}
 */
private static Expression tryFolding(Expression expression) {
if (expression.foldable()) {
try {
expression = new Literal(expression.source(), expression.fold(), expression.dataType());
} catch (ArithmeticException | DateTimeException e) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is the source of the issue in #108519, as ES|QL will never throw from folding. It's not as simple as just checking if we got a non-null value in the literal here though, as generating the null literal will have put the warning on the stack already.

// null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped.
expression = null;
}
}
return expression;
}

private abstract static class OperationSimplifier {
final EsqlBinaryComparison comparison;
final Literal bcLiteral;
final EsqlArithmeticOperation operation;
final Expression opLeft;
final Expression opRight;
final Literal opLiteral;

OperationSimplifier(EsqlBinaryComparison comparison) {
this.comparison = comparison;
operation = (EsqlArithmeticOperation) comparison.left();
bcLiteral = (Literal) comparison.right();

opLeft = operation.left();
opRight = operation.right();

if (opLeft instanceof Literal) {
opLiteral = (Literal) opLeft;
} else if (opRight instanceof Literal) {
opLiteral = (Literal) opRight;
} else {
opLiteral = null;
}
}

/**
 * Quick check of whether the operation can't (safely) be reduced.
 *
 * @param typesCompatible tells whether the data types of the comparison literal and of the operation literal are compatible
 * @return {@code true} if the simplification must be skipped
 */
final boolean isUnsafe(BiFunction<DataType, DataType, Boolean> typesCompatible) {
    // one of the arithm. operands must be a literal, otherwise the operation wouldn't simplify anything
    if (opLiteral == null) {
        return true;
    }

    // Only operations on fixed point literals are supported, since optimizing floating point operations can also change the
    // outcome of the filtering:
    // x + 1e18 > 1e18::long will yield different results with a field value in [-2^6, 2^6], optimised vs original;
    // x * (1 + 1e-15d) > 1 : same with a field value of (1 - 1e-15d)
    // so consequently, int fields optimisation requiring FP arithmetic isn't possible either: (x - 1e-15) * (1 + 1e-15) > 1.
    boolean floatingPointLiteral = opLiteral.dataType().isRational() || bcLiteral.dataType().isRational();
    if (floatingPointLiteral) {
        return true;
    }

    // the Literal will be moved to the right of the comparison, but only if data-compatible with what's there
    boolean compatibleLiterals = typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType());
    return compatibleLiterals == false || isOpUnsafe();
}

/**
 * Performs the simplification: the inverse of the arithmetic operation is created over the comparison literal and the
 * operation's right operand, and the result is passed through {@code tryFolding}.
 *
 * @return the original comparison if {@code tryFolding} returned {@code null}; otherwise the outcome of
 *         {@code postProcess} on a copy of the comparison whose children are the operation's left operand and the
 *         folded expression
 */
final Expression apply() {
// force floating point folding when the arithmetic operation has a floating point data type: the comparison literal is
// re-created as a DOUBLE literal
Literal bcl = operation.dataType().isRational()
? new Literal(bcLiteral.source(), ((Number) bcLiteral.value()).doubleValue(), DataTypes.DOUBLE)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is changed from the original rule. The original version read ? Literal.of(bcLiteral, ((Number) bcLiteral.value()).doubleValue()), however Literal.of() takes the data type from the first parameter, not the second. This was causing us to create (e.g.) Literals with a value of 4.0 and a type of INTEGER, which in turn caused a class cast exception. See #108388

: bcLiteral;

Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse()
.create(bcl.source(), bcl, opRight);
bcRightExpression = tryFolding(bcRightExpression);
return bcRightExpression != null
? postProcess((EsqlBinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression)))
: comparison;
}

// operation-specific operations:
// - fast-tracking of simplification unsafety
/**
 * Operation-specific part of the safety check; {@code isUnsafe} returns its result once its own generic checks have passed.
 *
 * @return {@code true} if the simplification must be skipped for this operation
 */
abstract boolean isOpUnsafe();

// - post optimisation adjustments
/**
 * Hook called by {@code apply} on the rewritten comparison. This default implementation returns it unchanged.
 *
 * @param binaryComparison the comparison produced by the simplification
 * @return the expression to use as the outcome of the simplification
 */
Expression postProcess(EsqlBinaryComparison binaryComparison) {
return binaryComparison;
}
}

/**
 * Simplifier for comparisons having an addition or a subtraction on their left side.
 */
private static class AddSubSimplifier extends SimplifyComparisonsArithmetics.OperationSimplifier {

    AddSubSimplifier(EsqlBinaryComparison comparison) {
        super(comparison);
    }

    /**
     * Unsafe whenever the operation has a floating point data type or, for a subtraction whose right operand is not a
     * literal, whenever subtracting the comparison literal from the left operand can't be folded.
     */
    @Override
    boolean isOpUnsafe() {
        // no ADD/SUB with floating fields
        if (operation.dataType().isRational()) {
            return true;
        }

        // such as: 1 - x > -MAX
        boolean subtractsNonLiteral = operation.function().equals(SUB) && opRight instanceof Literal == false;
        if (subtractsNonLiteral == false) {
            return false;
        }

        // if next simplification step would fail on overflow anyway, skip the optimisation already
        Expression folded = tryFolding(new Sub(EMPTY, opLeft, bcLiteral));
        return folded == null;
    }
}

/**
 * Simplifier for comparisons having a multiplication or a division on their left side.
 */
private static class MulDivSimplifier extends SimplifyComparisonsArithmetics.OperationSimplifier {

    private final boolean isDiv; // and not MUL.
    private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal)

    MulDivSimplifier(EsqlBinaryComparison comparison) {
        super(comparison);
        this.isDiv = operation.function().equals(DIV);
        this.opRightSign = sign(opRight);
    }

    @Override
    boolean isOpUnsafe() {
        boolean integerOperation = operation.dataType().isInteger();

        if (isDiv) {
            // Integer divisions are not safe to optimise: x / 5 > 1 <=/=> x > 5 for x in [6, 9]; same for the `==` comp
            // Other than that, a 0 can't be moved in Mul/Div comparisons.
            return integerOperation || opRightSign == 0;
        }

        // The current operation is a multiplication, so its inverse will be a division: safe only if outcome is still integral.
        if (opLeft.dataType().isInteger()) {
            long multiplier = ((Number) opLiteral.value()).longValue();
            if (multiplier == 0) {
                return true;
            }
            long compared = ((Number) bcLiteral.value()).longValue();
            return compared % multiplier != 0;
        }

        // can't move a 0 in Mul/Div comparisons
        return opRightSign == 0;
    }

    @Override
    Expression postProcess(EsqlBinaryComparison binaryComparison) {
        // negative multiplication/division changes the direction of the comparison
        if (opRightSign < 0) {
            return binaryComparison.reverse();
        }
        return binaryComparison;
    }

    /**
     * Computes the sign of a number, a literal, a negation, or a multiplication/division of those.
     * Anything else is considered positive.
     */
    private static int sign(Object obj) {
        if (obj instanceof Number number) {
            return (int) signum(number.doubleValue());
        }
        if (obj instanceof Literal literal) {
            return sign(literal.value());
        }
        if (obj instanceof Neg negation) {
            return -sign(negation.field());
        }
        if (obj instanceof EsqlArithmeticOperation arithmetic && isMulOrDiv(arithmetic.function())) {
            return sign(arithmetic.left()) * sign(arithmetic.right());
        }
        return 1;
    }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ public static boolean isRepresentable(DataType t) {
&& isCounterType(t) == false;
}

@Deprecated
public static boolean areCompatible(DataType left, DataType right) {
if (left == right) {
return true;
Expand Down
Loading