diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/UdfCompiler.java b/ksql-engine/src/main/java/io/confluent/ksql/function/UdfCompiler.java index ea9b04d6275b..dff5adcde909 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/UdfCompiler.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/UdfCompiler.java @@ -88,7 +88,7 @@ public class UdfCompiler { private final SqlTypeParser typeParser; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - public UdfCompiler(final Optional metrics) { + UdfCompiler(final Optional metrics) { this.metrics = Objects.requireNonNull(metrics, "metrics can't be null"); this.typeParser = SqlTypeParser.create(TypeRegistry.EMPTY); } @@ -195,13 +195,13 @@ private class UdafTypes { } private void validateTypes(final Type t) { - if (!isTypeSupported((Class)getRawType(t), SUPPORTED_TYPES)) { + if (isUnsupportedType((Class) getRawType(t))) { throw new KsqlException(String.format(invalidClassErrorMsg, t)); } } Schema getInputSchema(final String inSchema) { - validateStructAnnotation(inputType, inSchema, ""); + validateStructAnnotation(inputType, inSchema, "paramSchema"); final Schema inputSchema = getSchemaFromType(inputType, inSchema); //Currently, aggregate functions cannot have reified types as input parameters. if (!GenericsUtil.constituentGenerics(inputSchema).isEmpty()) { @@ -212,19 +212,18 @@ Schema getInputSchema(final String inSchema) { } Schema getAggregateSchema(final String aggSchema) { - validateStructAnnotation(aggregateType, aggSchema, ""); + validateStructAnnotation(aggregateType, aggSchema, "aggregateSchema"); return getSchemaFromType(aggregateType, aggSchema); } Schema getOutputSchema(final String outSchema) { - validateStructAnnotation(outputType, outSchema, ""); + validateStructAnnotation(outputType, outSchema, "returnSchema"); return getSchemaFromType(outputType, outSchema); } private void validateStructAnnotation(final Type type, final String schema, final String msg) { if (type.equals(Struct.class) && schema.isEmpty()) { - throw new KsqlException(String.format("Must specify '%s' for STRUCT parameter in " - + "@UdafFactory.", msg)); + throw new KsqlException("Must specify '" + msg + "' for STRUCT parameter in @UdafFactory."); } } @@ -283,7 +282,7 @@ private static String generateUdafClass( ) { validateMethodSignature(method); Arrays.stream(method.getParameterTypes()) - .filter(type -> !UdfCompiler.isTypeSupported(type, SUPPORTED_TYPES)) + .filter(UdfCompiler::isUnsupportedType) .findFirst() .ifPresent(type -> { throw new KsqlException( @@ -319,7 +318,7 @@ private static String generateCode(final Method method) { continue; } - if (!UdfCompiler.isTypeSupported(type, SUPPORTED_TYPES)) { + if (isUnsupportedType(type)) { throw new KsqlException( String.format( "Type %s is not supported by UDF methods. " @@ -347,11 +346,10 @@ private static void validateMethodSignature(final Method method) { } } - @SuppressWarnings("BooleanMethodIsAlwaysInverted") - private static boolean isTypeSupported(final Class type, final Set> supportedTypes) { - return supportedTypes.contains(type) - || type.isArray() && supportedTypes.contains(type.getComponentType()) - || supportedTypes.stream().anyMatch(supported -> supported.isAssignableFrom(type)); + private static boolean isUnsupportedType(final Class type) { + return !SUPPORTED_TYPES.contains(type) + && (!type.isArray() || !SUPPORTED_TYPES.contains(type.getComponentType())) + && SUPPORTED_TYPES.stream().noneMatch(supported -> supported.isAssignableFrom(type)); } private static IScriptEvaluator createScriptEvaluator( diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java index 2fdc8737c991..423eaa93bd18 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java @@ -41,6 +41,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; +@SuppressWarnings({"MethodMayBeStatic", "WeakerAccess", "unused"}) // UDFs not static / private public class UdfCompilerTest { private static final Schema STRUCT_SCHEMA = @@ -272,37 +273,59 @@ public void shouldThrowIfUnsupportedInputType() throws Exception { ""); } - @Test(expected = KsqlException.class) + @Test public void shouldThrowIfMissingInputTypeSchema() throws Exception { - udfCompiler.compileAggregate(UdfCompilerTest.class.getMethod("missingInputSchemaAnnotationUdaf"), - classLoader, - "test", - "desc", - "", - "", - ""); + // Then: + expectedException.expect(KsqlException.class); + expectedException.expectMessage( + "Must specify 'paramSchema' for STRUCT parameter in @UdafFactory."); + + // When: + udfCompiler.compileAggregate( + UdfCompilerTest.class.getMethod("missingInputSchemaAnnotationUdaf"), + classLoader, + "test", + "desc", + "", + "", + ""); } - @Test(expected = KsqlException.class) + @Test public void shouldThrowIfMissingAggregateTypeSchema() throws Exception { - udfCompiler.compileAggregate(UdfCompilerTest.class.getMethod("missingAggregateSchemaAnnotationUdaf"), - classLoader, - "test", - "desc", - "", - "", - ""); + // Then: + expectedException.expect(KsqlException.class); + expectedException.expectMessage( + "Must specify 'aggregateSchema' for STRUCT parameter in @UdafFactory."); + + // When: + udfCompiler.compileAggregate( + UdfCompilerTest.class.getMethod("missingAggregateSchemaAnnotationUdaf"), + classLoader, + "test", + "desc", + "", + "", + ""); } - @Test(expected = KsqlException.class) + @Test public void shouldThrowIfMissingOutputTypeSchema() throws Exception { - udfCompiler.compileAggregate(UdfCompilerTest.class.getMethod("missingOutputSchemaAnnotationUdaf"), - classLoader, - "test", - "desc", - "", - "", - ""); + // Then: + expectedException.expect(KsqlException.class); + expectedException.expectMessage( + "Must specify 'returnSchema' for STRUCT parameter in @UdafFactory."); + + // When: + udfCompiler.compileAggregate( + UdfCompilerTest.class.getMethod("missingOutputSchemaAnnotationUdaf"), + classLoader, + "test", + "desc", + "", + "", + "" + ); } @Test diff --git a/ksql-udf/src/main/java/io/confluent/ksql/function/udaf/Udaf.java b/ksql-udf/src/main/java/io/confluent/ksql/function/udaf/Udaf.java index f84d01ec0ec1..cf7ddb467b32 100644 --- a/ksql-udf/src/main/java/io/confluent/ksql/function/udaf/Udaf.java +++ b/ksql-udf/src/main/java/io/confluent/ksql/function/udaf/Udaf.java @@ -18,11 +18,21 @@ /** * {@code Udaf} represents a custom UDAF (User Defined Aggregate Function) * that can be used to perform aggregations on KSQL Streams. - * Type support is presently limited to: int, Integer, long, Long, boolean, Boolean, double, + * + *

Type support is presently limited to: int, Integer, long, Long, boolean, Boolean, double, * Double, String, Map, and List. * - * @param value type - * @param aggregate type + *

Sequence of calls is: + *

    + *
  1. {@code initialize()}: to get the initial value for the aggregate
  2. + *
  3. {@code aggregate(value, aggregate)}: adds {@code value} to the {@code aggregate}.
  4. + *
  5. {@code merge(agg1, agg2)}: merges to aggregates together, e.g. on session merges.
  6. + *
  7. {@code map(agg)}: reduces the intermediate state to the final output type.
  8. + *
+ * + * @param the input type + * @param the intermediate aggregate type + * @param the final output type */ public interface Udaf { /** @@ -39,13 +49,6 @@ public interface Udaf { */ A aggregate(I current, A aggregate); - /** - * Map the intermediate aggregate value into the actual returned value. - * @param agg aggregate value of current record - * @return new value of current record - */ - O map(A agg); - /** * Merge two aggregates * @param aggOne first aggregate @@ -53,4 +56,11 @@ public interface Udaf { * @return new aggregate */ A merge(A aggOne, A aggTwo); + + /** + * Map the intermediate aggregate value into the actual returned value. + * @param agg aggregate value of current record + * @return new value of current record + */ + O map(A agg); }