diff --git a/docs/developer-guide/udf.rst b/docs/developer-guide/udf.rst index 5b95f3a335e6..6739a28784ff 100644 --- a/docs/developer-guide/udf.rst +++ b/docs/developer-guide/udf.rst @@ -52,10 +52,10 @@ Follow these steps to create your custom functions: For a detailed walkthrough on creating a UDF, see :ref:`implement-a-udf`. ====================== -Creating UDF and UDAFs +Creating UDFs and UDAFs ====================== -KSQL supports creating User Defined Scalar Functions (UDFs) and User Defined Aggregate Functions (UDAF) via custom jars that are +KSQL supports creating User Defined Scalar Functions (UDFs) and User Defined Aggregate Functions (UDAFs) via custom jars that are uploaded to the ``ext/`` directory of the KSQL installation. At start up time KSQL scans the jars in the directory looking for any classes that annotated with ``@UdfDescription`` (UDF) or ``@UdafDescription`` (UDAF). @@ -104,6 +104,20 @@ The KSQL server will check the value being passed to each parameter and report a log for any null values being passed to a primitive type. The associated column in the output row will be ``null``. + +Dynamic return type +~~~~~~~~~~~~~~~~~~~ + +UDFs support dynamic return types that are resolved at runtime. This is useful if you want to +implement a UDF with a non-deterministic return type. A UDF which returns ``BigDecimal``, +for example, may vary the precision and scale of the output based on the input schema. + +To use this functionality, you need to specify a method with signature +``public Schema (final List params)`` and annotate it with ``@SchemaProvider``. +Also, you need to link it to the corresponding UDF by using the ``schemaProvider=`` +parameter of the ``@Udf`` annotation. + + Generics in UDFS ~~~~~~~~~~~~~~~~ diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java b/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java index 842c86d4e51c..c3b5928f2701 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java @@ -20,6 +20,7 @@ import io.confluent.ksql.function.udf.Kudf; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.SchemaUtil; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -36,7 +37,8 @@ public final class KsqlFunction implements IndexedFunction { static final String INTERNAL_PATH = "internal"; - private final Schema returnType; + private final Function,Schema> returnSchemaProvider; + private final Schema javaReturnType; private final List parameters; private final String functionName; private final Class kudfClass; @@ -44,7 +46,43 @@ public final class KsqlFunction implements IndexedFunction { private final String description; private final String pathLoadedFrom; private final boolean isVariadic; - private final boolean hasGenerics; + + private KsqlFunction( + final Function,Schema> returnSchemaProvider, + final Schema javaReturnType, + final List arguments, + final String functionName, + final Class kudfClass, + final Function udfFactory, + final String description, + final String pathLoadedFrom, + final boolean isVariadic) { + + this.returnSchemaProvider = Objects.requireNonNull(returnSchemaProvider, "schemaProvider"); + this.javaReturnType = Objects.requireNonNull(javaReturnType, "javaReturnType"); + this.parameters = ImmutableList.copyOf(Objects.requireNonNull(arguments, "arguments")); + this.functionName = Objects.requireNonNull(functionName, "functionName"); + this.kudfClass = Objects.requireNonNull(kudfClass, "kudfClass"); + this.udfFactory = Objects.requireNonNull(udfFactory, "udfFactory"); + this.description = Objects.requireNonNull(description, "description"); + this.pathLoadedFrom = Objects.requireNonNull(pathLoadedFrom, "pathLoadedFrom"); + this.isVariadic = isVariadic; + + + if (arguments.stream().anyMatch(Objects::isNull)) { + throw new IllegalArgumentException("KSQL Function can't have null argument types"); + } + if (isVariadic) { + if (arguments.isEmpty()) { + throw new IllegalArgumentException( + "KSQL variadic functions must have at least one parameter"); + } + if (!Iterables.getLast(arguments).type().equals(Type.ARRAY)) { + throw new IllegalArgumentException( + "KSQL variadic functions must have ARRAY type as their last parameter"); + } + } + } /** * Create built in / legacy function. @@ -66,7 +104,8 @@ public static KsqlFunction createLegacyBuiltIn( }; return create( - returnType, arguments, functionName, kudfClass, udfFactory, "", INTERNAL_PATH, false); + ignored -> returnType, returnType, arguments, functionName, kudfClass, udfFactory, "", + INTERNAL_PATH, false); } /** @@ -75,7 +114,8 @@ public static KsqlFunction createLegacyBuiltIn( *

Can be either built-in UDF or true user-supplied. */ static KsqlFunction create( - final Schema returnType, + final Function,Schema> schemaProvider, + final Schema javaReturnType, final List arguments, final String functionName, final Class kudfClass, @@ -85,7 +125,8 @@ static KsqlFunction create( final boolean isVariadic ) { return new KsqlFunction( - returnType, + schemaProvider, + javaReturnType, arguments, functionName, kudfClass, @@ -95,47 +136,20 @@ static KsqlFunction create( isVariadic); } - private KsqlFunction( - final Schema returnType, - final List arguments, - final String functionName, - final Class kudfClass, - final Function udfFactory, - final String description, - final String pathLoadedFrom, - final boolean isVariadic) { - this.returnType = Objects.requireNonNull(returnType, "returnType"); - this.parameters = ImmutableList.copyOf(Objects.requireNonNull(arguments, "arguments")); - this.functionName = Objects.requireNonNull(functionName, "functionName"); - this.kudfClass = Objects.requireNonNull(kudfClass, "kudfClass"); - this.udfFactory = Objects.requireNonNull(udfFactory, "udfFactory"); - this.description = Objects.requireNonNull(description, "description"); - this.pathLoadedFrom = Objects.requireNonNull(pathLoadedFrom, "pathLoadedFrom"); - this.isVariadic = isVariadic; - this.hasGenerics = GenericsUtil.hasGenerics(returnType); + public Schema getReturnType(final List arguments) { - if (arguments.stream().anyMatch(Objects::isNull)) { - throw new IllegalArgumentException("KSQL Function can't have null argument types"); - } - if (isVariadic) { - if (arguments.isEmpty()) { - throw new IllegalArgumentException( - "KSQL variadic functions must have at least one parameter"); - } - if (!Iterables.getLast(arguments).type().equals(Type.ARRAY)) { - throw new IllegalArgumentException( - "KSQL variadic functions must have ARRAY type as their last parameter"); - } + final Schema returnType = returnSchemaProvider.apply(arguments); + + if (returnType == null) { + throw new KsqlException(String.format("Return type of UDF %s cannot be null.", functionName)); } if (!returnType.isOptional()) { throw new IllegalArgumentException("KSQL only supports optional field types"); } - } - - public Schema getReturnType(final List arguments) { - if (!hasGenerics) { + if (!GenericsUtil.hasGenerics(returnType)) { + checkMatchingReturnTypes(returnType, javaReturnType); return returnType; } @@ -152,7 +166,21 @@ public Schema getReturnType(final List arguments) { genericMapping.putAll(GenericsUtil.resolveGenerics(schema, instance)); } - return GenericsUtil.applyResolved(returnType, genericMapping); + final Schema genericSchema = GenericsUtil.applyResolved(returnType, genericMapping); + final Schema genericJavaSchema = GenericsUtil.applyResolved(javaReturnType, genericMapping); + checkMatchingReturnTypes(genericSchema, genericJavaSchema); + + return genericSchema; + } + + private void checkMatchingReturnTypes(final Schema s1, final Schema s2) { + if (!SchemaUtil.areCompatible(s1, s2)) { + throw new KsqlException(String.format("Return type %s of UDF %s does not match the declared " + + "return type %s.", + s1.toString(), + functionName, + s2.toString())); + } } public List getArguments() { @@ -188,7 +216,7 @@ public boolean equals(final Object o) { return false; } final KsqlFunction that = (KsqlFunction) o; - return Objects.equals(returnType, that.returnType) + return Objects.equals(javaReturnType, that.javaReturnType) && Objects.equals(parameters, that.parameters) && Objects.equals(functionName, that.functionName) && Objects.equals(kudfClass, that.kudfClass) @@ -199,13 +227,13 @@ public boolean equals(final Object o) { @Override public int hashCode() { return Objects.hash( - returnType, parameters, functionName, kudfClass, pathLoadedFrom, isVariadic); + returnSchemaProvider, parameters, functionName, kudfClass, pathLoadedFrom, isVariadic); } @Override public String toString() { return "KsqlFunction{" - + "returnType=" + returnType + + "returnType=" + javaReturnType + ", arguments=" + parameters.stream().map(Schema::type).collect(Collectors.toList()) + ", functionName='" + functionName + '\'' + ", kudfClass=" + kudfClass diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java b/ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java index 11a0d7db6b05..a2c78fa3638e 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/UdfIndex.java @@ -16,11 +16,11 @@ package io.confluent.ksql.function; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.confluent.ksql.schema.connect.SqlSchemaFormatter; import io.confluent.ksql.util.DecimalUtil; import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.SchemaUtil; import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; @@ -29,11 +29,9 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; -import java.util.function.BiPredicate; import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; import org.apache.kafka.connect.data.Schema; -import org.apache.kafka.connect.data.Schema.Type; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -258,14 +256,6 @@ int compare(final Node other) { */ static final class Parameter { - private static final Map> CUSTOM_SCHEMA_EQ = - ImmutableMap.>builder() - .put(Type.MAP, Parameter::mapEquals) - .put(Type.ARRAY, Parameter::arrayEquals) - .put(Type.STRUCT, Parameter::structEquals) - .put(Type.BYTES, Parameter::bytesEquals) - .build(); - private final Schema schema; private final boolean isVararg; @@ -311,14 +301,7 @@ boolean accepts(final Schema argument, final Map reservedGeneric return reserveGenerics(schema, argument, reservedGenerics); } - final Schema.Type type = schema.type(); - - // we require a custom equals method that ignores certain values (e.g. - // whether or not the schema is optional, and the documentation) - return Objects.equals(type, argument.type()) - && CUSTOM_SCHEMA_EQ.getOrDefault(type, (a, b) -> true).test(schema, argument) - && Objects.equals(schema.version(), argument.version()) - && Objects.deepEquals(schema.defaultValue(), argument.defaultValue()); + return SchemaUtil.areCompatible(schema, argument); } // CHECKSTYLE_RULES.ON: BooleanExpressionComplexity diff --git a/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java b/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java index 37d060f617c3..d130efc3bd60 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java +++ b/ksql-common/src/main/java/io/confluent/ksql/util/SchemaUtil.java @@ -31,12 +31,15 @@ import java.util.List; import java.util.Map; import java.util.NavigableMap; +import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.BiPredicate; import org.apache.avro.LogicalTypes; import org.apache.avro.SchemaBuilder.FieldAssembler; import org.apache.kafka.connect.data.Field; import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Schema.Type; import org.apache.kafka.connect.data.SchemaBuilder; import org.apache.kafka.connect.data.Struct; @@ -97,6 +100,15 @@ public final class SchemaUtil { .put(Schema.Type.BOOLEAN, "(Boolean)") .build(); + private static final Map> CUSTOM_SCHEMA_EQ = + ImmutableMap.>builder() + .put(Type.MAP, SchemaUtil::mapEquals) + .put(Type.ARRAY, SchemaUtil::arrayEquals) + .put(Type.STRUCT, SchemaUtil::structEquals) + .put(Type.BYTES, SchemaUtil::bytesEquals) + .build(); + + private SchemaUtil() { } @@ -363,4 +375,40 @@ public static Schema ensureOptional(final Schema schema) { .build(); } + + public static boolean areCompatible(final Schema arg1, final Schema arg2) { + if (arg2 == null) { + return arg1.isOptional(); + } + + // we require a custom equals method that ignores certain values (e.g. + // whether or not the schema is optional, and the documentation) + return Objects.equals(arg1.type(), arg2.type()) + && CUSTOM_SCHEMA_EQ.getOrDefault(arg1.type(), (a, b) -> true).test(arg1, arg2) + && Objects.equals(arg1.version(), arg2.version()) + && Objects.deepEquals(arg1.defaultValue(), arg2.defaultValue()); + } + + private static boolean mapEquals(final Schema mapA, final Schema mapB) { + return Objects.equals(mapA.keySchema(), mapB.keySchema()) + && Objects.equals(mapA.valueSchema(), mapB.valueSchema()); + } + + private static boolean arrayEquals(final Schema arrayA, final Schema arrayB) { + return Objects.equals(arrayA.valueSchema(), arrayB.valueSchema()); + } + + private static boolean structEquals(final Schema structA, final Schema structB) { + return structA.fields().isEmpty() + || structB.fields().isEmpty() + || Objects.equals(structA.fields(), structB.fields()); + } + + private static boolean bytesEquals(final Schema bytesA, final Schema bytesB) { + // from a Java schema perspective, all decimals are the same + // since they can all be cast to BigDecimal - other bytes types + // are not supported in UDFs + return DecimalUtil.isDecimal(bytesA) && DecimalUtil.isDecimal(bytesB); + } + } \ No newline at end of file diff --git a/ksql-common/src/test/java/io/confluent/ksql/function/KsqlFunctionTest.java b/ksql-common/src/test/java/io/confluent/ksql/function/KsqlFunctionTest.java index a82b8721d228..e50d6541eab3 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/function/KsqlFunctionTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/function/KsqlFunctionTest.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.function.udf.Kudf; +import io.confluent.ksql.util.DecimalUtil; import io.confluent.ksql.util.KsqlConfig; import java.util.List; import java.util.function.Function; @@ -131,7 +132,35 @@ public void shouldThrowOnNonOptionalReturnType() { expectedException.expectMessage("KSQL only supports optional field types"); // When: - createFunction(Schema.INT32_SCHEMA, ImmutableList.of()); + final KsqlFunction function = createFunction(Schema.INT32_SCHEMA, ImmutableList.of()); + function.getReturnType(ImmutableList.of()); + + } + + @Test + public void shouldResolveSchemaProvider() { + // Given: + final Schema decimalSchema = DecimalUtil.builder(2,1).build(); + final Function, Schema> schemaProviderFunction = args -> { + return decimalSchema; + }; + + final KsqlFunction udf = KsqlFunction.create( + schemaProviderFunction, + decimalSchema, + ImmutableList.of(Schema.INT32_SCHEMA), + "funcName", + MyUdf.class, + udfFactory, + "the description", + "path/udf/loaded/from.jar", + false); + + // When: + final Schema returnType = udf.getReturnType(ImmutableList.of(Schema.INT32_SCHEMA)); + + // Then: + assertThat(returnType, is(decimalSchema)); } private KsqlFunction createFunction(final Schema returnSchema, final List args) { @@ -144,6 +173,7 @@ private KsqlFunction createFunction( final boolean isVariadic ) { return KsqlFunction.create( + ignored -> returnSchema, returnSchema, args, "funcName", diff --git a/ksql-common/src/test/java/io/confluent/ksql/function/UdfFactoryTest.java b/ksql-common/src/test/java/io/confluent/ksql/function/UdfFactoryTest.java index 0d6866144982..088435ebc625 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/function/UdfFactoryTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/function/UdfFactoryTest.java @@ -47,6 +47,7 @@ public void shouldThrowExceptionIfAddingFunctionWithDifferentPath() { expectedException.expect(KafkaException.class); expectedException.expectMessage("as a function with the same name has been loaded from a different jar"); factory.addFunction(KsqlFunction.create( + ignored -> Schema.OPTIONAL_STRING_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA, Collections.emptyList(), "TestFunc", diff --git a/ksql-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java b/ksql-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java index addf7329ca8c..1548b944a76b 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java @@ -742,6 +742,7 @@ private static KsqlFunction function( }; return KsqlFunction.create( + ignored -> Schema.OPTIONAL_STRING_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA, Arrays.asList(args), name, diff --git a/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java b/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java index 3879916d69ec..0cb1066dafec 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/util/SchemaUtilTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2018 Confluent Inc. + * Copyright 2019 Confluent Inc. * * Licensed under the Confluent Community License (the "License"); you may not use * this file except in compliance with the License. You may obtain a copy of the @@ -21,6 +21,7 @@ import static org.hamcrest.Matchers.notNullValue; import com.google.common.collect.ImmutableMap; +import io.confluent.ksql.function.GenericsUtil; import io.confluent.ksql.schema.Operator; import io.confluent.ksql.schema.ksql.PersistenceSchema; import java.math.BigDecimal; @@ -30,6 +31,7 @@ import org.apache.kafka.connect.data.ConnectSchema; import org.apache.kafka.connect.data.Field; import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Schema.Type; import org.apache.kafka.connect.data.SchemaBuilder; import org.junit.Assert; import org.junit.Test; @@ -813,6 +815,32 @@ public void shouldFailIsNumberForString() { assertThat(SchemaUtil.isNumber(Schema.STRING_SCHEMA), is(false)); } + @Test + public void shouldFailINonCompatibleSchemas() { + assertThat(SchemaUtil.areCompatible(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA), is(false)); + + assertThat(SchemaUtil.areCompatible(DecimalUtil.builder(1,1).build(), + Schema.BYTES_SCHEMA), is(false)); + + assertThat(SchemaUtil.areCompatible(GenericsUtil.generic("a").build(), + GenericsUtil.generic("b").build()), is(false)); + + assertThat(SchemaUtil.areCompatible(GenericsUtil.array("a").build(), + GenericsUtil.array("b").build()), is(false)); + } + + @Test + public void shouldPassCompatibleSchemas() { + assertThat(SchemaUtil.areCompatible(Schema.STRING_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA), + is(true)); + + assertThat(SchemaUtil.areCompatible(DecimalUtil.builder(2,2), + DecimalUtil.builder(1,1)), is(true)); + + assertThat(SchemaUtil.areCompatible(GenericsUtil.generic("a").build(), + GenericsUtil.generic("a").build()), is(false)); + } + @Test public void shouldBuildAliasedFieldName() { // When: diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/InternalFunctionRegistry.java b/ksql-engine/src/main/java/io/confluent/ksql/function/InternalFunctionRegistry.java index 19980ed2bf7d..7bc87c807b9f 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/InternalFunctionRegistry.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/InternalFunctionRegistry.java @@ -25,7 +25,6 @@ import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.function.udf.json.ArrayContainsKudf; import io.confluent.ksql.function.udf.json.JsonExtractStringKudf; -import io.confluent.ksql.function.udf.math.AbsKudf; import io.confluent.ksql.function.udf.math.CeilKudf; import io.confluent.ksql.function.udf.math.RandomKudf; import io.confluent.ksql.function.udf.math.RoundKudf; @@ -219,18 +218,6 @@ private void addStringFunctions() { private void addMathFunctions() { - addBuiltInFunction(KsqlFunction.createLegacyBuiltIn( - Schema.OPTIONAL_FLOAT64_SCHEMA, - ImmutableList.of(Schema.OPTIONAL_FLOAT64_SCHEMA), - AbsKudf.NAME, - AbsKudf.class)); - - addBuiltInFunction(KsqlFunction.createLegacyBuiltIn( - Schema.OPTIONAL_FLOAT64_SCHEMA, - Collections.singletonList(Schema.OPTIONAL_INT64_SCHEMA), - AbsKudf.NAME, - AbsKudf.class)); - addBuiltInFunction(KsqlFunction.createLegacyBuiltIn( Schema.OPTIONAL_FLOAT64_SCHEMA, Collections.singletonList(Schema.OPTIONAL_FLOAT64_SCHEMA), diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java b/ksql-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java index 6ba859c6eeb8..aaf4211532a4 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java @@ -23,10 +23,12 @@ import io.confluent.ksql.function.udf.UdfDescription; import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.function.udf.UdfParameter; +import io.confluent.ksql.function.udf.UdfSchemaProvider; import io.confluent.ksql.metrics.MetricCollectors; import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.schema.ksql.TypeContextUtil; import io.confluent.ksql.security.ExtensionSecurityManager; +import io.confluent.ksql.util.DecimalUtil; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.SchemaUtil; @@ -35,6 +37,7 @@ import io.github.lukehutch.fastclasspathscanner.matchprocessor.MethodAnnotationMatchProcessor; import java.io.File; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; @@ -46,6 +49,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -140,7 +144,6 @@ public void loadUdfFromClass(final Class ... udfClass) { } } - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") private void loadUdfs(final ClassLoader loader, final Optional path) { final String pathLoadedFrom @@ -259,11 +262,12 @@ private void handleUdfAnnotation(final Class theClass, LOGGER.info("Adding UDF name='{}' from class={}", annotation.name(), theClass); final UdfInvoker udf = UdfCompiler.compile(method, classLoader); - addFunction(annotation, method, udf, path); + addFunction(theClass, annotation, method, udf, path); } - private void addFunction(final UdfDescription classLevelAnnotation, + private void addFunction(final Class theClass, + final UdfDescription classLevelAnnotation, final Method method, final UdfInvoker udf, final String path) { @@ -302,9 +306,9 @@ private void addFunction(final UdfDescription classLevelAnnotation, if (name.trim().isEmpty()) { throw new KsqlFunctionException( - String.format("Cannot resolve parameter name for param at index %d for udf %s:%s. " - + "Please specify a name in @UdfParameter or compile your JAR with -parameters " - + "to infer the name from the parameter name.", + String.format("Cannot resolve parameter name for param at index %d for UDF %s:%s. " + + "Please specify a name in @UdfParameter or compile your JAR with -parameters " + + "to infer the name from the parameter name.", idx, classLevelAnnotation.name(), method.getName())); } @@ -320,10 +324,15 @@ private void addFunction(final UdfDescription classLevelAnnotation, return UdfUtil.getSchemaFromType(type, name, doc); }).collect(Collectors.toList()); - final Schema returnType = getReturnType(method, udfAnnotation); + final Schema javaReturnSchema = getReturnType(method, udfAnnotation); functionRegistry.addFunction(KsqlFunction.create( - returnType, + handleUdfReturnSchema(theClass, + method, + javaReturnSchema, + udfAnnotation, + classLevelAnnotation), + javaReturnSchema, parameters, functionName, udfClass, @@ -342,8 +351,9 @@ private void addFunction(final UdfDescription classLevelAnnotation, method.isVarArgs())); } + private static Object instantiateUdfClass(final Method method, - final UdfDescription annotation) { + final UdfDescription annotation) { try { return method.getDeclaringClass().newInstance(); } catch (final Exception e) { @@ -354,6 +364,79 @@ private static Object instantiateUdfClass(final Method method, } } + private static Object instantiateUdfClass(final Class udfClass, + final UdfDescription annotation) { + try { + return udfClass.newInstance(); + } catch (final Exception e) { + throw new KsqlException("Failed to create instance for UDF=" + + annotation.name(), e); + } + } + + private Function,Schema> handleUdfReturnSchema(final Class theClass, + final Method method, + final Schema javaReturnSchema, + final Udf udfAnnotation, + final UdfDescription descAnnotation) { + + final String schemaProviderName = udfAnnotation.schemaProvider(); + + if (!schemaProviderName.equals("")) { + return handleUdfSchemaProviderAnnotation(schemaProviderName, theClass, method, + descAnnotation); + } else if (DecimalUtil.isDecimal(javaReturnSchema)) { + throw new KsqlException(String.format("Cannot load UDF %s. BigDecimal return type " + + "is not supported without a schema provider method.", descAnnotation.name())); + } + + return ignored -> javaReturnSchema; + } + + private Function,Schema> handleUdfSchemaProviderAnnotation( + final String schemaProviderName, + final Class theClass, + final Method method, + final UdfDescription annotation) { + + // throws exception if cannot find method + final Method m = findSchemaProvider(theClass, schemaProviderName); + final Object instance = instantiateUdfClass(theClass, annotation); + + return parameterTypes -> { + return invokeSchemaProviderMethod(instance, m, parameterTypes, annotation); + }; + } + + private Method findSchemaProvider(final Class theClass, + final String schemaProviderName) { + try { + final Method m = theClass.getDeclaredMethod(schemaProviderName, List.class); + if (!m.isAnnotationPresent(UdfSchemaProvider.class)) { + throw new KsqlException(String.format( + "Method %s should be annotated with @UdfSchemaProvider.", + schemaProviderName)); + } + return m; + } catch (NoSuchMethodException e) { + throw new KsqlException(String.format( + "Cannot find schema provider method with name %s and parameter List in class %s.", + schemaProviderName,theClass.getName()),e); + } + } + + private Schema invokeSchemaProviderMethod(final Object instance, final Method m, + final List args, final UdfDescription annotation) { + try { + return (Schema) m.invoke(instance, args); + } catch (IllegalAccessException + | InvocationTargetException e) { + throw new KsqlException(String.format("Cannot invoke the schema provider " + + "method %s for UDF %s. ", + m.getName(), annotation.name()), e); + } + } + private void addSensor(final String sensorName, final String udfName) { metrics.ifPresent(metrics -> { if (metrics.getSensor(sensorName) == null) { @@ -377,8 +460,7 @@ private void addSensor(final String sensorName, final String udfName) { public static UdfLoader newInstance(final KsqlConfig config, final MutableFunctionRegistry metaStore, - final String ksqlInstallDir - ) { + final String ksqlInstallDir) { final Boolean loadCustomerUdfs = config.getBoolean(KsqlConfig.KSQL_ENABLE_UDFS); final Boolean collectMetrics = config.getBoolean(KsqlConfig.KSQL_COLLECT_UDF_METRICS); final String extDirName = config.getString(KsqlConfig.KSQL_EXT_DIR); diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udf/math/Abs.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udf/math/Abs.java new file mode 100644 index 000000000000..5f0fff13a646 --- /dev/null +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udf/math/Abs.java @@ -0,0 +1,68 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function.udf.math; + +import io.confluent.ksql.function.udf.Udf; +import io.confluent.ksql.function.udf.UdfDescription; +import io.confluent.ksql.function.udf.UdfParameter; +import io.confluent.ksql.function.udf.UdfSchemaProvider; +import io.confluent.ksql.util.DecimalUtil; +import io.confluent.ksql.util.KsqlException; +import java.math.BigDecimal; +import java.util.List; +import org.apache.kafka.connect.data.Schema; + +@UdfDescription(name = "Abs", description = Abs.DESCRIPTION) +public class Abs { + + static final String DESCRIPTION = "Returns the absolute value of its argument. If the argument " + + "is not negative, the argument is returned. If the argument is negative, the negation of " + + "the argument is returned."; + + + @Udf + public Double abs(@UdfParameter final Integer val) { + return (val == null) ? null : (double)Math.abs(val); + } + + @Udf + public Double abs(@UdfParameter final Long val) { + return (val == null) ? null : (double)Math.abs(val); + } + + @Udf + public Double abs(@UdfParameter final Double val) { + return (val == null) ? null : Math.abs(val); + } + + @Udf(schemaProvider = "provideSchema") + public BigDecimal abs(@UdfParameter final BigDecimal val) { + return (val == null) ? null : val.abs(); + } + + @UdfSchemaProvider + public Schema provideSchema(final List params) { + if (params.size() != 1) { + throw new KsqlException("Abs udf accepts one parameter"); + } + final Schema s = params.get(0); + if (!DecimalUtil.isDecimal(s)) { + throw new KsqlException("The schema provider method for Abs expects a BigDecimal parameter" + + "type"); + } + return s; + } +} \ No newline at end of file diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/InternalFunctionRegistryTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/InternalFunctionRegistryTest.java index 8a528d12e684..3643f0831fa6 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/InternalFunctionRegistryTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/InternalFunctionRegistryTest.java @@ -363,7 +363,7 @@ public void shouldHaveBuiltInUDFRegistered() { // String UDF "LCASE", "UCASE", "CONCAT", "TRIM", "IFNULL", "LEN", // Math UDF - "ABS", "CEIL", "ROUND", "RANDOM", + "CEIL", "ROUND", "RANDOM", // JSON UDF "EXTRACTJSONFIELD", "ARRAYCONTAINS", // Struct UDF diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index 6d22382f4ebb..da7e7957aef5 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -35,11 +35,14 @@ import io.confluent.ksql.function.udf.Udf; import io.confluent.ksql.function.udf.UdfDescription; import io.confluent.ksql.function.udf.UdfParameter; +import io.confluent.ksql.function.udf.UdfSchemaProvider; import io.confluent.ksql.util.DecimalUtil; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; import java.io.File; import java.lang.reflect.Field; +import java.math.BigDecimal; +import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -219,6 +222,120 @@ public void shouldLoadFunctionWithStructReturnType() { ); } + @Test + public void shouldLoadFunctionWithSchemaProvider() { + // Given: + final UdfFactory returnDecimal = FUNC_REG.getUdfFactory("returndecimal"); + + // When: + final Schema decimal = DecimalUtil.builder(2, 1).build(); + final List args = Collections.singletonList(decimal); + final KsqlFunction function = returnDecimal.getFunction(args); + + // Then: + assertThat(function.getReturnType(args), equalTo(decimal)); + } + + @Test + public void shouldThrowOnReturnTypeMismatch() { + // Given: + final UdfFactory returnIncompatible = FUNC_REG.getUdfFactory("returnincompatible"); + final Schema decimal = DecimalUtil.builder(2, 1).build(); + final List args = Collections.singletonList(decimal); + final KsqlFunction function = returnIncompatible.getFunction(args); + + // Expect: + expectedException.expect(KsqlException.class); + expectedException.expectMessage(is("Return type Schema{org.apache.kafka.connect.data." + + "Decimal:BYTES} of UDF ReturnIncompatible does not " + + "match the declared return type Schema{STRING}.")); + + // When: + function.getReturnType(args); + } + + @Test + public void shouldThrowOnMissingAnnotation() throws ClassNotFoundException { + // Given: + final MutableFunctionRegistry functionRegistry = new InternalFunctionRegistry(); + final Path udfJar = new File("src/test/resources/udf-failing-tests.jar").toPath(); + final UdfClassLoader udfClassLoader = UdfClassLoader.newClassLoader(udfJar, + PARENT_CLASS_LOADER, + resourceName -> false); + Class clazz = udfClassLoader.loadClass("org.damian.ksql.udf.MissingAnnotationUdf"); + final UdfLoader udfLoader = new UdfLoader(functionRegistry, + new File("src/test/resources/udf-failing-tests.jar"), + udfClassLoader, + value -> false, + COMPILER, + Optional.empty(), + true); + + // Expect: + expectedException.expect(KsqlException.class); + expectedException.expectMessage(is("Cannot load UDF MissingAnnotation. BigDecimal return type " + + "is not supported without a schema provider method.")); + + // When: + udfLoader.loadUdfFromClass(clazz); + + } + + @Test + public void shouldThrowOnMissingSchemaProvider() throws ClassNotFoundException { + // Given: + final MutableFunctionRegistry functionRegistry = new InternalFunctionRegistry(); + final Path udfJar = new File("src/test/resources/udf-failing-tests.jar").toPath(); + final UdfClassLoader udfClassLoader = UdfClassLoader.newClassLoader(udfJar, + PARENT_CLASS_LOADER, + resourceName -> false); + Class clazz = udfClassLoader.loadClass("org.damian.ksql.udf.MissingSchemaProviderUdf"); + final UdfLoader udfLoader = new UdfLoader(functionRegistry, + new File("src/test/resources/udf-failing-tests.jar"), + udfClassLoader, + value -> false, + COMPILER, + Optional.empty(), + true); + + // Expect: + expectedException.expect(KsqlException.class); + expectedException.expectMessage(is("Cannot find schema provider method with name provideSchema " + + "and parameter List in class org.damian.ksql.udf." + + "MissingSchemaProviderUdf.")); + + /// When: + udfLoader.loadUdfFromClass(clazz); + } + + @Test + public void shouldThrowOnReturnDecimalWithoutSchemaProvider() throws ClassNotFoundException { + // Given: + final MutableFunctionRegistry functionRegistry = new InternalFunctionRegistry(); + final Path udfJar = new File("src/test/resources/udf-failing-tests.jar").toPath(); + final UdfClassLoader udfClassLoader = UdfClassLoader.newClassLoader(udfJar, + PARENT_CLASS_LOADER, + resourceName -> false); + Class clazz = udfClassLoader.loadClass("org.damian.ksql.udf." + + "ReturnDecimalWithoutSchemaProviderUdf"); + final UdfLoader udfLoader = new UdfLoader(functionRegistry, + new File("src/test/resources/udf-failing-tests.jar"), + udfClassLoader, + value -> false, + COMPILER, + Optional.empty(), + true); + + // Expect: + expectedException.expect(KsqlException.class); + expectedException.expectMessage(is("Cannot load UDF ReturnDecimalWithoutSchemaProvider. " + + "BigDecimal return type is not supported without a " + + "schema provider method.")); + + /// When: + udfLoader.loadUdfFromClass(clazz); + } + @Test public void shouldPutJarUdfsInClassLoaderForJar() throws Exception { final UdfFactory toString = FUNC_REG.getUdfFactory("tostring"); @@ -313,7 +430,7 @@ public void shouldNotLoadInternalUdfs() { expectedException.expectMessage(is("Can't find any functions with the name 'substring'")); // When: - functionRegistry.getUdfFactory("substring"); + functionRegistry.getUdfFactory("substring"); } @Test @@ -473,7 +590,7 @@ private static UdfLoader createUdfLoader( final Optional metrics ) { return new UdfLoader(functionRegistry, - new File("src/test/resources"), + new File("src/test/resources/udf-example.jar"), PARENT_CLASS_LOADER, value -> false, COMPILER, @@ -540,4 +657,40 @@ public Object foo(final String noValue) { return 0; } } + + @SuppressWarnings({"unused", "MethodMayBeStatic"}) // Invoked via reflection in test. + @UdfDescription( + name = "ReturnDecimal", + description = "A test-only UDF for testing 'SchemaProvider'") + + public static class ReturnDecimalUdf { + + @Udf(schemaProvider = "provideSchema") + public BigDecimal foo(@UdfParameter("justValue") final BigDecimal p) { + return p; + } + + @UdfSchemaProvider + public Schema provideSchema(List params) { + return DecimalUtil.builder(2, 1).build(); + } + } + + @SuppressWarnings({"unused", "MethodMayBeStatic"}) // Invoked via reflection in test. + @UdfDescription( + name = "ReturnIncompatible", + description = "A test-only UDF for testing 'SchemaProvider'") + + public static class ReturnIncompatibleUdf { + + @Udf(schemaProvider = "provideSchema") + public String foo(@UdfParameter("justValue") final BigDecimal p) { + return "lala"; + } + + @UdfSchemaProvider + public Schema provideSchema(List params) { + return DecimalUtil.builder(2, 1).build(); + } + } } \ No newline at end of file diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/udf/math/AbsKudfTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/udf/math/AbsTest.java similarity index 53% rename from ksql-engine/src/test/java/io/confluent/ksql/function/udf/math/AbsKudfTest.java rename to ksql-engine/src/test/java/io/confluent/ksql/function/udf/math/AbsTest.java index 15599b71422a..18e50ad404cd 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/udf/math/AbsKudfTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/udf/math/AbsTest.java @@ -20,33 +20,40 @@ import static org.junit.Assert.assertThat; import io.confluent.ksql.function.udf.KudfTester; +import java.math.BigDecimal; import org.junit.Before; import org.junit.Test; -public class AbsKudfTest { +public class AbsTest { - private AbsKudf udf; + private Abs udf; @Before public void setUp() { - udf = new AbsKudf(); + udf = new Abs(); } @Test - public void shouldBeWellBehavedUdf() { - new KudfTester(AbsKudf::new) - .withArgumentTypes(Number.class) - .test(); + public void shouldHandleNull() { + assertThat(udf.abs((Integer) null), is(nullValue())); + assertThat(udf.abs((Long)null), is(nullValue())); + assertThat(udf.abs((Double)null), is(nullValue())); + assertThat(udf.abs((BigDecimal) null), is(nullValue())); } @Test - public void shouldReturnNullWhenArgNull() { - assertThat(udf.evaluate((Object)null), is(nullValue())); + public void shouldHandleNegative() { + assertThat(udf.abs(-1), is(1.0)); + assertThat(udf.abs(-1L), is(1.0)); + assertThat(udf.abs(-1.0), is(1.0)); + assertThat(udf.abs(new BigDecimal(-1)), is(new BigDecimal(-1).abs())); } @Test - public void shouldAbs() { - assertThat(udf.evaluate(-1.234), is(1.234)); - assertThat(udf.evaluate(5567), is(5567.0)); + public void shouldHandlePositive() { + assertThat(udf.abs(1), is(1.0)); + assertThat(udf.abs(1L), is(1.0)); + assertThat(udf.abs(1.0), is(1.0)); + assertThat(udf.abs(new BigDecimal(1)), is(new BigDecimal(1).abs())); } } \ No newline at end of file diff --git a/ksql-engine/src/test/resources/udf-failing-tests.jar b/ksql-engine/src/test/resources/udf-failing-tests.jar new file mode 100644 index 000000000000..be374a0909c9 Binary files /dev/null and b/ksql-engine/src/test/resources/udf-failing-tests.jar differ diff --git a/ksql-engine/src/test/resources/udf-failing-tests/pom.xml b/ksql-engine/src/test/resources/udf-failing-tests/pom.xml new file mode 100644 index 000000000000..d0fdc83306fb --- /dev/null +++ b/ksql-engine/src/test/resources/udf-failing-tests/pom.xml @@ -0,0 +1,54 @@ + + + 4.0.0 + org.damian.ksql.udf + udf-failing-tests + 1 + + + 5.4.0-SNAPSHOT + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.7.0 + + 8 + 8 + UTF-8 + + -parameters + + + + + + + + + + + io.confluent.ksql + ksql-udf + ${ksql.version} + + + + io.confluent.ksql + ksql-common + ${ksql.version} + + + + org.apache.kafka + connect-api + 2.2.0 + + + + + \ No newline at end of file diff --git a/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/MissingAnnotationUdf.java b/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/MissingAnnotationUdf.java new file mode 100644 index 000000000000..d6455098cb00 --- /dev/null +++ b/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/MissingAnnotationUdf.java @@ -0,0 +1,47 @@ +/* + * Copyright 2018 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package org.damian.ksql.udf; + +import io.confluent.ksql.function.udf.Udf; +import io.confluent.ksql.function.udf.UdfDescription; +import io.confluent.ksql.function.udf.UdfParameter; +import io.confluent.ksql.function.udf.UdfSchemaProvider; +import io.confluent.ksql.util.DecimalUtil; +import java.math.BigDecimal; +import java.util.List; +import org.apache.kafka.connect.data.Schema; + +/** + * Class used to test the loading of UDFs. This is packaged in udf-failing-tests.jar + * Attention: This test crashes the UdfLoader. + */ + +@UdfDescription( + name = "MissingAnnotation", + description = "A test-only UDF for testing 'SchemaProvider'") + +public class MissingAnnotationUdf { + + @Udf + public BigDecimal foo(@UdfParameter("justValue") final BigDecimal p) { + return p; + } + + @UdfSchemaProvider + public Schema provideSchema(List params) { + return DecimalUtil.builder(2, 1).build(); + } +} diff --git a/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/MissingSchemaProviderUdf.java b/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/MissingSchemaProviderUdf.java new file mode 100644 index 000000000000..3fa715240c86 --- /dev/null +++ b/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/MissingSchemaProviderUdf.java @@ -0,0 +1,40 @@ +/* + * Copyright 2018 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package org.damian.ksql.udf; + +import io.confluent.ksql.function.udf.Udf; +import io.confluent.ksql.function.udf.UdfDescription; +import io.confluent.ksql.function.udf.UdfParameter; +import io.confluent.ksql.function.udf.UdfSchemaProvider; +import io.confluent.ksql.util.DecimalUtil; +import java.math.BigDecimal; + +/** + * Class used to test the loading of UDFs. This is packaged in udf-failing-tests.jar + * Attention: This test crashes the UdfLoader. + */ + +@UdfDescription( + name = "MissingSchemaProviderMethod", + description = "A test-only UDF for testing 'SchemaProvider'") + +public class MissingSchemaProviderUdf { + + @Udf(schemaProvider = "provideSchema") + public BigDecimal foo(@UdfParameter("justValue") final BigDecimal p) { + return p; + } +} diff --git a/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/ReturnDecimalWithoutSchemaProviderUdf.java b/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/ReturnDecimalWithoutSchemaProviderUdf.java new file mode 100644 index 000000000000..17442b8f394f --- /dev/null +++ b/ksql-engine/src/test/resources/udf-failing-tests/src/main/java/org/damian/ksql/udf/ReturnDecimalWithoutSchemaProviderUdf.java @@ -0,0 +1,37 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package org.damian.ksql.udf; + +import io.confluent.ksql.function.udf.Udf; +import io.confluent.ksql.function.udf.UdfDescription; +import java.math.BigDecimal; + +/** + * Class used to test the loading of UDFs. This is packaged in udf-failing-tests.jar + * Attention: This test crashes the UdfLoader. + */ + +@UdfDescription( + name = "ReturnDecimalWithoutSchemaProvider", + description = "A test-only UDF for testing 'SchemaProvider'") + +public class ReturnDecimalWithoutSchemaProviderUdf { + + @Udf + public BigDecimal foo(final BigDecimal p) { + return new BigDecimal(1); + } +} diff --git a/ksql-engine/src/test/resources/udf-failing-tests/src/main/resources/log4j.properties b/ksql-engine/src/test/resources/udf-failing-tests/src/main/resources/log4j.properties new file mode 100644 index 000000000000..06ce81399847 --- /dev/null +++ b/ksql-engine/src/test/resources/udf-failing-tests/src/main/resources/log4j.properties @@ -0,0 +1,22 @@ +# +# Copyright 2018 Confluent Inc. +# +# Licensed under the Confluent Community License (the "License"); you may not use +# this file except in compliance with the License. You may obtain a copy of the +# License at +# +# http://www.confluent.io/confluent-community-license +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. +# + +log4j.rootLogger=WARN,stdout + +log4j.logger.io.confluent.ksql=DEBUG +log4j.logger.io.confluent.ksql.integration=DEBUG +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n diff --git a/ksql-engine/src/test/resources/udf-failing-tests/src/main/resources/resource-blacklist.txt b/ksql-engine/src/test/resources/udf-failing-tests/src/main/resources/resource-blacklist.txt new file mode 100644 index 000000000000..a664ade0610d --- /dev/null +++ b/ksql-engine/src/test/resources/udf-failing-tests/src/main/resources/resource-blacklist.txt @@ -0,0 +1 @@ +java.lang.Runtime$ diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/math.json b/ksql-functional-tests/src/test/resources/query-validation-tests/math.json index efce183cb9e1..ec50d307067b 100644 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/math.json +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/math.json @@ -97,6 +97,25 @@ {"topic": "OUTPUT", "value": {"I": 0.0, "L": 0.0, "D": 0.0, "B": 0.0}}, {"topic": "OUTPUT", "value": {"I": 1.0, "L": 2.0, "D": 3.0, "B": 3.0}} ] + }, + { + "name": "abs", + "statements": [ + "CREATE STREAM INPUT (i INT, l BIGINT, d DOUBLE, b DECIMAL(2,1)) WITH (kafka_topic='input', value_format='AVRO');", + "CREATE STREAM OUTPUT AS SELECT abs(i) i, abs(l) l, abs(d) d, abs(b) b FROM INPUT;" + ], + "inputs": [ + {"topic": "input", "value": {"i": null, "l": null, "d": null}}, + {"topic": "input", "value": {"i": -1, "l": -2, "d": -3.1, "b": "-3.2"}}, + {"topic": "input", "value": {"i": 0, "l": 0, "d": 0.0, "b": "0.0"}}, + {"topic": "input", "value": {"i": 1, "l": 2, "d": 3.3, "b": "3.4"}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"I": null, "L": null, "D": null, "B": null}}, + {"topic": "OUTPUT", "value": {"I": 1.0, "L": 2.0, "D": 3.1, "B": "3.2"}}, + {"topic": "OUTPUT", "value": {"I": 0.0, "L": 0.0, "D": 0.0, "B": "0.0"}}, + {"topic": "OUTPUT", "value": {"I": 1.0, "L": 2.0, "D": 3.3, "B": "3.4"}} + ] } ] } \ No newline at end of file diff --git a/ksql-udf/src/main/java/io/confluent/ksql/function/udf/Udf.java b/ksql-udf/src/main/java/io/confluent/ksql/function/udf/Udf.java index 3d66ce1267b5..d0d2488ac603 100644 --- a/ksql-udf/src/main/java/io/confluent/ksql/function/udf/Udf.java +++ b/ksql-udf/src/main/java/io/confluent/ksql/function/udf/Udf.java @@ -46,4 +46,10 @@ * this is required and will fail if not supplied. */ String schema() default ""; + + /** + * The name of the method that provides the return type of the UDF. + * @return the name of the other method + */ + String schemaProvider() default ""; } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udf/math/AbsKudf.java b/ksql-udf/src/main/java/io/confluent/ksql/function/udf/UdfSchemaProvider.java similarity index 52% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udf/math/AbsKudf.java rename to ksql-udf/src/main/java/io/confluent/ksql/function/udf/UdfSchemaProvider.java index 4f8326cc978d..06c1a2a3f690 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udf/math/AbsKudf.java +++ b/ksql-udf/src/main/java/io/confluent/ksql/function/udf/UdfSchemaProvider.java @@ -13,20 +13,21 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udf.math; +package io.confluent.ksql.function.udf; -import io.confluent.ksql.function.UdfUtil; -import io.confluent.ksql.function.udf.Kudf; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; -public class AbsKudf implements Kudf { - public static final String NAME = "ABS"; +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD}) +/** + * The {@code UdfSchemaProvider} annotation on a method tells KSQL to use this method to resolve + * the return type of the udf at runtime. + * + *

The corresponding udf annotation must have the {@code schemaProvider} specified. + */ +public @interface UdfSchemaProvider { - @Override - public Object evaluate(final Object... args) { - UdfUtil.ensureCorrectArgs(NAME, args, Number.class); - if (args[0] == null) { - return null; - } - return Math.abs(((Number) args[0]).doubleValue()); - } -} +} \ No newline at end of file