Skip to content

Commit

Permalink
Fix ulong tests
Browse files Browse the repository at this point in the history
  • Loading branch information
not-napoleon committed Nov 30, 2023
1 parent 04b5211 commit 12981c3
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
import org.elasticsearch.xpack.ql.type.EsField;
import org.elasticsearch.xpack.ql.util.NumericUtils;
import org.elasticsearch.xpack.ql.util.StringUtils;
import org.elasticsearch.xpack.versionfield.Version;
import org.hamcrest.Matcher;
Expand Down Expand Up @@ -88,6 +89,7 @@
import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
Expand Down Expand Up @@ -234,7 +236,7 @@ private void testEvaluate(boolean readFloating) {
Object result;
try (ExpressionEvaluator evaluator = evaluator(expression).get(driverContext())) {
try (Block block = evaluator.eval(row(testCase.getDataValues()))) {
result = toJavaObject(block, 0);
result = toJavaObjectBigIntegerAware(block, 0, testCase.expectedType);
}
}
assertThat(result, not(equalTo(Double.NaN)));
Expand All @@ -248,6 +250,16 @@ private void testEvaluate(boolean readFloating) {
}
}

private static Object toJavaObjectBigIntegerAware(Block block, int position, DataType expectedType) {
Object result;
result = toJavaObject(block, position);
if (expectedType == DataTypes.UNSIGNED_LONG) {
assertThat(result, instanceOf(Long.class));
result = NumericUtils.unsignedLongAsBigInteger((Long) result);
}
return result;
}

