Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Esql migrate sql tests for simplify comparison arithmetics #108454

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class EsqlArithmeticOperation extends ArithmeticOperation implem
* used just for its symbol.
* The rest of the methods should not be triggered hence the UOE.
*/
enum OperationSymbol implements BinaryArithmeticOperation {
public enum OperationSymbol implements BinaryArithmeticOperation {
ADD("+"),
SUB("-"),
MUL("*"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,8 @@ public String formatIncompatibleTypesMessage() {
);
}

@Override
public String toString() {
return left() + symbol() + right();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.LiteralsOnTheRight;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneLiteralsInOrderBy;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SetAsOptimized;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics;
import org.elasticsearch.xpack.ql.plan.logical.Filter;
import org.elasticsearch.xpack.ql.plan.logical.Limit;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
Expand Down Expand Up @@ -152,9 +151,10 @@ protected static Batch<LogicalPlan> operators() {
// needs to occur before BinaryComparison combinations (see class)
new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.PropagateEquals(),
new PropagateNullable(),
new OptimizerRules.CombineBinaryComparisons(),
new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.BooleanFunctionEqualsElimination(),
new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.CombineDisjunctionsToIn(),
new SimplifyComparisonsArithmetics(EsqlDataTypes::areCompatible),
new org.elasticsearch.xpack.esql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics(EsqlDataTypes::areCompatible),
// prune/elimination
new PruneFilters(),
new PruneColumns(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

package org.elasticsearch.xpack.esql.optimizer;

import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
Expand Down Expand Up @@ -49,12 +53,17 @@
import org.elasticsearch.xpack.ql.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.ql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.ArithmeticOperation;
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryArithmeticOperation;
import org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.BinaryComparisonInversible;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.plan.QueryPlan;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
import org.elasticsearch.xpack.ql.util.CollectionUtils;

import java.time.DateTimeException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Iterator;
Expand All @@ -64,12 +73,21 @@
import java.util.List;
import java.util.Map;
import java.util.Set;

import java.util.function.BiFunction;

import static java.lang.Math.signum;
import static java.util.Arrays.asList;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.ADD;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.DIV;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MOD;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MUL;
import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.SUB;
import static org.elasticsearch.xpack.ql.common.Failure.fail;
import static org.elasticsearch.xpack.ql.expression.Literal.FALSE;
import static org.elasticsearch.xpack.ql.expression.Literal.TRUE;
import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.combineOr;
import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitOr;
import static org.elasticsearch.xpack.ql.tree.Source.EMPTY;

class OptimizerRules {

Expand Down Expand Up @@ -612,4 +630,217 @@ private static Expression propagate(Or or) {
return updated ? Predicates.combineOr(CollectionUtils.combine(exps, equals, notEquals, inequalities, ranges)) : or;
}
}

/**
* Simplifies arithmetic expressions with BinaryComparisons and fixed point fields, such as: (int + 2) / 3 > 4 => int > 10
*/
public static final class SimplifyComparisonsArithmetics extends
org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerExpressionRule<EsqlBinaryComparison> {
BiFunction<DataType, DataType, Boolean> typesCompatible;

SimplifyComparisonsArithmetics(BiFunction<DataType, DataType, Boolean> typesCompatible) {
super(org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection.UP);
this.typesCompatible = typesCompatible;
}

@Override
protected Expression rule(EsqlBinaryComparison bc) {
// optimize only once the expression has a literal on the right side of the binary comparison
if (bc.right() instanceof Literal) {
if (bc.left() instanceof ArithmeticOperation) {
return simplifyBinaryComparison(bc);
}
if (bc.left() instanceof Neg) {
return foldNegation(bc);
}
}
return bc;
}

private Expression simplifyBinaryComparison(EsqlBinaryComparison comparison) {
EsqlArithmeticOperation operation = (EsqlArithmeticOperation) comparison.left();
BinaryArithmeticOperation function = operation.function();
// Modulo can't be simplified.
if (function.equals(MOD)) {
return comparison;
}
OperationSimplifier simplification = null;
if (isMulOrDiv(function)) {
simplification = new MulDivSimplifier(comparison);
} else if (function.equals(ADD) || function.equals(SUB)) {
simplification = new AddSubSimplifier(comparison);
}

return (simplification == null || simplification.isUnsafe(typesCompatible)) ? comparison : simplification.apply();
}

private static boolean isMulOrDiv(BinaryArithmeticOperation op) {
return op.equals(MUL) || op.equals(DIV);
}

private static Expression foldNegation(EsqlBinaryComparison bc) {
Literal bcLiteral = (Literal) bc.right();
Expression literalNeg = tryFolding(new Neg(bcLiteral.source(), bcLiteral));
return literalNeg == null ? bc : bc.reverse().replaceChildren(asList(((Neg) bc.left()).field(), literalNeg));
}

private static Expression tryFolding(Expression expression) {
if (expression.foldable()) {
try {
expression = new Literal(expression.source(), expression.fold(), expression.dataType());
} catch (ArithmeticException | DateTimeException e) {
// null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped.
expression = null;
}
}
return expression;
}

private abstract static class OperationSimplifier {
final EsqlBinaryComparison comparison;
final Literal bcLiteral;
final EsqlArithmeticOperation operation;
final Expression opLeft;
final Expression opRight;
final Literal opLiteral;

OperationSimplifier(EsqlBinaryComparison comparison) {
this.comparison = comparison;
operation = (EsqlArithmeticOperation) comparison.left();
bcLiteral = (Literal) comparison.right();

opLeft = operation.left();
opRight = operation.right();

if (opLeft instanceof Literal) {
opLiteral = (Literal) opLeft;
} else if (opRight instanceof Literal) {
opLiteral = (Literal) opRight;
} else {
opLiteral = null;
}
}

// can it be quickly fast-tracked that the operation can't be reduced?
final boolean isUnsafe(BiFunction<DataType, DataType, Boolean> typesCompatible) {
if (opLiteral == null) {
// one of the arithm. operands must be a literal, otherwise the operation wouldn't simplify anything
return true;
}

// Only operations on fixed point literals are supported, since optimizing float point operations can also change the
// outcome of the filtering:
// x + 1e18 > 1e18::long will yield different results with a field value in [-2^6, 2^6], optimised vs original;
// x * (1 + 1e-15d) > 1 : same with a field value of (1 - 1e-15d)
// so consequently, int fields optimisation requiring FP arithmetic isn't possible either: (x - 1e-15) * (1 + 1e-15) > 1.
if (opLiteral.dataType().isRational() || bcLiteral.dataType().isRational()) {
return true;
}

// the Literal will be moved to the right of the comparison, but only if data-compatible with what's there
if (typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType()) == false) {
return true;
}

return isOpUnsafe();
}

final Expression apply() {
// force float point folding for FlP field
Literal bcl = operation.dataType().isRational()
// ? Literal.of(bcLiteral, ((Number) bcLiteral.value()).doubleValue())
? new Literal(bcLiteral.source(), ((Number) bcLiteral.value()).doubleValue(), DataTypes.DOUBLE)
: bcLiteral;

Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse()
.create(bcl.source(), bcl, opRight);
bcRightExpression = tryFolding(bcRightExpression);
return bcRightExpression != null
? postProcess((EsqlBinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression)))
: comparison;
}

// operation-specific operations:
// - fast-tracking of simplification unsafety
abstract boolean isOpUnsafe();

// - post optimisation adjustments
Expression postProcess(EsqlBinaryComparison binaryComparison) {
return binaryComparison;
}
}

private static class AddSubSimplifier extends SimplifyComparisonsArithmetics.OperationSimplifier {

AddSubSimplifier(EsqlBinaryComparison comparison) {
super(comparison);
}

@Override
boolean isOpUnsafe() {
// no ADD/SUB with floating fields
if (operation.dataType().isRational()) {
return true;
}

if (operation.function().equals(SUB) && opRight instanceof Literal == false) { // such as: 1 - x > -MAX
// if next simplification step would fail on overflow anyway, skip the optimisation already
return tryFolding(new Sub(EMPTY, opLeft, bcLiteral)) == null;
}

return false;
}
}

private static class MulDivSimplifier extends SimplifyComparisonsArithmetics.OperationSimplifier {

private final boolean isDiv; // and not MUL.
private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal)

MulDivSimplifier(EsqlBinaryComparison comparison) {
super(comparison);
isDiv = operation.function().equals(DIV);
opRightSign = sign(opRight);
}

@Override
boolean isOpUnsafe() {
// Integer divisions are not safe to optimise: x / 5 > 1 <=/=> x > 5 for x in [6, 9]; same for the `==` comp
if (operation.dataType().isInteger() && isDiv) {
return true;
}

// If current operation is a multiplication, it's inverse will be a division: safe only if outcome is still integral.
if (isDiv == false && opLeft.dataType().isInteger()) {
long opLiteralValue = ((Number) opLiteral.value()).longValue();
return opLiteralValue == 0 || ((Number) bcLiteral.value()).longValue() % opLiteralValue != 0;
}

// can't move a 0 in Mul/Div comparisons
return opRightSign == 0;
}

@Override
Expression postProcess(EsqlBinaryComparison binaryComparison) {
// negative multiplication/division changes the direction of the comparison
return opRightSign < 0 ? binaryComparison.reverse() : binaryComparison;
}

private static int sign(Object obj) {
int sign = 1;
if (obj instanceof Number) {
sign = (int) signum(((Number) obj).doubleValue());
} else if (obj instanceof Literal) {
sign = sign(((Literal) obj).value());
} else if (obj instanceof Neg) {
sign = -sign(((Neg) obj).field());
} else if (obj instanceof EsqlArithmeticOperation operation) {
if (isMulOrDiv(operation.function())) {
sign = sign(operation.left()) * sign(operation.right());
}
}
return sign;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ public static boolean isRepresentable(DataType t) {
&& isCounterType(t) == false;
}

@Deprecated
public static boolean areCompatible(DataType left, DataType right) {
if (left == right) {
return true;
Expand Down
Loading