Skip to content

Commit

Permalink
Implement EvaluatorMapper for Equals
Browse files Browse the repository at this point in the history
  • Loading branch information
not-napoleon committed Feb 5, 2024
1 parent 24079b1 commit 698c041
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,27 @@

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.ann.Evaluator;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions;
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cast;
import org.elasticsearch.xpack.esql.type.EsqlDataTypeRegistry;
import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.TypeResolutions;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;

import java.time.ZoneId;
import java.util.function.Function;

import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.DEFAULT;

public class Equals extends org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals {
public class Equals extends org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals implements EvaluatorMapper {
public Equals(Source source, Expression left, Expression right) {
super(source, left, right);
}
Expand Down Expand Up @@ -82,4 +91,38 @@ static boolean processBools(boolean lhs, boolean rhs) {
static boolean processGeometries(BytesRef lhs, BytesRef rhs) {
return lhs.equals(rhs);
}

@Override
public EvalOperator.ExpressionEvaluator.Factory toEvaluator(Function<Expression, EvalOperator.ExpressionEvaluator.Factory> toEvaluator) {
// Our type is always boolean, so figure out the evaluator type from the inputs
DataType commonType = EsqlDataTypeRegistry.INSTANCE.commonType(left().dataType(), right().dataType());
var lhs = Cast.cast(source(), left().dataType(), commonType, toEvaluator.apply(left()));
var rhs = Cast.cast(source(), right().dataType(), commonType, toEvaluator.apply(right()));
if (DataTypes.isDateTime(commonType)) {
return new EqualsLongsEvaluator.Factory(source(), lhs, rhs);
}
if (EsqlDataTypes.isSpatial(commonType)) {
return new EqualsGeometriesEvaluator.Factory(source(), lhs, rhs);
}
if (commonType.equals(DataTypes.INTEGER)) {
return new EqualsIntsEvaluator.Factory(source(), lhs, rhs);
}
if (commonType.equals(DataTypes.LONG) || commonType.equals(DataTypes.UNSIGNED_LONG)) {
return new EqualsLongsEvaluator.Factory(source(), lhs, rhs);
}
if (commonType.equals(DataTypes.DOUBLE)) {
return new EqualsDoublesEvaluator.Factory(source(), lhs, rhs);
}
if (commonType.equals(DataTypes.BOOLEAN)) {
return new EqualsBoolsEvaluator.Factory(source(), lhs, rhs);
}
if (DataTypes.isString(commonType)) {
return new EqualsKeywordsEvaluator.Factory(source(), lhs, rhs);
}
throw new EsqlIllegalArgumentException("Unsupported type " + left().dataType());
}
@Override
public Boolean fold() {
return (Boolean) EvaluatorMapper.super.fold();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,20 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
import org.hamcrest.Matcher;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import static org.hamcrest.Matchers.equalTo;

public class EqualsTests extends AbstractBinaryComparisonTestCase {
public class EqualsTests extends AbstractFunctionTestCase {
public EqualsTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
this.testCase = testCaseSupplier.get();
}
Expand Down Expand Up @@ -79,36 +76,8 @@ public static Iterable<Object[]> parameters() {
List.of()
)
);
// Datetime, Period/Duration Cases

/*
For some reason, DatePeriods aren't working. Will investigate after fixing double folding.
suppliers.addAll(
TestCaseSupplier.forBinaryNotCasting(
"No evaluator, the tests only trigger the folding code since Period is not representable",
"lhs",
"rhs",
Object::equals,
DataTypes.BOOLEAN,
TestCaseSupplier.datePeriodCases(),
TestCaseSupplier.datePeriodCases(),
List.of()
)
);
*/
suppliers.addAll(
TestCaseSupplier.forBinaryNotCasting(
"No evaluator, the tests only trigger the folding code since Duration is not representable",
"lhs",
"rhs",
Object::equals,
DataTypes.BOOLEAN,
TestCaseSupplier.timeDurationCases(),
TestCaseSupplier.timeDurationCases(),
List.of()
)
);

// Datetime
// TODO: I'm surprised this passes. Shouldn't there be a cast from DateTime to Long?
suppliers.addAll(
TestCaseSupplier.forBinaryNotCasting(
"EqualsLongsEvaluator",
Expand Down Expand Up @@ -165,17 +134,7 @@ public static Iterable<Object[]> parameters() {
}

@Override
protected <T extends Comparable<T>> Matcher<Object> resultMatcher(T lhs, T rhs) {
return equalTo(lhs.equals(rhs));
}

@Override
protected BinaryComparison build(Source source, Expression lhs, Expression rhs) {
return new Equals(source, lhs, rhs);
}

@Override
protected boolean isEquality() {
return true;
protected Expression build(Source source, List<Expression> args) {
return new Equals(source, args.get(0), args.get(1));
}
}

0 comments on commit 698c041

Please sign in to comment.