Skip to content

Commit

Permalink
feat: decimal math with other numbers (#3001)
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra authored Jun 24, 2019
1 parent 9a634b8 commit 14d2bb7
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 6 deletions.
22 changes: 22 additions & 0 deletions ksql-common/src/main/java/io/confluent/ksql/util/DecimalUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,28 @@ public static BigDecimal ensureFit(final BigDecimal value, final Schema schema)
}
}

/**
* Converts a schema to a decimal schema with set precision/scale without losing
* scale or precision.
*
* @param schema the schema
* @return the decimal schema
* @throws KsqlException if the schema cannot safely be converted to decimal
*/
public static Schema toDecimal(final Schema schema) {
switch (schema.type()) {
case BYTES:
requireDecimal(schema);
return schema;
case INT32:
return builder(10, 0).build();
case INT64:
return builder(19, 0).build();
default:
throw new KsqlException("Cannot convert schema of type " + schema.type() + " to decimal.");
}
}

public static BigDecimal cast(final long value, final int precision, final int scale) {
validateParameters(precision, scale);
final BigDecimal decimal = new BigDecimal(value, new MathContext(precision));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,12 @@ public static Schema resolveBinaryOperatorResultType(
return Schema.OPTIONAL_STRING_SCHEMA;
}

if (DecimalUtil.isDecimal(left) && DecimalUtil.isDecimal(right)) {
return resolveDecimalOperatorResultType(left, right, operator);
if (DecimalUtil.isDecimal(left) || DecimalUtil.isDecimal(right)) {
if (left.type() != Schema.Type.FLOAT64 && right.type() != Schema.Type.FLOAT64) {
return resolveDecimalOperatorResultType(
DecimalUtil.toDecimal(left), DecimalUtil.toDecimal(right), operator);
}
return Schema.OPTIONAL_FLOAT64_SCHEMA;
}

if (!TYPE_TO_SCHEMA.containsKey(left.type()) || !TYPE_TO_SCHEMA.containsKey(right.type())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.math.BigDecimal;
import org.apache.kafka.connect.data.Decimal;
import org.apache.kafka.connect.data.Schema;
import org.hamcrest.Matchers;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
Expand Down Expand Up @@ -235,6 +236,46 @@ public void shouldCastStringRoundUp() {
assertThat(decimal, is(new BigDecimal("1.2")));
}

@Test
public void shouldConvertInteger() {
// When:
final Schema decimal = DecimalUtil.toDecimal(Schema.OPTIONAL_INT32_SCHEMA);

// Then:
assertThat(decimal, Matchers.is(DecimalUtil.builder(10, 0).build()));
}

@Test
public void shouldConvertLong() {
// When:
final Schema decimal = DecimalUtil.toDecimal(Schema.OPTIONAL_INT64_SCHEMA);

// Then:
assertThat(decimal, Matchers.is(DecimalUtil.builder(19, 0).build()));
}

@Test
public void shouldConvertDecimal() {
// Given:
final Schema given = DecimalUtil.builder(2, 2);

// When:
final Schema decimal = DecimalUtil.toDecimal(given);

// Then:
assertThat(decimal, sameInstance(given));
}

@Test
public void shouldThrowIfConvertString() {
// Expect:
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Cannot convert schema of type STRING to decimal");

// When:
DecimalUtil.toDecimal(Schema.OPTIONAL_STRING_SCHEMA);
}

@Test
public void shouldEnsureFitIfExactMatch() {
// No Exception When:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,42 @@ public void shouldResolveDecimalMod() {
});
}

@Test
public void shouldResolveDecimalLongAdd() {
final Map<PrecisionScale, PrecisionScale> inputToExpected =
ImmutableMap.<PrecisionScale, PrecisionScale>builder()
.put(PrecisionScale.of(2, 1), PrecisionScale.of(21, 1))
.put(PrecisionScale.of(3, 3), PrecisionScale.of(23, 3))
.put(PrecisionScale.of(23, 0), PrecisionScale.of(24, 0))
.build();

inputToExpected.forEach((in, out) -> {
// Given:
final Schema d1 = DecimalUtil.builder(in.precision, in.scale).build();
final Schema d2 = Schema.OPTIONAL_INT64_SCHEMA;

// When:
final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.ADD);

// Then:
assertThat(String.format("precision: %s", in), DecimalUtil.precision(result), is(out.precision));
assertThat(String.format("scale: %s", in), DecimalUtil.scale(result), is(out.scale));
});
}

@Test
public void shouldResolveDecimalDoubleMath() {
// Given:
final Schema d1 = DecimalUtil.builder(15, 10).build();
final Schema d2 = Schema.OPTIONAL_FLOAT64_SCHEMA;

// When:
final Schema result = SchemaUtil.resolveBinaryOperatorResultType(d1, d2, Operator.ADD);

// Then:
assertThat(result, is(Schema.OPTIONAL_FLOAT64_SCHEMA));
}

