diff --git a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java index 2dcd258626c6..1f9f29f84593 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java @@ -27,6 +27,7 @@ import io.confluent.ksql.execution.expression.tree.ComparisonExpression; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.plan.SelectExpression; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.name.ColumnName; @@ -40,7 +41,6 @@ import io.confluent.ksql.parser.tree.GroupingElement; import io.confluent.ksql.parser.tree.Join; import io.confluent.ksql.parser.tree.JoinOn; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Select; import io.confluent.ksql.parser.tree.SelectItem; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/StatementRewriter.java b/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/StatementRewriter.java index e1e8065dbaac..8a139b85c41e 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/StatementRewriter.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/engine/rewrite/StatementRewriter.java @@ -31,7 +31,6 @@ import io.confluent.ksql.parser.tree.GroupingElement; import io.confluent.ksql.parser.tree.InsertInto; import io.confluent.ksql.parser.tree.Join; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.RegisterType; import io.confluent.ksql.parser.tree.Relation; @@ -236,14 +235,8 @@ protected AstNode visitWindowExpression(final WindowExpression node, final C con return new WindowExpression( node.getLocation(), node.getWindowName(), - (KsqlWindowExpression) rewriter.apply(node.getKsqlWindowExpression(), context)); - } - - @Override - protected AstNode visitKsqlWindowExpression( - final KsqlWindowExpression node, - final C context) { - return node; + node.getKsqlWindowExpression() + ); } @Override diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/GeneratedTableAggregateFunction.java b/ksql-engine/src/main/java/io/confluent/ksql/function/GeneratedTableAggregateFunction.java index 95f2f20e15c4..0a9b38e3ae0a 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/GeneratedTableAggregateFunction.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/GeneratedTableAggregateFunction.java @@ -15,6 +15,7 @@ package io.confluent.ksql.function; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.udaf.TableUdaf; import java.util.List; import java.util.Optional; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/count/CountKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/count/CountKudaf.java index b93089c5e746..3d1dd8a8713c 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/count/CountKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/count/CountKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.count; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import java.util.Collections; import java.util.List; import java.util.function.Function; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DecimalSumKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DecimalSumKudaf.java index 2e0768ee4d75..84f37caa6777 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DecimalSumKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DecimalSumKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.sum; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import io.confluent.ksql.util.DecimalUtil; import java.math.BigDecimal; import java.math.MathContext; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DoubleSumKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DoubleSumKudaf.java index 5b806ceb3dd3..261bb774b88a 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DoubleSumKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/DoubleSumKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.sum; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import java.util.Collections; import java.util.function.Function; import org.apache.kafka.connect.data.Schema; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/IntegerSumKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/IntegerSumKudaf.java index d0d37d080c9e..82c972d25f8f 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/IntegerSumKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/IntegerSumKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.sum; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import java.util.Collections; import java.util.function.Function; import org.apache.kafka.connect.data.Schema; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/LongSumKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/LongSumKudaf.java index a1627a627c63..1a263c2dd67d 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/LongSumKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/sum/LongSumKudaf.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function.udaf.sum; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.BaseAggregateFunction; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; import java.util.Collections; import java.util.function.Function; import org.apache.kafka.connect.data.Schema; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowEndKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowEndKudaf.java index b76eccb29226..482abe7d8d8f 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowEndKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowEndKudaf.java @@ -15,6 +15,7 @@ package io.confluent.ksql.function.udaf.window; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.function.udaf.TableUdaf; import io.confluent.ksql.function.udaf.UdafDescription; import io.confluent.ksql.function.udaf.UdafFactory; @@ -37,7 +38,7 @@ private WindowEndKudaf() { } static String getFunctionName() { - return "WindowEnd"; + return WindowSelectMapper.WINDOW_END_NAME; } @UdafFactory(description = "Extracts the window end time") diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowStartKudaf.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowStartKudaf.java index 278cb7dfffac..c8765548373b 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowStartKudaf.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowStartKudaf.java @@ -15,6 +15,7 @@ package io.confluent.ksql.function.udaf.window; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.function.udaf.TableUdaf; import io.confluent.ksql.function.udaf.UdafDescription; import io.confluent.ksql.function.udaf.UdafFactory; @@ -37,7 +38,7 @@ private WindowStartKudaf() { } static String getFunctionName() { - return "WindowStart"; + return WindowSelectMapper.WINDOW_START_NAME; } @UdafFactory(description = "Extracts the window start time") diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java index 8a21d14ecb94..5e653eb0709d 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java @@ -19,7 +19,6 @@ import static java.util.Objects.requireNonNull; import com.google.common.collect.ImmutableList; -import io.confluent.ksql.GenericRow; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; @@ -29,12 +28,10 @@ import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.Literal; import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; +import io.confluent.ksql.execution.function.UdafUtil; import io.confluent.ksql.execution.plan.SelectExpression; -import io.confluent.ksql.execution.util.ExpressionTypeManager; -import io.confluent.ksql.function.AggregateFunctionArguments; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.udaf.KudafInitializer; import io.confluent.ksql.materialization.MaterializationInfo; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.name.ColumnName; @@ -42,11 +39,9 @@ import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.PhysicalSchema; import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.schema.ksql.SchemaConverters.ConnectToSqlTypeConverter; import io.confluent.ksql.schema.ksql.types.SqlType; -import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.services.KafkaTopicClient; import io.confluent.ksql.structured.SchemaKGroupedStream; @@ -66,8 +61,6 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; -import org.apache.kafka.common.serialization.Serde; -import org.apache.kafka.connect.data.Schema; public class AggregateNode extends PlanNode { @@ -234,54 +227,43 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { builder ); - // Aggregate computations - final KudafInitializer initializer = new KudafInitializer(requiredColumns.size()); - - final Map aggValToFunctionMap = createAggValToFunctionMap( - aggregateArgExpanded, - initializer, - requiredColumns.size(), - builder.getFunctionRegistry(), - internalSchema - ); + final List functionsWithInternalIdentifiers = functionList.stream() + .map( + fc -> new FunctionCall( + fc.getName(), + internalSchema.getInternalArgsExpressionList(fc.getArguments()) + ) + ) + .collect(Collectors.toList()); // This is the schema of the aggregation change log topic and associated state store. // It contains all columns from prepareSchema and columns for any aggregating functions // It uses internal column names, e.g. KSQL_INTERNAL_COL_0 and KSQL_AGG_VARIABLE_0 final LogicalSchema aggregationSchema = buildLogicalSchema( prepareSchema, - aggValToFunctionMap, + functionsWithInternalIdentifiers, + builder.getFunctionRegistry(), true ); final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME); - final Serde aggValueGenericRowSerde = builder.buildValueSerde( - valueFormat.getFormatInfo(), - PhysicalSchema.from(aggregationSchema, SerdeOption.none()), - aggregationContext.getQueryContext() - ); - - final List functionsWithInternalIdentifiers = functionList.stream() - .map(internalSchema::resolveToInternal) - .map(FunctionCall.class::cast) - .collect(Collectors.toList()); - final LogicalSchema outputSchema = buildLogicalSchema( prepareSchema, - aggValToFunctionMap, - false); + functionsWithInternalIdentifiers, + builder.getFunctionRegistry(), + false + ); SchemaKTable aggregated = schemaKGroupedStream.aggregate( + aggregationSchema, outputSchema, - initializer, requiredColumns.size(), functionsWithInternalIdentifiers, - aggValToFunctionMap, windowExpression, valueFormat, - aggValueGenericRowSerde, - aggregationContext + aggregationContext, + builder ); final Optional havingExpression = Optional.ofNullable(havingExpressions) @@ -316,61 +298,12 @@ protected int getPartitions(final KafkaTopicClient kafkaTopicClient) { return source.getPartitions(kafkaTopicClient); } - private Map createAggValToFunctionMap( - final SchemaKStream aggregateArgExpanded, - final KudafInitializer initializer, - final int initialUdafIndex, - final FunctionRegistry functionRegistry, - final InternalSchema internalSchema - ) { - int udafIndexInAggSchema = initialUdafIndex; - final Map aggValToAggFunctionMap = new HashMap<>(); - for (final FunctionCall functionCall : functionList) { - final KsqlAggregateFunction aggregateFunction = getAggregateFunction( - functionRegistry, - internalSchema, - functionCall, aggregateArgExpanded.getSchema()); - - aggValToAggFunctionMap.put(udafIndexInAggSchema++, aggregateFunction); - initializer.addAggregateIntializer(aggregateFunction.getInitialValueSupplier()); - } - return aggValToAggFunctionMap; - } - - @SuppressWarnings("deprecation") // Need to migrate away from Connect Schema use. - private static KsqlAggregateFunction getAggregateFunction( - final FunctionRegistry functionRegistry, - final InternalSchema internalSchema, - final FunctionCall functionCall, - final LogicalSchema schema - ) { - try { - final ExpressionTypeManager expressionTypeManager = - new ExpressionTypeManager(schema, functionRegistry); - final List functionArgs = internalSchema.getInternalArgsExpressionList( - functionCall.getArguments()); - final Schema expressionType = expressionTypeManager.getExpressionSchema(functionArgs.get(0)); - final KsqlAggregateFunction aggregateFunctionInfo = functionRegistry - .getAggregate(functionCall.getName().name(), expressionType); - - final List args = functionArgs.stream() - .map(Expression::toString) - .collect(Collectors.toList()); - - final int udafIndex = Integer - .parseInt(args.get(0).substring(INTERNAL_COLUMN_NAME_PREFIX.length())); - - return aggregateFunctionInfo.getInstance(new AggregateFunctionArguments(udafIndex, args)); - } catch (final Exception e) { - throw new KsqlException("Failed to create aggregate function: " + functionCall, e); - } - } - private LogicalSchema buildLogicalSchema( final LogicalSchema inputSchema, - final Map aggregateFunctions, - final boolean useAggregate) { - + final List aggregations, + final FunctionRegistry functionRegistry, + final boolean useAggregate + ) { final LogicalSchema.Builder schemaBuilder = LogicalSchema.builder(); final List cols = inputSchema.value(); @@ -382,18 +315,13 @@ private LogicalSchema buildLogicalSchema( final ConnectToSqlTypeConverter converter = SchemaConverters.connectToSqlConverter(); - for (int idx = 0; idx < aggregateFunctions.size(); idx++) { - - final KsqlAggregateFunction aggregateFunction = aggregateFunctions - .get(requiredColumns.size() + idx); - - final ColumnName colName = ColumnName.aggregate(idx); - SqlType fieldType = null; - if (useAggregate) { - fieldType = converter.toSqlType(aggregateFunction.getAggregateType()); - } else { - fieldType = converter.toSqlType(aggregateFunction.getReturnType()); - } + for (int i = 0; i < aggregations.size(); i++) { + final KsqlAggregateFunction aggregateFunction = + UdafUtil.resolveAggregateFunction(functionRegistry, aggregations.get(i), inputSchema); + final ColumnName colName = ColumnName.aggregate(i); + final SqlType fieldType = converter.toSqlType( + useAggregate ? aggregateFunction.getAggregateType() : aggregateFunction.getReturnType() + ); schemaBuilder.valueColumn(colName, fieldType); } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java index 6f1593186e73..b84252d60250 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java @@ -16,21 +16,19 @@ package io.confluent.ksql.structured; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamAggregate; +import io.confluent.ksql.execution.plan.StreamWindowedAggregate; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; +import io.confluent.ksql.execution.streams.StreamAggregateBuilder; import io.confluent.ksql.function.FunctionRegistry; -import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.UdafAggregator; -import io.confluent.ksql.function.udaf.KudafAggregator; -import io.confluent.ksql.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.model.WindowType; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.serde.Format; @@ -43,15 +41,11 @@ import io.confluent.ksql.util.KsqlConfig; import java.time.Duration; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Optional; -import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; import org.apache.kafka.streams.kstream.KGroupedStream; import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; import org.apache.kafka.streams.kstream.Windowed; public class SchemaKGroupedStream { @@ -121,18 +115,18 @@ public ExecutionStep> getSourceStep() { @SuppressWarnings("unchecked") public SchemaKTable aggregate( + final LogicalSchema aggregateSchema, final LogicalSchema outputSchema, - final Initializer initializer, final int nonFuncColumnCount, final List aggregations, - final Map aggValToFunctionMap, final Optional windowExpression, final ValueFormat valueFormat, - final Serde topicValueSerDe, - final QueryContext.Stacker contextStacker + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder ) { - throwOnValueFieldCountMismatch(outputSchema, nonFuncColumnCount, aggValToFunctionMap); + throwOnValueFieldCountMismatch(outputSchema, nonFuncColumnCount, aggregations); + final ExecutionStep> step; final KTable table; final KeySerde newKeySerde; final KeyFormat keyFormat; @@ -140,37 +134,44 @@ public SchemaKTable aggregate( if (windowExpression.isPresent()) { keyFormat = getKeyFormat(windowExpression.get()); newKeySerde = getKeySerde(windowExpression.get()); - - table = aggregateWindowed( - initializer, + final StreamWindowedAggregate aggregate = ExecutionStepFactory.streamWindowedAggregate( + contextStacker, + sourceStep, + outputSchema, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), nonFuncColumnCount, - aggValToFunctionMap, - windowExpression.get(), - topicValueSerDe, - contextStacker + aggregations, + aggregateSchema, + windowExpression.get().getKsqlWindowExpression() + ); + step = aggregate; + table = StreamAggregateBuilder.build( + kgroupedStream, + aggregate, + queryBuilder, + materializedFactory ); } else { keyFormat = this.keyFormat; newKeySerde = keySerde; - - table = aggregateNonWindowed( - initializer, + final StreamAggregate aggregate = ExecutionStepFactory.streamAggregate( + contextStacker, + sourceStep, + outputSchema, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), nonFuncColumnCount, - aggValToFunctionMap, - topicValueSerDe, - contextStacker + aggregations, + aggregateSchema + ); + step = aggregate; + table = StreamAggregateBuilder.build( + kgroupedStream, + aggregate, + queryBuilder, + materializedFactory ); } - final ExecutionStep step = ExecutionStepFactory.streamAggregate( - contextStacker, - sourceStep, - outputSchema, - Formats.of(keyFormat, valueFormat, SerdeOption.none()), - nonFuncColumnCount, - aggregations - ); - return new SchemaKTable( table, step, @@ -184,61 +185,6 @@ public SchemaKTable aggregate( ); } - @SuppressWarnings("unchecked") - private KTable aggregateNonWindowed( - final Initializer initializer, - final int nonFuncColumnCount, - final Map indexToFunctionMap, - final Serde topicValueSerDe, - final QueryContext.Stacker contextStacker - ) { - final UdafAggregator aggregator = new KudafAggregator(nonFuncColumnCount, indexToFunctionMap); - - final Materialized materialized = materializedFactory.create( - keySerde, - topicValueSerDe, - StreamsUtil.buildOpName(contextStacker.getQueryContext()) - ); - - final KTable aggTable = kgroupedStream.aggregate(initializer, aggregator, materialized); - - return getAggregationResult(aggTable, aggregator); - } - - @SuppressWarnings("unchecked") - private KTable aggregateWindowed( - final Initializer initializer, - final int nonFuncColumnCount, - final Map indexToFunctionMap, - final WindowExpression windowExpression, - final Serde topicValueSerDe, - final QueryContext.Stacker contextStacker - ) { - final UdafAggregator aggregator = new KudafAggregator(nonFuncColumnCount, indexToFunctionMap); - - final KsqlWindowExpression ksqlWindowExpression = windowExpression.getKsqlWindowExpression(); - - final Materialized materialized = materializedFactory.create( - keySerde, - topicValueSerDe, - StreamsUtil.buildOpName(contextStacker.getQueryContext()) - ); - - final KTable, GenericRow> aggKtable = ksqlWindowExpression.applyAggregate( - kgroupedStream, initializer, aggregator, materialized); - - // Apply the mapper before window_start and window_end functions that return null if a - // record is not part of the window. - final KTable reducedTable = getAggregationResult(aggKtable, aggregator); - - final WindowSelectMapper windowSelectMapper = new WindowSelectMapper(indexToFunctionMap); - if (!windowSelectMapper.hasSelects()) { - return reducedTable; - } - - return reducedTable.mapValues(windowSelectMapper); - } - private KeyFormat getKeyFormat(final WindowExpression windowExpression) { if (ksqlConfig.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)) { return KeyFormat.windowed( @@ -255,11 +201,6 @@ private KeyFormat getKeyFormat(final WindowExpression windowExpression) { ); } - @SuppressWarnings("unchecked") - private KTable getAggregationResult(final KTable table, final UdafAggregator aggregator) { - return table.mapValues(aggregator.getResultMapper()); - } - private KeySerde> getKeySerde(final WindowExpression windowExpression) { if (ksqlConfig.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)) { return keySerde.rebind(WindowInfo.of( @@ -274,10 +215,9 @@ private KeySerde> getKeySerde(final WindowExpression windowExpr static void throwOnValueFieldCountMismatch( final LogicalSchema aggregateSchema, final int nonFuncColumnCount, - final Map aggValToFunctionMap + final List aggregateFunctions ) { - final int nonAggColumnCount = aggValToFunctionMap.size(); - final int totalColumnCount = nonAggColumnCount + nonFuncColumnCount; + final int totalColumnCount = aggregateFunctions.size() + nonFuncColumnCount; final int valueColumnCount = aggregateSchema.value().size(); if (valueColumnCount != totalColumnCount) { diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java index 75693817de03..de0571c317f4 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java @@ -16,18 +16,19 @@ package io.confluent.ksql.structured; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.UdafUtil; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.TableAggregate; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; +import io.confluent.ksql.execution.streams.TableAggregateBuilder; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; -import io.confluent.ksql.function.udaf.KudafAggregator; -import io.confluent.ksql.function.udaf.KudafUndoAggregator; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.schema.ksql.LogicalSchema; @@ -38,16 +39,11 @@ import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; -import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; import org.apache.kafka.streams.kstream.KGroupedTable; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; public class SchemaKGroupedTable extends SchemaKGroupedStream { private final KGroupedTable kgroupedTable; @@ -109,25 +105,25 @@ public ExecutionStep> getSourceTableStep() { @SuppressWarnings("unchecked") @Override public SchemaKTable aggregate( + final LogicalSchema aggregateSchema, final LogicalSchema outputSchema, - final Initializer initializer, final int nonFuncColumnCount, final List aggregations, - final Map aggValToFunctionMap, final Optional windowExpression, final ValueFormat valueFormat, - final Serde topicValueSerDe, - final QueryContext.Stacker contextStacker + final QueryContext.Stacker contextStacker, + final KsqlQueryBuilder queryBuilder ) { if (windowExpression.isPresent()) { throw new KsqlException("Windowing not supported for table aggregations."); } - throwOnValueFieldCountMismatch(outputSchema, nonFuncColumnCount, aggValToFunctionMap); + throwOnValueFieldCountMismatch(outputSchema, nonFuncColumnCount, aggregations); - final List unsupportedFunctionNames = aggValToFunctionMap.values() - .stream() - .filter(function -> !(function instanceof TableAggregationFunction)) + final List unsupportedFunctionNames = aggregations.stream() + .map(call -> UdafUtil.resolveAggregateFunction( + queryBuilder.getFunctionRegistry(), call, sourceTableStep.getSchema()) + ).filter(function -> !(function instanceof TableAggregationFunction)) .map(KsqlAggregateFunction::getFunctionName) .collect(Collectors.toList()); if (!unsupportedFunctionNames.isEmpty()) { @@ -137,46 +133,23 @@ public SchemaKTable aggregate( String.join(", ", unsupportedFunctionNames))); } - final KudafAggregator aggregator = new KudafAggregator( - nonFuncColumnCount, aggValToFunctionMap); - - final Map aggValToUndoFunctionMap = - aggValToFunctionMap.keySet() - .stream() - .collect( - Collectors.toMap( - k -> k, - k -> ((TableAggregationFunction) aggValToFunctionMap.get(k)))); - - final KudafUndoAggregator subtractor = new KudafUndoAggregator( - nonFuncColumnCount, aggValToUndoFunctionMap); - - final Materialized materialized = materializedFactory.create( - keySerde, - topicValueSerDe, - StreamsUtil.buildOpName(contextStacker.getQueryContext()) - ); - - final KTable aggKtable = kgroupedTable.aggregate( - initializer, - aggregator, - subtractor, - materialized); - - final ExecutionStep step = ExecutionStepFactory.tableAggregate( + final TableAggregate step = ExecutionStepFactory.tableAggregate( contextStacker, sourceTableStep, outputSchema, Formats.of(keyFormat, valueFormat, SerdeOption.none()), nonFuncColumnCount, - aggregations + aggregations, + aggregateSchema ); - final KTable outputTable = aggKtable.mapValues( - aggregator.getResultMapper()); - return new SchemaKTable<>( - outputTable, + TableAggregateBuilder.build( + kgroupedTable, + step, + queryBuilder, + materializedFactory + ), step, keyFormat, keySerde, diff --git a/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java b/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java index b0a6c2efbaa3..fe591db0609c 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java @@ -42,7 +42,7 @@ import io.confluent.ksql.parser.tree.Join; import io.confluent.ksql.parser.tree.Join.Type; import io.confluent.ksql.parser.tree.JoinCriteria; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Relation; import io.confluent.ksql.parser.tree.ResultMaterialization; @@ -376,11 +376,8 @@ public void shouldRewriteJoinWithWindowExpression() { public void shouldRewriteWindowExpression() { // Given: final KsqlWindowExpression ksqlWindowExpression = mock(KsqlWindowExpression.class); - final KsqlWindowExpression rewrittenKsqlWindowExpression = mock(KsqlWindowExpression.class); final WindowExpression windowExpression = new WindowExpression(location, "name", ksqlWindowExpression); - when(mockRewriter.apply(ksqlWindowExpression, context)) - .thenReturn(rewrittenKsqlWindowExpression); // When: final AstNode rewritten = rewriter.rewrite(windowExpression, context); @@ -388,7 +385,7 @@ public void shouldRewriteWindowExpression() { // Then: assertThat( rewritten, - equalTo(new WindowExpression(location, "name", rewrittenKsqlWindowExpression)) + equalTo(new WindowExpression(location, "name", ksqlWindowExpression)) ); } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/KudafUndoAggregatorTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/KudafUndoAggregatorTest.java index 5433c574f11d..96a4dd5008d2 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/KudafUndoAggregatorTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/KudafUndoAggregatorTest.java @@ -20,7 +20,8 @@ import static org.junit.Assert.assertThat; import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.udaf.KudafUndoAggregator; +import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.udaf.KudafUndoAggregator; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java index 423eaa93bd18..0b1fbca1a1db 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfCompilerTest.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableList; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import io.confluent.ksql.execution.function.TableAggregationFunction; import io.confluent.ksql.function.udaf.TestUdaf; import io.confluent.ksql.function.udaf.Udaf; import io.confluent.ksql.util.KsqlException; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/sum/BaseSumKudafTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/sum/BaseSumKudafTest.java index cc8373211434..aef5c8ff2810 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/sum/BaseSumKudafTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/sum/BaseSumKudafTest.java @@ -18,7 +18,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.Assert.assertThat; -import io.confluent.ksql.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.TableAggregationFunction; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java index 4cf4bea2f6ee..d3533a609eaa 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java @@ -15,39 +15,35 @@ package io.confluent.ksql.structured; -import static java.util.Collections.emptyList; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.sameInstance; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; +import io.confluent.ksql.execution.windows.SessionWindowExpression; import io.confluent.ksql.function.FunctionRegistry; -import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.model.WindowType; -import io.confluent.ksql.parser.tree.KsqlWindowExpression; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.query.QueryId; -import io.confluent.ksql.schema.ksql.Column; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.KeyFormat; @@ -56,21 +52,17 @@ import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.serde.WindowInfo; import io.confluent.ksql.util.KsqlConfig; -import java.time.Duration; -import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.IntStream; +import java.util.concurrent.TimeUnit; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; import org.apache.kafka.streams.kstream.KGroupedStream; import org.apache.kafka.streams.kstream.KTable; import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.SessionWindowedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; import org.apache.kafka.streams.kstream.ValueMapper; -import org.apache.kafka.streams.kstream.ValueMapperWithKey; import org.apache.kafka.streams.kstream.Windowed; import org.junit.Before; import org.junit.Test; @@ -81,40 +73,45 @@ @SuppressWarnings("unchecked") @RunWith(MockitoJUnitRunner.class) public class SchemaKGroupedStreamTest { + private static final LogicalSchema IN_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("IN0"), SqlTypes.STRING) + .valueColumn(ColumnName.of("IN1"), SqlTypes.INTEGER) + .build(); + private static final LogicalSchema AGG_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("IN0"), SqlTypes.STRING) + .valueColumn(ColumnName.of("AGG0"), SqlTypes.BIGINT) + .build(); + private static final LogicalSchema OUT_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("IN0"), SqlTypes.STRING) + .valueColumn(ColumnName.of("OUT0"), SqlTypes.STRING) + .build(); + private static final FunctionCall AGG = new FunctionCall( + FunctionName.of("SUM"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of(ColumnName.of("IN1")))) + ); + private static final KsqlWindowExpression KSQL_WINDOW_EXP = new SessionWindowExpression( + 100, TimeUnit.SECONDS + ); - @Mock - private LogicalSchema aggregateSchema; @Mock private KGroupedStream groupedStream; @Mock + private SessionWindowedKStream sessionWindowedStream; + @Mock private KeyField keyField; @Mock private List sourceStreams; @Mock private KsqlConfig config; @Mock - private FunctionRegistry funcRegistry; - @Mock - private Initializer initializer; - @Mock private Serde topicValueSerDe; @Mock - private KsqlAggregateFunction windowStartFunc; - @Mock - private KsqlAggregateFunction windowEndFunc; - @Mock - private KsqlAggregateFunction otherFunc; - @Mock private FunctionCall aggCall; @Mock private KTable table; @Mock - private KTable table2; - @Mock private WindowExpression windowExp; @Mock - private KsqlWindowExpression ksqlWindowExp; - @Mock private MaterializedFactory materializedFactory; @Mock private Materialized materialized; @@ -123,17 +120,19 @@ public class SchemaKGroupedStreamTest { @Mock private KeySerde> windowedKeySerde; @Mock - private Column field; - @Mock private ExecutionStep sourceStep; @Mock private KeyFormat keyFormat; @Mock private ValueFormat valueFormat; + @Mock + private KsqlQueryBuilder builder; + + private final FunctionRegistry functionRegistry = new InternalFunctionRegistry(); private final QueryContext.Stacker queryContext = new QueryContext.Stacker(new QueryId("query")).push("node"); + private SchemaKGroupedStream schemaGroupedStream; - private Map someUdfs; @Before public void setUp() { @@ -145,330 +144,56 @@ public void setUp() { keyField, sourceStreams, config, - funcRegistry, + functionRegistry, materializedFactory ); - - when(windowStartFunc.getFunctionName()).thenReturn("WindowStart"); - when(windowEndFunc.getFunctionName()).thenReturn("WindowEnd"); - when(otherFunc.getFunctionName()).thenReturn("NotWindowStartFunc"); - when(windowExp.getKsqlWindowExpression()).thenReturn(ksqlWindowExp); + when(sourceStep.getSchema()).thenReturn(IN_SCHEMA); + when(windowExp.getKsqlWindowExpression()).thenReturn(KSQL_WINDOW_EXP); when(config.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)).thenReturn(false); when(materializedFactory.create(any(), any(), any())).thenReturn(materialized); - - when(ksqlWindowExp.getWindowInfo()) - .thenReturn(WindowInfo.of(WindowType.SESSION, Optional.empty())); - + when(builder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(builder.buildValueSerde(any(), any(), any())).thenReturn(topicValueSerDe); + when(builder.getFunctionRegistry()).thenReturn(functionRegistry); when(keySerde.rebind(any(WindowInfo.class))).thenReturn(windowedKeySerde); - - when(aggregateSchema.value()).thenReturn(ImmutableList.of(mock(Column.class))); - - when(ksqlWindowExp.applyAggregate(any(), any(), any(), any())).thenReturn(table); when(table.mapValues(any(ValueMapper.class))).thenReturn(table); - - someUdfs = ImmutableMap.of(0, otherFunc); - } - - @Test - public void shouldNoUseSelectMapperForNonWindowed() { - // Given: - final Map invalidWindowFuncs = ImmutableMap.of( - 0, windowStartFunc, 1, windowEndFunc); - - // When: - assertDoesNotInstallWindowSelectMapper(null, invalidWindowFuncs); - } - - @Test - public void shouldNotUseSelectMapperForWindowedWithoutWindowSelects() { - // Given: - final Map nonWindowFuncs = ImmutableMap.of(0, otherFunc); - - // When: - assertDoesNotInstallWindowSelectMapper(windowExp, nonWindowFuncs); - } - - @Test - public void shouldUseSelectMapperForWindowedWithWindowStart() { - // Given: - Map funcMapWithWindowStart = ImmutableMap.of( - 0, otherFunc, 1, windowStartFunc); - - // Then: - assertDoesInstallWindowSelectMapper(funcMapWithWindowStart); - } - - @Test - public void shouldUseSelectMapperForWindowedWithWindowEnd() { - // Given: - Map funcMapWithWindowEnd = ImmutableMap.of( - 0, windowEndFunc, 1, otherFunc); - - // Then: - assertDoesInstallWindowSelectMapper(funcMapWithWindowEnd); - } - - @Test - public void shouldSupportSessionWindowedKey() { - // Given: - final WindowInfo windowInfo = WindowInfo.of(WindowType.SESSION, Optional.empty()); - when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - someUdfs, - Optional.of(windowExp), - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(keySerde).rebind(windowInfo); - assertThat(result.getKeySerde(), is(windowedKeySerde)); - } - - @Test - public void shouldSupportHoppingWindowedKey() { - // Given: - final WindowInfo windowInfo = WindowInfo - .of(WindowType.HOPPING, Optional.of(Duration.ofMillis(10))); - when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - someUdfs, - Optional.of(windowExp), - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(keySerde).rebind(windowInfo); - assertThat(result.getKeySerde(), is(windowedKeySerde)); - } - - @Test - public void shouldSupportTumblingWindowedKey() { - // Given: - final WindowInfo windowInfo = WindowInfo - .of(WindowType.TUMBLING, Optional.of(Duration.ofMillis(10))); - - when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo); - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - someUdfs, - Optional.of(windowExp), - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(keySerde).rebind(windowInfo); - assertThat(result.getKeySerde(), is(windowedKeySerde)); } @Test - public void shouldUseTimeWindowKeySerdeForWindowedIfLegacyConfig() { - // Given: - when(config.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)) - .thenReturn(true); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - someUdfs, - Optional.of(windowExp), - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(keySerde) - .rebind(WindowInfo.of(WindowType.TUMBLING, Optional.of(Duration.ofMillis(Long.MAX_VALUE)))); - assertThat(result.getKeySerde(), is(windowedKeySerde)); - } - - private void assertDoesNotInstallWindowSelectMapper( - final WindowExpression windowExp, - final Map funcMap) { - - // Given: - if (windowExp != null) { - when(ksqlWindowExp.applyAggregate(any(), any(), any(), any())) - .thenReturn(table); - } else { - when(groupedStream.aggregate(any(), any(), any())) - .thenReturn(table); - } - givenAggregateSchemaFieldCount(funcMap.size()); - - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - funcMap, - Optional.ofNullable(windowExp), - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - assertThat(result.getKtable(), is(sameInstance(table))); - verify(table, never()).mapValues(any(ValueMapperWithKey.class)); - } - - private void assertDoesInstallWindowSelectMapper( - final Map funcMap) { - + public void shouldReturnKTableWithOutputSchema() { // Given: - when(table.mapValues(any(ValueMapperWithKey.class))).thenReturn(table2); - givenAggregateSchemaFieldCount(funcMap.size()); + when(groupedStream.aggregate(any(), any(), any())).thenReturn(table); // When: final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - funcMap, - Optional.of(windowExp), - valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - assertThat(result.getKtable(), is(sameInstance(table2))); - verify(table, times(1)).mapValues(any(ValueMapperWithKey.class)); - } - - @SuppressWarnings("unchecked") - private Materialized whenMaterializedFactoryCreates() { - final Materialized materialized = mock(Materialized.class); - when(materializedFactory.create(any(), any(), any())).thenReturn(materialized); - return materialized; - } - - @SuppressWarnings("unchecked") - @Test - public void shouldUseMaterializedFactoryForStateStore() { - // Given: - final Materialized materialized = whenMaterializedFactoryCreates(); - final KTable mockKTable = mock(KTable.class); - when(groupedStream.aggregate(any(), any(), same(materialized))).thenReturn(mockKTable); - - // When: - schemaGroupedStream.aggregate( - aggregateSchema, - () -> null, - 0, - emptyList(), - someUdfs, + AGG_SCHEMA, + OUT_SCHEMA, + 1, + ImmutableList.of(AGG), Optional.empty(), valueFormat, - topicValueSerDe, - queryContext - ); - - // Then: - verify(materializedFactory) - .create( - same(keySerde), - same(topicValueSerDe), - eq(StreamsUtil.buildOpName(queryContext.getQueryContext()))); - verify(groupedStream, times(1)).aggregate(any(), any(), same(materialized)); - } - - @SuppressWarnings("unchecked") - @Test - public void shouldUseMaterializedFactoryWindowedStateStore() { - // Given: - final Materialized materialized = whenMaterializedFactoryCreates(); - when(ksqlWindowExp.applyAggregate(any(), any(), any(), same(materialized))) - .thenReturn(table); - - // When: - schemaGroupedStream.aggregate( - aggregateSchema, - () -> null, - 0, - emptyList(), - someUdfs, - Optional.of(windowExp), - valueFormat, - topicValueSerDe, - queryContext); - - // Then: - verify(materializedFactory) - .create( - same(keySerde), - same(topicValueSerDe), - eq(StreamsUtil.buildOpName(queryContext.getQueryContext()))); - verify(ksqlWindowExp, times(1)).applyAggregate(any(), any(), any(), same(materialized)); - } - - @Test - public void shouldReturnKTableWithAggregateSchema() { - // When: - final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - someUdfs, - Optional.of(windowExp), - valueFormat, - topicValueSerDe, - queryContext + queryContext, + builder ); // Then: - assertThat(result.getSchema(), is(aggregateSchema)); + assertThat(result.getSchema(), is(OUT_SCHEMA)); } @Test public void shouldBuildStepForAggregate() { // Given: - final Map functions = ImmutableMap.of(1, otherFunc); - when(aggregateSchema.value()) - .thenReturn(ImmutableList.of(mock(Column.class), mock(Column.class))); - when(groupedStream.aggregate(any(), any(), any())) - .thenReturn(table); + when(groupedStream.aggregate(any(), any(), any())).thenReturn(table); // When: final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, + AGG_SCHEMA, + OUT_SCHEMA, 1, - ImmutableList.of(aggCall), - functions, + ImmutableList.of(AGG), Optional.empty(), valueFormat, - topicValueSerDe, - queryContext + queryContext, + builder ); // Then: @@ -478,28 +203,34 @@ public void shouldBuildStepForAggregate() { ExecutionStepFactory.streamAggregate( queryContext, schemaGroupedStream.getSourceStep(), - aggregateSchema, + OUT_SCHEMA, Formats.of(keyFormat, valueFormat, SerdeOption.none()), 1, - ImmutableList.of(aggCall) + ImmutableList.of(AGG), + AGG_SCHEMA ) ) ); + assertThat(result.getKtable(), is(table)); } @Test - public void shouldBuildStepKeyFormatForWindowedAggregate() { + public void shouldBuildStepForWindowedAggregate() { + // Given: + when(groupedStream.windowedBy(any(SessionWindows.class))).thenReturn(sessionWindowedStream); + when(sessionWindowedStream.aggregate(any(), any(), any(), any())).thenReturn(table); + when(table.mapValues(any(ValueMapper.class))).thenReturn(table); + // When: final SchemaKTable result = schemaGroupedStream.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - someUdfs, + AGG_SCHEMA, + OUT_SCHEMA, + 1, + ImmutableList.of(AGG), Optional.of(windowExp), valueFormat, - topicValueSerDe, - queryContext + queryContext, + builder ); // Then: @@ -510,47 +241,33 @@ public void shouldBuildStepKeyFormatForWindowedAggregate() { assertThat( result.getSourceTableStep(), equalTo( - ExecutionStepFactory.streamAggregate( + ExecutionStepFactory.streamWindowedAggregate( queryContext, schemaGroupedStream.getSourceStep(), - aggregateSchema, + OUT_SCHEMA, Formats.of(expected, valueFormat, SerdeOption.none()), - 0, - Collections.emptyList() + 1, + ImmutableList.of(AGG), + AGG_SCHEMA, + KSQL_WINDOW_EXP ) ) ); + assertThat(result.getKtable(), is(table)); } @Test(expected = IllegalArgumentException.class) public void shouldThrowOnColumnCountMismatch() { - // Given: - // Agg schema has 2 fields: - givenAggregateSchemaFieldCount(2); - - // Where as params have 1 nonAgg and 2 agg fields: - final Map aggColumns = ImmutableMap.of(2, otherFunc); - // When: schemaGroupedStream.aggregate( - aggregateSchema, - initializer, + AGG_SCHEMA, + OUT_SCHEMA, 2, ImmutableList.of(aggCall), - aggColumns, Optional.of(windowExp), valueFormat, - topicValueSerDe, - queryContext + queryContext, + builder ); } - - private void givenAggregateSchemaFieldCount(final int count) { - final List valueFields = IntStream - .range(0, count) - .mapToObj(i -> field) - .collect(Collectors.toList()); - - when(aggregateSchema.value()).thenReturn(valueFields); - } } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java index 7342c5c03f91..846f75bbd5e1 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java @@ -15,84 +15,45 @@ package io.confluent.ksql.structured; -import static java.util.Collections.emptyList; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient; -import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; -import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.name.ColumnName; -import io.confluent.ksql.name.SourceName; +import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; -import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; -import io.confluent.ksql.execution.streams.StreamsUtil; import io.confluent.ksql.function.InternalFunctionRegistry; -import io.confluent.ksql.function.KsqlAggregateFunction; -import io.confluent.ksql.function.TableAggregationFunction; -import io.confluent.ksql.function.udaf.KudafInitializer; -import io.confluent.ksql.logging.processing.NoopProcessingLogContext; -import io.confluent.ksql.logging.processing.ProcessingLogContext; -import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.KeyField; -import io.confluent.ksql.metastore.model.KsqlTable; import io.confluent.ksql.parser.tree.WindowExpression; -import io.confluent.ksql.planner.plan.PlanNode; import io.confluent.ksql.query.QueryId; -import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.PersistenceSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; -import io.confluent.ksql.serde.GenericRowSerDe; import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.serde.ValueFormat; -import io.confluent.ksql.testutils.AnalysisTestUtil; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; -import io.confluent.ksql.util.MetaStoreFixture; -import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import org.apache.kafka.common.serialization.Serde; -import org.apache.kafka.common.serialization.Serdes; -import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.StreamsBuilder; -import org.apache.kafka.streams.kstream.Consumed; -import org.apache.kafka.streams.kstream.Initializer; import org.apache.kafka.streams.kstream.KGroupedTable; import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; import org.apache.kafka.streams.kstream.ValueMapper; -import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -104,16 +65,33 @@ @SuppressWarnings("unchecked") @RunWith(MockitoJUnitRunner.class) public class SchemaKGroupedTableTest { + private static final LogicalSchema IN_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("IN0"), SqlTypes.STRING) + .valueColumn(ColumnName.of("IN1"), SqlTypes.INTEGER) + .build(); + private static final LogicalSchema AGG_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("IN0"), SqlTypes.STRING) + .valueColumn(ColumnName.of("AGG0"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("AGG1"), SqlTypes.BIGINT) + .build(); + private static final LogicalSchema OUT_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("IN0"), SqlTypes.STRING) + .valueColumn(ColumnName.of("OUT0"), SqlTypes.STRING) + .valueColumn(ColumnName.of("OUT1"), SqlTypes.STRING) + .build(); + private static final FunctionCall MIN = udaf("MIN"); + private static final FunctionCall MAX = udaf("MAX"); + private static final FunctionCall SUM = udaf("SUM"); + private static final FunctionCall COUNT = udaf("COUNT"); + private final KsqlConfig ksqlConfig = new KsqlConfig(Collections.emptyMap()); private final InternalFunctionRegistry functionRegistry = new InternalFunctionRegistry(); - private final ProcessingLogContext processingLogContext = ProcessingLogContext.create(); private final KGroupedTable mockKGroupedTable = mock(KGroupedTable.class); private final LogicalSchema schema = LogicalSchema.builder() .valueColumn(ColumnName.of("GROUPING_COLUMN"), SqlTypes.STRING) .valueColumn(ColumnName.of("AGG_VALUE"), SqlTypes.INTEGER) .build(); private final MaterializedFactory materializedFactory = mock(MaterializedFactory.class); - private final MetaStore metaStore = MetaStoreFixture.getNewMetaStore(new InternalFunctionRegistry()); private final QueryContext.Stacker queryContext = new QueryContext.Stacker(new QueryId("query")).push("node"); private final ValueFormat valueFormat = ValueFormat.of(FormatInfo.of(Format.JSON)); @@ -125,102 +103,23 @@ public class SchemaKGroupedTableTest { @Mock private KeySerde keySerde; @Mock - private LogicalSchema aggregateSchema; - @Mock - private Initializer initializer; - @Mock - private Serde topicValueSerDe; - @Mock - private FunctionCall aggCall1; - @Mock - private FunctionCall aggCall2; - @Mock - private Column field; - @Mock - private KsqlAggregateFunction otherFunc; - @Mock - private TableAggregationFunction tableFunc; - @Mock private KsqlQueryBuilder queryBuilder; @Mock private KTable table; - private KTable kTable; - private KsqlTable ksqlTable; - private Map someUdfs; - @Before public void init() { - ksqlTable = (KsqlTable) metaStore.getSource(SourceName.of("TEST2")); - final StreamsBuilder builder = new StreamsBuilder(); - - final Serde rowSerde = GenericRowSerDe.from( - ksqlTable.getKsqlTopic().getValueFormat().getFormatInfo(), - PersistenceSchema.from(ksqlTable.getSchema().valueConnectSchema(), false), - new KsqlConfig(Collections.emptyMap()), - MockSchemaRegistryClient::new, - "", - NoopProcessingLogContext.INSTANCE - ); - - kTable = builder.table( - ksqlTable.getKsqlTopic().getKafkaTopicName(), - Consumed.with(Serdes.String(), rowSerde) - ); - when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); - when(queryBuilder.getKsqlConfig()).thenReturn(ksqlConfig); - - when(aggregateSchema.findValueColumn(ColumnName.of("GROUPING_COLUMN"))) - .thenReturn(Optional.of(Column.of(ColumnName.of("GROUPING_COLUMN"), SqlTypes.STRING))); - - when(aggregateSchema.value()).thenReturn(ImmutableList.of(mock(Column.class))); - when(mockKGroupedTable.aggregate(any(), any(), any(), any())).thenReturn(table); when(table.mapValues(any(ValueMapper.class))).thenReturn(table); - - someUdfs = ImmutableMap.of(0, tableFunc); } private ExecutionStep buildSourceTableStep(final LogicalSchema schema) { final ExecutionStep step = mock(ExecutionStep.class); - when(step.getProperties()).thenReturn( - new DefaultExecutionStepProperties(schema, queryContext.getQueryContext()) - ); when(step.getSchema()).thenReturn(schema); return step; } - private SchemaKGroupedTable buildSchemaKGroupedTableFromQuery( - final String query, - final String...groupByColumns - ) { - when(keySerde.rebind(any(PersistenceSchema.class))).thenReturn(keySerde); - - final PlanNode logicalPlan = AnalysisTestUtil.buildLogicalPlan(ksqlConfig, query, metaStore); - - final SchemaKTable initialSchemaKTable = new SchemaKTable( - kTable, - buildSourceTableStep(logicalPlan.getTheSourceNode().getSchema()), - keyFormat, - keySerde, - logicalPlan.getTheSourceNode().getKeyField(), - new ArrayList<>(), - SchemaKStream.Type.SOURCE, - ksqlConfig, - functionRegistry); - - final List groupByExpressions = - Arrays.stream(groupByColumns) - .map(c -> new ColumnReferenceExp(ColumnRef.of(SourceName.of("TEST1"), ColumnName.of(c)))) - .collect(Collectors.toList()); - - final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( - valueFormat, groupByExpressions, queryContext, queryBuilder); - Assert.assertThat(groupedSchemaKTable, instanceOf(SchemaKGroupedTable.class)); - return (SchemaKGroupedTable)groupedSchemaKTable; - } - @Test public void shouldFailWindowedTableAggregation() { // Given: @@ -235,56 +134,39 @@ public void shouldFailWindowedTableAggregation() { // When: groupedTable.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - someUdfs, + AGG_SCHEMA, + OUT_SCHEMA, + 1, + ImmutableList.of(SUM, COUNT), Optional.of(windowExp), valueFormat, - topicValueSerDe, - queryContext + queryContext, + queryBuilder ); } @Test public void shouldFailUnsupportedAggregateFunction() { - final SchemaKGroupedTable kGroupedTable = buildSchemaKGroupedTableFromQuery( - "SELECT col0, col1, col2 FROM test1 EMIT CHANGES;", "COL1", "COL2"); - final InternalFunctionRegistry functionRegistry = new InternalFunctionRegistry(); - try { - final Map aggValToFunctionMap = new HashMap<>(); - aggValToFunctionMap.put( - 0, functionRegistry.getAggregate("MAX", Schema.OPTIONAL_INT64_SCHEMA)); - aggValToFunctionMap.put( - 1, functionRegistry.getAggregate("MIN", Schema.OPTIONAL_INT64_SCHEMA)); + // Given: + final SchemaKGroupedTable kGroupedTable = + buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); - givenAggregateSchemaFieldCount(aggValToFunctionMap.size() + 1); + // Then: + expectedException.expect(KsqlException.class); + expectedException.expectMessage( + "The aggregation function(s) (MIN, MAX) cannot be applied to a table."); - kGroupedTable.aggregate( - aggregateSchema, - new KudafInitializer(1), - 1, - ImmutableList.of(aggCall1, aggCall2), - aggValToFunctionMap, - Optional.empty(), - valueFormat, - GenericRowSerDe.from( - FormatInfo.of(Format.JSON, Optional.empty()), - PersistenceSchema.from(ksqlTable.getSchema().valueConnectSchema(), false), - ksqlConfig, - () -> null, - "test", - processingLogContext), - queryContext - ); - Assert.fail("Should fail to build topology for aggregation with unsupported function"); - } catch(final KsqlException e) { - Assert.assertThat( - e.getMessage(), - equalTo( - "The aggregation function(s) (MAX, MIN) cannot be applied to a table.")); - } + // When: + kGroupedTable.aggregate( + AGG_SCHEMA, + OUT_SCHEMA, + 1, + ImmutableList.of(MIN, MAX), + Optional.empty(), + valueFormat, + queryContext, + queryBuilder + ); } private SchemaKGroupedTable buildSchemaKGroupedTable( @@ -293,80 +175,31 @@ private SchemaKGroupedTable buildSchemaKGroupedTable( ) { return new SchemaKGroupedTable( kGroupedTable, - buildSourceTableStep(schema), + buildSourceTableStep(IN_SCHEMA), keyFormat, keySerde, - KeyField.of(schema.value().get(0).name(), schema.value().get(0)), + KeyField.of(IN_SCHEMA.value().get(0).name(), IN_SCHEMA.value().get(0)), Collections.emptyList(), ksqlConfig, functionRegistry, materializedFactory); } - @Test - public void shouldUseMaterializedFactoryForStateStore() { - // Given: - final Serde valueSerde = mock(Serde.class); - final Materialized materialized = MaterializedFactory.create(ksqlConfig).create( - Serdes.String(), - valueSerde, - StreamsUtil.buildOpName(queryContext.getQueryContext())); - - when(materializedFactory.create(any(), any(), any())).thenReturn(materialized); - - final KTable mockKTable = mock(KTable.class); - when(mockKGroupedTable.aggregate(any(), any(), any(), any())).thenReturn(mockKTable); - - final SchemaKGroupedTable groupedTable = - buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); - - // When: - groupedTable.aggregate( - aggregateSchema, - () -> null, - 0, - emptyList(), - someUdfs, - Optional.empty(), - valueFormat, - valueSerde, - queryContext); - - // Then: - verify(materializedFactory).create( - eq(keySerde), - same(valueSerde), - eq(StreamsUtil.buildOpName(queryContext.getQueryContext())) - ); - - verify(mockKGroupedTable).aggregate( - any(), - any(), - any(), - same(materialized) - ); - } - @Test public void shouldBuildStepForAggregate() { // Given: - final Map functions = ImmutableMap.of(1, tableFunc); - final SchemaKGroupedTable groupedTable = + final SchemaKGroupedTable kGroupedTable = buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); - when(aggregateSchema.value()).thenReturn( - ImmutableList.of(mock(Column.class), mock(Column.class))); - // When: - final SchemaKTable result = groupedTable.aggregate( - aggregateSchema, - initializer, + final SchemaKTable result = kGroupedTable.aggregate( + AGG_SCHEMA, + OUT_SCHEMA, 1, - ImmutableList.of(aggCall1), - functions, + ImmutableList.of(SUM, COUNT), Optional.empty(), valueFormat, - topicValueSerDe, - queryContext + queryContext, + queryBuilder ); // Then: @@ -375,37 +208,37 @@ public void shouldBuildStepForAggregate() { equalTo( ExecutionStepFactory.tableAggregate( queryContext, - groupedTable.getSourceTableStep(), - aggregateSchema, + kGroupedTable.getSourceTableStep(), + OUT_SCHEMA, Formats.of(keyFormat, valueFormat, SerdeOption.none()), 1, - ImmutableList.of(aggCall1) + ImmutableList.of(SUM, COUNT), + AGG_SCHEMA ) ) ); } @Test - public void shouldReturnKTableWithAggregateSchema() { + public void shouldReturnKTableWithOutputSchema() { // Given: final SchemaKGroupedTable groupedTable = buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); // When: final SchemaKTable result = groupedTable.aggregate( - aggregateSchema, - initializer, - 0, - emptyList(), - someUdfs, + AGG_SCHEMA, + OUT_SCHEMA, + 1, + ImmutableList.of(SUM, COUNT), Optional.empty(), valueFormat, - topicValueSerDe, - queryContext + queryContext, + queryBuilder ); // Then: - assertThat(result.getSchema(), is(aggregateSchema)); + assertThat(result.getSchema(), is(OUT_SCHEMA)); } @Test(expected = IllegalArgumentException.class) @@ -414,32 +247,23 @@ public void shouldThrowOnColumnCountMismatch() { final SchemaKGroupedTable groupedTable = buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); - // Agg schema has 2 fields: - givenAggregateSchemaFieldCount(2); - - // Where as params have 1 nonAgg and 2 agg fields: - final Map aggColumns = ImmutableMap.of(2, otherFunc); - // When: groupedTable.aggregate( - aggregateSchema, - initializer, + AGG_SCHEMA, + OUT_SCHEMA, 2, - ImmutableList.of(aggCall1), - aggColumns, + ImmutableList.of(SUM, COUNT), Optional.empty(), valueFormat, - topicValueSerDe, - queryContext + queryContext, + queryBuilder ); } - private void givenAggregateSchemaFieldCount(final int count) { - final List valueFields = IntStream - .range(0, count) - .mapToObj(i -> field) - .collect(Collectors.toList()); - - when(aggregateSchema.value()).thenReturn(valueFields); + private static FunctionCall udaf(final String name) { + return new FunctionCall( + FunctionName.of(name), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("IN1"))) + ); } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/TableAggregationFunction.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/TableAggregationFunction.java similarity index 87% rename from ksql-engine/src/main/java/io/confluent/ksql/function/TableAggregationFunction.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/TableAggregationFunction.java index fbef04cfe2cc..9249aa63522c 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/TableAggregationFunction.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/TableAggregationFunction.java @@ -13,7 +13,9 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function; +package io.confluent.ksql.execution.function; + +import io.confluent.ksql.function.KsqlAggregateFunction; public interface TableAggregationFunction extends KsqlAggregateFunction { A undo(I valueToUndo, A aggregateValue); diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/function/UdafUtil.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/UdafUtil.java new file mode 100644 index 000000000000..bc7a58559042 --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/UdafUtil.java @@ -0,0 +1,61 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.function; + +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.util.ExpressionTypeManager; +import io.confluent.ksql.function.AggregateFunctionArguments; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.util.KsqlException; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.kafka.connect.data.Schema; + +public final class UdafUtil { + private UdafUtil() { + } + + @SuppressWarnings("deprecation") // Need to migrate away from Connect Schema use. + public static KsqlAggregateFunction resolveAggregateFunction( + final FunctionRegistry functionRegistry, + final FunctionCall functionCall, + final LogicalSchema schema + ) { + try { + final ExpressionTypeManager expressionTypeManager = + new ExpressionTypeManager(schema, functionRegistry); + final List functionArgs = functionCall.getArguments(); + final Schema expressionType = expressionTypeManager.getExpressionSchema(functionArgs.get(0)); + final KsqlAggregateFunction aggregateFunctionInfo = functionRegistry.getAggregate( + functionCall.getName().name(), + expressionType + ); + + final List args = functionArgs.stream() + .map(Expression::toString) + .collect(Collectors.toList()); + + final int udafIndex = schema.valueColumnIndex(args.get(0)).getAsInt(); + + return aggregateFunctionInfo.getInstance(new AggregateFunctionArguments(udafIndex, args)); + } catch (final Exception e) { + throw new KsqlException("Failed to create aggregate function: " + functionCall, e); + } + } +} diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafAggregator.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafAggregator.java similarity index 96% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafAggregator.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafAggregator.java index 267abb225b77..4510deaf2023 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafAggregator.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafAggregator.java @@ -13,7 +13,7 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udaf; +package io.confluent.ksql.execution.function.udaf; import static java.util.Objects.requireNonNull; @@ -151,4 +151,12 @@ private static List validateAggregates( } return builder.build(); } + + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public List getAggValToAggFunctionMap() { + return aggregateFunctions; + } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafInitializer.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafInitializer.java similarity index 70% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafInitializer.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafInitializer.java index e59f41da3e4f..40fc553abb7c 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafInitializer.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafInitializer.java @@ -13,11 +13,12 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udaf; +package io.confluent.ksql.execution.function.udaf; +import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; -import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -25,11 +26,14 @@ public class KudafInitializer implements Initializer { - private final List aggValueSuppliers = new ArrayList<>(); + private final List initialValueSuppliers; private final int nonAggValSize; - public KudafInitializer(final int nonAggValSize) { + public KudafInitializer(final int nonAggValSize, final List initialValueSuppliers) { this.nonAggValSize = nonAggValSize; + this.initialValueSuppliers = ImmutableList.copyOf( + Objects.requireNonNull(initialValueSuppliers, "initialValueSuppliers") + ); } @Override @@ -38,11 +42,7 @@ public GenericRow apply() { .mapToObj(value -> null) .collect(Collectors.toList()); - aggValueSuppliers.forEach(supplier -> values.add(supplier.get())); + initialValueSuppliers.forEach(supplier -> values.add(supplier.get())); return new GenericRow(values); } - - public void addAggregateIntializer(final Supplier intialValueSupplier) { - aggValueSuppliers.add(intialValueSupplier); - } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafUndoAggregator.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafUndoAggregator.java similarity index 86% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafUndoAggregator.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafUndoAggregator.java index e5b6989e93de..82cabe5098e1 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/KudafUndoAggregator.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/KudafUndoAggregator.java @@ -13,11 +13,11 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udaf; +package io.confluent.ksql.execution.function.udaf; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.TableAggregationFunction; import java.util.Map; import java.util.Objects; import org.apache.kafka.connect.data.Struct; @@ -53,4 +53,12 @@ public GenericRow apply(final Struct k, final GenericRow rowValue, final Generic aggRowValue.getColumns().get(aggRowIndex)))); return aggRowValue; } + + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public Map getAggValToAggFunctionMap() { + return aggValToAggFunctionMap; + } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowSelectMapper.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapper.java similarity index 86% rename from ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowSelectMapper.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapper.java index 76954479148d..09860e299ab4 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udaf/window/WindowSelectMapper.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapper.java @@ -13,7 +13,7 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udaf.window; +package io.confluent.ksql.execution.function.udaf.window; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; @@ -26,14 +26,17 @@ import org.apache.kafka.streams.kstream.Windowed; /** - * Used to handle the special cased {@link WindowStartKudaf} and {@link WindowEndKudaf}. + * Used to handle the special cased {WindowStart} and {WindowEnd}. */ public final class WindowSelectMapper implements ValueMapperWithKey, GenericRow, GenericRow> { + public static final String WINDOW_START_NAME = "WindowStart"; + public static final String WINDOW_END_NAME = "WindowEnd"; + private static final Map WINDOW_FUNCTION_NAMES = ImmutableMap.of( - WindowStartKudaf.getFunctionName().toUpperCase(), Type.StartTime, - WindowEndKudaf.getFunctionName().toUpperCase(), Type.EndTime + WINDOW_START_NAME.toUpperCase(), Type.StartTime, + WINDOW_END_NAME.toUpperCase(), Type.EndTime ); private final Map windowSelects; diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java index 1a47b2ffe9ab..5fa29b514af8 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java @@ -15,31 +15,39 @@ package io.confluent.ksql.execution.plan; import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.schema.ksql.LogicalSchema; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; @Immutable -public class StreamAggregate implements ExecutionStep { +public class StreamAggregate implements ExecutionStep> { private final ExecutionStepProperties properties; - private final ExecutionStep source; + private final ExecutionStep> source; private final Formats formats; private final int nonFuncColumnCount; private final List aggregations; + private final LogicalSchema aggregationSchema; public StreamAggregate( final ExecutionStepProperties properties, - final ExecutionStep source, + final ExecutionStep> source, final Formats formats, final int nonFuncColumnCount, - final List aggregations) { + final List aggregations, + final LogicalSchema aggregationSchema) { this.properties = Objects.requireNonNull(properties, "properties"); this.source = Objects.requireNonNull(source, "source"); this.formats = Objects.requireNonNull(formats, "formats"); this.nonFuncColumnCount = nonFuncColumnCount; - this.aggregations = Objects.requireNonNull(aggregations); + this.aggregations = Objects.requireNonNull(aggregations, "aggregations"); + this.aggregationSchema = Objects.requireNonNull(aggregationSchema, "aggregationSchema"); } @Override @@ -52,8 +60,24 @@ public List> getSources() { return Collections.singletonList(source); } + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public List getAggregations() { + return aggregations; + } + + public Formats getFormats() { + return formats; + } + + public LogicalSchema getAggregationSchema() { + return aggregationSchema; + } + @Override - public T build(final KsqlQueryBuilder streamsBuilder) { + public KTable build(final KsqlQueryBuilder streamsBuilder) { throw new UnsupportedOperationException(); } @@ -65,17 +89,25 @@ public boolean equals(final Object o) { if (o == null || getClass() != o.getClass()) { return false; } - final StreamAggregate that = (StreamAggregate) o; + final StreamAggregate that = (StreamAggregate) o; return Objects.equals(properties, that.properties) && Objects.equals(source, that.source) && Objects.equals(formats, that.formats) && Objects.equals(aggregations, that.aggregations) - && nonFuncColumnCount == that.nonFuncColumnCount; + && nonFuncColumnCount == that.nonFuncColumnCount + && aggregationSchema.equals(that.aggregationSchema); } @Override public int hashCode() { - return Objects.hash(properties, source, formats, aggregations, nonFuncColumnCount); + return Objects.hash( + properties, + source, + formats, + aggregations, + nonFuncColumnCount, + aggregationSchema + ); } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamWindowedAggregate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamWindowedAggregate.java new file mode 100644 index 000000000000..37eb3703cbc3 --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamWindowedAggregate.java @@ -0,0 +1,124 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.plan; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Windowed; + +public class StreamWindowedAggregate + implements ExecutionStep, GenericRow>> { + private final ExecutionStepProperties properties; + private final ExecutionStep> source; + private final Formats formats; + private final int nonFuncColumnCount; + private final List aggregations; + private final LogicalSchema aggregationSchema; + private final KsqlWindowExpression windowExpression; + + public StreamWindowedAggregate( + final ExecutionStepProperties properties, + final ExecutionStep> source, + final Formats formats, + final int nonFuncColumnCount, + final List aggregations, + final LogicalSchema aggregationSchema, + final KsqlWindowExpression windowExpression) { + this.properties = Objects.requireNonNull(properties, "properties"); + this.source = Objects.requireNonNull(source, "source"); + this.formats = Objects.requireNonNull(formats, "formats"); + this.nonFuncColumnCount = nonFuncColumnCount; + this.aggregations = Objects.requireNonNull(aggregations, "aggregations"); + this.aggregationSchema = Objects.requireNonNull(aggregationSchema, "aggregationSchema"); + this.windowExpression = Objects.requireNonNull(windowExpression, "windowExpression"); + } + + @Override + public ExecutionStepProperties getProperties() { + return properties; + } + + @Override + public List> getSources() { + return Collections.singletonList(source); + } + + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public List getAggregations() { + return aggregations; + } + + public Formats getFormats() { + return formats; + } + + public LogicalSchema getAggregationSchema() { + return aggregationSchema; + } + + public KsqlWindowExpression getWindowExpression() { + return windowExpression; + } + + @Override + public KTable, GenericRow> build(final KsqlQueryBuilder streamsBuilder) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final StreamWindowedAggregate that = (StreamWindowedAggregate) o; + return Objects.equals(properties, that.properties) + && Objects.equals(source, that.source) + && Objects.equals(formats, that.formats) + && Objects.equals(aggregations, that.aggregations) + && nonFuncColumnCount == that.nonFuncColumnCount + && Objects.equals(aggregationSchema, that.aggregationSchema) + && Objects.equals(windowExpression, that.windowExpression); + } + + @Override + public int hashCode() { + + return Objects.hash( + properties, + source, + formats, + aggregations, + nonFuncColumnCount, + aggregationSchema, + windowExpression + ); + } +} diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java index 768d38d7948b..51c5fe5d4090 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java @@ -15,31 +15,39 @@ package io.confluent.ksql.execution.plan; import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.schema.ksql.LogicalSchema; import java.util.Collections; import java.util.List; import java.util.Objects; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; @Immutable -public class TableAggregate implements ExecutionStep { +public class TableAggregate implements ExecutionStep> { private final ExecutionStepProperties properties; - private final ExecutionStep source; + private final ExecutionStep> source; private final Formats formats; private final int nonFuncColumnCount; private final List aggregations; + private final LogicalSchema aggregationSchema; public TableAggregate( final ExecutionStepProperties properties, - final ExecutionStep source, + final ExecutionStep> source, final Formats formats, final int nonFuncColumnCount, - final List aggregations) { + final List aggregations, + final LogicalSchema aggregationSchema) { this.properties = Objects.requireNonNull(properties, "properties"); this.source = Objects.requireNonNull(source, "source"); this.formats = Objects.requireNonNull(formats, "formats"); this.nonFuncColumnCount = nonFuncColumnCount; this.aggregations = Objects.requireNonNull(aggregations, "aggValToFunctionMap"); + this.aggregationSchema = Objects.requireNonNull(aggregationSchema, "aggregationSchema"); } @Override @@ -52,8 +60,24 @@ public List> getSources() { return Collections.singletonList(source); } + public Formats getFormats() { + return formats; + } + + public List getAggregations() { + return aggregations; + } + + public int getNonFuncColumnCount() { + return nonFuncColumnCount; + } + + public LogicalSchema getAggregationSchema() { + return aggregationSchema; + } + @Override - public T build(final KsqlQueryBuilder builder) { + public KTable build(final KsqlQueryBuilder builder) { throw new UnsupportedOperationException(); } @@ -65,12 +89,13 @@ public boolean equals(final Object o) { if (o == null || getClass() != o.getClass()) { return false; } - final TableAggregate that = (TableAggregate) o; + final TableAggregate that = (TableAggregate) o; return Objects.equals(properties, that.properties) && Objects.equals(source, that.source) && Objects.equals(formats, that.formats) && nonFuncColumnCount == that.nonFuncColumnCount - && Objects.equals(aggregations, that.aggregations); + && Objects.equals(aggregations, that.aggregations) + && Objects.equals(aggregationSchema, that.aggregationSchema); } @Override diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/HoppingWindowExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/HoppingWindowExpression.java similarity index 72% rename from ksql-parser/src/main/java/io/confluent/ksql/parser/tree/HoppingWindowExpression.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/windows/HoppingWindowExpression.java index 95f33723120d..01f4ce062884 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/HoppingWindowExpression.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/HoppingWindowExpression.java @@ -13,13 +13,11 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.parser.tree; +package io.confluent.ksql.execution.windows; import static java.util.Objects.requireNonNull; import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; import io.confluent.ksql.serde.WindowInfo; @@ -27,12 +25,6 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.TimeWindows; @Immutable public class HoppingWindowExpression extends KsqlWindowExpression { @@ -73,8 +65,24 @@ public WindowInfo getWindowInfo() { ); } + public TimeUnit getSizeUnit() { + return sizeUnit; + } + + public long getSize() { + return size; + } + + public TimeUnit getAdvanceByUnit() { + return advanceByUnit; + } + + public long getAdvanceBy() { + return advanceBy; + } + @Override - public R accept(final AstVisitor visitor, final C context) { + public R accept(final WindowVisitor visitor, final C context) { return visitor.visitHoppingWindowExpression(this, context); } @@ -102,21 +110,4 @@ public boolean equals(final Object o) { && hoppingWindowExpression.advanceBy == advanceBy && hoppingWindowExpression .advanceByUnit == advanceByUnit; } - - @SuppressWarnings("unchecked") - @Override - public KTable applyAggregate( - final KGroupedStream groupedStream, - final Initializer initializer, - final UdafAggregator aggregator, - final Materialized materialized - ) { - final TimeWindows windows = TimeWindows - .of(Duration.ofMillis(sizeUnit.toMillis(size))) - .advanceBy(Duration.ofMillis(advanceByUnit.toMillis(advanceBy))); - - return groupedStream - .windowedBy(windows) - .aggregate(initializer, aggregator, materialized); - } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/KsqlWindowExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/KsqlWindowExpression.java new file mode 100644 index 000000000000..f15717187065 --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/KsqlWindowExpression.java @@ -0,0 +1,34 @@ +/* + * Copyright 2018 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.windows; + +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.parser.Node; +import io.confluent.ksql.parser.NodeLocation; +import io.confluent.ksql.serde.WindowInfo; +import java.util.Optional; + +@Immutable +public abstract class KsqlWindowExpression extends Node { + + KsqlWindowExpression(final Optional nodeLocation) { + super(nodeLocation); + } + + public abstract WindowInfo getWindowInfo(); + + public abstract R accept(WindowVisitor visitor, C context); +} diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/SessionWindowExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/SessionWindowExpression.java similarity index 66% rename from ksql-parser/src/main/java/io/confluent/ksql/parser/tree/SessionWindowExpression.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/windows/SessionWindowExpression.java index f8c4154f5b01..ed22f3f4d06e 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/SessionWindowExpression.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/SessionWindowExpression.java @@ -13,26 +13,17 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.parser.tree; +package io.confluent.ksql.execution.windows; import static java.util.Objects.requireNonNull; import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; import io.confluent.ksql.serde.WindowInfo; -import java.time.Duration; import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.SessionWindows; @Immutable public class SessionWindowExpression extends KsqlWindowExpression { @@ -54,13 +45,21 @@ public SessionWindowExpression( this.sizeUnit = requireNonNull(sizeUnit, "sizeUnit"); } + public TimeUnit getSizeUnit() { + return sizeUnit; + } + + public long getGap() { + return gap; + } + @Override public WindowInfo getWindowInfo() { return WindowInfo.of(WindowType.SESSION, Optional.empty()); } @Override - public R accept(final AstVisitor visitor, final C context) { + public R accept(final WindowVisitor visitor, final C context) { return visitor.visitSessionWindowExpression(this, context); } @@ -85,18 +84,4 @@ public boolean equals(final Object o) { final SessionWindowExpression sessionWindowExpression = (SessionWindowExpression) o; return sessionWindowExpression.gap == gap && sessionWindowExpression.sizeUnit == sizeUnit; } - - @SuppressWarnings("unchecked") - @Override - public KTable applyAggregate(final KGroupedStream groupedStream, - final Initializer initializer, - final UdafAggregator aggregator, - final Materialized materialized) { - - final SessionWindows windows = SessionWindows.with(Duration.ofMillis(sizeUnit.toMillis(gap))); - - return groupedStream - .windowedBy(windows) - .aggregate(initializer, aggregator, aggregator.getMerger(), materialized); - } } diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/TumblingWindowExpression.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/TumblingWindowExpression.java similarity index 70% rename from ksql-parser/src/main/java/io/confluent/ksql/parser/tree/TumblingWindowExpression.java rename to ksql-execution/src/main/java/io/confluent/ksql/execution/windows/TumblingWindowExpression.java index 6db9121cbfe5..45b5602dcad2 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/TumblingWindowExpression.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/TumblingWindowExpression.java @@ -13,13 +13,11 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.parser.tree; +package io.confluent.ksql.execution.windows; import static java.util.Objects.requireNonNull; import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; import io.confluent.ksql.serde.WindowInfo; @@ -27,12 +25,6 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.TimeWindows; @Immutable public class TumblingWindowExpression extends KsqlWindowExpression { @@ -62,8 +54,16 @@ public WindowInfo getWindowInfo() { ); } + public TimeUnit getSizeUnit() { + return sizeUnit; + } + + public long getSize() { + return size; + } + @Override - public R accept(final AstVisitor visitor, final C context) { + public R accept(final WindowVisitor visitor, final C context) { return visitor.visitTumblingWindowExpression(this, context); } @@ -88,19 +88,4 @@ public boolean equals(final Object o) { final TumblingWindowExpression tumblingWindowExpression = (TumblingWindowExpression) o; return tumblingWindowExpression.size == size && tumblingWindowExpression.sizeUnit == sizeUnit; } - - @SuppressWarnings("unchecked") - @Override - public KTable applyAggregate(final KGroupedStream groupedStream, - final Initializer initializer, - final UdafAggregator aggregator, - final Materialized materialized) { - - final TimeWindows windows = TimeWindows.of(Duration.ofMillis(sizeUnit.toMillis(size))); - - return groupedStream - .windowedBy(windows) - .aggregate(initializer, aggregator, materialized); - - } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/WindowVisitor.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/WindowVisitor.java new file mode 100644 index 000000000000..f32ffb2d088f --- /dev/null +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/windows/WindowVisitor.java @@ -0,0 +1,24 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.windows; + +public interface WindowVisitor { + R visitHoppingWindowExpression(HoppingWindowExpression hoppingWindowExpression, C ctx); + + R visitSessionWindowExpression(SessionWindowExpression sessionWindowExpression, C ctx); + + R visitTumblingWindowExpression(TumblingWindowExpression tumblingWindowExpression, C ctx); +} diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/function/UdafUtilTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/function/UdafUtilTest.java new file mode 100644 index 000000000000..9c03a1b44f4a --- /dev/null +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/function/UdafUtilTest.java @@ -0,0 +1,111 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.function; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.function.AggregateFunctionArguments; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; +import io.confluent.ksql.schema.ksql.ColumnRef; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import org.apache.kafka.connect.data.Schema; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class UdafUtilTest { + private static final LogicalSchema SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("FOO"), SqlTypes.INTEGER) + .valueColumn(ColumnName.of("BAR"), SqlTypes.BIGINT) + .build(); + private static final FunctionCall FUNCTION_CALL = new FunctionCall( + FunctionName.of("AGG"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("BAR"))) + ); + + @Mock + private FunctionRegistry functionRegistry; + @Mock + private KsqlAggregateFunction function; + @Mock + private KsqlAggregateFunction resolved; + @Captor + private ArgumentCaptor argumentsCaptor; + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(functionRegistry.getAggregate(any(), any())).thenReturn(function); + when(function.getInstance(any())).thenReturn(resolved); + } + + @Test + public void shouldResolveUDAF() { + // When: + final KsqlAggregateFunction returned = + UdafUtil.resolveAggregateFunction(functionRegistry, FUNCTION_CALL, SCHEMA); + + // Then: + assertThat(returned, is(resolved)); + } + + @Test + public void shouldGetAggregateWithCorrectName() { + // When: + UdafUtil.resolveAggregateFunction(functionRegistry, FUNCTION_CALL, SCHEMA); + + // Then: + verify(functionRegistry).getAggregate(eq("AGG"), any()); + } + + @Test + public void shouldGetAggregateWithCorrectType() { + // When: + UdafUtil.resolveAggregateFunction(functionRegistry, FUNCTION_CALL, SCHEMA); + + // Then: + verify(functionRegistry).getAggregate(any(), eq(Schema.OPTIONAL_INT64_SCHEMA)); + } + + @Test + public void shouldResolveWithCorrectArgs() { + // When: + UdafUtil.resolveAggregateFunction(functionRegistry, FUNCTION_CALL, SCHEMA); + + // Then: + verify(function).getInstance(argumentsCaptor.capture()); + final AggregateFunctionArguments arguments = argumentsCaptor.getValue(); + assertThat(arguments.udafIndex(), equalTo(1)); + } +} \ No newline at end of file diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/window/WindowSelectMapperTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapperTest.java similarity index 96% rename from ksql-engine/src/test/java/io/confluent/ksql/function/udaf/window/WindowSelectMapperTest.java rename to ksql-execution/src/test/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapperTest.java index dba6ef59cf22..1b960b296e56 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/udaf/window/WindowSelectMapperTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapperTest.java @@ -13,7 +13,7 @@ * specific language governing permissions and limitations under the License. */ -package io.confluent.ksql.function.udaf.window; +package io.confluent.ksql.execution.function.udaf.window; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.function.KsqlAggregateFunction; import java.util.ArrayList; import java.util.Arrays; diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java index 019e73e1ccdf..4c207ab858eb 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java @@ -51,6 +51,9 @@ import io.confluent.ksql.execution.expression.tree.TimeLiteral; import io.confluent.ksql.execution.expression.tree.TimestampLiteral; import io.confluent.ksql.execution.expression.tree.WhenClause; +import io.confluent.ksql.execution.windows.HoppingWindowExpression; +import io.confluent.ksql.execution.windows.SessionWindowExpression; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.name.ColumnName; @@ -88,7 +91,6 @@ import io.confluent.ksql.parser.tree.Explain; import io.confluent.ksql.parser.tree.GroupBy; import io.confluent.ksql.parser.tree.GroupingElement; -import io.confluent.ksql.parser.tree.HoppingWindowExpression; import io.confluent.ksql.parser.tree.InsertInto; import io.confluent.ksql.parser.tree.InsertValues; import io.confluent.ksql.parser.tree.Join; @@ -111,7 +113,6 @@ import io.confluent.ksql.parser.tree.RunScript; import io.confluent.ksql.parser.tree.Select; import io.confluent.ksql.parser.tree.SelectItem; -import io.confluent.ksql.parser.tree.SessionWindowExpression; import io.confluent.ksql.parser.tree.SetProperty; import io.confluent.ksql.parser.tree.ShowColumns; import io.confluent.ksql.parser.tree.SimpleGroupBy; @@ -123,7 +124,6 @@ import io.confluent.ksql.parser.tree.TableElement.Namespace; import io.confluent.ksql.parser.tree.TableElements; import io.confluent.ksql.parser.tree.TerminateQuery; -import io.confluent.ksql.parser.tree.TumblingWindowExpression; import io.confluent.ksql.parser.tree.UnsetProperty; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.parser.tree.WithinExpression; diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/AstVisitor.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/AstVisitor.java index f232dda4adc3..64ef262434b5 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/AstVisitor.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/AstVisitor.java @@ -92,22 +92,6 @@ protected R visitWindowExpression(final WindowExpression node, final C context) return visitNode(node, context); } - protected R visitKsqlWindowExpression(final KsqlWindowExpression node, final C context) { - return visitNode(node, context); - } - - protected R visitTumblingWindowExpression(final TumblingWindowExpression node, final C context) { - return visitKsqlWindowExpression(node, context); - } - - protected R visitHoppingWindowExpression(final HoppingWindowExpression node, final C context) { - return visitKsqlWindowExpression(node, context); - } - - protected R visitSessionWindowExpression(final SessionWindowExpression node, final C context) { - return visitKsqlWindowExpression(node, context); - } - protected R visitTableElement(final TableElement node, final C context) { return visitNode(node, context); } diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/KsqlWindowExpression.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/KsqlWindowExpression.java deleted file mode 100644 index c7a48b6a7211..000000000000 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/KsqlWindowExpression.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2018 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.parser.tree; - -import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; -import io.confluent.ksql.parser.NodeLocation; -import io.confluent.ksql.serde.WindowInfo; -import java.util.Optional; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.KTable; -import org.apache.kafka.streams.kstream.Materialized; - -@Immutable -public abstract class KsqlWindowExpression extends AstNode { - - KsqlWindowExpression(final Optional location) { - super(location); - } - - public abstract KTable applyAggregate(KGroupedStream groupedStream, - Initializer initializer, - UdafAggregator aggregator, - Materialized materialized); - - public abstract WindowInfo getWindowInfo(); - - @Override - public R accept(final AstVisitor visitor, final C context) { - return visitor.visitKsqlWindowExpression(this, context); - } -} diff --git a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/WindowExpression.java b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/WindowExpression.java index 8d889157ca3d..32bdaae39da2 100644 --- a/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/WindowExpression.java +++ b/ksql-parser/src/main/java/io/confluent/ksql/parser/tree/WindowExpression.java @@ -18,6 +18,7 @@ import static java.util.Objects.requireNonNull; import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.parser.NodeLocation; import java.util.Objects; import java.util.Optional; diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java index edf41face930..a9112715447b 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java @@ -27,6 +27,7 @@ import com.google.common.testing.EqualsTester; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.windows.HoppingWindowExpression; import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; @@ -54,27 +55,6 @@ public class HoppingWindowExpressionTest { public static final NodeLocation SOME_LOCATION = new NodeLocation(0, 0); public static final NodeLocation OTHER_LOCATION = new NodeLocation(1, 0); - @Mock - private KGroupedStream stream; - @Mock - private TimeWindowedKStream windowedKStream; - @Mock - private UdafAggregator aggregator; - @Mock - private Initializer initializer; - @Mock - private Materialized> store; - private HoppingWindowExpression windowExpression; - - @Before - public void setUp() { - windowExpression = new HoppingWindowExpression(10, SECONDS, 4, TimeUnit.MILLISECONDS); - - when(stream - .windowedBy(any(TimeWindows.class))) - .thenReturn(windowedKStream); - } - @Test public void shouldImplementHashCodeAndEqualsProperty() { new EqualsTester() @@ -100,18 +80,6 @@ public void shouldImplementHashCodeAndEqualsProperty() { .testEquals(); } - @Test - public void shouldCreateHoppingWindowAggregate() { - // When: - windowExpression.applyAggregate(stream, initializer, aggregator, store); - - // Then: - verify(stream) - .windowedBy(TimeWindows.of(Duration.ofSeconds(10)).advanceBy(Duration.ofMillis(4L))); - - verify(windowedKStream).aggregate(initializer, aggregator, store); - } - @Test public void shouldReturnWindowInfo() { assertThat(new HoppingWindowExpression(10, SECONDS, 20, MINUTES).getWindowInfo(), diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java index e492ec627977..f9ab9730daae 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java @@ -32,6 +32,8 @@ import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.execution.expression.tree.StringLiteral; import io.confluent.ksql.execution.expression.tree.Type; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; import io.confluent.ksql.parser.properties.with.CreateSourceAsProperties; import io.confluent.ksql.parser.properties.with.CreateSourceProperties; import io.confluent.ksql.properties.with.CommonCreateConfigs; diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java index 68fee27e3d6a..0740e2d0f0fb 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.windows.SessionWindowExpression; import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.serde.WindowInfo; @@ -46,39 +47,11 @@ @RunWith(MockitoJUnitRunner.class) public class SessionWindowExpressionTest { - @Mock - private KGroupedStream stream; - @Mock - private SessionWindowedKStream windowedKStream; - @Mock - private UdafAggregator aggregator; - @Mock - private Initializer initializer; - @Mock - private Materialized> store; - @Mock - private Merger merger; private SessionWindowExpression windowExpression; @Before public void setUp() { windowExpression = new SessionWindowExpression(5, TimeUnit.SECONDS); - - when(stream - .windowedBy(any(SessionWindows.class))) - .thenReturn(windowedKStream); - - when(aggregator.getMerger()).thenReturn(merger); - } - - @Test - public void shouldCreateSessionWindowed() { - // When: - windowExpression.applyAggregate(stream, initializer, aggregator, store); - - // Then: - verify(stream).windowedBy(SessionWindows.with(Duration.ofSeconds(5))); - verify(windowedKStream).aggregate(initializer, aggregator, merger, store); } @Test diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java index 8c587e13b0c0..158f343a57e4 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java @@ -18,65 +18,17 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.function.UdafAggregator; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.serde.WindowInfo; import java.time.Duration; import java.util.Optional; -import java.util.concurrent.TimeUnit; -import org.apache.kafka.common.utils.Bytes; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.TimeWindowedKStream; -import org.apache.kafka.streams.kstream.TimeWindows; -import org.apache.kafka.streams.state.WindowStore; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @RunWith(MockitoJUnitRunner.class) public class TumblingWindowExpressionTest { - - @Mock - private KGroupedStream stream; - @Mock - private TimeWindowedKStream windowedKStream; - @Mock - private UdafAggregator aggregator; - @Mock - private Initializer initializer; - @Mock - private Materialized> store; - private TumblingWindowExpression windowExpression; - - @Before - public void setUp() { - windowExpression = new TumblingWindowExpression(10, TimeUnit.SECONDS); - - when(stream - .windowedBy(any(TimeWindows.class))) - .thenReturn(windowedKStream); - } - - @Test - public void shouldCreateTumblingWindowAggregate() { - // When: - windowExpression.applyAggregate(stream, initializer, aggregator, store); - - // Then: - verify(stream).windowedBy(TimeWindows.of(Duration.ofSeconds(10))); - verify(windowedKStream).aggregate(initializer, aggregator, store); - } - @Test public void shouldReturnWindowInfo() { assertThat(new TumblingWindowExpression(11, SECONDS).getWindowInfo(), diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/WindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/WindowExpressionTest.java index 0bf814abed79..f745415fbe7a 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/WindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/WindowExpressionTest.java @@ -18,6 +18,7 @@ import static org.mockito.Mockito.mock; import com.google.common.testing.EqualsTester; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.parser.NodeLocation; import java.util.Optional; import org.junit.Test; diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateBuilderUtils.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateBuilderUtils.java new file mode 100644 index 000000000000..5d0c5d4a534e --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateBuilderUtils.java @@ -0,0 +1,57 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.serde.KeySerde; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; + +public final class AggregateBuilderUtils { + private AggregateBuilderUtils() { + } + + static Materialized> buildMaterialized( + final QueryContext queryContext, + final LogicalSchema aggregateSchema, + final Formats formats, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory) { + final PhysicalSchema physicalAggregationSchema = PhysicalSchema.from( + aggregateSchema, + formats.getOptions() + ); + final KeySerde keySerde = queryBuilder.buildKeySerde( + formats.getKeyFormat().getFormatInfo(), + physicalAggregationSchema, + queryContext + ); + final Serde valueSerde = queryBuilder.buildValueSerde( + formats.getValueFormat().getFormatInfo(), + physicalAggregationSchema, + queryContext + ); + return materializedFactory.create(keySerde, valueSerde, StreamsUtil.buildOpName(queryContext)); + } +} diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java new file mode 100644 index 000000000000..145dcdf1c18d --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java @@ -0,0 +1,96 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import com.google.common.collect.ImmutableMap; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.UdafUtil; +import io.confluent.ksql.execution.function.udaf.KudafAggregator; +import io.confluent.ksql.execution.function.udaf.KudafInitializer; +import io.confluent.ksql.execution.function.udaf.KudafUndoAggregator; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +public final class AggregateParams { + private final KudafInitializer initializer; + private final int initialUdafIndex; + private final Map indexToFunction; + + AggregateParams( + final LogicalSchema internalSchema, + final int initialUdafIndex, + final FunctionRegistry functionRegistry, + final List functionList + ) { + final List initialValueSuppliers = new LinkedList<>(); + int udafIndexInAggSchema = initialUdafIndex; + final Map indexToFunction = new HashMap<>(); + for (final FunctionCall functionCall : functionList) { + final KsqlAggregateFunction aggregateFunction = UdafUtil.resolveAggregateFunction( + functionRegistry, + functionCall, + internalSchema + ); + + indexToFunction.put(udafIndexInAggSchema++, aggregateFunction); + initialValueSuppliers.add(aggregateFunction.getInitialValueSupplier()); + } + this.initialUdafIndex = initialUdafIndex; + this.initializer = new KudafInitializer(initialUdafIndex, initialValueSuppliers); + this.indexToFunction = ImmutableMap.copyOf(indexToFunction); + } + + public KudafInitializer getInitializer() { + return initializer; + } + + public KudafAggregator getAggregator() { + return new KudafAggregator(initialUdafIndex, indexToFunction); + } + + public KudafUndoAggregator getUndoAggregator() { + final Map indexToUndo = + indexToFunction.keySet() + .stream() + .collect( + Collectors.toMap( + k -> k, + k -> ((TableAggregationFunction) indexToFunction.get(k)))); + return new KudafUndoAggregator(initialUdafIndex, indexToUndo); + } + + public WindowSelectMapper getWindowSelectMapper() { + return new WindowSelectMapper(indexToFunction); + } + + public interface Factory { + AggregateParams create( + LogicalSchema internalSchema, + int initialUdafIndex, + FunctionRegistry functionRegistry, + List functionList + ); + } +} diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java index 9a74adcff112..88b96aa41255 100644 --- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java @@ -36,12 +36,14 @@ import io.confluent.ksql.execution.plan.StreamStreamJoin; import io.confluent.ksql.execution.plan.StreamTableJoin; import io.confluent.ksql.execution.plan.StreamToTable; +import io.confluent.ksql.execution.plan.StreamWindowedAggregate; import io.confluent.ksql.execution.plan.TableAggregate; import io.confluent.ksql.execution.plan.TableFilter; import io.confluent.ksql.execution.plan.TableGroupBy; import io.confluent.ksql.execution.plan.TableMapValues; import io.confluent.ksql.execution.plan.TableSink; import io.confluent.ksql.execution.plan.TableTableJoin; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.util.timestamp.TimestampExtractionPolicy; @@ -302,22 +304,45 @@ public static TableTableJoin tableTableJoin( ); } - public static StreamAggregate, KGroupedStream> - streamAggregate( - final QueryContext.Stacker stacker, - final ExecutionStep> sourceStep, - final LogicalSchema resultSchema, - final Formats formats, - final int nonFuncColumnCount, - final List aggregations + public static StreamAggregate streamAggregate( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final LogicalSchema resultSchema, + final Formats formats, + final int nonFuncColumnCount, + final List aggregations, + final LogicalSchema aggregateSchema + ) { + final QueryContext queryContext = stacker.getQueryContext(); + return new StreamAggregate( + new DefaultExecutionStepProperties(resultSchema, queryContext), + sourceStep, + formats, + nonFuncColumnCount, + aggregations, + aggregateSchema + ); + } + + public static StreamWindowedAggregate streamWindowedAggregate( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final LogicalSchema resultSchema, + final Formats formats, + final int nonFuncColumnCount, + final List aggregations, + final LogicalSchema aggregateSchema, + final KsqlWindowExpression window ) { final QueryContext queryContext = stacker.getQueryContext(); - return new StreamAggregate<>( + return new StreamWindowedAggregate( new DefaultExecutionStepProperties(resultSchema, queryContext), sourceStep, formats, nonFuncColumnCount, - aggregations + aggregations, + aggregateSchema, + window ); } @@ -349,22 +374,23 @@ public static StreamGroupByKey streamGroupByKey( ); } - public static TableAggregate, KGroupedTable> - tableAggregate( - final QueryContext.Stacker stacker, - final ExecutionStep> sourceStep, - final LogicalSchema resultSchema, - final Formats formats, - final int nonFuncColumnCount, - final List aggregations + public static TableAggregate tableAggregate( + final QueryContext.Stacker stacker, + final ExecutionStep> sourceStep, + final LogicalSchema resultSchema, + final Formats formats, + final int nonFuncColumnCount, + final List aggregations, + final LogicalSchema aggregateSchema ) { final QueryContext queryContext = stacker.getQueryContext(); - return new TableAggregate<>( + return new TableAggregate( new DefaultExecutionStepProperties(resultSchema, queryContext), sourceStep, formats, nonFuncColumnCount, - aggregations + aggregations, + aggregateSchema ); } diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamAggregateBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamAggregateBuilder.java new file mode 100644 index 000000000000..56ac16b5c07c --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/StreamAggregateBuilder.java @@ -0,0 +1,227 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamAggregate; +import io.confluent.ksql.execution.plan.StreamWindowedAggregate; +import io.confluent.ksql.execution.windows.HoppingWindowExpression; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; +import io.confluent.ksql.execution.windows.SessionWindowExpression; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; +import io.confluent.ksql.execution.windows.WindowVisitor; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.serde.KeySerde; +import java.time.Duration; +import java.util.Objects; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueStore; + +public final class StreamAggregateBuilder { + private StreamAggregateBuilder() { + } + + public static KTable build( + final KGroupedStream groupedStream, + final StreamAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory) { + return build(groupedStream, aggregate, queryBuilder, materializedFactory, AggregateParams::new); + } + + static KTable build( + final KGroupedStream kgroupedStream, + final StreamAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory, + final AggregateParams.Factory aggregateParamsFactory) { + final LogicalSchema sourceSchema = aggregate.getSources().get(0).getSchema(); + final int nonFuncColumns = aggregate.getNonFuncColumnCount(); + final AggregateParams aggregateParams = aggregateParamsFactory.create( + sourceSchema, + nonFuncColumns, + queryBuilder.getFunctionRegistry(), + aggregate.getAggregations() + ); + final Materialized> materialized = + AggregateBuilderUtils.buildMaterialized( + aggregate.getProperties().getQueryContext(), + aggregate.getAggregationSchema(), + aggregate.getFormats(), + queryBuilder, + materializedFactory + ); + final KTable aggregated = kgroupedStream.aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + materialized + ); + return aggregated.mapValues(aggregateParams.getAggregator().getResultMapper()); + } + + public static KTable, GenericRow> build( + final KGroupedStream groupedStream, + final StreamWindowedAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory + ) { + return build(groupedStream, aggregate, queryBuilder, materializedFactory, AggregateParams::new); + } + + static KTable, GenericRow> build( + final KGroupedStream groupedStream, + final StreamWindowedAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory, + final AggregateParams.Factory aggregateParamsFactory + ) { + final LogicalSchema sourceSchema = aggregate.getSources().get(0).getSchema(); + final int nonFuncColumns = aggregate.getNonFuncColumnCount(); + final AggregateParams aggregateParams = aggregateParamsFactory.create( + sourceSchema, + nonFuncColumns, + queryBuilder.getFunctionRegistry(), + aggregate.getAggregations() + ); + final KsqlWindowExpression ksqlWindowExpression = aggregate.getWindowExpression(); + final KTable, GenericRow> aggregated = ksqlWindowExpression.accept( + new WindowedAggregator( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParams + ), + null + ); + final KTable, GenericRow> reduced = aggregated.mapValues( + aggregateParams.getAggregator().getResultMapper() + ); + final WindowSelectMapper windowSelectMapper = aggregateParams.getWindowSelectMapper(); + if (!windowSelectMapper.hasSelects()) { + return reduced; + } + return reduced.mapValues(windowSelectMapper); + } + + private static class WindowedAggregator + implements WindowVisitor, GenericRow>, Void> { + final QueryContext queryContext; + final Formats formats; + final KGroupedStream groupedStream; + final KsqlQueryBuilder queryBuilder; + final MaterializedFactory materializedFactory; + final KeySerde keySerde; + final Serde valueSerde; + final AggregateParams aggregateParams; + + WindowedAggregator( + final KGroupedStream groupedStream, + final StreamWindowedAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory, + final AggregateParams aggregateParams) { + Objects.requireNonNull(aggregate, "aggregate"); + this.groupedStream = Objects.requireNonNull(groupedStream, "groupedStream"); + this.queryBuilder = Objects.requireNonNull(queryBuilder, "queryBuilder"); + this.materializedFactory = Objects.requireNonNull(materializedFactory, "materializedFactory"); + this.aggregateParams = Objects.requireNonNull(aggregateParams, "aggregateParams"); + this.queryContext = aggregate.getProperties().getQueryContext(); + this.formats = aggregate.getFormats(); + final PhysicalSchema physicalSchema = PhysicalSchema.from( + aggregate.getAggregationSchema(), + formats.getOptions() + ); + keySerde = queryBuilder.buildKeySerde( + formats.getKeyFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + valueSerde = queryBuilder.buildValueSerde( + formats.getValueFormat().getFormatInfo(), + physicalSchema, + queryContext + ); + } + + @Override + public KTable, GenericRow> visitHoppingWindowExpression( + final HoppingWindowExpression window, + final Void ctx) { + final TimeWindows windows = TimeWindows + .of(Duration.ofMillis(window.getSizeUnit().toMillis(window.getSize()))) + .advanceBy( + Duration.ofMillis(window.getAdvanceByUnit().toMillis(window.getAdvanceBy())) + ); + + return groupedStream + .windowedBy(windows) + .aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + materializedFactory.create( + keySerde, valueSerde, StreamsUtil.buildOpName(queryContext)) + ); + } + + @Override + public KTable, GenericRow> visitSessionWindowExpression( + final SessionWindowExpression window, + final Void ctx) { + final SessionWindows windows = SessionWindows.with( + Duration.ofMillis(window.getSizeUnit().toMillis(window.getGap())) + ); + return groupedStream + .windowedBy(windows) + .aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + aggregateParams.getAggregator().getMerger(), + materializedFactory.create( + keySerde, valueSerde, StreamsUtil.buildOpName(queryContext)) + ); + } + + @Override + public KTable, GenericRow> visitTumblingWindowExpression( + final TumblingWindowExpression window, + final Void ctx) { + final TimeWindows windows = TimeWindows.of( + Duration.ofMillis(window.getSizeUnit().toMillis(window.getSize()))); + return groupedStream + .windowedBy(windows) + .aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + materializedFactory.create( + keySerde, valueSerde, StreamsUtil.buildOpName(queryContext)) + ); + } + } +} diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableAggregateBuilder.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableAggregateBuilder.java new file mode 100644 index 000000000000..89d5b8c45832 --- /dev/null +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/TableAggregateBuilder.java @@ -0,0 +1,76 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.plan.TableAggregate; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; + +public final class TableAggregateBuilder { + private TableAggregateBuilder() { + } + + public static KTable build( + final KGroupedTable kgroupedTable, + final TableAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory) { + return build( + kgroupedTable, + aggregate, + queryBuilder, + materializedFactory, + AggregateParams::new + ); + } + + public static KTable build( + final KGroupedTable kgroupedTable, + final TableAggregate aggregate, + final KsqlQueryBuilder queryBuilder, + final MaterializedFactory materializedFactory, + final AggregateParams.Factory aggregateParamsFactory) { + final LogicalSchema sourceSchema = aggregate.getSources().get(0).getSchema(); + final int nonFuncColumns = aggregate.getNonFuncColumnCount(); + final AggregateParams aggregateParams = aggregateParamsFactory.create( + sourceSchema, + nonFuncColumns, + queryBuilder.getFunctionRegistry(), + aggregate.getAggregations() + ); + final Materialized> materialized = + AggregateBuilderUtils.buildMaterialized( + aggregate.getProperties().getQueryContext(), + aggregate.getAggregationSchema(), + aggregate.getFormats(), + queryBuilder, + materializedFactory + ); + return kgroupedTable.aggregate( + aggregateParams.getInitializer(), + aggregateParams.getAggregator(), + aggregateParams.getUndoAggregator(), + materialized + ).mapValues(aggregateParams.getAggregator().getResultMapper()); + } +} \ No newline at end of file diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/AggregateParamsTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/AggregateParamsTest.java new file mode 100644 index 000000000000..ffbd3a84add0 --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/AggregateParamsTest.java @@ -0,0 +1,180 @@ +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.execution.function.udaf.KudafAggregator; +import io.confluent.ksql.execution.function.udaf.KudafInitializer; +import io.confluent.ksql.execution.function.udaf.KudafUndoAggregator; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; +import io.confluent.ksql.schema.ksql.ColumnRef; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import java.util.List; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class AggregateParamsTest { + private static final LogicalSchema INPUT_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("REQUIRED0"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("REQUIRED1"), SqlTypes.STRING) + .valueColumn(ColumnName.of("ARGUMENT0"), SqlTypes.INTEGER) + .valueColumn(ColumnName.of("ARGUMENT1"), SqlTypes.DOUBLE) + .build(); + private static final FunctionCall AGG0 = new FunctionCall( + FunctionName.of("AGG0"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("ARGUMENT0"))) + ); + private static final long INITIAL_VALUE0 = 123; + private static final FunctionCall AGG1 = new FunctionCall( + FunctionName.of("AGG1"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("ARGUMENT1"))) + ); + private static final FunctionCall TABLE_AGG = new FunctionCall( + FunctionName.of("TABLE_AGG"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("ARGUMENT0"))) + ); + private static final FunctionCall WINDOW_START = new FunctionCall( + FunctionName.of("WindowStart"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("ARGUMENT0"))) + ); + private static final String INITIAL_VALUE1 = "initial"; + private static final List FUNCTIONS = ImmutableList.of(AGG0, AGG1); + + @Mock + private FunctionRegistry functionRegistry; + @Mock + private KsqlAggregateFunction agg0; + @Mock + private KsqlAggregateFunction agg0Resolved; + @Mock + private KsqlAggregateFunction agg1; + @Mock + private KsqlAggregateFunction agg1Resolved; + @Mock + private TableAggregationFunction tableAgg; + @Mock + private KsqlAggregateFunction windowStart; + + private AggregateParams aggregateParams; + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(functionRegistry.getAggregate(same(AGG0.getName().name()), any())).thenReturn(agg0); + when(agg0Resolved.getInitialValueSupplier()).thenReturn(() -> INITIAL_VALUE0); + when(agg0Resolved.getFunctionName()).thenReturn(AGG0.getName().name()); + when(agg0.getInstance(any())).thenReturn(agg0Resolved); + when(functionRegistry.getAggregate(same(AGG1.getName().name()), any())).thenReturn(agg1); + when(agg1Resolved.getInitialValueSupplier()).thenReturn(() -> INITIAL_VALUE1); + when(agg1Resolved.getFunctionName()).thenReturn(AGG1.getName().name()); + when(agg1.getInstance(any())).thenReturn(agg1Resolved); + when(functionRegistry.getAggregate(same(TABLE_AGG.getName().name()), any())) + .thenReturn(tableAgg); + when(tableAgg.getInitialValueSupplier()).thenReturn(() -> INITIAL_VALUE0); + when(tableAgg.getInstance(any())).thenReturn(tableAgg); + when(functionRegistry.getAggregate(same(WINDOW_START.getName().name()), any())) + .thenReturn(windowStart); + when(windowStart.getInitialValueSupplier()).thenReturn(() -> INITIAL_VALUE0); + when(windowStart.getInstance(any())).thenReturn(windowStart); + when(windowStart.getFunctionName()).thenReturn(WINDOW_START.getName().name()); + aggregateParams = new AggregateParams( + INPUT_SCHEMA, + 2, + functionRegistry, + FUNCTIONS + ); + } + + @Test + public void shouldReturnCorrectAggregator() { + // When: + final KudafAggregator aggregator = aggregateParams.getAggregator(); + + // Then: + assertThat(aggregator.getNonFuncColumnCount(), equalTo(2)); + assertThat( + aggregator.getAggValToAggFunctionMap(), + equalTo(ImmutableList.of(agg0Resolved, agg1Resolved)) + ); + } + + @Test + public void shouldReturnCorrectInitializer() { + // When: + final KudafInitializer initializer = aggregateParams.getInitializer(); + + // Then: + assertThat( + initializer.apply(), + equalTo(new GenericRow(null, null, INITIAL_VALUE0, INITIAL_VALUE1)) + ); + } + + @Test + public void shouldReturnUndoAggregator() { + // Given: + aggregateParams = + new AggregateParams(INPUT_SCHEMA, 2, functionRegistry, ImmutableList.of(TABLE_AGG)); + + // When: + final KudafUndoAggregator undoAggregator = aggregateParams.getUndoAggregator(); + + // Then: + assertThat(undoAggregator.getNonFuncColumnCount(), equalTo(2)); + assertThat( + undoAggregator.getAggValToAggFunctionMap(), + equalTo(ImmutableMap.of(2, tableAgg)) + ); + } + + @Test + public void shouldReturnCorrectWindowSelectMapperForNonWindowSelections() { + // When: + final WindowSelectMapper windowSelectMapper = aggregateParams.getWindowSelectMapper(); + + // Then: + assertThat(windowSelectMapper.hasSelects(), is(false)); + } + + @Test + public void shouldReturnCorrectWindowSelectMapperForWindowSelections() { + // Given: + aggregateParams = new AggregateParams( + INPUT_SCHEMA, + 2, + functionRegistry, + ImmutableList.of(WINDOW_START) + ); + + // When: + final WindowSelectMapper windowSelectMapper = aggregateParams.getWindowSelectMapper(); + + // Then: + final Windowed window = new Windowed<>(null, new TimeWindow(10, 20)); + assertThat( + windowSelectMapper.apply(window, new GenericRow("fiz", "baz", null)), + equalTo(new GenericRow("fiz", "baz", 10)) + ); + } +} diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamAggregateBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamAggregateBuilderTest.java new file mode 100644 index 000000000000..2e16d738ee97 --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamAggregateBuilderTest.java @@ -0,0 +1,621 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.function.udaf.KudafAggregator; +import io.confluent.ksql.execution.function.udaf.KudafInitializer; +import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.StreamAggregate; +import io.confluent.ksql.execution.plan.StreamWindowedAggregate; +import io.confluent.ksql.execution.windows.HoppingWindowExpression; +import io.confluent.ksql.execution.windows.SessionWindowExpression; +import io.confluent.ksql.execution.windows.TumblingWindowExpression; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.ColumnRef; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; +import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.SessionWindowedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.TimeWindowedKStream; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.WindowStore; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class StreamAggregateBuilderTest { + private static final LogicalSchema INPUT_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("REQUIRED0"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("REQUIRED1"), SqlTypes.STRING) + .valueColumn(ColumnName.of("ARGUMENT0"), SqlTypes.INTEGER) + .valueColumn(ColumnName.of("ARGUMENT1"), SqlTypes.DOUBLE) + .build(); + private static final LogicalSchema AGGREGATE_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("REQUIRED0"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("REQUIRED1"), SqlTypes.STRING) + .valueColumn(ColumnName.of("RESULT0"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("RESULT1"), SqlTypes.STRING) + .build(); + private static final PhysicalSchema PHYSICAL_AGGREGATE_SCHEMA = PhysicalSchema.from( + AGGREGATE_SCHEMA, + SerdeOption.none() + ); + private static final FunctionCall AGG0 = new FunctionCall( + FunctionName.of("AGG0"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("ARGUMENT0"))) + ); + private static final FunctionCall AGG1 = new FunctionCall( + FunctionName.of("AGG1"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("ARGUMENT1"))) + ); + private static final List FUNCTIONS = ImmutableList.of(AGG0, AGG1); + private static final QueryContext CTX = + new QueryContext.Stacker(new QueryId("qid")).push("agg").push("regate").getQueryContext(); + private static final KeyFormat KEY_FORMAT = KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)); + private static final ValueFormat VALUE_FORMAT = ValueFormat.of(FormatInfo.of(Format.JSON)); + private static final Duration WINDOW = Duration.ofMillis(30000); + private static final Duration HOP = Duration.ofMillis(10000); + + @Mock + private KGroupedStream groupedStream; + @Mock + private KTable aggregated; + @Mock + private KTable aggregatedWithResults; + @Mock + private TimeWindowedKStream timeWindowedStream; + @Mock + private SessionWindowedKStream sessionWindowedStream; + @Mock + private KTable, GenericRow> windowed; + @Mock + private KTable, GenericRow> windowedWithResults; + @Mock + private KTable, GenericRow> windowedWithWindowBoundaries; + @Mock + private KsqlQueryBuilder queryBuilder; + @Mock + private FunctionRegistry functionRegistry; + @Mock + private AggregateParams.Factory aggregateParamsFactory; + @Mock + private AggregateParams aggregateParams; + @Mock + private KudafInitializer initializer; + @Mock + private KudafAggregator aggregator; + @Mock + private ValueMapper resultMapper; + @Mock + private WindowSelectMapper windowSelectMapper; + @Mock + private Merger merger; + @Mock + private MaterializedFactory materializedFactory; + @Mock + private KeySerde keySerde; + @Mock + private Serde valueSerde; + @Mock + private Materialized> materialized; + @Mock + private Materialized> timeWindowMaterialized; + @Mock + private Materialized> sessionWindowMaterialized; + @Mock + private ExecutionStep> sourceStep; + + private StreamAggregate aggregate; + private StreamWindowedAggregate windowedAggregate; + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(sourceStep.getSchema()).thenReturn(INPUT_SCHEMA); + when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(aggregateParamsFactory.create(any(), anyInt(), any(), any())).thenReturn(aggregateParams); + when(aggregateParams.getAggregator()).thenReturn(aggregator); + when(aggregator.getMerger()).thenReturn(merger); + when(aggregator.getResultMapper()).thenReturn(resultMapper); + when(aggregateParams.getInitializer()).thenReturn(initializer); + when(aggregateParams.getWindowSelectMapper()).thenReturn(windowSelectMapper); + when(windowSelectMapper.hasSelects()).thenReturn(false); + } + + @SuppressWarnings("unchecked") + private void givenUnwindowedAggregate() { + when(materializedFactory.>create(any(), any(), any())) + .thenReturn(materialized); + when(groupedStream.aggregate(any(), any(), any(Materialized.class))).thenReturn(aggregated); + when(aggregated.mapValues(any(ValueMapper.class))).thenReturn(aggregatedWithResults); + aggregate = new StreamAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + 2, + FUNCTIONS, + AGGREGATE_SCHEMA + ); + } + + @SuppressWarnings("unchecked") + private void givenTimeWindowedAggregate() { + when(materializedFactory.>create(any(), any(), any())) + .thenReturn(timeWindowMaterialized); + when(groupedStream.windowedBy(any(Windows.class))).thenReturn(timeWindowedStream); + when(timeWindowedStream.aggregate(any(), any(), any(Materialized.class))) + .thenReturn(windowed); + when(windowed.mapValues(any(ValueMapper.class))).thenReturn(windowedWithResults); + } + + private void givenTumblingWindowedAggregate() { + givenTimeWindowedAggregate(); + windowedAggregate = new StreamWindowedAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + 2, + FUNCTIONS, + AGGREGATE_SCHEMA, + new TumblingWindowExpression(WINDOW.getSeconds(), TimeUnit.SECONDS) + ); + } + + private void givenHoppingWindowedAggregate() { + givenTimeWindowedAggregate(); + windowedAggregate = new StreamWindowedAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + 2, + FUNCTIONS, + AGGREGATE_SCHEMA, + new HoppingWindowExpression( + WINDOW.getSeconds(), + TimeUnit.SECONDS, + HOP.getSeconds(), + TimeUnit.SECONDS + ) + ); + } + + @SuppressWarnings("unchecked") + private void givenSessionWindowedAggregate() { + when(materializedFactory.>create(any(), any(), any())) + .thenReturn(sessionWindowMaterialized); + when(groupedStream.windowedBy(any(SessionWindows.class))).thenReturn(sessionWindowedStream); + when(sessionWindowedStream.aggregate(any(), any(), any(), any(Materialized.class))) + .thenReturn(windowed); + when(windowed.mapValues(any(ValueMapper.class))).thenReturn(windowedWithResults); + windowedAggregate = new StreamWindowedAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + 2, + FUNCTIONS, + AGGREGATE_SCHEMA, + new SessionWindowExpression(WINDOW.getSeconds(), TimeUnit.SECONDS) + ); + } + + @Test + public void shouldBuildUnwindowedAggregateCorrectly() { + // Given: + givenUnwindowedAggregate(); + + // When: + final KTable result = StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(aggregatedWithResults)); + final InOrder inOrder = Mockito.inOrder(groupedStream, aggregated, aggregatedWithResults); + inOrder.verify(groupedStream).aggregate(initializer, aggregator, materialized); + inOrder.verify(aggregated).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void shouldBuildMaterializedWithCorrectSerdesForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(same(keySerde), same(valueSerde), any()); + } + + @Test + public void shouldBuildMaterializedWithCorrectNameForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(any(), any(), eq("agg-regate")); + } + + @Test + public void shouldBuildKeySerdeCorrectlyForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder).buildKeySerde(KEY_FORMAT.getFormatInfo(), PHYSICAL_AGGREGATE_SCHEMA, CTX); + } + + @Test + public void shouldBuildValueSerdeCorrectlyForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder).buildValueSerde( + VALUE_FORMAT.getFormatInfo(), + PHYSICAL_AGGREGATE_SCHEMA, + CTX + ); + } + + @Test + public void shouldBuildAggregatorParamsCorrectlyForUnwindowedAggregate() { + // Given: + givenUnwindowedAggregate(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(aggregateParamsFactory).create(INPUT_SCHEMA, 2, functionRegistry, FUNCTIONS); + } + + @Test + public void shouldBuildTumblingWindowedAggregateCorrectly() { + // Given: + givenTumblingWindowedAggregate(); + + // When: + final KTable, GenericRow> result = StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(windowedWithResults)); + final InOrder inOrder = Mockito.inOrder( + groupedStream, + timeWindowedStream, + windowed, + windowedWithResults + ); + inOrder.verify(groupedStream).windowedBy(TimeWindows.of(WINDOW)); + inOrder.verify(timeWindowedStream).aggregate(initializer, aggregator, timeWindowMaterialized); + inOrder.verify(windowed).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void shouldBuildHoppingWindowedAggregateCorrectly() { + // Given: + givenHoppingWindowedAggregate(); + + // When: + final KTable, GenericRow> result = StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(windowedWithResults)); + final InOrder inOrder = Mockito.inOrder( + groupedStream, + timeWindowedStream, + windowed, + windowedWithResults + ); + inOrder.verify(groupedStream).windowedBy(TimeWindows.of(WINDOW).advanceBy(HOP)); + inOrder.verify(timeWindowedStream).aggregate(initializer, aggregator, timeWindowMaterialized); + inOrder.verify(windowed).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void shouldBuildSessionWindowedAggregateCorrectly() { + // Given: + givenSessionWindowedAggregate(); + + // When: + final KTable, GenericRow> result = StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(windowedWithResults)); + final InOrder inOrder = Mockito.inOrder( + groupedStream, + sessionWindowedStream, + windowed, + windowedWithResults + ); + inOrder.verify(groupedStream).windowedBy(SessionWindows.with(WINDOW)); + inOrder.verify(sessionWindowedStream).aggregate( + initializer, + aggregator, + merger, + sessionWindowMaterialized + ); + inOrder.verify(windowed).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + private List given() { + return ImmutableList.of( + this::givenHoppingWindowedAggregate, + this::givenTumblingWindowedAggregate, + this::givenSessionWindowedAggregate + ); + } + + @Test + public void shouldBuildMaterializedWithCorrectSerdesForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset(groupedStream, timeWindowedStream, sessionWindowedStream, aggregated, materializedFactory); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(same(keySerde), same(valueSerde), any()); + } + } + + @Test + public void shouldBuildMaterializedWithCorrectNameForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset(groupedStream, timeWindowedStream, sessionWindowedStream, aggregated, materializedFactory); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(any(), any(), eq("agg-regate")); + } + } + + @Test + public void shouldBuildKeySerdeCorrectlyForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset(groupedStream, timeWindowedStream, sessionWindowedStream, aggregated, queryBuilder); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder) + .buildKeySerde(KEY_FORMAT.getFormatInfo(), PHYSICAL_AGGREGATE_SCHEMA, CTX); + } + } + + @Test + public void shouldBuildValueSerdeCorrectlyForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset(groupedStream, timeWindowedStream, sessionWindowedStream, aggregated, queryBuilder); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder) + .buildValueSerde(VALUE_FORMAT.getFormatInfo(), PHYSICAL_AGGREGATE_SCHEMA, CTX); + } + } + + @Test + public void shouldBuildAggregatorParamsCorrectlyForWindowedAggregate() { + for (final Runnable given : given()) { + // Given: + reset( + groupedStream, + timeWindowedStream, + sessionWindowedStream, + aggregated, + aggregateParamsFactory + ); + when(aggregateParamsFactory.create(any(), anyInt(), any(), any())) + .thenReturn(aggregateParams); + given.run(); + + // When: + StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(aggregateParamsFactory) + .create(INPUT_SCHEMA, 2, functionRegistry, FUNCTIONS); + } + } + + @Test + @SuppressWarnings("unchecked") + public void shouldAddWindowBoundariesIfSpecified() { + for (final Runnable given : given()) { + // Given: + reset( + groupedStream, timeWindowedStream, sessionWindowedStream, windowed, windowedWithResults); + when(windowSelectMapper.hasSelects()).thenReturn(true); + when(windowedWithResults.mapValues(any(ValueMapperWithKey.class))).thenReturn( + windowedWithWindowBoundaries); + given.run(); + + // When: + final KTable, GenericRow> result = StreamAggregateBuilder.build( + groupedStream, + windowedAggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(windowedWithWindowBoundaries)); + verify(windowedWithResults).mapValues(windowSelectMapper); + } + } +} \ No newline at end of file diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableAggregateBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableAggregateBuilderTest.java new file mode 100644 index 000000000000..1830d61c7754 --- /dev/null +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableAggregateBuilderTest.java @@ -0,0 +1,262 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.streams; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.function.udaf.KudafAggregator; +import io.confluent.ksql.execution.function.udaf.KudafInitializer; +import io.confluent.ksql.execution.function.udaf.KudafUndoAggregator; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.TableAggregate; +import io.confluent.ksql.function.FunctionRegistry; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.ColumnRef; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.PhysicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; +import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; +import java.util.List; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.state.KeyValueStore; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class TableAggregateBuilderTest { + private static final LogicalSchema INPUT_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("REQUIRED0"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("REQUIRED1"), SqlTypes.STRING) + .valueColumn(ColumnName.of("ARGUMENT0"), SqlTypes.INTEGER) + .valueColumn(ColumnName.of("ARGUMENT1"), SqlTypes.DOUBLE) + .build(); + private static final LogicalSchema AGGREGATE_SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("REQUIRED0"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("REQUIRED1"), SqlTypes.STRING) + .valueColumn(ColumnName.of("RESULT0"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("RESULT1"), SqlTypes.STRING) + .build(); + private static final PhysicalSchema PHYSICAL_AGGREGATE_SCHEMA = PhysicalSchema.from( + AGGREGATE_SCHEMA, + SerdeOption.none() + ); + private static final FunctionCall AGG0 = new FunctionCall( + FunctionName.of("AGG0"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("ARGUMENT0"))) + ); + private static final FunctionCall AGG1 = new FunctionCall( + FunctionName.of("AGG1"), + ImmutableList.of(new ColumnReferenceExp(ColumnRef.of("ARGUMENT1"))) + ); + private static final List FUNCTIONS = ImmutableList.of(AGG0, AGG1); + private static final QueryContext CTX = + new QueryContext.Stacker(new QueryId("qid")).push("agg").push("regate").getQueryContext(); + private static final KeyFormat KEY_FORMAT = KeyFormat.nonWindowed(FormatInfo.of(Format.KAFKA)); + private static final ValueFormat VALUE_FORMAT = ValueFormat.of(FormatInfo.of(Format.JSON)); + + @Mock + private KGroupedTable groupedTable; + @Mock + private KTable aggregated; + @Mock + private KTable aggregatedWithResults; + @Mock + private KsqlQueryBuilder queryBuilder; + @Mock + private FunctionRegistry functionRegistry; + @Mock + private AggregateParams.Factory aggregateParamsFactory; + @Mock + private AggregateParams aggregateParams; + @Mock + private KudafInitializer initializer; + @Mock + private KudafAggregator aggregator; + @Mock + private ValueMapper resultMapper; + @Mock + private KudafUndoAggregator undoAggregator; + @Mock + private MaterializedFactory materializedFactory; + @Mock + private KeySerde keySerde; + @Mock + private Serde valueSerde; + @Mock + private Materialized> materialized; + @Mock + private ExecutionStep> sourceStep; + + private TableAggregate aggregate; + + @Before + @SuppressWarnings("unchecked") + public void init() { + when(sourceStep.getSchema()).thenReturn(INPUT_SCHEMA); + when(queryBuilder.buildKeySerde(any(), any(), any())).thenReturn(keySerde); + when(queryBuilder.buildValueSerde(any(), any(), any())).thenReturn(valueSerde); + when(queryBuilder.getFunctionRegistry()).thenReturn(functionRegistry); + when(aggregateParamsFactory.create(any(), anyInt(), any(), any())).thenReturn(aggregateParams); + when(aggregateParams.getAggregator()).thenReturn(aggregator); + when(aggregateParams.getUndoAggregator()).thenReturn(undoAggregator); + when(aggregateParams.getInitializer()).thenReturn(initializer); + when(aggregator.getResultMapper()).thenReturn(resultMapper); + when(materializedFactory.>create(any(), any(), any())) + .thenReturn(materialized); + when(groupedTable.aggregate(any(), any(), any(), any(Materialized.class))).thenReturn( + aggregated); + when(aggregated.mapValues(any(ValueMapper.class))).thenReturn(aggregatedWithResults); + aggregate = new TableAggregate( + new DefaultExecutionStepProperties(INPUT_SCHEMA, CTX), + sourceStep, + Formats.of(KEY_FORMAT, VALUE_FORMAT, SerdeOption.none()), + 2, + FUNCTIONS, + AGGREGATE_SCHEMA + ); + } + + @Test + public void shouldBuildAggregateCorrectly() { + // When: + final KTable result = TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + assertThat(result, is(aggregatedWithResults)); + final InOrder inOrder = Mockito.inOrder(groupedTable, aggregated, aggregatedWithResults); + inOrder.verify(groupedTable).aggregate(initializer, aggregator, undoAggregator, materialized); + inOrder.verify(aggregated).mapValues(resultMapper); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void shouldBuildMaterializedWithCorrectSerdesForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(same(keySerde), same(valueSerde), any()); + } + + @Test + public void shouldBuildMaterializedWithCorrectNameForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(materializedFactory).create(any(), any(), eq("agg-regate")); + } + + @Test + public void shouldBuildKeySerdeCorrectlyForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder).buildKeySerde(KEY_FORMAT.getFormatInfo(), PHYSICAL_AGGREGATE_SCHEMA, CTX); + } + + @Test + public void shouldBuildValueSerdeCorrectlyForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(queryBuilder).buildValueSerde( + VALUE_FORMAT.getFormatInfo(), + PHYSICAL_AGGREGATE_SCHEMA, + CTX + ); + } + + @Test + public void shouldBuildAggregatorParamsCorrectlyForAggregate() { + // When: + TableAggregateBuilder.build( + groupedTable, + aggregate, + queryBuilder, + materializedFactory, + aggregateParamsFactory + ); + + // Then: + verify(aggregateParamsFactory).create(INPUT_SCHEMA, 2, functionRegistry, FUNCTIONS); + } +} \ No newline at end of file diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableTableJoinBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableTableJoinBuilderTest.java index 39634811e36a..e87816a54662 100644 --- a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableTableJoinBuilderTest.java +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableTableJoinBuilderTest.java @@ -84,7 +84,6 @@ public void init() { new DefaultExecutionStepProperties(LEFT_SCHEMA, SRC_CTX)); when(right.getProperties()).thenReturn( new DefaultExecutionStepProperties(RIGHT_SCHEMA, SRC_CTX)); - when(joinedFactory.create(any(Serde.class), any(), any(), any())).thenReturn(joined); } @SuppressWarnings("unchecked")