From b0bbea479df7e6526383f13d7c0d1645c58ed9ed Mon Sep 17 00:00:00 2001 From: Tim Fox Date: Sat, 2 Nov 2019 08:55:17 -0700 Subject: [PATCH] feat: Implement describe and list functions for UDTFs (#3716) This PR implements describe and list functions for table functions. It also does some refactoring of the existing describe and list functions code to make it simpler. --- .../java/io/confluent/ksql/cli/CliTest.java | 35 ++++++++-- .../ksql/cli/console/ConsoleTest.java | 5 +- .../function/AggregateFunctionFactory.java | 24 ++----- .../ksql/function/FunctionRegistry.java | 14 ++++ .../ksql/function/TableFunctionFactory.java | 37 ++++++----- .../confluent/ksql/function/UdfFactory.java | 24 ++----- .../function/InternalFunctionRegistry.java | 15 +++++ .../confluent/ksql/function/UdtfLoader.java | 8 +-- .../function/UdtfTableFunctionFactory.java | 52 --------------- .../ksql/function/UdfLoaderTest.java | 4 +- .../execution/util/ExpressionTypeManager.java | 4 +- .../codegen/SqlToJavaVisitorTest.java | 3 + .../util/ExpressionTypeManagerTest.java | 10 ++- .../ksql/metastore/MetaStoreImpl.java | 11 ++++ .../execution/DescribeFunctionExecutor.java | 64 +++++++++++++++---- .../execution/ListFunctionsExecutor.java | 22 +++++-- .../ksql/rest/server/TemporaryEngine.java | 50 ++++++++++++++- .../DescribeFunctionExecutorTest.java | 30 ++++++++- .../execution/ListFunctionsExecutorTest.java | 21 ++++-- .../server/resources/KsqlResourceTest.java | 14 ++-- .../ksql/rest/entity/FunctionType.java | 6 +- 21 files changed, 293 insertions(+), 160 deletions(-) delete mode 100644 ksql-engine/src/main/java/io/confluent/ksql/function/UdtfTableFunctionFactory.java diff --git a/ksql-cli/src/test/java/io/confluent/ksql/cli/CliTest.java b/ksql-cli/src/test/java/io/confluent/ksql/cli/CliTest.java index 4f6283130aa7..b33ef95e96ca 100644 --- a/ksql-cli/src/test/java/io/confluent/ksql/cli/CliTest.java +++ b/ksql-cli/src/test/java/io/confluent/ksql/cli/CliTest.java @@ -789,7 +789,7 @@ public void shouldDescribeScalarFunction() { + "Overview : Converts a BIGINT millisecond timestamp value into the string" + " representation of the \n" + " timestamp in the given format.\n" - + "Type : scalar\n" + + "Type : SCALAR\n" + "Jar : internal\n" + "Variations :"; @@ -832,7 +832,7 @@ public void shouldDescribeOverloadedScalarFunction() { + "Overview : Returns a substring of the passed in value.\n" )); assertThat(output, containsString( - "Type : scalar\n" + "Type : SCALAR\n" + "Jar : internal\n" + "Variations :" )); @@ -853,8 +853,8 @@ public void shouldDescribeOverloadedScalarFunction() { public void shouldDescribeAggregateFunction() { final String expectedSummary = "Name : TOPK\n" + - "Author : Confluent\n" + - "Type : aggregate\n" + + "Author : Confluent\n" + + "Type : AGGREGATE\n" + "Jar : internal\n" + "Variations : \n"; @@ -870,6 +870,33 @@ public void shouldDescribeAggregateFunction() { assertThat(output, containsString(expectedVariant)); } + @Test + public void shouldDescribeTableFunction() { + final String expectedOutput = + "Name : EXPLODE\n" + + "Author : Confluent\n" + + "Overview : Explodes an array. This function outputs one value for each element of the array.\n" + + "Type : TABLE\n" + + "Jar : internal\n" + + "Variations : "; + + localCli.handleLine("describe function explode;"); + final String outputString = terminal.getOutputString(); + assertThat(outputString, containsString(expectedOutput)); + + // variations for Udfs are loaded non-deterministically. Don't assume which variation is first + String expectedVariation = + "\tVariation : EXPLODE(list ARRAY)\n" + + "\tReturns : BYTES\n" + + "\tDescription : Explodes an array. This function outputs one value for each element of the array."; + assertThat(outputString, containsString(expectedVariation)); + + expectedVariation = "\tVariation : EXPLODE(input ARRAY)\n" + + "\tReturns : DECIMAL(1, 0)\n" + + "\tDescription : Explodes an array. This function outputs one value for each element of the array."; + assertThat(outputString, containsString(expectedVariation)); + } + @Test public void shouldExplainQueryId() { // Given: diff --git a/ksql-cli/src/test/java/io/confluent/ksql/cli/console/ConsoleTest.java b/ksql-cli/src/test/java/io/confluent/ksql/cli/console/ConsoleTest.java index b3a0f8c910a8..1d5d3a37e612 100644 --- a/ksql-cli/src/test/java/io/confluent/ksql/cli/console/ConsoleTest.java +++ b/ksql-cli/src/test/java/io/confluent/ksql/cli/console/ConsoleTest.java @@ -1226,7 +1226,8 @@ public void shouldPrintFunctionDescription() throws IOException { + "really, really, really, really, really, really, really, really, really, " + "really, really, really, really, really, really, really, really, long\n" + "and contains\n\ttabs and stuff" - )), FunctionType.scalar))); + )), FunctionType.SCALAR + ))); console.printKsqlEntityList(entityList); @@ -1243,7 +1244,7 @@ public void shouldPrintFunctionDescription() throws IOException { + " and containing new lines\n" + " \tAND TABS\n" + " too!\n" - + "Type : scalar\n" + + "Type : SCALAR\n" + "Jar : some.jar\n" + "Variations : \n" + "\n" diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/AggregateFunctionFactory.java b/ksql-common/src/main/java/io/confluent/ksql/function/AggregateFunctionFactory.java index c6b9bffaab23..03f05dc1978a 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/AggregateFunctionFactory.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/AggregateFunctionFactory.java @@ -58,24 +58,12 @@ public AggregateFunctionFactory(final UdfMetadata metadata) { protected abstract List> supportedArgs(); - public String getName() { - return metadata.getName(); - } - - public String getDescription() { - return metadata.getDescription(); - } - - public String getPath() { - return metadata.getPath(); + public UdfMetadata getMetadata() { + return metadata; } - public String getAuthor() { - return metadata.getAuthor(); - } - - public String getVersion() { - return metadata.getVersion(); + public String getName() { + return metadata.getName(); } public void eachFunction(final Consumer> consumer) { @@ -83,10 +71,6 @@ public void eachFunction(final Consumer> consumer consumer.accept(createAggregateFunction(args, getDefaultArguments()))); } - public boolean isInternal() { - return metadata.isInternal(); - } - public AggregateFunctionInitArguments getDefaultArguments() { return AggregateFunctionInitArguments.EMPTY_ARGS; } diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/FunctionRegistry.java b/ksql-common/src/main/java/io/confluent/ksql/function/FunctionRegistry.java index 392eeca23dc9..cfcf6ecf2e70 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/FunctionRegistry.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/FunctionRegistry.java @@ -52,6 +52,15 @@ public interface FunctionRegistry { */ UdfFactory getUdfFactory(String functionName); + /** + * Get the factory for a table function. + * + * @param functionName the name of the function. + * @return the factory. + * @throws KsqlException on unknown table function. + */ + TableFunctionFactory getTableFunctionFactory(String functionName); + /** * Get the factory for a UDAF. * @@ -100,6 +109,11 @@ public interface FunctionRegistry { */ List listFunctions(); + /** + * @return all table function factories. + */ + List listTableFunctions(); + /** * @return all UDAF factories. */ diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/TableFunctionFactory.java b/ksql-common/src/main/java/io/confluent/ksql/function/TableFunctionFactory.java index aa74390724d9..56cba4586391 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/TableFunctionFactory.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/TableFunctionFactory.java @@ -18,41 +18,46 @@ import io.confluent.ksql.function.udf.UdfMetadata; import java.util.List; import java.util.Objects; +import java.util.function.Consumer; +import java.util.stream.Collectors; import org.apache.kafka.connect.data.Schema; -public abstract class TableFunctionFactory { +public class TableFunctionFactory { + + private final UdfIndex udtfIndex; private final UdfMetadata metadata; public TableFunctionFactory(final UdfMetadata metadata) { - this.metadata = Objects.requireNonNull(metadata, "metadata can't be null"); + this.metadata = Objects.requireNonNull(metadata, "metadata"); + this.udtfIndex = new UdfIndex<>(metadata.getName()); } - public abstract KsqlTableFunction createTableFunction(List argTypeList); - - protected abstract List> supportedArgs(); + public UdfMetadata getMetadata() { + return metadata; + } public String getName() { return metadata.getName(); } - public String getDescription() { - return metadata.getDescription(); + public synchronized void eachFunction(final Consumer consumer) { + udtfIndex.values().forEach(consumer); } - public String getPath() { - return metadata.getPath(); + public synchronized KsqlTableFunction createTableFunction(final List argTypeList) { + return udtfIndex.getFunction(argTypeList); } - public String getAuthor() { - return metadata.getAuthor(); + protected synchronized List> supportedArgs() { + return udtfIndex.values() + .stream() + .map(KsqlTableFunction::getArguments) + .collect(Collectors.toList()); } - public String getVersion() { - return metadata.getVersion(); + synchronized void addFunction(final KsqlTableFunction tableFunction) { + udtfIndex.addFunction(tableFunction); } - public boolean isInternal() { - return metadata.isInternal(); - } } diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/UdfFactory.java b/ksql-common/src/main/java/io/confluent/ksql/function/UdfFactory.java index a19a9a716765..76fd3c408edb 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/UdfFactory.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/UdfFactory.java @@ -53,34 +53,18 @@ private void checkCompatible(final KsqlFunction ksqlFunction) { } } - public String getName() { - return metadata.getName(); - } - - public String getAuthor() { - return metadata.getAuthor(); - } - - public String getVersion() { - return metadata.getVersion(); + public UdfMetadata getMetadata() { + return metadata; } - public String getDescription() { - return metadata.getDescription(); + public String getName() { + return metadata.getName(); } public synchronized void eachFunction(final Consumer consumer) { udfIndex.values().forEach(consumer); } - public boolean isInternal() { - return metadata.isInternal(); - } - - public String getPath() { - return metadata.getPath(); - } - public boolean matches(final UdfFactory that) { return this == that || (this.udfClass.equals(that.udfClass) && this.metadata.equals(that.metadata)); diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/InternalFunctionRegistry.java b/ksql-engine/src/main/java/io/confluent/ksql/function/InternalFunctionRegistry.java index a84ca1707050..ebf97d9a6be1 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/InternalFunctionRegistry.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/InternalFunctionRegistry.java @@ -197,11 +197,26 @@ public synchronized AggregateFunctionFactory getAggregateFactory(final String fu return udafFactory; } + @Override + public synchronized TableFunctionFactory getTableFunctionFactory(final String functionName) { + final TableFunctionFactory tableFunctionFactory = udtfs.get(functionName.toUpperCase()); + if (tableFunctionFactory == null) { + throw new KsqlException( + "Can not find any table functions with the name '" + functionName + "'"); + } + return tableFunctionFactory; + } + @Override public synchronized List listAggregateFunctions() { return new ArrayList<>(udafs.values()); } + @Override + public synchronized List listTableFunctions() { + return new ArrayList<>(udtfs.values()); + } + private void validateFunctionName(final String functionName) { if (!functionNameValidator.test(functionName)) { throw new KsqlException(functionName + " is not a valid function name." diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java b/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java index ff0cf48cbfe6..3b7e4dfc66ef 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java @@ -36,7 +36,7 @@ /** * Loads user defined table functions (UDTFs) */ -class UdtfLoader { +public class UdtfLoader { private static final Logger LOGGER = LoggerFactory.getLogger(UdtfLoader.class); @@ -45,7 +45,7 @@ class UdtfLoader { private final SqlTypeParser typeParser; private final boolean throwExceptionOnLoadFailure; - UdtfLoader( + public UdtfLoader( final MutableFunctionRegistry functionRegistry, final Optional metrics, final SqlTypeParser typeParser, @@ -57,7 +57,7 @@ class UdtfLoader { this.throwExceptionOnLoadFailure = throwExceptionOnLoadFailure; } - void loadUdtfFromClass( + public void loadUdtfFromClass( final Class theClass, final String path ) { @@ -79,7 +79,7 @@ void loadUdtfFromClass( false ); - final UdtfTableFunctionFactory udtfFactory = new UdtfTableFunctionFactory(metadata); + final TableFunctionFactory udtfFactory = new TableFunctionFactory(metadata); Arrays.stream(theClass.getMethods()) .filter(method -> method.getAnnotation(Udtf.class) != null) diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfTableFunctionFactory.java b/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfTableFunctionFactory.java deleted file mode 100644 index d7876a8e02cd..000000000000 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfTableFunctionFactory.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2019 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.function; - -import io.confluent.ksql.function.udf.UdfMetadata; -import java.util.List; -import java.util.stream.Collectors; -import org.apache.kafka.connect.data.Schema; - -/** - * A table function factory used for creating user defined table functions. - */ -public class UdtfTableFunctionFactory extends TableFunctionFactory { - - private final UdfIndex udtfIndex; - - public UdtfTableFunctionFactory(final UdfMetadata metadata) { - super(metadata); - this.udtfIndex = new UdfIndex<>(metadata.getName()); - } - - @Override - public KsqlTableFunction createTableFunction(final List argTypeList) { - return udtfIndex.getFunction(argTypeList); - } - - @Override - protected List> supportedArgs() { - return udtfIndex.values() - .stream() - .map(KsqlTableFunction::getArguments) - .collect(Collectors.toList()); - } - - void addFunction(final KsqlTableFunction tableFunction) { - udtfIndex.addFunction(tableFunction); - } - -} diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index a388613d0b9d..4683b6edc5af 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -403,13 +403,13 @@ public void shouldAllowClassesWithSameFQCNInDifferentUDFJars() throws Exception @Test public void shouldCreateUdfFactoryWithJarPathWhenExternal() { final UdfFactory tostring = FUNC_REG.getUdfFactory("tostring"); - assertThat(tostring.getPath(), equalTo("src/test/resources/udf-example.jar")); + assertThat(tostring.getMetadata().getPath(), equalTo("src/test/resources/udf-example.jar")); } @Test public void shouldCreateUdfFactoryWithInternalPathWhenInternal() { final UdfFactory substring = FUNC_REG.getUdfFactory("substring"); - assertThat(substring.getPath(), equalTo(KsqlFunction.INTERNAL_PATH)); + assertThat(substring.getMetadata().getPath(), equalTo(KsqlFunction.INTERNAL_PATH)); } @Test diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 6d4fdcdce94f..07ee31929203 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -54,6 +54,7 @@ import io.confluent.ksql.function.KsqlFunctionException; import io.confluent.ksql.function.KsqlTableFunction; import io.confluent.ksql.function.UdfFactory; +import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.SchemaConverters; @@ -407,7 +408,8 @@ public Void visitFunctionCall( } final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName().name()); - if (udfFactory.isInternal()) { + final UdfMetadata metadata = udfFactory.getMetadata(); + if (metadata.isInternal()) { // Internal UDFs, e.g. FetchFieldFromStruct, should not be used directly by users: throw new KsqlException( "Can't find any functions with the name '" + node.getName().name() + "'"); diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index 6e20a1b7e886..8ed04a6d0dbe 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -56,6 +56,7 @@ import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlFunction; import io.confluent.ksql.function.UdfFactory; +import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; @@ -774,5 +775,7 @@ private void givenUdf( when(functionRegistry.getUdfFactory(name)).thenReturn(factory); when(factory.getFunction(anyList())).thenReturn(function); when(function.getReturnType(anyList())).thenReturn(returnType); + UdfMetadata metadata = mock(UdfMetadata.class); + when(factory.getMetadata()).thenReturn(metadata); } } diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java index 0fc8a1f84ea2..a05490f96a33 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java @@ -27,6 +27,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -57,6 +58,7 @@ import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlFunction; import io.confluent.ksql.function.UdfFactory; +import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; @@ -103,9 +105,11 @@ public void init() { expressionTypeManager = new ExpressionTypeManager(SCHEMA, functionRegistry); final UdfFactory internalFactory = mock(UdfFactory.class); - when(internalFactory.isInternal()).thenReturn(true); + final UdfMetadata metadata = mock(UdfMetadata.class); + when(internalFactory.getMetadata()).thenReturn(metadata); + when(metadata.isInternal()).thenReturn(true); - when(functionRegistry.getUdfFactory(FetchFieldFromStruct.FUNCTION_NAME.name())) + when(functionRegistry.getUdfFactory(anyString())) .thenReturn(internalFactory); } @@ -559,5 +563,7 @@ private void givenUdfWithNameAndReturnType( when(functionRegistry.getUdfFactory(name)).thenReturn(factory); when(factory.getFunction(anyList())).thenReturn(function); when(function.getReturnType(anyList())).thenReturn(returnType); + UdfMetadata metadata = mock(UdfMetadata.class); + when(factory.getMetadata()).thenReturn(metadata); } } diff --git a/ksql-metastore/src/main/java/io/confluent/ksql/metastore/MetaStoreImpl.java b/ksql-metastore/src/main/java/io/confluent/ksql/metastore/MetaStoreImpl.java index 54d9a60ac369..4a97ba03590f 100644 --- a/ksql-metastore/src/main/java/io/confluent/ksql/metastore/MetaStoreImpl.java +++ b/ksql-metastore/src/main/java/io/confluent/ksql/metastore/MetaStoreImpl.java @@ -20,6 +20,7 @@ import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; import io.confluent.ksql.function.KsqlTableFunction; +import io.confluent.ksql.function.TableFunctionFactory; import io.confluent.ksql.function.UdfFactory; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.name.SourceName; @@ -240,11 +241,21 @@ public AggregateFunctionFactory getAggregateFactory(final String functionName) { return functionRegistry.getAggregateFactory(functionName); } + @Override + public TableFunctionFactory getTableFunctionFactory(final String functionName) { + return functionRegistry.getTableFunctionFactory(functionName); + } + @Override public List listAggregateFunctions() { return functionRegistry.listAggregateFunctions(); } + @Override + public List listTableFunctions() { + return functionRegistry.listTableFunctions(); + } + private Stream streamSources(final Set sourceNames) { return sourceNames.stream() .map(sourceName -> { diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/DescribeFunctionExecutor.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/DescribeFunctionExecutor.java index 60e78379e2a1..ee6d5c4ab681 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/DescribeFunctionExecutor.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/DescribeFunctionExecutor.java @@ -18,7 +18,9 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.KsqlExecutionContext; import io.confluent.ksql.function.AggregateFunctionFactory; +import io.confluent.ksql.function.TableFunctionFactory; import io.confluent.ksql.function.UdfFactory; +import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.parser.tree.DescribeFunction; import io.confluent.ksql.rest.entity.ArgumentInfo; import io.confluent.ksql.rest.entity.FunctionDescriptionList; @@ -56,6 +58,11 @@ public static Optional execute( describeAggregateFunction(executionContext, functionName, statement.getStatementText())); } + if (executionContext.getMetaStore().isTableFunction(functionName)) { + return Optional.of( + describeTableFunction(executionContext, functionName, statement.getStatementText())); + } + return Optional.of( describeNonAggregateFunction(executionContext, functionName, statement.getStatementText())); } @@ -73,15 +80,31 @@ private static FunctionDescriptionList describeAggregateFunction( aggregateFactory.eachFunction(func -> listBuilder.add( getFunctionInfo(func.getArguments(), func.getReturnType(), func.getDescription(), false))); - return new FunctionDescriptionList( + return createFunctionDescriptionList( + statementText, aggregateFactory.getMetadata(), listBuilder.build(), FunctionType.AGGREGATE); + } + + private static FunctionDescriptionList describeTableFunction( + final KsqlExecutionContext executionContext, + final String functionName, + final String statementText + ) { + final TableFunctionFactory tableFunctionFactory = executionContext.getMetaStore() + .getTableFunctionFactory(functionName); + + final ImmutableList.Builder listBuilder = ImmutableList.builder(); + + tableFunctionFactory.eachFunction(func -> listBuilder.add( + getFunctionInfo( + func.getArguments(), + func.getReturnType(func.getArguments()), func.getDescription(), func.isVariadic() + ))); + + return createFunctionDescriptionList( statementText, - aggregateFactory.getName().toUpperCase(), - aggregateFactory.getDescription(), - aggregateFactory.getAuthor(), - aggregateFactory.getVersion(), - aggregateFactory.getPath(), + tableFunctionFactory.getMetadata(), listBuilder.build(), - FunctionType.aggregate + FunctionType.TABLE ); } @@ -99,15 +122,11 @@ private static FunctionDescriptionList describeNonAggregateFunction( func.getArguments(), func.getReturnType(func.getArguments()), func.getDescription(), func.isVariadic()))); - return new FunctionDescriptionList( + return createFunctionDescriptionList( statementText, - udfFactory.getName().toUpperCase(), - udfFactory.getDescription(), - udfFactory.getAuthor(), - udfFactory.getVersion(), - udfFactory.getPath(), + udfFactory.getMetadata(), listBuilder.build(), - FunctionType.scalar + FunctionType.SCALAR ); } @@ -129,4 +148,21 @@ private static FunctionInfo getFunctionInfo( return new FunctionInfo(args, returnType, description); } + private static FunctionDescriptionList createFunctionDescriptionList( + final String statementText, + final UdfMetadata metadata, final List functionInfos, + final FunctionType functionType + ) { + return new FunctionDescriptionList( + statementText, + metadata.getName().toUpperCase(), + metadata.getDescription(), + metadata.getAuthor(), + metadata.getVersion(), + metadata.getPath(), + functionInfos, + functionType + ); + } + } diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/ListFunctionsExecutor.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/ListFunctionsExecutor.java index a1e9bfbb22e6..ecf6e05e7744 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/ListFunctionsExecutor.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/ListFunctionsExecutor.java @@ -42,18 +42,28 @@ public static Optional execute( final FunctionRegistry functionRegistry = executionContext.getMetaStore(); final List all = functionRegistry.listFunctions().stream() - .filter(factory -> !factory.isInternal()) + .filter(factory -> !factory.getMetadata().isInternal()) .map(factory -> new SimpleFunctionInfo( factory.getName().toUpperCase(), - FunctionType.scalar)) + FunctionType.SCALAR + )) .collect(Collectors.toList()); - all.addAll(functionRegistry.listAggregateFunctions().stream() - .filter(factory -> !factory.isInternal()) + functionRegistry.listTableFunctions().stream() + .filter(factory -> !factory.getMetadata().isInternal()) .map(factory -> new SimpleFunctionInfo( factory.getName().toUpperCase(), - FunctionType.aggregate)) - .collect(Collectors.toList())); + FunctionType.TABLE + )) + .forEach(all::add); + + functionRegistry.listAggregateFunctions().stream() + .filter(factory -> !factory.getMetadata().isInternal()) + .map(factory -> new SimpleFunctionInfo( + factory.getName().toUpperCase(), + FunctionType.AGGREGATE + )) + .forEach(all::add); return Optional.of(new FunctionNameList(statement.getStatementText(), all)); } diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/TemporaryEngine.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/TemporaryEngine.java index b177b34ffdba..488e92bd7a92 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/TemporaryEngine.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/TemporaryEngine.java @@ -15,14 +15,20 @@ package io.confluent.ksql.rest.server; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.KsqlConfigTestUtil; import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.engine.KsqlEngineTestUtil; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; import io.confluent.ksql.function.InternalFunctionRegistry; +import io.confluent.ksql.function.UdtfLoader; +import io.confluent.ksql.function.udf.UdfParameter; +import io.confluent.ksql.function.udtf.Udtf; +import io.confluent.ksql.function.udtf.UdtfDescription; import io.confluent.ksql.metastore.MetaStoreImpl; import io.confluent.ksql.metastore.MutableMetaStore; +import io.confluent.ksql.metastore.TypeRegistry; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.metastore.model.DataSource.DataSourceType; import io.confluent.ksql.metastore.model.KeyField; @@ -33,6 +39,7 @@ import io.confluent.ksql.parser.DefaultKsqlParser; import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.SqlTypeParser; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; @@ -48,6 +55,8 @@ import io.confluent.rest.RestConfig; import java.util.Collections; import java.util.HashMap; +import java.util.List; +import java.util.Optional; import org.junit.rules.ExternalResource; @SuppressWarnings("OptionalGetWithoutIsPresent") @@ -63,6 +72,7 @@ public class TemporaryEngine extends ExternalResource { .build(); private MutableMetaStore metaStore; + private InternalFunctionRegistry functionRegistry; private KsqlConfig ksqlConfig; private KsqlEngine engine; @@ -70,7 +80,9 @@ public class TemporaryEngine extends ExternalResource { @Override protected void before() { - metaStore = new MetaStoreImpl(new InternalFunctionRegistry()); + functionRegistry = new InternalFunctionRegistry(); + metaStore = new MetaStoreImpl(functionRegistry); + serviceContext = TestServiceContext.create(); engine = (KsqlEngineTestUtil.createKsqlEngine(getServiceContext(), metaStore)); @@ -81,6 +93,13 @@ protected void before() { RestConfig.LISTENERS_CONFIG, "http://localhost:8088" ) ); + + final SqlTypeParser typeParser = SqlTypeParser.create(TypeRegistry.EMPTY); + UdtfLoader udtfLoader = new UdtfLoader(functionRegistry, Optional.empty(), + typeParser, true + ); + udtfLoader.loadUdtfFromClass(TestUdtf1.class, "whatever"); + udtfLoader.loadUdtfFromClass(TestUdtf2.class, "whatever"); } @Override @@ -164,4 +183,33 @@ public KsqlEngine getEngine() { public ServiceContext getServiceContext() { return serviceContext; } + + @UdtfDescription(name = "test_udtf1", description = "test_udtf1 description") + public static class TestUdtf1 { + + @Udtf + public List foo1(@UdfParameter(value = "foo") int foo) { + return ImmutableList.of(1); + } + + @Udtf + public List foo2(@UdfParameter(value = "foo") double foo) { + return ImmutableList.of(1.0d); + } + } + + @UdtfDescription(name = "test_udtf2", description = "test_udtf2 description") + public static class TestUdtf2 { + + @Udtf + public List foo1(@UdfParameter(value = "foo") int foo) { + return ImmutableList.of(1); + } + + @Udtf + public List foo2(@UdfParameter(value = "foo") double foo) { + return ImmutableList.of(1.0d); + } + } + } diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/execution/DescribeFunctionExecutorTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/execution/DescribeFunctionExecutorTest.java index fa949f7383f3..212938fa3b08 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/execution/DescribeFunctionExecutorTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/execution/DescribeFunctionExecutorTest.java @@ -49,7 +49,7 @@ public void shouldDescribeUDF() { @Override protected boolean matchesSafely(FunctionDescriptionList item) { return functionList.getName().equals("CONCAT") - && functionList.getType().equals(FunctionType.scalar); + && functionList.getType().equals(FunctionType.SCALAR); } @Override @@ -75,7 +75,33 @@ public void shouldDescribeUDAF() { @Override protected boolean matchesSafely(FunctionDescriptionList item) { return functionList.getName().equals("MAX") - && functionList.getType().equals(FunctionType.aggregate); + && functionList.getType().equals(FunctionType.AGGREGATE); + } + + @Override + public void describeTo(Description description) { + description.appendText(functionList.getName()); + } + }); + } + + @Test + public void shouldDescribeUDTF() { + // When: + final FunctionDescriptionList functionList = (FunctionDescriptionList) + CustomExecutors.DESCRIBE_FUNCTION.execute( + engine.configure("DESCRIBE FUNCTION TEST_UDTF1;"), + ImmutableMap.of(), + engine.getEngine(), + engine.getServiceContext() + ).orElseThrow(IllegalStateException::new); + + // Then: + assertThat(functionList, new TypeSafeMatcher() { + @Override + protected boolean matchesSafely(FunctionDescriptionList item) { + return functionList.getName().equals("TEST_UDTF1") + && functionList.getType().equals(FunctionType.TABLE); } @Override diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/execution/ListFunctionsExecutorTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/execution/ListFunctionsExecutorTest.java index 1a182c0bb552..5b44c87d2b43 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/execution/ListFunctionsExecutorTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/execution/ListFunctionsExecutorTest.java @@ -25,6 +25,7 @@ import io.confluent.ksql.rest.entity.FunctionType; import io.confluent.ksql.rest.entity.SimpleFunctionInfo; import io.confluent.ksql.rest.server.TemporaryEngine; +import java.util.Collection; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -37,6 +38,7 @@ public class ListFunctionsExecutorTest { @Test public void shouldListFunctions() { + // When: final FunctionNameList functionList = (FunctionNameList) CustomExecutors.LIST_FUNCTIONS.execute( engine.configure("LIST FUNCTIONS;"), @@ -46,15 +48,20 @@ public void shouldListFunctions() { ).orElseThrow(IllegalStateException::new); // Then: - assertThat(functionList.getFunctions(), hasItems( - new SimpleFunctionInfo("EXTRACTJSONFIELD", FunctionType.scalar), - new SimpleFunctionInfo("ARRAYCONTAINS", FunctionType.scalar), - new SimpleFunctionInfo("CONCAT", FunctionType.scalar), - new SimpleFunctionInfo("TOPK", FunctionType.aggregate), - new SimpleFunctionInfo("MAX", FunctionType.aggregate))); + Collection functions = functionList.getFunctions(); + assertThat(functions, hasItems( + new SimpleFunctionInfo("EXTRACTJSONFIELD", FunctionType.SCALAR), + new SimpleFunctionInfo("ARRAYCONTAINS", FunctionType.SCALAR), + new SimpleFunctionInfo("CONCAT", FunctionType.SCALAR), + new SimpleFunctionInfo("TOPK", FunctionType.AGGREGATE), + new SimpleFunctionInfo("MAX", FunctionType.AGGREGATE), + new SimpleFunctionInfo("TEST_UDTF1", FunctionType.TABLE), + new SimpleFunctionInfo("TEST_UDTF2", FunctionType.TABLE) + )); assertThat("shouldn't contain internal functions", functionList.getFunctions(), - not(hasItem(new SimpleFunctionInfo("FETCH_FIELD_FROM_STRUCT", FunctionType.scalar)))); + not(hasItem(new SimpleFunctionInfo("FETCH_FIELD_FROM_STRUCT", FunctionType.SCALAR))) + ); } diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/KsqlResourceTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/KsqlResourceTest.java index cd6a41c3d0f1..3f07659bc895 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/KsqlResourceTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/KsqlResourceTest.java @@ -415,14 +415,16 @@ public void shouldListFunctions() { // Then: assertThat(functionList.getFunctions(), hasItems( - new SimpleFunctionInfo("EXTRACTJSONFIELD", FunctionType.scalar), - new SimpleFunctionInfo("ARRAYCONTAINS", FunctionType.scalar), - new SimpleFunctionInfo("CONCAT", FunctionType.scalar), - new SimpleFunctionInfo("TOPK", FunctionType.aggregate), - new SimpleFunctionInfo("MAX", FunctionType.aggregate))); + new SimpleFunctionInfo("EXTRACTJSONFIELD", FunctionType.SCALAR), + new SimpleFunctionInfo("ARRAYCONTAINS", FunctionType.SCALAR), + new SimpleFunctionInfo("CONCAT", FunctionType.SCALAR), + new SimpleFunctionInfo("TOPK", FunctionType.AGGREGATE), + new SimpleFunctionInfo("MAX", FunctionType.AGGREGATE) + )); assertThat("shouldn't contain internal functions", functionList.getFunctions(), - not(hasItem(new SimpleFunctionInfo("FETCH_FIELD_FROM_STRUCT", FunctionType.scalar)))); + not(hasItem(new SimpleFunctionInfo("FETCH_FIELD_FROM_STRUCT", FunctionType.SCALAR))) + ); } @Test diff --git a/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/FunctionType.java b/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/FunctionType.java index 8edff36e8f4f..7ed0d68d0039 100644 --- a/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/FunctionType.java +++ b/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/FunctionType.java @@ -15,4 +15,8 @@ package io.confluent.ksql.rest.entity; -public enum FunctionType { scalar, aggregate } +public enum FunctionType { + SCALAR, + AGGREGATE, + TABLE +}