Skip to content

Commit

Permalink
add argument compatibility checking
Browse files Browse the repository at this point in the history
  • Loading branch information
not-napoleon committed Feb 6, 2024
1 parent 25e112d commit 92ee423
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,23 @@
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cast;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation;
import org.elasticsearch.xpack.esql.type.EsqlDataTypeRegistry;
import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.TypeResolutions;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparisonProcessor;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;

import java.time.ZoneId;
import java.util.Map;
import java.util.function.Function;

import static org.elasticsearch.xpack.ql.type.DataTypes.UNSIGNED_LONG;

public abstract class EsqlBinaryComparison extends BinaryComparison implements EvaluatorMapper {

private final Map<DataType, BinaryEvaluator> evaluatorMap;
Expand All @@ -38,6 +44,7 @@ protected EsqlBinaryComparison(
) {
this(source, left, right, operation, null, evaluatorMap);
}

protected EsqlBinaryComparison(
Source source,
Expression left,
Expand Down Expand Up @@ -71,6 +78,51 @@ public Boolean fold() {
return (Boolean) EvaluatorMapper.super.fold();
}

@Override
protected TypeResolution resolveType() {
TypeResolution typeResolution = super.resolveType();
if (typeResolution.unresolved()) {
return typeResolution;
}

return checkCompatibility();
}

@Override
protected TypeResolution resolveInputType(Expression e, TypeResolutions.ParamOrdinal paramOrdinal) {
return TypeResolutions.isType(
e,
evaluatorMap::containsKey,
sourceText(),
paramOrdinal,
evaluatorMap.keySet().stream().map(DataType::typeName).toArray(String[]::new)
);
}

/**
* Check if the two input types are compatible for this operation
*
* @return TypeResolution.TYPE_RESOLVED iff the types are compatible. Otherwise, an appropriate type resolution error.
*/
protected TypeResolution checkCompatibility() {
DataType leftType = left().dataType();
DataType rightType = right().dataType();

// Unsigned long is only interoperable with other unsigned longs
if ((rightType == UNSIGNED_LONG && (false == (leftType == UNSIGNED_LONG || leftType == DataTypes.NULL)))
|| (leftType == UNSIGNED_LONG && (false == (rightType == UNSIGNED_LONG || rightType == DataTypes.NULL)))) {
return new TypeResolution(EsqlArithmeticOperation.formatIncompatibleTypesMessage(symbol(), leftType, rightType));
}

if ((leftType.isNumeric() && rightType.isNumeric())
|| leftType.equals(rightType)
|| DataTypes.isNull(leftType)
|| DataTypes.isNull(rightType)) {
return TypeResolution.TYPE_RESOLVED;
}
return new TypeResolution(EsqlArithmeticOperation.formatIncompatibleTypesMessage(symbol(), leftType, rightType));
}

// NOCOMMIT: This is the same as EsqlArithmeticOperation#ArithmeticEvaluator, and they should be refactored to the same place
@FunctionalInterface
interface BinaryEvaluator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import static org.elasticsearch.xpack.ql.type.DataTypes.LONG;
import static org.elasticsearch.xpack.ql.type.DataTypes.UNSIGNED_LONG;

abstract class EsqlArithmeticOperation extends ArithmeticOperation implements EvaluatorMapper {
public abstract class EsqlArithmeticOperation extends ArithmeticOperation implements EvaluatorMapper {

/**
* The only role of this enum is to fit the super constructor that expects a BinaryOperation which is
Expand Down Expand Up @@ -139,7 +139,7 @@ protected TypeResolution checkCompatibility() {
return TypeResolution.TYPE_RESOLVED;
}

static String formatIncompatibleTypesMessage(String symbol, DataType leftType, DataType rightType) {
public static String formatIncompatibleTypesMessage(String symbol, DataType leftType, DataType rightType) {
return format(null, "[{}] has arguments with incompatible types [{}] and [{}]", symbol, leftType.typeName(), rightType.typeName());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;

public class EqualsTests extends AbstractFunctionTestCase {
Expand Down Expand Up @@ -113,7 +114,7 @@ public static Iterable<Object[]> parameters() {
)
);
// Datetime
// TODO: I'm surprised this passes. Shouldn't there be a cast from DateTime to Long?
// TODO: I'm surprised this passes. Shouldn't there be a cast from DateTime to Long?
suppliers.addAll(
TestCaseSupplier.forBinaryNotCasting(
"EqualsLongsEvaluator",
Expand Down Expand Up @@ -165,8 +166,46 @@ public static Iterable<Object[]> parameters() {
List.of()
)
);
suppliers.addAll(
TestCaseSupplier.forBinaryNotCasting(
"EqualsGeometriesEvaluator",
"lhs",
"rhs",
Object::equals,
DataTypes.BOOLEAN,
TestCaseSupplier.cartesianPointCases(),
TestCaseSupplier.cartesianPointCases(),
List.of()
)
);

return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers)));
suppliers.addAll(
TestCaseSupplier.forBinaryNotCasting(
"EqualsGeometriesEvaluator",
"lhs",
"rhs",
Object::equals,
DataTypes.BOOLEAN,
TestCaseSupplier.cartesianShapeCases(),
TestCaseSupplier.cartesianShapeCases(),
List.of()
)
);


return parameterSuppliersFromTypedData(
errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers), EqualsTests::errorMessageString)
);
}

private static String errorMessageString(boolean includeOrdinal, List<Set<DataType>> validPerPosition, List<DataType> types) {
try {
return typeErrorMessage(includeOrdinal, validPerPosition, types);
} catch (IllegalStateException e) {
// This means all the positional args were okay, so the expected error is from the combination
return "[==] has arguments with incompatible types [" + types.get(0).typeName() + "] and [" + types.get(1).typeName() + "]";

}
}

@Override
Expand Down

0 comments on commit 92ee423

Please sign in to comment.