diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/Equals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/Equals.java index 9fb899b8e36df..011b70c46c3a7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/Equals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/predicate/operator/comparison/Equals.java @@ -8,18 +8,27 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.ann.Evaluator; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cast; +import org.elasticsearch.xpack.esql.type.EsqlDataTypeRegistry; +import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.TypeResolutions; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; +import org.elasticsearch.xpack.ql.type.DataType; +import org.elasticsearch.xpack.ql.type.DataTypes; import java.time.ZoneId; +import java.util.function.Function; import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.DEFAULT; -public class Equals extends org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals { +public class Equals extends org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals implements EvaluatorMapper { public Equals(Source source, Expression left, Expression right) { super(source, left, right); } @@ -82,4 +91,38 @@ static boolean processBools(boolean lhs, boolean rhs) { static boolean processGeometries(BytesRef lhs, BytesRef rhs) { return lhs.equals(rhs); } + + @Override + public EvalOperator.ExpressionEvaluator.Factory toEvaluator(Function toEvaluator) { + // Our type is always boolean, so figure out the evaluator type from the inputs + DataType commonType = EsqlDataTypeRegistry.INSTANCE.commonType(left().dataType(), right().dataType()); + var lhs = Cast.cast(source(), left().dataType(), commonType, toEvaluator.apply(left())); + var rhs = Cast.cast(source(), right().dataType(), commonType, toEvaluator.apply(right())); + if (DataTypes.isDateTime(commonType)) { + return new EqualsLongsEvaluator.Factory(source(), lhs, rhs); + } + if (EsqlDataTypes.isSpatial(commonType)) { + return new EqualsGeometriesEvaluator.Factory(source(), lhs, rhs); + } + if (commonType.equals(DataTypes.INTEGER)) { + return new EqualsIntsEvaluator.Factory(source(), lhs, rhs); + } + if (commonType.equals(DataTypes.LONG) || commonType.equals(DataTypes.UNSIGNED_LONG)) { + return new EqualsLongsEvaluator.Factory(source(), lhs, rhs); + } + if (commonType.equals(DataTypes.DOUBLE)) { + return new EqualsDoublesEvaluator.Factory(source(), lhs, rhs); + } + if (commonType.equals(DataTypes.BOOLEAN)) { + return new EqualsBoolsEvaluator.Factory(source(), lhs, rhs); + } + if (DataTypes.isString(commonType)) { + return new EqualsKeywordsEvaluator.Factory(source(), lhs, rhs); + } + throw new EsqlIllegalArgumentException("Unsupported type " + left().dataType()); + } + @Override + public Boolean fold() { + return (Boolean) EvaluatorMapper.super.fold(); + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsTests.java index d62041b61c983..6c222032cb94f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsTests.java @@ -11,23 +11,20 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; import org.elasticsearch.xpack.ql.expression.Expression; -import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; import org.elasticsearch.xpack.ql.type.DataTypes; -import org.hamcrest.Matcher; import java.math.BigInteger; import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; -import static org.hamcrest.Matchers.equalTo; - -public class EqualsTests extends AbstractBinaryComparisonTestCase { +public class EqualsTests extends AbstractFunctionTestCase { public EqualsTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -79,36 +76,8 @@ public static Iterable parameters() { List.of() ) ); - // Datetime, Period/Duration Cases - - /* - For some reason, DatePeriods aren't working. Will investigate after fixing double folding. - suppliers.addAll( - TestCaseSupplier.forBinaryNotCasting( - "No evaluator, the tests only trigger the folding code since Period is not representable", - "lhs", - "rhs", - Object::equals, - DataTypes.BOOLEAN, - TestCaseSupplier.datePeriodCases(), - TestCaseSupplier.datePeriodCases(), - List.of() - ) - ); - */ - suppliers.addAll( - TestCaseSupplier.forBinaryNotCasting( - "No evaluator, the tests only trigger the folding code since Duration is not representable", - "lhs", - "rhs", - Object::equals, - DataTypes.BOOLEAN, - TestCaseSupplier.timeDurationCases(), - TestCaseSupplier.timeDurationCases(), - List.of() - ) - ); - + // Datetime + // TODO: I'm surprised this passes. Shouldn't there be a cast from DateTime to Long? suppliers.addAll( TestCaseSupplier.forBinaryNotCasting( "EqualsLongsEvaluator", @@ -165,17 +134,7 @@ public static Iterable parameters() { } @Override - protected > Matcher resultMatcher(T lhs, T rhs) { - return equalTo(lhs.equals(rhs)); - } - - @Override - protected BinaryComparison build(Source source, Expression lhs, Expression rhs) { - return new Equals(source, lhs, rhs); - } - - @Override - protected boolean isEquality() { - return true; + protected Expression build(Source source, List args) { + return new Equals(source, args.get(0), args.get(1)); } }