From 6fc4f72966dc3f1c9c6941967608e47e9cadb168 Mon Sep 17 00:00:00 2001 From: Almog Gavra Date: Wed, 29 Jan 2020 17:48:22 -0800 Subject: [PATCH] feat: support implicit casting in UDFs (#4406) --- .../io/confluent/ksql/function/UdfIndex.java | 30 ++++-- .../ksql/schema/ksql/SchemaConverters.java | 45 +++++++++ .../io/confluent/ksql/util/SchemaUtil.java | 41 ++++++-- .../confluent/ksql/function/UdfIndexTest.java | 33 +++++++ .../confluent/ksql/util/SchemaUtilTest.java | 20 ++++ .../execution/codegen/SqlToJavaVisitor.java | 62 ++++++++++--- .../codegen/SqlToJavaVisitorTest.java | 79 ++++++++++++++++ .../udf-implicit-cast.json | 93 +++++++++++++++++++ 8 files changed, 375 insertions(+), 28 deletions(-) create mode 100644 ksql-functional-tests/src/test/resources/query-validation-tests/udf-implicit-cast.json diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java b/ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java index b5d961ff0296..f97a8588903c 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java @@ -32,6 +32,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; @@ -141,8 +142,21 @@ void addFunction(final T function) { T getFunction(final List arguments) { final List candidates = new ArrayList<>(); - getCandidates(arguments, 0, root, candidates, new HashMap<>()); + // first try to get the candidates without any implicit casting + getCandidates(arguments, 0, root, candidates, new HashMap<>(), false); + final Optional fun = candidates + .stream() + .max(Node::compare) + .map(node -> node.value); + + if (fun.isPresent()) { + return fun.get(); + } + + // if none were found (candidates is empty) try again with + // implicit casting + getCandidates(arguments, 0, root, candidates, new HashMap<>(), true); return candidates .stream() .max(Node::compare) @@ -155,7 +169,8 @@ private void getCandidates( final int argIndex, final Node current, final List candidates, - final Map reservedGenerics + final Map reservedGenerics, + final boolean allowCasts ) { if (argIndex == arguments.size()) { if (current.value != null) { @@ -167,9 +182,9 @@ private void getCandidates( final SqlType arg = arguments.get(argIndex); for (final Entry candidate : current.children.entrySet()) { final Map reservedCopy = new HashMap<>(reservedGenerics); - if (candidate.getKey().accepts(arg, reservedCopy)) { + if (candidate.getKey().accepts(arg, reservedCopy, allowCasts)) { final Node node = candidate.getValue(); - getCandidates(arguments, argIndex + 1, node, candidates, reservedCopy); + getCandidates(arguments, argIndex + 1, node, candidates, reservedCopy, allowCasts); } } } @@ -324,12 +339,13 @@ public int hashCode() { * @param reservedGenerics a mapping of generics to already reserved types - this map * will be updated if the parameter is generic to point to the * current argument for future checks to accept - * + * @param allowCasts whether or not to accept an implicit cast * @return whether or not this argument can be used as a value for * this parameter */ // CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity - boolean accepts(final SqlType argument, final Map reservedGenerics) { + boolean accepts(final SqlType argument, final Map reservedGenerics, + final boolean allowCasts) { if (argument == null) { return true; } @@ -338,7 +354,7 @@ boolean accepts(final SqlType argument, final Map reserved return reserveGenerics(type, argument, reservedGenerics); } - return SchemaUtil.areCompatible(argument, type); + return SchemaUtil.areCompatible(argument, type, allowCasts); } // CHECKSTYLE_RULES.ON: BooleanExpressionComplexity diff --git a/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/SchemaConverters.java b/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/SchemaConverters.java index 88c3fac43c2e..8e89d95c8d98 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/SchemaConverters.java +++ b/ksql-common/src/main/java/io/confluent/ksql/schema/ksql/SchemaConverters.java @@ -89,6 +89,8 @@ public final class SchemaConverters { private static final FunctionToSqlConverter FUNCTION_TO_SQL_CONVERTER = new FunctionToSql(); + private static final FunctionToSqlBase FUNCTION_TO_BASE_CONVERTER = new FunctionToSqlBase(); + private SchemaConverters() { } @@ -159,6 +161,10 @@ public interface FunctionToSqlConverter { SqlType toSqlType(ParamType paramType); } + public interface FunctionToSqlBaseConverter { + SqlBaseType toBaseType(ParamType paramType); + } + public static ConnectToSqlTypeConverter connectToSqlConverter() { return CONNECT_TO_SQL_CONVERTER; } @@ -183,6 +189,10 @@ public static FunctionToSqlConverter functionToSqlConverter() { return FUNCTION_TO_SQL_CONVERTER; } + public static FunctionToSqlBaseConverter functionToSqlBaseConverter() { + return FUNCTION_TO_BASE_CONVERTER; + } + private static final class ConnectToSqlConverter implements ConnectToSqlTypeConverter { private static final Map> CONNECT_TO_SQL = ImmutableMap @@ -377,6 +387,41 @@ public SqlType toSqlType(final ParamType paramType) { } } + private static class FunctionToSqlBase implements FunctionToSqlBaseConverter { + + private static final BiMap FUNCTION_TO_BASE = + ImmutableBiMap.builder() + .put(ParamTypes.STRING, SqlBaseType.STRING) + .put(ParamTypes.BOOLEAN, SqlBaseType.BOOLEAN) + .put(ParamTypes.INTEGER, SqlBaseType.INTEGER) + .put(ParamTypes.LONG, SqlBaseType.BIGINT) + .put(ParamTypes.DOUBLE, SqlBaseType.DOUBLE) + .put(ParamTypes.DECIMAL, SqlBaseType.DECIMAL) + .build(); + + @Override + public SqlBaseType toBaseType(final ParamType paramType) { + final SqlBaseType sqlType = FUNCTION_TO_BASE.get(paramType); + if (sqlType != null) { + return sqlType; + } + + if (paramType instanceof MapType) { + return SqlBaseType.MAP; + } + + if (paramType instanceof ArrayType) { + return SqlBaseType.ARRAY; + } + + if (paramType instanceof StructType) { + return SqlBaseType.STRUCT; + } + + throw new KsqlException("Cannot convert param type to sql type: " + paramType); + } + } + private static class SqlToFunction implements SqlToFunctionConverter { @Override diff --git a/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java b/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java index 9fd63577c0bc..376c2e1fda26 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java +++ b/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java @@ -15,6 +15,8 @@ package io.confluent.ksql.util; +import static io.confluent.ksql.schema.ksql.SchemaConverters.functionToSqlBaseConverter; + import com.google.common.collect.ImmutableSet; import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.BooleanType; @@ -152,14 +154,26 @@ public static Schema ensureOptional(final Schema schema) { } public static boolean areCompatible(final SqlType actual, final ParamType declared) { + return areCompatible(actual, declared, false); + } + + public static boolean areCompatible( + final SqlType actual, + final ParamType declared, + final boolean allowCast + ) { if (actual.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) { - return areCompatible(((SqlArray) actual).getItemType(), ((ArrayType) declared).element()); + return areCompatible( + ((SqlArray) actual).getItemType(), + ((ArrayType) declared).element(), + allowCast); } if (actual.baseType() == SqlBaseType.MAP && declared instanceof MapType) { return areCompatible( ((SqlMap) actual).getValueType(), - ((MapType) declared).value() + ((MapType) declared).value(), + allowCast ); } @@ -167,7 +181,7 @@ public static boolean areCompatible(final SqlType actual, final ParamType declar return isStructCompatible(actual, declared); } - return isPrimitiveMatch(actual, declared); + return isPrimitiveMatch(actual, declared, allowCast); } private static boolean isStructCompatible(final SqlType actual, final ParamType declared) { @@ -181,6 +195,7 @@ private static boolean isStructCompatible(final SqlType actual, final ParamType for (final Entry entry : ((StructType) declared).getSchema().entrySet()) { final String k = entry.getKey(); final Optional field = actualStruct.field(k); + // intentionally do not allow implicit casting within structs if (!field.isPresent() || !areCompatible(field.get().type(), entry.getValue())) { return false; } @@ -189,15 +204,21 @@ private static boolean isStructCompatible(final SqlType actual, final ParamType } // CHECKSTYLE_RULES.OFF: CyclomaticComplexity - private static boolean isPrimitiveMatch(final SqlType actual, final ParamType declared) { + private static boolean isPrimitiveMatch( + final SqlType actual, + final ParamType declared, + final boolean allowCast + ) { // CHECKSTYLE_RULES.ON: CyclomaticComplexity // CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity - return actual.baseType() == SqlBaseType.STRING && declared instanceof StringType - || actual.baseType() == SqlBaseType.INTEGER && declared instanceof IntegerType - || actual.baseType() == SqlBaseType.BIGINT && declared instanceof LongType - || actual.baseType() == SqlBaseType.BOOLEAN && declared instanceof BooleanType - || actual.baseType() == SqlBaseType.DOUBLE && declared instanceof DoubleType - || actual.baseType() == SqlBaseType.DECIMAL && declared instanceof DecimalType; + final SqlBaseType base = actual.baseType(); + return base == SqlBaseType.STRING && declared instanceof StringType + || base == SqlBaseType.INTEGER && declared instanceof IntegerType + || base == SqlBaseType.BIGINT && declared instanceof LongType + || base == SqlBaseType.BOOLEAN && declared instanceof BooleanType + || base == SqlBaseType.DOUBLE && declared instanceof DoubleType + || base == SqlBaseType.DECIMAL && declared instanceof DecimalType + || allowCast && base.canImplicitlyCast(functionToSqlBaseConverter().toBaseType(declared)); // CHECKSTYLE_RULES.ON: BooleanExpressionComplexity } } \ No newline at end of file diff --git a/ksql-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java b/ksql-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java index fac3c7967733..1bc9433263d0 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java @@ -34,6 +34,8 @@ public class UdfIndexTest { private static final ParamType STRING = ParamTypes.STRING; private static final ParamType DECIMAL = ParamTypes.DECIMAL; private static final ParamType INT = ParamTypes.INTEGER; + private static final ParamType LONG = ParamTypes.LONG; + private static final ParamType DOUBLE = ParamTypes.DOUBLE; private static final ParamType STRUCT1 = StructType.builder().field("a", STRING).build(); private static final ParamType STRUCT2 = StructType.builder().field("b", INT).build(); private static final ParamType MAP1 = MapType.of(STRING); @@ -88,6 +90,37 @@ public void shouldFindOneArg() { assertThat(fun.name(), equalTo(EXPECTED)); } + @Test + public void shouldFindOneArgWithCast() { + // Given: + final KsqlScalarFunction[] functions = new KsqlScalarFunction[]{ + function(EXPECTED, false, LONG)}; + Arrays.stream(functions).forEach(udfIndex::addFunction); + + // When: + final KsqlScalarFunction fun = udfIndex.getFunction(ImmutableList.of(SqlTypes.INTEGER)); + + // Then: + assertThat(fun.name(), equalTo(EXPECTED)); + } + + @Test + public void shouldFindPreferredOneArgWithCast() { + // Given: + final KsqlScalarFunction[] functions = new KsqlScalarFunction[]{ + function(OTHER, false, LONG), + function(EXPECTED, false, INT), + function(OTHER, false, DOUBLE) + }; + Arrays.stream(functions).forEach(udfIndex::addFunction); + + // When: + final KsqlScalarFunction fun = udfIndex.getFunction(ImmutableList.of(SqlTypes.INTEGER)); + + // Then: + assertThat(fun.name(), equalTo(EXPECTED)); + } + @Test public void shouldFindTwoDifferentArgs() { // Given: diff --git a/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java b/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java index 1daba34baeaa..45f56fcc5498 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java @@ -302,7 +302,27 @@ public void shouldPassCompatibleSchemas() { SqlTypes.map(SqlTypes.decimal(1, 1)), MapType.of(ParamTypes.DECIMAL)), is(true)); + } + + @Test + public void shouldPassCompatibleSchemasWithImplicitCasting() { + assertThat(SchemaUtil.areCompatible(SqlTypes.INTEGER, ParamTypes.LONG, true), is(true)); + assertThat(SchemaUtil.areCompatible(SqlTypes.INTEGER, ParamTypes.DOUBLE, true), is(true)); + assertThat(SchemaUtil.areCompatible(SqlTypes.INTEGER, ParamTypes.DECIMAL, true), is(true)); + + assertThat(SchemaUtil.areCompatible(SqlTypes.BIGINT, ParamTypes.DOUBLE, true), is(true)); + assertThat(SchemaUtil.areCompatible(SqlTypes.BIGINT, ParamTypes.DECIMAL, true), is(true)); + + assertThat(SchemaUtil.areCompatible(SqlTypes.decimal(2, 1), ParamTypes.DOUBLE, true), is(true)); + } + + @Test + public void shouldNotPassInCompatibleSchemasWithImplicitCasting() { + assertThat(SchemaUtil.areCompatible(SqlTypes.BIGINT, ParamTypes.INTEGER, true), is(false)); + + assertThat(SchemaUtil.areCompatible(SqlTypes.DOUBLE, ParamTypes.LONG, true), is(false)); + assertThat(SchemaUtil.areCompatible(SqlTypes.DOUBLE, ParamTypes.DECIMAL, true), is(false)); } @Test diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index a7d2d6c99cda..ea1de749638c 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -15,12 +15,14 @@ package io.confluent.ksql.execution.codegen; +import static io.confluent.ksql.schema.ksql.SchemaConverters.sqlToFunctionConverter; import static java.lang.String.format; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.HashMultiset; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.common.collect.Multiset; import io.confluent.ksql.execution.codegen.helpers.ArrayAccess; import io.confluent.ksql.execution.codegen.helpers.ArrayBuilder; @@ -63,8 +65,13 @@ import io.confluent.ksql.execution.expression.tree.WhenClause; import io.confluent.ksql.execution.util.ExpressionTypeManager; import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.GenericsUtil; +import io.confluent.ksql.function.KsqlFunction; import io.confluent.ksql.function.KsqlFunctionException; import io.confluent.ksql.function.UdfFactory; +import io.confluent.ksql.function.types.ArrayType; +import io.confluent.ksql.function.types.ParamType; +import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.Operator; import io.confluent.ksql.schema.ksql.Column; @@ -88,6 +95,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.StringJoiner; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -352,24 +360,56 @@ public Pair visitFunctionCall(final FunctionCall node, final Vo final String instanceName = funNameToCodeName.apply(functionName); - final SqlType functionReturnSchema = getFunctionReturnSchema(node); + final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); + final List argumentSchemas = node.getArguments().stream() + .map(expressionTypeManager::getExpressionSqlType) + .collect(Collectors.toList()); + + final KsqlFunction function = udfFactory.getFunction(argumentSchemas); + + final SqlType functionReturnSchema = function.getReturnType(argumentSchemas); final String javaReturnType = SchemaConverters.sqlToJavaConverter().toJavaType(functionReturnSchema).getSimpleName(); - final String arguments = node.getArguments().stream() - .map(arg -> process(arg, context).getLeft()) - .collect(Collectors.joining(", ")); + + final List arguments = node.getArguments(); + + final StringJoiner joiner = new StringJoiner(", "); + for (int i = 0; i < arguments.size(); i++) { + final Expression arg = arguments.get(i); + final SqlType sqlType = argumentSchemas.get(i); + + final ParamType paramType; + if (i >= function.parameters().size() - 1 && function.isVariadic()) { + paramType = ((ArrayType) Iterables.getLast(function.parameters())).element(); + } else { + paramType = function.parameters().get(i); + } + + joiner.add(process(convertArgument(arg, sqlType, paramType), context).getLeft()); + } + + + final String argumentsString = joiner.toString(); final String codeString = "((" + javaReturnType + ") " + instanceName - + ".evaluate(" + arguments + "))"; + + ".evaluate(" + argumentsString + "))"; return new Pair<>(codeString, functionReturnSchema); } - private SqlType getFunctionReturnSchema(final FunctionCall node) { - final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); - final List argumentSchemas = node.getArguments().stream() - .map(expressionTypeManager::getExpressionSqlType) - .collect(Collectors.toList()); + private Expression convertArgument( + final Expression argument, + final SqlType argType, + final ParamType funType + ) { + if (argType == null + || GenericsUtil.hasGenerics(funType) + || sqlToFunctionConverter().toFunctionType(argType).equals(funType)) { + return argument; + } - return udfFactory.getFunction(argumentSchemas).getReturnType(argumentSchemas); + final SqlType target = funType == ParamTypes.DECIMAL + ? DecimalUtil.toSqlDecimal(argType) + : SchemaConverters.functionToSqlConverter().toSqlType(funType); + return new Cast(argument, new Type(target)); } @Override diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index 03d1c28c2151..e8ddfbb94010 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -42,6 +42,7 @@ import io.confluent.ksql.execution.expression.tree.CreateMapExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression; import io.confluent.ksql.execution.expression.tree.CreateStructExpression.Field; +import io.confluent.ksql.execution.expression.tree.DecimalLiteral; import io.confluent.ksql.execution.expression.tree.DoubleLiteral; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; @@ -62,6 +63,9 @@ import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlScalarFunction; import io.confluent.ksql.function.UdfFactory; +import io.confluent.ksql.function.types.ArrayType; +import io.confluent.ksql.function.types.GenericType; +import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; @@ -71,6 +75,7 @@ import io.confluent.ksql.schema.ksql.types.SqlDecimal; import io.confluent.ksql.schema.ksql.types.SqlPrimitiveType; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import java.math.BigDecimal; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; @@ -244,7 +249,11 @@ public void shouldPostfixFunctionInstancesWithUniqueId() { final UdfFactory catFactory = mock(UdfFactory.class); final KsqlScalarFunction catFunction = mock(KsqlScalarFunction.class); givenUdf("SUBSTRING", ssFactory, ssFunction); + when(ssFunction.parameters()) + .thenReturn(ImmutableList.of(ParamTypes.STRING, ParamTypes.INTEGER, ParamTypes.INTEGER)); givenUdf("CONCAT", catFactory, catFunction); + when(catFunction.parameters()) + .thenReturn(ImmutableList.of(ParamTypes.STRING, ParamTypes.STRING)); final FunctionName ssName = FunctionName.of("SUBSTRING"); final FunctionName catName = FunctionName.of("CONCAT"); final FunctionCall substring1 = new FunctionCall( @@ -275,6 +284,76 @@ public void shouldPostfixFunctionInstancesWithUniqueId() { + " ((String) SUBSTRING_3.evaluate(COL1, 4, 5))))))")); } + @Test + public void shouldImplicitlyCastFunctionCallParameters() { + // Given: + final UdfFactory udfFactory = mock(UdfFactory.class); + final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); + givenUdf("FOO", udfFactory, udf); + when(udf.parameters()).thenReturn(ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.LONG)); + + // When: + final String javaExpression = sqlToJavaVisitor.process( + new FunctionCall( + FunctionName.of("FOO"), + ImmutableList.of(new DecimalLiteral(new BigDecimal("1.2")), new IntegerLiteral(1)) + ) + ); + + // Then: + assertThat(javaExpression, is( + "((String) FOO_0.evaluate(((new BigDecimal(\"1.2\")).doubleValue()), (new Integer(1).longValue())))" + )); + } + + @Test + public void shouldImplicitlyCastFunctionCallParametersVariadic() { + // Given: + final UdfFactory udfFactory = mock(UdfFactory.class); + final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); + givenUdf("FOO", udfFactory, udf); + when(udf.parameters()).thenReturn(ImmutableList.of(ParamTypes.DOUBLE, ArrayType.of(ParamTypes.LONG))); + when(udf.isVariadic()).thenReturn(true); + + // When: + final String javaExpression = sqlToJavaVisitor.process( + new FunctionCall( + FunctionName.of("FOO"), + ImmutableList.of( + new DecimalLiteral(new BigDecimal("1.2")), + new IntegerLiteral(1), + new IntegerLiteral(1)) + ) + ); + + // Then: + assertThat(javaExpression, is( + "((String) FOO_0.evaluate(((new BigDecimal(\"1.2\")).doubleValue()), (new Integer(1).longValue()), (new Integer(1).longValue())))" + )); + } + + @Test + public void shouldHandleFunctionCallsWithGenerics() { + // Given: + final UdfFactory udfFactory = mock(UdfFactory.class); + final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); + givenUdf("FOO", udfFactory, udf); + when(udf.parameters()).thenReturn(ImmutableList.of(GenericType.of("T"), GenericType.of("T"))); + + // When: + final String javaExpression = sqlToJavaVisitor.process( + new FunctionCall( + FunctionName.of("FOO"), + ImmutableList.of( + new IntegerLiteral(1), + new IntegerLiteral(1)) + ) + ); + + // Then: + assertThat(javaExpression, is("((String) FOO_0.evaluate(1, 1))")); + } + @Test public void shouldEscapeQuotesInStringLiteral() { // Given: diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/udf-implicit-cast.json b/ksql-functional-tests/src/test/resources/query-validation-tests/udf-implicit-cast.json new file mode 100644 index 000000000000..b74a34f927a1 --- /dev/null +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/udf-implicit-cast.json @@ -0,0 +1,93 @@ +{ + "tests": [ + { + "name": "int literal -> double", + "statements": [ + "CREATE STREAM TEST (ID bigint, LAT1 double, LON1 double) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM DISTANCE_STREAM AS select ID, geo_distance(LAT1, LON1, 51, 0) as CALCULATED_DISTANCE from test;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"ID": 1, "LAT1": 37.4439, "LON1": -122.1663}} + ], + "outputs": [ + {"topic": "DISTANCE_STREAM", "value": {"ID": 1, "CALCULATED_DISTANCE": 8682.459061368269}} + ] + }, + { + "name": "int field -> double", + "statements": [ + "CREATE STREAM TEST (ID bigint, LAT1 double, LON1 double, LAT2 int) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM DISTANCE_STREAM AS select ID, geo_distance(LAT1, LON1, LAT2, 0) as CALCULATED_DISTANCE from test;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"ID": 1, "LAT1": 37.4439, "LON1": -122.1663, "LAT2": 51}} + ], + "outputs": [ + {"topic": "DISTANCE_STREAM", "value": {"ID": 1, "CALCULATED_DISTANCE": 8682.459061368269}} + ] + }, + { + "name": "long field -> double", + "statements": [ + "CREATE STREAM TEST (ID bigint, LAT1 double, LON1 double, LAT2 bigint) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM DISTANCE_STREAM AS select ID, geo_distance(LAT1, LON1, LAT2, 0) as CALCULATED_DISTANCE from test;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"ID": 1, "LAT1": 37.4439, "LON1": -122.1663, "LAT2": 51}} + ], + "outputs": [ + {"topic": "DISTANCE_STREAM", "value": {"ID": 1, "CALCULATED_DISTANCE": 8682.459061368269}} + ] + }, + { + "name": "decimal literal -> double", + "statements": [ + "CREATE STREAM TEST (ID bigint, LAT1 double, LON1 double) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM DISTANCE_STREAM AS select ID, geo_distance(LAT1, LON1, 51.0, 0) as CALCULATED_DISTANCE from test;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"ID": 1, "LAT1": 37.4439, "LON1": -122.1663}} + ], + "outputs": [ + {"topic": "DISTANCE_STREAM", "value": {"ID": 1, "CALCULATED_DISTANCE": 8682.459061368269}} + ] + }, + { + "name": "decimal field -> double", + "statements": [ + "CREATE STREAM TEST (ID bigint, LAT1 double, LON1 double, LAT2 decimal(3, 1)) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM DISTANCE_STREAM AS select ID, geo_distance(LAT1, LON1, LAT2, 0) as CALCULATED_DISTANCE from test;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"ID": 1, "LAT1": 37.4439, "LON1": -122.1663, "LAT2": 51.0}} + ], + "outputs": [ + {"topic": "DISTANCE_STREAM", "value": {"ID": 1, "CALCULATED_DISTANCE": 8682.459061368269}} + ] + }, + { + "name": "choose the exact match first", + "statements": [ + "CREATE STREAM TEST (ID int) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS select ID, test_udf(ID, 'foo') as foo from test;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"ID": 1}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"ID": 1, "FOO": "doStuffIntString"}} + ] + }, + { + "name": "string literal -> double", + "statements": [ + "CREATE STREAM TEST (ID bigint, LAT1 double, LON1 double) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM DISTANCE_STREAM AS select ID, geo_distance(LAT1, LON1, 'foo', 0) as CALCULATED_DISTANCE from test;" + ], + "expectedException": { + "type": "io.confluent.ksql.util.KsqlStatementException", + "message": "Function 'geo_distance' does not accept parameters" + } + } + ] +} \ No newline at end of file