/**
* Evaluates a {@link Block} of values, all copied from the input pattern, read directly from the page.
* <p>
Expand Down Expand Up @@ -391,7 +403,7 @@ private void testEvaluateBlock(BlockFactory inputBlockFactory, DriverContext con
assertThat(toJavaObject(block, p), allNullsMatcher());
continue;
}
assertThat(toJavaObject(block, p), testCase.getMatcher());
assertThat(toJavaObjectBigIntegerAware(block, p, testCase.expectedType), testCase.getMatcher());
}
assertThat(
"evaluates to tracked block",
Expand Down Expand Up @@ -456,7 +468,7 @@ public final void testEvaluateInManyThreads() throws ExecutionException, Interru
try (EvalOperator.ExpressionEvaluator eval = evalSupplier.get(driverContext())) {
for (int c = 0; c < count; c++) {
try (Block block = eval.eval(page)) {
assertThat(toJavaObject(block, 0), testCase.getMatcher());
assertThat(toJavaObjectBigIntegerAware(block, 0, testCase.expectedType), testCase.getMatcher());
}
}
}
Expand Down Expand Up @@ -497,7 +509,11 @@ public final void testFold() {
expression = new FoldNull().rule(expression);
assertThat(expression.dataType(), equalTo(testCase.expectedType));
assertTrue(expression.foldable());
assertThat(expression.fold(), testCase.getMatcher());
Object result = expression.fold();
if (testCase.expectedType == DataTypes.UNSIGNED_LONG) {
result = NumericUtils.unsignedLongAsBigInteger((Long) result);
}
assertThat(result, testCase.getMatcher());
if (testCase.getExpectedWarnings() != null) {
assertWarnings(testCase.getExpectedWarnings());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ private static void casesCrossProduct(
List.of(lhsTyped, rhsTyped),
evaluatorToString.apply(lhsSupplier.type(), rhsSupplier.type()),
expectedType,
equalTo(expected.apply(lhs.doubleValue(), rhs.doubleValue()))
equalTo(expected.apply(lhs, rhs))
);
for (String warning : warnings) {
testCase = testCase.withWarning(warning);
Expand Down Expand Up @@ -240,8 +240,8 @@ public static List<TestCaseSupplier> forBinaryNumericNotCasting(
DataType expectedType,
List<TypedDataSupplier> lhsSuppliers,
List<TypedDataSupplier> rhsSuppliers,
List<String> warnings

List<String> warnings,
boolean symetric
) {
List<TestCaseSupplier> suppliers = new ArrayList<>();
casesCrossProduct(
Expand All @@ -253,6 +253,18 @@ public static List<TestCaseSupplier> forBinaryNumericNotCasting(
suppliers,
expectedType
);
if (symetric) {
// reverse lhs and rhs suppliers
casesCrossProduct(
expected,
rhsSuppliers,
lhsSuppliers,
(lhsType, rhsType) -> name + "[" + lhsName + "=Attribute[channel=0], " + rhsName + "=Attribute[channel=1]]",
warnings,
suppliers,
expectedType
);
}
return suppliers;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,35 +53,51 @@ public static Iterable<Object[]> parameters() {
DataTypes.INTEGER,
TestCaseSupplier.intCases((Integer.MIN_VALUE >> 1) - 1, (Integer.MAX_VALUE >> 1) - 1),
TestCaseSupplier.intCases((Integer.MIN_VALUE >> 1) - 1, (Integer.MAX_VALUE >> 1) - 1),
List.of()
)
List.of(),
false)
);
suppliers.addAll(List.of(new TestCaseSupplier("Long + Long", () -> {
// Ensure we don't have an overflow
long rhs = randomLongBetween((Long.MIN_VALUE >> 1) - 1, (Long.MAX_VALUE >> 1) - 1);
long lhs = randomLongBetween((Long.MIN_VALUE >> 1) - 1, (Long.MAX_VALUE >> 1) - 1);
return new TestCaseSupplier.TestCase(
List.of(
new TestCaseSupplier.TypedData(lhs, DataTypes.LONG, "lhs"),
new TestCaseSupplier.TypedData(rhs, DataTypes.LONG, "rhs")
),
"AddLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]",
suppliers.addAll(
TestCaseSupplier.forBinaryNumericNotCasting(
"AddLongsEvaluator",
"lhs",
"rhs",
(l, r) -> l.longValue() + r.longValue(),
DataTypes.LONG,
equalTo(lhs + rhs)
);
}), new TestCaseSupplier("Double + Double", () -> {
double rhs = randomDouble();
double lhs = randomDouble();
return new TestCaseSupplier.TestCase(
List.of(
new TestCaseSupplier.TypedData(lhs, DataTypes.DOUBLE, "lhs"),
new TestCaseSupplier.TypedData(rhs, DataTypes.DOUBLE, "rhs")
),
"AddDoublesEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]",
TestCaseSupplier.longCases((Long.MIN_VALUE >> 1) - 1, (Long.MAX_VALUE >> 1) - 1),
TestCaseSupplier.longCases((Long.MIN_VALUE >> 1) - 1, (Long.MAX_VALUE >> 1) - 1),
List.of(),
false)
);
suppliers.addAll(
TestCaseSupplier.forBinaryNumericNotCasting(
"AddDoublesEvaluator",
"lhs",
"rhs",
(l, r) -> l.doubleValue() + r.doubleValue(),
DataTypes.DOUBLE,
equalTo(lhs + rhs)
);
})/*, new TestCaseSupplier("ULong + ULong", () -> {
TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY),
TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY),
List.of(),
false)
);
suppliers.addAll(
TestCaseSupplier.forBinaryNumericNotCasting(
"AddUnsignedLongsEvaluator",
"lhs",
"rhs",
(l, r) -> {
assert l instanceof BigInteger;
assert r instanceof BigInteger;
return ((BigInteger)l).add((BigInteger) r);
},
DataTypes.UNSIGNED_LONG,
// TODO: we should be able to test values over Long.MAX_VALUE too...
TestCaseSupplier.ulongCases(BigInteger.ONE, BigInteger.valueOf(Long.MAX_VALUE)),
TestCaseSupplier.ulongCases(BigInteger.ONE, BigInteger.valueOf(Long.MAX_VALUE)),
List.of(),
false)
);
suppliers.addAll(List.of(/*, new TestCaseSupplier("ULong + ULong", () -> {
// Ensure we don't have an overflow
// TODO: we should be able to test values over Long.MAX_VALUE too...
long rhs = randomLongBetween(0, (Long.MAX_VALUE >> 1) - 1);
Expand All @@ -94,7 +110,7 @@ public static Iterable<Object[]> parameters() {
"AddUnsignedLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]",
equalTo(asLongUnsigned(lhsBI.add(rhsBI).longValue()))
);
}) */, new TestCaseSupplier("Datetime + Period", () -> {
}) */ new TestCaseSupplier("Datetime + Period", () -> {
long lhs = (Long) randomLiteral(DataTypes.DATETIME).value();
Period rhs = (Period) randomLiteral(EsqlDataTypes.DATE_PERIOD).value();
return new TestCaseSupplier.TestCase(
Expand Down

0 comments on commit 12981c3

Please sign in to comment.