private static class PrecisionScale {
final int precision;
final int scale;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import io.confluent.ksql.parser.tree.Node;
import io.confluent.ksql.parser.tree.NotExpression;
import io.confluent.ksql.parser.tree.NullLiteral;
import io.confluent.ksql.parser.tree.PrimitiveType;
import io.confluent.ksql.parser.tree.QualifiedName;
import io.confluent.ksql.parser.tree.QualifiedNameReference;
import io.confluent.ksql.parser.tree.SearchedCaseExpression;
Expand Down Expand Up @@ -502,23 +503,39 @@ protected Pair<String, Schema> visitArithmeticBinary(
left.getRight(), right.getRight(), node.getOperator());

if (DecimalUtil.isDecimal(schema)) {
final String leftExpr = CastVisitor.getCast(
left,
Decimal.of(DecimalUtil.toDecimal(left.right))).getLeft();
final String rightExpr = CastVisitor.getCast(
right,
Decimal.of(DecimalUtil.toDecimal(right.right))).getLeft();

return new Pair<>(
String.format(
"(%s.%s(%s, new MathContext(%d, RoundingMode.UNNECESSARY)).setScale(%d))",
left.getLeft(),
leftExpr,
DECIMAL_OPERATOR_NAME.get(node.getOperator()),
right.getLeft(),
rightExpr,
DecimalUtil.precision(schema),
DecimalUtil.scale(schema)),
schema
);
} else {
final String leftExpr =
DecimalUtil.isDecimal(left.getRight())
? CastVisitor.getCast(left, PrimitiveType.of(SqlType.DOUBLE)).getLeft()
: left.getLeft();
final String rightExpr =
DecimalUtil.isDecimal(right.getRight())
? CastVisitor.getCast(right, PrimitiveType.of(SqlType.DOUBLE)).getLeft()
: right.getLeft();

return new Pair<>(
String.format(
"(%s %s %s)",
left.getLeft(),
leftExpr,
node.getOperator().getSymbol(),
right.getLeft()),
rightExpr),
schema
);
}
Expand Down Expand Up @@ -826,6 +843,10 @@ private static Pair<String, Schema> castDecimal(
throw new KsqlException("Expected decimal type: " + type);
}

if (DecimalUtil.isDecimal(expr.right) && Decimal.of(expr.right).equals(type)) {
return expr;
}

return new Pair<>(
getDecimalCastString(expr.getRight(), expr.getLeft(), (Decimal) type),
returnType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,38 @@ public void shouldGenerateCorrectCodeForDecimalAdd() {
assertThat(java, is("(TEST1_COL8.add(TEST1_COL8, new MathContext(3, RoundingMode.UNNECESSARY)).setScale(1))"));
}

@Test
public void shouldGenerateCastLongToDecimalInBinaryExpression() {
// Given:
final ArithmeticBinaryExpression binExp = new ArithmeticBinaryExpression(
Operator.ADD,
new QualifiedNameReference(QualifiedName.of("TEST1.COL8")),
new QualifiedNameReference(QualifiedName.of("TEST1.COL0"))
);

// When:
final String java = sqlToJavaVisitor.process(binExp);

// Then:
assertThat(java, containsString("DecimalUtil.cast(TEST1_COL0, 19, 0)"));
}

@Test
public void shouldGenerateCastDecimalToDoubleInBinaryExpression() {
// Given:
final ArithmeticBinaryExpression binExp = new ArithmeticBinaryExpression(
Operator.ADD,
new QualifiedNameReference(QualifiedName.of("TEST1.COL8")),
new QualifiedNameReference(QualifiedName.of("TEST1.COL3"))
);

// When:
final String java = sqlToJavaVisitor.process(binExp);

// Then:
assertThat(java, containsString("(TEST1_COL8).doubleValue()"));
}

@Test
public void shouldGenerateCorrectCodeForDecimalSubtract() {
// Given:
Expand Down Expand Up @@ -515,6 +547,21 @@ public void shouldGenerateCorrectCodeForDecimalCast() {
assertThat(java, is("(DecimalUtil.cast(TEST1_COL3, 2, 1))"));
}

@Test
public void shouldGenerateCorrectCodeForDecimalCastNoOp() {
// Given:
final Cast cast = new Cast(
new QualifiedNameReference(QualifiedName.of("TEST1.COL8")),
Decimal.of(2, 1)
);

// When:
final String java = sqlToJavaVisitor.process(cast);

// Then:
assertThat(java, is("TEST1_COL8"));
}

@Test
public void shouldGenerateCorrectCodeForDecimalToIntCast() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,40 @@
{"topic": "TEST2", "key": 0, "value": {"RESULT": "10.01"}}
]
},
{
"name": "addition with double",
"statements": [
"CREATE STREAM TEST (a DECIMAL(4,2), b DOUBLE) WITH (kafka_topic='test', value_format='AVRO');",
"CREATE STREAM TEST2 AS SELECT (a + b) AS RESULT FROM TEST;"
],
"inputs": [
{"topic": "test", "key": 0, "value": {"A": "10.01", "B": 5.1}},
{"topic": "test", "key": 0, "value": {"A": "10.01", "B": -5.0}},
{"topic": "test", "key": 0, "value": {"A": "10.01", "B": 0.0}}
],
"outputs": [
{"topic": "TEST2", "key": 0, "value": {"RESULT": 15.11}},
{"topic": "TEST2", "key": 0, "value": {"RESULT": 5.01}},
{"topic": "TEST2", "key": 0, "value": {"RESULT": 10.01}}
]
},
{
"name": "addition with int",
"statements": [
"CREATE STREAM TEST (a DECIMAL(4,2), b INT) WITH (kafka_topic='test', value_format='AVRO');",
"CREATE STREAM TEST2 AS SELECT (a + b) AS RESULT FROM TEST;"
],
"inputs": [
{"topic": "test", "key": 0, "value": {"A": "10.01", "B": 5}},
{"topic": "test", "key": 0, "value": {"A": "10.01", "B": -5}},
{"topic": "test", "key": 0, "value": {"A": "10.01", "B": 0}}
],
"outputs": [
{"topic": "TEST2", "key": 0, "value": {"RESULT": "15.01"}},
{"topic": "TEST2", "key": 0, "value": {"RESULT": "5.01"}},
{"topic": "TEST2", "key": 0, "value": {"RESULT": "10.01"}}
]
},
{
"name": "addition 3 columns",
"statements": [
Expand Down

0 comments on commit 14d2bb7

Please sign in to comment.