Skip to content

Commit

Permalink
refactored the type cross product logic
Browse files Browse the repository at this point in the history
  • Loading branch information
not-napoleon committed Dec 4, 2023
1 parent f90a713 commit 066b900
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ private void testEvaluate(boolean readFloating) {
assertFalse("expected resolved", expression.typeResolved().unresolved());
expression = new FoldNull().rule(expression);
assertThat(expression.dataType(), equalTo(testCase.expectedType));
// TODO should we convert unsigned_long into BigDecimal so it's easier to assert?
logger.info("Result type: " + expression.dataType());

Object result;
try (ExpressionEvaluator evaluator = evaluator(expression).get(driverContext())) {
try (Block block = evaluator.eval(row(testCase.getDataValues()))) {
Expand All @@ -253,7 +254,7 @@ private void testEvaluate(boolean readFloating) {
private static Object toJavaObjectBigIntegerAware(Block block, int position, DataType expectedType) {
Object result;
result = toJavaObject(block, position);
if (expectedType == DataTypes.UNSIGNED_LONG) {
if (result != null && expectedType == DataTypes.UNSIGNED_LONG) {
assertThat(result, instanceOf(Long.class));
result = NumericUtils.unsignedLongAsBigInteger((Long) result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,111 @@ public static List<TestCaseSupplier> forBinaryNumericNotCasting(
);
}

public record StuffForNumericType(
Number min,
Number max,
BinaryOperator<Number> expected,
String evaluatorName
) {}

public record AllTheTypeSpecificSettings(
StuffForNumericType intStuff,
StuffForNumericType longStuff,
StuffForNumericType ulongStuff,
StuffForNumericType doubleStuff
) {
public StuffForNumericType get(DataType type) {
if (type == DataTypes.INTEGER) {
return intStuff;
}
if (type == DataTypes.LONG) {
return longStuff;
}
if (type == DataTypes.UNSIGNED_LONG) {
return ulongStuff;
}
if (type == DataTypes.DOUBLE) {
return doubleStuff;
}
throw new IllegalArgumentException("bogus numeric type [" + type + "]");
}
}

private static DataType widen(DataType lhs, DataType rhs) {
if (lhs == rhs) {
return lhs;
}
if (lhs == DataTypes.DOUBLE || rhs == DataTypes.DOUBLE) {
return DataTypes.DOUBLE;
}
if (lhs == DataTypes.UNSIGNED_LONG || rhs == DataTypes.UNSIGNED_LONG) {
return DataTypes.UNSIGNED_LONG;
}
if (lhs == DataTypes.LONG || rhs == DataTypes.LONG) {
return DataTypes.LONG;
}
throw new IllegalArgumentException("Invalid numeric widening lhs: [" + lhs + "] rhs: [" + rhs + "]");
}

private static List<TypedDataSupplier> getSuppliersForNumericType(DataType type, Number min, Number max) {
if (type == DataTypes.INTEGER) {
return intCases(min.intValue(), max.intValue());
}
if (type == DataTypes.LONG) {
return longCases(min.longValue(), max.longValue());
}
if (type == DataTypes.UNSIGNED_LONG) {
return ulongCases(
min instanceof BigInteger ? (BigInteger) min : BigInteger.valueOf(Math.max(min.longValue(), 0L)),
max instanceof BigInteger ? (BigInteger) max : BigInteger.valueOf(Math.max(max.longValue(), 0L))
);
}
if (type == DataTypes.DOUBLE) {
return doubleCases(min.doubleValue(), max.doubleValue());
}
throw new IllegalArgumentException("bogus numeric type [" + type + "]");
}

public static List<TestCaseSupplier> forBinaryWithWidening(
AllTheTypeSpecificSettings typeStuff,
String lhsName,
String rhsName,
List<String> warnings
) {
List<TestCaseSupplier> suppliers = new ArrayList<>();
// TODO: Surely this list exists elsewhere already NOCOMMIT
List<DataType> numericTypes = List.of(DataTypes.INTEGER, DataTypes.LONG, DataTypes.UNSIGNED_LONG, DataTypes.DOUBLE);

for (DataType lhsType : numericTypes) {
for (DataType rhsType : numericTypes) {
DataType expected = widen(lhsType, rhsType);
StuffForNumericType expectedTypeStuff = typeStuff.get(expected);
String evaluator = expectedTypeStuff.evaluatorName()
+ "["
+ lhsName
+ "="
+ getCastEvaluator("Attribute[channel=0]", lhsType, expected)
+ ", "
+ rhsName
+ "="
+ getCastEvaluator("Attribute[channel=1]", rhsType, expected)
+ "]";
casesCrossProduct(
(l, r) -> expectedTypeStuff.expected().apply((Number) l, (Number) r),
getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max()),
getSuppliersForNumericType(rhsType, expectedTypeStuff.min(), expectedTypeStuff.max()),
// TODO: This doesn't really need to be a function
(lt, rt) -> evaluator,
warnings,
suppliers,
expected
);
}
}

return suppliers;
}

public static List<TestCaseSupplier> forBinaryNotCasting(
String name,
String lhsName,
Expand All @@ -279,18 +384,6 @@ public static List<TestCaseSupplier> forBinaryNotCasting(
suppliers,
expectedType
);
if (symetric) {
// reverse lhs and rhs suppliers
casesCrossProduct(
expected,
rhsSuppliers,
lhsSuppliers,
(lhsType, rhsType) -> name + "[" + lhsName + "=Attribute[channel=0], " + rhsName + "=Attribute[channel=1]]",
warnings,
suppliers,
expectedType
);
}
return suppliers;
}

Expand Down Expand Up @@ -810,6 +903,54 @@ public static List<TypedDataSupplier> versionCases(String prefix) {
);
}

private static String getCastEvaluator(String original, DataType current, DataType target) {
if (current == target) {
return original;
}
if (target == DataTypes.LONG) {
return castToLongEvaluator(original, current);
}
if (target == DataTypes.UNSIGNED_LONG) {
return castToUnsignedLongEvaluator(original, current);
}
if (target == DataTypes.DOUBLE) {
return castToDoubleEvaluator(original, current);
}
throw new IllegalArgumentException("Invalid numeric cast to [" + target + "]");
}

private static String castToLongEvaluator(String original, DataType current) {
if (current == DataTypes.LONG) {
return original;
}
if (current == DataTypes.INTEGER) {
return "CastIntToLongEvaluator[v=" + original + "]";
}
if (current == DataTypes.DOUBLE) {
return "CastDoubleToLongEvaluator[v=" + original + "]";
}
if (current == DataTypes.UNSIGNED_LONG) {
return "CastUnsignedLongToLong[v=" + original + "]";
}
throw new UnsupportedOperationException();
}

private static String castToUnsignedLongEvaluator(String original, DataType current) {
if (current == DataTypes.UNSIGNED_LONG) {
return original;
}
if (current == DataTypes.INTEGER) {
return "CastIntToUnsignedLongEvaluator[v=" + original + "]";
}
if (current == DataTypes.LONG) {
return "CastLongToUnsignedLongEvaluator[v=" + original + "]";
}
if (current == DataTypes.DOUBLE) {
return "CastDoubleToUnsignedLongEvaluator[v=" + original + "]";
}
throw new UnsupportedOperationException();
}

private static String castToDoubleEvaluator(String original, DataType current) {
if (current == DataTypes.DOUBLE) {
return original;
Expand Down Expand Up @@ -954,6 +1095,9 @@ public TypedData(Object data, String name) {

@Override
public String toString() {
if (type == DataTypes.UNSIGNED_LONG && data != null) {
return type.toString() + "(" + NumericUtils.unsignedLongAsBigInteger((Long) data).toString() + ")";
}
return type.toString() + "(" + (data == null ? "null" : data.toString()) + ")";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,55 +44,44 @@ public AddTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSu
public static Iterable<Object[]> parameters() {
List<TestCaseSupplier> suppliers = new ArrayList<>();
suppliers.addAll(
TestCaseSupplier.forBinaryNumericNotCasting(
"AddIntsEvaluator",
"lhs",
"rhs",
(l, r) -> l.intValue() + r.intValue(),
DataTypes.INTEGER,
TestCaseSupplier.intCases((Integer.MIN_VALUE >> 1) - 1, (Integer.MAX_VALUE >> 1) - 1),
TestCaseSupplier.intCases((Integer.MIN_VALUE >> 1) - 1, (Integer.MAX_VALUE >> 1) - 1),
List.of(),
false
)
);
suppliers.addAll(
TestCaseSupplier.forBinaryNumericNotCasting(
"AddLongsEvaluator",
"lhs",
"rhs",
(l, r) -> l.longValue() + r.longValue(),
DataTypes.LONG,
TestCaseSupplier.longCases((Long.MIN_VALUE >> 1) - 1, (Long.MAX_VALUE >> 1) - 1),
TestCaseSupplier.longCases((Long.MIN_VALUE >> 1) - 1, (Long.MAX_VALUE >> 1) - 1),
List.of(),
false
)
);
suppliers.addAll(
TestCaseSupplier.forBinaryNumericNotCasting(
"AddDoublesEvaluator",
TestCaseSupplier.forBinaryWithWidening(
new TestCaseSupplier.AllTheTypeSpecificSettings(
new TestCaseSupplier.StuffForNumericType(
(Integer.MIN_VALUE >> 1) - 1,
(Integer.MAX_VALUE >> 1) - 1,
(l, r) -> l.intValue() + r.intValue(),
"AddIntsEvaluator"
),
new TestCaseSupplier.StuffForNumericType(
(Long.MIN_VALUE >> 1) - 1,
(Long.MAX_VALUE >> 1) - 1,
(l, r) -> l.longValue() + r.longValue(),
"AddLongsEvaluator"
),
new TestCaseSupplier.StuffForNumericType(
BigInteger.ONE,
BigInteger.valueOf(Long.MAX_VALUE),
(l, r) -> {
BigInteger bigL = l instanceof BigInteger ? (BigInteger) l : BigInteger.valueOf(l.longValue());
BigInteger bigR = r instanceof BigInteger ? (BigInteger) r : BigInteger.valueOf(r.longValue());
return bigL.add(bigR);
},
"AddUnsignedLongsEvaluator"
),
new TestCaseSupplier.StuffForNumericType(
Double.NEGATIVE_INFINITY,
Double.POSITIVE_INFINITY,
(l, r) -> l.doubleValue() + r.doubleValue(),
"AddDoublesEvaluator"
)
),
"lhs",
"rhs",
(l, r) -> l.doubleValue() + r.doubleValue(),
DataTypes.DOUBLE,
TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY),
TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY),
List.of(),
false
List.of()
)
);
suppliers.addAll(TestCaseSupplier.forBinaryNumericNotCasting("AddUnsignedLongsEvaluator", "lhs", "rhs", (l, r) -> {
assert l instanceof BigInteger;
assert r instanceof BigInteger;
return ((BigInteger) l).add((BigInteger) r);
},
DataTypes.UNSIGNED_LONG,
TestCaseSupplier.ulongCases(BigInteger.ONE, BigInteger.valueOf(Long.MAX_VALUE)),
TestCaseSupplier.ulongCases(BigInteger.ONE, BigInteger.valueOf(Long.MAX_VALUE)),
List.of(),
false
));

// Datetime Cases
suppliers.addAll(
TestCaseSupplier.forBinaryNotCasting(
// TODO: There is an evaluator for Datetime + Period, so it should be tested. Similarly below.
Expand Down Expand Up @@ -171,6 +160,8 @@ public static Iterable<Object[]> parameters() {
false
)
);

// Cases that should generate warnings
suppliers.addAll(List.of(new TestCaseSupplier("MV", () -> {
// Ensure we don't have an overflow
int rhs = randomIntBetween((Integer.MIN_VALUE >> 1) - 1, (Integer.MAX_VALUE >> 1) - 1);
Expand Down

0 comments on commit 066b900

Please sign in to comment.