From 0d0b1c3db762a707896693b65e76fb168218ccfe Mon Sep 17 00:00:00 2001 From: Rohan Date: Fri, 6 Sep 2019 13:26:58 -0700 Subject: [PATCH] feat: build execution plan from structured package (#3285) This patch updates the classes in the structured package (SchemaKStream, SchemaKTable, SchemaKGroupedStream, and SchemaKGroupedTable) to build the query execution plan internally. The plan nodes themselves are constructed by a factory class (ExecutionStepFactory). Note that this currently requires some ugly passing around of redundant parameters. For example, we have to pass around serdes _and_ value/key formats for building those serdes. Once we move the actual streams calls to the ksql-streams layer, this will get cleaned up. --- .../ksql/physical/PhysicalPlanBuilder.java | 23 +- .../ksql/physical/TransientQueryQueue.java | 4 +- .../ksql/planner/plan/AggregateNode.java | 8 +- .../ksql/planner/plan/DataSourceNode.java | 21 +- .../confluent/ksql/planner/plan/JoinNode.java | 21 +- .../ksql/planner/plan/KsqlBareOutputNode.java | 9 +- .../plan/KsqlStructuredDataOutputNode.java | 12 +- .../ksql/structured/QueuedSchemaKStream.java | 100 ---- .../ksql/structured/SchemaKGroupedStream.java | 56 ++- .../ksql/structured/SchemaKGroupedTable.java | 51 +- .../ksql/structured/SchemaKStream.java | 334 +++++++++---- .../ksql/structured/SchemaKTable.java | 147 ++++-- .../physical/PhysicalPlanBuilderTest.java | 15 +- .../physical/TransientQueryQueueTest.java | 4 +- .../ksql/planner/plan/DataSourceNodeTest.java | 16 +- .../ksql/planner/plan/JoinNodeTest.java | 17 +- .../KsqlStructuredDataOutputNodeTest.java | 68 ++- .../structured/SchemaKGroupedStreamTest.java | 130 ++++- .../structured/SchemaKGroupedTableTest.java | 89 +++- .../ksql/structured/SchemaKStreamTest.java | 444 ++++++++++++++++-- .../ksql/structured/SchemaKTableTest.java | 265 ++++++++--- .../plan/DefaultExecutionStepProperties.java | 16 +- .../plan/ExecutionStepProperties.java | 2 + .../ksql/execution/plan/StreamAggregate.java | 21 +- .../ksql/execution/plan/StreamSelectKey.java | 3 + .../ksql/execution/plan/StreamStreamJoin.java | 39 +- .../ksql/execution/plan/TableAggregate.java | 21 +- .../ksql/execution/plan/TableTableJoin.java | 6 +- .../streams/ExecutionStepFactory.java | 282 ++++++++++- .../streams/StreamSourceBuilderTest.java | 6 +- 30 files changed, 1713 insertions(+), 517 deletions(-) delete mode 100644 ksql-engine/src/main/java/io/confluent/ksql/structured/QueuedSchemaKStream.java diff --git a/ksql-engine/src/main/java/io/confluent/ksql/physical/PhysicalPlanBuilder.java b/ksql-engine/src/main/java/io/confluent/ksql/physical/PhysicalPlanBuilder.java index cf5426550efc..7249244ea396 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/physical/PhysicalPlanBuilder.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/physical/PhysicalPlanBuilder.java @@ -39,7 +39,6 @@ import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.PhysicalSchema; import io.confluent.ksql.services.ServiceContext; -import io.confluent.ksql.structured.QueuedSchemaKStream; import io.confluent.ksql.structured.SchemaKStream; import io.confluent.ksql.structured.SchemaKTable; import io.confluent.ksql.util.KsqlConfig; @@ -124,20 +123,11 @@ public QueryMetadata buildPhysicalPlan(final LogicalPlanNode logicalPlanNode) { final SchemaKStream resultStream = outputNode.buildStream(ksqlQueryBuilder); if (outputNode instanceof KsqlBareOutputNode) { - if (!(resultStream instanceof QueuedSchemaKStream)) { - throw new KsqlException(String.format( - 
"Mismatch between logical and physical output; " - + "expected a QueuedSchemaKStream based on logical " - + "KsqlBareOutputNode, found a %s instead", - resultStream.getClass().getCanonicalName() - )); - } - final String transientQueryPrefix = ksqlConfig.getString(KsqlConfig.KSQL_TRANSIENT_QUERY_NAME_PREFIX_CONFIG); return buildPlanForBareQuery( - (QueuedSchemaKStream) resultStream, + resultStream, (KsqlBareOutputNode) outputNode, getServiceId(), transientQueryPrefix, @@ -147,15 +137,6 @@ public QueryMetadata buildPhysicalPlan(final LogicalPlanNode logicalPlanNode) { } if (outputNode instanceof KsqlStructuredDataOutputNode) { - if (resultStream instanceof QueuedSchemaKStream) { - throw new KsqlException(String.format( - "Mismatch between logical and physical output; " - + "expected a SchemaKStream based on logical " - + "QueuedSchemaKStream, found a %s instead", - resultStream.getClass().getCanonicalName() - )); - } - final KsqlStructuredDataOutputNode ksqlStructuredDataOutputNode = (KsqlStructuredDataOutputNode) outputNode; @@ -177,7 +158,7 @@ public QueryMetadata buildPhysicalPlan(final LogicalPlanNode logicalPlanNode) { } private QueryMetadata buildPlanForBareQuery( - final QueuedSchemaKStream schemaKStream, + final SchemaKStream schemaKStream, final KsqlBareOutputNode bareOutputNode, final String serviceId, final String transientQueryPrefix, diff --git a/ksql-engine/src/main/java/io/confluent/ksql/physical/TransientQueryQueue.java b/ksql-engine/src/main/java/io/confluent/ksql/physical/TransientQueryQueue.java index 5cb256e78e38..be7193a1e58a 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/physical/TransientQueryQueue.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/physical/TransientQueryQueue.java @@ -16,7 +16,7 @@ package io.confluent.ksql.physical; import io.confluent.ksql.GenericRow; -import io.confluent.ksql.structured.QueuedSchemaKStream; +import io.confluent.ksql.structured.SchemaKStream; import io.confluent.ksql.util.KsqlException; import java.util.Objects; import java.util.OptionalInt; @@ -37,7 +37,7 @@ class TransientQueryQueue { private final BlockingQueue> rowQueue = new LinkedBlockingQueue<>(100); - TransientQueryQueue(final QueuedSchemaKStream schemaKStream, final OptionalInt limit) { + TransientQueryQueue(final SchemaKStream schemaKStream, final OptionalInt limit) { this.callback = limit.isPresent() ? 
new LimitedQueueCallback(limit.getAsInt()) : new UnlimitedQueueCallback(); diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java index 37c37e8af6e5..52bbee798ff0 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java @@ -217,6 +217,7 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { getGroupByExpressions()); final SchemaKGroupedStream schemaKGroupedStream = aggregateArgExpanded.groupBy( + valueFormat, genericRowSerde, internalGroupByColumns, groupByContext @@ -246,15 +247,20 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { aggregationContext.getQueryContext() ); + final List functionsWithInternalIdentifiers = functionList.stream() + .map(internalSchema::resolveToInternal) + .map(FunctionCall.class::cast) + .collect(Collectors.toList()); SchemaKTable aggregated = schemaKGroupedStream.aggregate( aggStageSchema, initializer, requiredColumns.size(), + functionsWithInternalIdentifiers, aggValToFunctionMap, getWindowExpression(), + valueFormat, aggValueGenericRowSerde, aggregationContext - ); if (havingExpressions != null) { diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/DataSourceNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/DataSourceNode.java index faffcf3b224a..4991d576ee61 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/DataSourceNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/DataSourceNode.java @@ -17,7 +17,6 @@ import static java.util.Objects.requireNonNull; -import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.context.QueryContext.Stacker; @@ -37,7 +36,6 @@ import javax.annotation.concurrent.Immutable; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.common.config.ConfigException; -import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.streams.Topology; import org.apache.kafka.streams.Topology.AutoOffsetReset; @@ -134,7 +132,7 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { builder, dataSource, schema, - contextStacker.push(SOURCE_OP_NAME).getQueryContext(), + contextStacker.push(SOURCE_OP_NAME), timestampIndex(), getAutoOffsetReset(builder.getKsqlConfig().getKsqlStreamConfigProps()), keyField @@ -143,12 +141,15 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { return schemaKStream; } final Stacker reduceContextStacker = contextStacker.push(REDUCE_OP_NAME); - final Serde tableSerde = builder.buildValueSerde( - dataSource.getKsqlTopic().getValueFormat().getFormatInfo(), - PhysicalSchema.from(getSchema(), SerdeOption.none()), - reduceContextStacker.getQueryContext() - ); - return schemaKStream.toTable(tableSerde, reduceContextStacker); + return schemaKStream.toTable( + dataSource.getKsqlTopic().getKeyFormat(), + dataSource.getKsqlTopic().getValueFormat(), + builder.buildValueSerde( + dataSource.getKsqlTopic().getValueFormat().getFormatInfo(), + PhysicalSchema.from(getSchema(), SerdeOption.none()), + reduceContextStacker.getQueryContext() + ), + reduceContextStacker); } interface SchemaKStreamFactory { @@ -156,7 +157,7 @@ SchemaKStream create( KsqlQueryBuilder builder, DataSource dataSource, 
LogicalSchemaWithMetaAndKeyFields schemaWithMetaAndKeyFields, - QueryContext queryContext, + QueryContext.Stacker contextStacker, int timestampIndex, Optional offsetReset, KeyField keyField diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java index 7006396d9879..6c5fdf07d75d 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java @@ -19,7 +19,6 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; -import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.metastore.model.DataSource.DataSourceType; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.metastore.model.KeyField.LegacyField; @@ -280,15 +279,17 @@ static SchemaKStream maybeRePartitionByKey( return stream.selectKey(joinFieldName, true, contextStacker); } + static ValueFormat getFormatForSource(final DataSourceNode sourceNode) { + return sourceNode.getDataSource() + .getKsqlTopic() + .getValueFormat(); + } + Serde getSerDeForSource( final DataSourceNode sourceNode, final QueryContext.Stacker contextStacker ) { - final DataSource dataSource = sourceNode.getDataSource(); - - final ValueFormat valueFormat = dataSource - .getKsqlTopic() - .getValueFormat(); + final ValueFormat valueFormat = getFormatForSource(sourceNode); final LogicalSchema logicalSchema = sourceNode.getSchema() .withoutAlias(); @@ -367,6 +368,8 @@ public SchemaKStream join() { joinNode.schema, getJoinedKeyField(joinNode.left.getAlias(), leftStream.getKeyField()), joinNode.withinExpression.get().joinWindow(), + getFormatForSource(joinNode.left), + getFormatForSource(joinNode.right), getSerDeForSource(joinNode.left, contextStacker.push(LEFT_SERDE_CONTEXT_NAME)), getSerDeForSource(joinNode.right, contextStacker.push(RIGHT_SERDE_CONTEXT_NAME)), contextStacker); @@ -376,6 +379,8 @@ public SchemaKStream join() { joinNode.schema, getOuterJoinedKeyField(joinNode.left.getAlias(), leftStream.getKeyField()), joinNode.withinExpression.get().joinWindow(), + getFormatForSource(joinNode.left), + getFormatForSource(joinNode.right), getSerDeForSource(joinNode.left, contextStacker.push(LEFT_SERDE_CONTEXT_NAME)), getSerDeForSource(joinNode.right, contextStacker.push(RIGHT_SERDE_CONTEXT_NAME)), contextStacker); @@ -385,6 +390,8 @@ public SchemaKStream join() { joinNode.schema, getJoinedKeyField(joinNode.left.getAlias(), leftStream.getKeyField()), joinNode.withinExpression.get().joinWindow(), + getFormatForSource(joinNode.left), + getFormatForSource(joinNode.right), getSerDeForSource(joinNode.left, contextStacker.push(LEFT_SERDE_CONTEXT_NAME)), getSerDeForSource(joinNode.right, contextStacker.push(RIGHT_SERDE_CONTEXT_NAME)), contextStacker); @@ -424,6 +431,7 @@ public SchemaKStream join() { rightTable, joinNode.schema, getJoinedKeyField(joinNode.left.getAlias(), leftStream.getKeyField()), + getFormatForSource(joinNode.left), getSerDeForSource(joinNode.left, contextStacker.push(LEFT_SERDE_CONTEXT_NAME)), contextStacker); @@ -432,6 +440,7 @@ public SchemaKStream join() { rightTable, joinNode.schema, getJoinedKeyField(joinNode.left.getAlias(), leftStream.getKeyField()), + getFormatForSource(joinNode.left), getSerDeForSource(joinNode.left, contextStacker.push(LEFT_SERDE_CONTEXT_NAME)), contextStacker); case OUTER: diff --git 
a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlBareOutputNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlBareOutputNode.java index 8b724cbae56d..d4d7aaf137fd 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlBareOutputNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlBareOutputNode.java @@ -19,7 +19,6 @@ import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.structured.QueuedSchemaKStream; import io.confluent.ksql.structured.SchemaKStream; import io.confluent.ksql.util.QueryIdGenerator; import io.confluent.ksql.util.timestamp.TimestampExtractionPolicy; @@ -56,12 +55,6 @@ public KeyField getKeyField() { @Override public SchemaKStream buildStream(final KsqlQueryBuilder builder) { - final SchemaKStream schemaKStream = getSource() - .buildStream(builder); - - return new QueuedSchemaKStream<>( - schemaKStream, - builder.buildNodeContext(getId().toString()).getQueryContext() - ); + return getSource().buildStream(builder); } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java index 43a12f1070ba..a70cd51bb320 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java @@ -140,13 +140,15 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { contextStacker.getQueryContext() ); - result.into( + return result.into( getKsqlTopic().getKafkaTopicName(), outputRowSerde, - implicitAndKeyFieldIndexes + getSchema(), + getKsqlTopic().getValueFormat(), + serdeOptions, + implicitAndKeyFieldIndexes, + contextStacker ); - - return result; } @SuppressWarnings("unchecked") @@ -164,7 +166,7 @@ private SchemaKStream createOutputStream( getKeyField().legacy() ); - final SchemaKStream result = schemaKStream.sink(resultKeyField, contextStacker); + final SchemaKStream result = schemaKStream.withKeyField(resultKeyField); if (!partitionByField.isPresent()) { return result; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/QueuedSchemaKStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/QueuedSchemaKStream.java deleted file mode 100644 index 1aba9b71fb32..000000000000 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/QueuedSchemaKStream.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright 2018 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. 
- */ - -package io.confluent.ksql.structured; - -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.execution.context.QueryContext; -import io.confluent.ksql.execution.expression.tree.Expression; -import io.confluent.ksql.execution.plan.SelectExpression; -import io.confluent.ksql.logging.processing.ProcessingLogContext; -import io.confluent.ksql.metastore.model.KeyField; -import io.confluent.ksql.schema.ksql.LogicalSchema; -import java.util.List; -import java.util.Set; -import org.apache.kafka.common.serialization.Serde; -import org.apache.kafka.connect.data.Struct; - -public class QueuedSchemaKStream extends SchemaKStream { - - public QueuedSchemaKStream( - final SchemaKStream schemaKStream, - final QueryContext queryContext - ) { - super( - schemaKStream.getKstream(), - schemaKStream.schema, - schemaKStream.keySerde, - schemaKStream.keyField, - schemaKStream.sourceSchemaKStreams, - Type.SINK, - schemaKStream.ksqlConfig, - schemaKStream.functionRegistry, - queryContext - ); - } - - @Override - public SchemaKStream into( - final String kafkaTopicName, - final Serde topicValueSerDe, - final Set rowkeyIndexes - ) { - throw new UnsupportedOperationException(); - } - - @Override - public SchemaKStream filter( - final Expression filterExpression, - final QueryContext.Stacker contextStacker, - final ProcessingLogContext processingLogContext) { - throw new UnsupportedOperationException(); - } - - @Override - public SchemaKStream select( - final List expressions, - final QueryContext.Stacker contextStacker, - final ProcessingLogContext processingLogContext) { - throw new UnsupportedOperationException(); - } - - @Override - public SchemaKStream leftJoin( - final SchemaKTable schemaKTable, - final LogicalSchema joinSchema, - final KeyField keyField, - final Serde joinSerde, - final QueryContext.Stacker contextStacker - ) { - throw new UnsupportedOperationException(); - } - - @Override - public SchemaKStream selectKey( - final String fieldName, - final boolean updateRowKey, - final QueryContext.Stacker contextStacker) { - throw new UnsupportedOperationException(); - } - - @Override - public SchemaKGroupedStream groupBy( - final Serde valSerde, - final List groupByExpressions, - final QueryContext.Stacker contextStacker) { - throw new UnsupportedOperationException(); - } -} diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java index 1ff321de8610..18dba021456a 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedStream.java @@ -17,6 +17,10 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; import io.confluent.ksql.function.UdafAggregator; @@ -27,7 +31,12 @@ import io.confluent.ksql.parser.tree.KsqlWindowExpression; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; import 
io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.serde.WindowInfo; import io.confluent.ksql.streams.MaterializedFactory; import io.confluent.ksql.streams.StreamsUtil; @@ -48,7 +57,8 @@ public class SchemaKGroupedStream { final KGroupedStream kgroupedStream; - final LogicalSchema schema; + final ExecutionStep> sourceStep; + final KeyFormat keyFormat; final KeySerde keySerde; final KeyField keyField; final List sourceSchemaKStreams; @@ -58,7 +68,8 @@ public class SchemaKGroupedStream { SchemaKGroupedStream( final KGroupedStream kgroupedStream, - final LogicalSchema schema, + final ExecutionStep> sourceStep, + final KeyFormat keyFormat, final KeySerde keySerde, final KeyField keyField, final List sourceSchemaKStreams, @@ -66,7 +77,9 @@ public class SchemaKGroupedStream { final FunctionRegistry functionRegistry ) { this( - kgroupedStream, schema, + kgroupedStream, + sourceStep, + keyFormat, keySerde, keyField, sourceSchemaKStreams, @@ -78,7 +91,8 @@ public class SchemaKGroupedStream { SchemaKGroupedStream( final KGroupedStream kgroupedStream, - final LogicalSchema schema, + final ExecutionStep> sourceStep, + final KeyFormat keyFormat, final KeySerde keySerde, final KeyField keyField, final List sourceSchemaKStreams, @@ -87,7 +101,8 @@ public class SchemaKGroupedStream { final MaterializedFactory materializedFactory ) { this.kgroupedStream = kgroupedStream; - this.schema = schema; + this.sourceStep = sourceStep; + this.keyFormat = Objects.requireNonNull(keyFormat, "keyFormat"); this.keySerde = Objects.requireNonNull(keySerde, "keySerde"); this.keyField = keyField; this.sourceSchemaKStreams = sourceSchemaKStreams; @@ -100,13 +115,19 @@ public KeyField getKeyField() { return keyField; } + public ExecutionStep> getSourceStep() { + return sourceStep; + } + @SuppressWarnings("unchecked") public SchemaKTable aggregate( final LogicalSchema aggregateSchema, final Initializer initializer, final int nonFuncColumnCount, + final List aggregations, final Map aggValToFunctionMap, final WindowExpression windowExpression, + final ValueFormat valueFormat, final Serde topicValueSerDe, final QueryContext.Stacker contextStacker ) { @@ -114,7 +135,9 @@ public SchemaKTable aggregate( final KTable table; final KeySerde newKeySerde; + final KeyFormat keyFormat; if (windowExpression != null) { + keyFormat = getKeyFormat(windowExpression); newKeySerde = getKeySerde(windowExpression); table = aggregateWindowed( @@ -126,6 +149,7 @@ public SchemaKTable aggregate( contextStacker ); } else { + keyFormat = this.keyFormat; newKeySerde = keySerde; table = aggregateNonWindowed( @@ -137,16 +161,25 @@ public SchemaKTable aggregate( ); } + final ExecutionStep step = ExecutionStepFactory.streamAggregate( + contextStacker, + sourceStep, + aggregateSchema, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + nonFuncColumnCount, + aggregations + ); return new SchemaKTable( table, - aggregateSchema, + step, + keyFormat, newKeySerde, keyField, sourceSchemaKStreams, SchemaKStream.Type.AGGREGATE, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext()); + functionRegistry + ); } @SuppressWarnings("unchecked") @@ -198,6 +231,13 @@ private KTable aggregateWindowed( return aggKtable.mapValues(windowSelectMapper); } + private KeyFormat getKeyFormat(final WindowExpression windowExpression) { + return KeyFormat.windowed( + FormatInfo.of(Format.KAFKA), + windowExpression.getKsqlWindowExpression().getWindowInfo() + ); + } + private 
KeySerde> getKeySerde(final WindowExpression windowExpression) { if (ksqlConfig.getBoolean(KsqlConfig.KSQL_WINDOWED_SESSION_KEY_LEGACY_CONFIG)) { return keySerde.rebind(WindowInfo.of( diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java index d4b58abaa59e..a81dc4909eac 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKGroupedTable.java @@ -17,6 +17,10 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; import io.confluent.ksql.function.TableAggregationFunction; @@ -25,7 +29,10 @@ import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.streams.MaterializedFactory; import io.confluent.ksql.streams.StreamsUtil; import io.confluent.ksql.util.KsqlConfig; @@ -43,10 +50,12 @@ public class SchemaKGroupedTable extends SchemaKGroupedStream { private final KGroupedTable kgroupedTable; + private final ExecutionStep> sourceTableStep; SchemaKGroupedTable( final KGroupedTable kgroupedTable, - final LogicalSchema schema, + final ExecutionStep> sourceTableStep, + final KeyFormat keyFormat, final KeySerde keySerde, final KeyField keyField, final List sourceSchemaKStreams, @@ -55,7 +64,8 @@ public class SchemaKGroupedTable extends SchemaKGroupedStream { ) { this( kgroupedTable, - schema, + sourceTableStep, + keyFormat, keySerde, keyField, sourceSchemaKStreams, @@ -66,7 +76,8 @@ public class SchemaKGroupedTable extends SchemaKGroupedStream { SchemaKGroupedTable( final KGroupedTable kgroupedTable, - final LogicalSchema schema, + final ExecutionStep> sourceTableStep, + final KeyFormat keyFormat, final KeySerde keySerde, final KeyField keyField, final List sourceSchemaKStreams, @@ -74,10 +85,24 @@ public class SchemaKGroupedTable extends SchemaKGroupedStream { final FunctionRegistry functionRegistry, final MaterializedFactory materializedFactory ) { - super(null, schema, keySerde, keyField, sourceSchemaKStreams, - ksqlConfig, functionRegistry, materializedFactory); + super( + null, + null, + keyFormat, + keySerde, + keyField, + sourceSchemaKStreams, + ksqlConfig, + functionRegistry, + materializedFactory + ); this.kgroupedTable = Objects.requireNonNull(kgroupedTable, "kgroupedTable"); + this.sourceTableStep = Objects.requireNonNull(sourceTableStep, "sourceTableStep"); + } + + public ExecutionStep> getSourceTableStep() { + return sourceTableStep; } @SuppressWarnings("unchecked") @@ -86,8 +111,10 @@ public SchemaKTable aggregate( final LogicalSchema aggregateSchema, final Initializer initializer, final int nonFuncColumnCount, + final List aggregations, final Map aggValToFunctionMap, final WindowExpression windowExpression, + final ValueFormat valueFormat, final Serde topicValueSerDe, final 
QueryContext.Stacker contextStacker ) { @@ -135,16 +162,24 @@ public SchemaKTable aggregate( subtractor, materialized); + final ExecutionStep step = ExecutionStepFactory.tableAggregate( + contextStacker, + sourceTableStep, + aggregateSchema, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + nonFuncColumnCount, + aggregations + ); return new SchemaKTable<>( aggKtable, - aggregateSchema, + step, + keyFormat, keySerde, keyField, sourceSchemaKStreams, SchemaKStream.Type.AGGREGATE, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java index 303906cf25a8..adcb1c96168a 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java @@ -19,6 +19,7 @@ import static io.confluent.ksql.execution.streams.ExecutionStepFactory.streamSourceWindowed; import static java.util.Objects.requireNonNull; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; import io.confluent.ksql.codegen.CodeGenRunner; @@ -28,10 +29,14 @@ import io.confluent.ksql.execution.expression.tree.DereferenceExpression; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.ExecutionStepProperties; import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.JoinType; import io.confluent.ksql.execution.plan.LogicalSchemaWithMetaAndKeyFields; import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.plan.StreamSource; +import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.StreamSourceBuilder; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.logging.processing.ProcessingLogContext; @@ -44,7 +49,10 @@ import io.confluent.ksql.schema.ksql.FormatOptions; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.streams.StreamsFactories; import io.confluent.ksql.streams.StreamsUtil; import io.confluent.ksql.structured.SelectValueMapper.SelectInfo; @@ -84,7 +92,7 @@ public class SchemaKStream { public enum Type { SOURCE, PROJECT, FILTER, AGGREGATE, SINK, REKEY, JOIN } final KStream kstream; - final LogicalSchema schema; + final KeyFormat keyFormat; final KeySerde keySerde; final KeyField keyField; final List sourceSchemaKStreams; @@ -92,25 +100,26 @@ public enum Type { SOURCE, PROJECT, FILTER, AGGREGATE, SINK, REKEY, JOIN } final KsqlConfig ksqlConfig; final FunctionRegistry functionRegistry; final StreamsFactories streamsFactories; - private final QueryContext queryContext; + private final ExecutionStep> sourceStep; + private final ExecutionStepProperties sourceProperties; private static SchemaKStream forSource( final KsqlQueryBuilder builder, + final KeyFormat keyFormat, final KeySerde keySerde, final StreamSource> streamSource, - final KeyField keyField, - final QueryContext 
queryContext) { + final KeyField keyField) { final KStream kstream = streamSource.build(builder); return new SchemaKStream<>( kstream, - streamSource.getProperties().getSchema(), + streamSource, + keyFormat, keySerde, keyField, ImmutableList.of(), SchemaKStream.Type.SOURCE, builder.getKsqlConfig(), - builder.getFunctionRegistry(), - queryContext + builder.getFunctionRegistry() ); } @@ -118,7 +127,7 @@ public static SchemaKStream forSource( final KsqlQueryBuilder builder, final DataSource dataSource, final LogicalSchemaWithMetaAndKeyFields schemaWithMetaAndKeyFields, - final QueryContext queryContext, + final QueryContext.Stacker contextStacker, final int timestampIndex, final Optional offsetReset, final KeyField keyField @@ -126,7 +135,7 @@ public static SchemaKStream forSource( final KsqlTopic topic = dataSource.getKsqlTopic(); if (topic.getKeyFormat().isWindowed()) { final StreamSource, GenericRow>> step = streamSourceWindowed( - queryContext, + contextStacker, schemaWithMetaAndKeyFields, topic.getKafkaTopicName(), Formats.of(topic.getKeyFormat(), topic.getValueFormat(), dataSource.getSerdeOptions()), @@ -136,13 +145,13 @@ public static SchemaKStream forSource( ); return forSource( builder, + topic.getKeyFormat(), StreamSourceBuilder.getWindowedKeySerde(builder, step), step, - keyField, - queryContext); + keyField); } else { final StreamSource> step = streamSource( - queryContext, + contextStacker, schemaWithMetaAndKeyFields, topic.getKafkaTopicName(), Formats.of(topic.getKeyFormat(), topic.getValueFormat(), dataSource.getSerdeOptions()), @@ -152,62 +161,71 @@ public static SchemaKStream forSource( ); return forSource( builder, + topic.getKeyFormat(), StreamSourceBuilder.getKeySerde(builder, step), step, - keyField, - queryContext); + keyField); } } - protected SchemaKStream( + @VisibleForTesting + SchemaKStream( final KStream kstream, - final LogicalSchema schema, + final ExecutionStep> sourceStep, + final KeyFormat keyFormat, final KeySerde keySerde, final KeyField keyField, final List sourceSchemaKStreams, final Type type, final KsqlConfig ksqlConfig, - final FunctionRegistry functionRegistry, - final QueryContext queryContext + final FunctionRegistry functionRegistry ) { this( - kstream, schema, + kstream, + requireNonNull(sourceStep, "sourceStep"), + sourceStep.getProperties(), + keyFormat, keySerde, keyField, sourceSchemaKStreams, type, ksqlConfig, functionRegistry, - StreamsFactories.create(ksqlConfig), - queryContext); + StreamsFactories.create(ksqlConfig) + ); } + // CHECKSTYLE_RULES.OFF: ParameterNumber SchemaKStream( final KStream kstream, - final LogicalSchema schema, + final ExecutionStep> sourceStep, + final ExecutionStepProperties sourceProperties, + final KeyFormat keyFormat, final KeySerde keySerde, final KeyField keyField, final List sourceSchemaKStreams, final Type type, final KsqlConfig ksqlConfig, final FunctionRegistry functionRegistry, - final StreamsFactories streamsFactories, - final QueryContext queryContext + final StreamsFactories streamsFactories ) { - this.schema = requireNonNull(schema, "schema"); + // CHECKSTYLE_RULES.ON: ParameterNumber + this.keyFormat = requireNonNull(keyFormat, "keyFormat"); this.keySerde = requireNonNull(keySerde, "keySerde"); + this.sourceStep = sourceStep; + this.sourceProperties = Objects.requireNonNull(sourceProperties, "sourceProperties"); this.kstream = kstream; - this.keyField = requireNonNull(keyField, "keyField") - .validateKeyExistsIn(schema); + this.keyField = requireNonNull(keyField, 
"keyField").validateKeyExistsIn(getSchema()); this.sourceSchemaKStreams = requireNonNull(sourceSchemaKStreams, "sourceSchemaKStreams"); this.type = requireNonNull(type, "type"); this.ksqlConfig = requireNonNull(ksqlConfig, "ksqlConfig"); this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry"); this.streamsFactories = requireNonNull(streamsFactories); - this.queryContext = requireNonNull(queryContext); } public SchemaKTable toTable( + final KeyFormat keyFormat, + final ValueFormat valueFormat, final Serde valueSerde, final QueryContext.Stacker contextStacker ) { @@ -233,39 +251,46 @@ public SchemaKTable toTable( () -> null, (k, value, oldValue) -> value.orElse(null), materialized); + final ExecutionStep> step = ExecutionStepFactory.streamToTable( + contextStacker, + Formats.of(keyFormat, valueFormat, Collections.emptySet()), + sourceStep + ); return new SchemaKTable<>( ktable, - schema, + step, + keyFormat, keySerde, keyField, Collections.singletonList(this), type, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } - public SchemaKStream sink( - final KeyField resultKeyField, - final QueryContext.Stacker contextStacker) { + public SchemaKStream withKeyField(final KeyField resultKeyField) { return new SchemaKStream<>( kstream, - schema, + sourceStep, + keyFormat, keySerde, resultKeyField, - Collections.singletonList(this), - SchemaKStream.Type.SINK, + sourceSchemaKStreams, + type, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } - public SchemaKStream into( + public SchemaKStream into( final String kafkaTopicName, final Serde topicValueSerDe, - final Set rowkeyIndexes + final LogicalSchema outputSchema, + final ValueFormat valueFormat, + final Set options, + final Set rowkeyIndexes, + final QueryContext.Stacker contextStacker ) { kstream .mapValues(row -> { @@ -280,7 +305,24 @@ public SchemaKStream into( } return new GenericRow(columns); }).to(kafkaTopicName, Produced.with(keySerde, topicValueSerDe)); - return this; + final ExecutionStep> step = ExecutionStepFactory.streamSink( + contextStacker, + outputSchema, + Formats.of(keyFormat, valueFormat, options), + sourceStep, + kafkaTopicName + ); + return new SchemaKStream<>( + kstream, + step, + keyFormat, + keySerde, + keyField, + Collections.singletonList(this), + SchemaKStream.Type.SINK, + ksqlConfig, + functionRegistry + ); } public SchemaKStream filter( @@ -290,7 +332,7 @@ public SchemaKStream filter( ) { final SqlPredicate predicate = new SqlPredicate( filterExpression, - schema, + getSchema(), ksqlConfig, functionRegistry, processingLogContext.getLoggerFactory().getLogger( @@ -300,16 +342,21 @@ public SchemaKStream filter( ); final KStream filteredKStream = kstream.filter(predicate.getPredicate()); + final ExecutionStep> step = ExecutionStepFactory.streamFilter( + contextStacker, + sourceStep, + filterExpression + ); return new SchemaKStream<>( filteredKStream, - schema, + step, + keyFormat, keySerde, keyField, Collections.singletonList(this), Type.FILTER, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -323,16 +370,22 @@ public SchemaKStream select( QueryLoggerUtil.queryLoggerName( contextStacker.push(Type.PROJECT.name()).getQueryContext())) ); + final ExecutionStep> step = ExecutionStepFactory.streamMapValues( + contextStacker, + sourceStep, + selectExpressions, + selection.getProjectedSchema() + ); return new SchemaKStream<>( kstream.mapValues(selection.getSelectValueMapper()), - 
selection.getProjectedSchema(), + step, + keyFormat, keySerde, selection.getKey(), Collections.singletonList(this), Type.PROJECT, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -349,7 +402,7 @@ class Selection { this.key = findKeyField(selectExpressions); this.selectValueMapper = SelectValueMapperFactory.create( selectExpressions, - SchemaKStream.this.schema, + SchemaKStream.this.getSchema(), ksqlConfig, functionRegistry, processingLogger @@ -439,9 +492,9 @@ private LogicalSchema buildSchema( ) { final LogicalSchema.Builder schemaBuilder = LogicalSchema.builder(); - final List keyFields = SchemaKStream.this.schema.isAliased() - ? SchemaKStream.this.schema.withoutAlias().keyFields() - : SchemaKStream.this.schema.keyFields(); + final List keyFields = SchemaKStream.this.getSchema().isAliased() + ? SchemaKStream.this.getSchema().withoutAlias().keyFields() + : SchemaKStream.this.getSchema().keyFields(); schemaBuilder.keyFields(keyFields); @@ -470,10 +523,10 @@ public SchemaKStream leftJoin( final SchemaKTable schemaKTable, final LogicalSchema joinSchema, final KeyField keyField, + final ValueFormat valueFormat, final Serde leftValueSerDe, final QueryContext.Stacker contextStacker ) { - final KStream joinedKStream = kstream.leftJoin( schemaKTable.getKtable(), @@ -484,17 +537,24 @@ public SchemaKStream leftJoin( null, StreamsUtil.buildOpName(contextStacker.getQueryContext())) ); - + final ExecutionStep> step = ExecutionStepFactory.streamTableJoin( + contextStacker, + JoinType.LEFT, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + sourceStep, + schemaKTable.getSourceTableStep(), + joinSchema + ); return new SchemaKStream<>( joinedKStream, - joinSchema, + step, + keyFormat, keySerde, keyField, ImmutableList.of(this, schemaKTable), Type.JOIN, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -503,6 +563,8 @@ public SchemaKStream leftJoin( final LogicalSchema joinSchema, final KeyField keyField, final JoinWindows joinWindows, + final ValueFormat leftFormat, + final ValueFormat rightFormat, final Serde leftSerde, final Serde rightSerde, final QueryContext.Stacker contextStacker) { @@ -519,17 +581,26 @@ public SchemaKStream leftJoin( rightSerde, StreamsUtil.buildOpName(contextStacker.getQueryContext())) ); - + final ExecutionStep> step = ExecutionStepFactory.streamStreamJoin( + contextStacker, + JoinType.LEFT, + Formats.of(keyFormat, leftFormat, SerdeOption.none()), + Formats.of(keyFormat, rightFormat, SerdeOption.none()), + sourceStep, + otherSchemaKStream.sourceStep, + joinSchema, + joinWindows + ); return new SchemaKStream<>( joinStream, - joinSchema, + step, + keyFormat, keySerde, keyField, ImmutableList.of(this, otherSchemaKStream), Type.JOIN, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -538,6 +609,7 @@ public SchemaKStream join( final SchemaKTable schemaKTable, final LogicalSchema joinSchema, final KeyField keyField, + final ValueFormat valueFormat, final Serde joinSerDe, final QueryContext.Stacker contextStacker ) { @@ -551,17 +623,24 @@ public SchemaKStream join( null, StreamsUtil.buildOpName(contextStacker.getQueryContext())) ); - + final ExecutionStep> step = ExecutionStepFactory.streamTableJoin( + contextStacker, + JoinType.INNER, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + sourceStep, + schemaKTable.getSourceTableStep(), + joinSchema + ); return new SchemaKStream<>( joinedKStream, - joinSchema, + step, + keyFormat, keySerde, 
keyField, ImmutableList.of(this, schemaKTable), Type.JOIN, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -570,6 +649,8 @@ public SchemaKStream join( final LogicalSchema joinSchema, final KeyField keyField, final JoinWindows joinWindows, + final ValueFormat leftFormat, + final ValueFormat rightFormat, final Serde leftSerde, final Serde rightSerde, final QueryContext.Stacker contextStacker) { @@ -585,17 +666,26 @@ public SchemaKStream join( rightSerde, StreamsUtil.buildOpName(contextStacker.getQueryContext())) ); - + final ExecutionStep> step = ExecutionStepFactory.streamStreamJoin( + contextStacker, + JoinType.INNER, + Formats.of(keyFormat, leftFormat, SerdeOption.none()), + Formats.of(keyFormat, rightFormat, SerdeOption.none()), + sourceStep, + otherSchemaKStream.sourceStep, + joinSchema, + joinWindows + ); return new SchemaKStream<>( joinStream, - joinSchema, + step, + keyFormat, keySerde, keyField, ImmutableList.of(this, otherSchemaKStream), Type.JOIN, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -604,6 +694,8 @@ public SchemaKStream outerJoin( final LogicalSchema joinSchema, final KeyField keyField, final JoinWindows joinWindows, + final ValueFormat leftFormat, + final ValueFormat rightFormat, final Serde leftSerde, final Serde rightSerde, final QueryContext.Stacker contextStacker) { @@ -618,17 +710,26 @@ public SchemaKStream outerJoin( rightSerde, StreamsUtil.buildOpName(contextStacker.getQueryContext())) ); - + final ExecutionStep> step = ExecutionStepFactory.streamStreamJoin( + contextStacker, + JoinType.OUTER, + Formats.of(keyFormat, leftFormat, SerdeOption.none()), + Formats.of(keyFormat, rightFormat, SerdeOption.none()), + sourceStep, + otherSchemaKStream.sourceStep, + joinSchema, + joinWindows + ); return new SchemaKStream<>( joinStream, - joinSchema, + step, + keyFormat, keySerde, keyField, ImmutableList.of(this, otherSchemaKStream), Type.JOIN, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -642,9 +743,9 @@ public SchemaKStream selectKey( throw new UnsupportedOperationException("Can not selectKey of windowed stream"); } - final Optional existingKey = keyField.resolve(schema, ksqlConfig); + final Optional existingKey = keyField.resolve(getSchema(), ksqlConfig); - final Field proposedKey = schema.findValueField(fieldName) + final Field proposedKey = getSchema().findValueField(fieldName) .orElseThrow(IllegalArgumentException::new); final LegacyField proposedLegacy = LegacyField.of(proposedKey.fullName(), proposedKey.type()); @@ -663,20 +764,20 @@ public SchemaKStream selectKey( final boolean treatAsRowKey = usingNewKeyFields() && isRowKey(proposedKey.name()); if (namesMatch || treatAsRowKey) { - return new SchemaKStream<>( + return (SchemaKStream) new SchemaKStream<>( kstream, - schema, - (KeySerde) keySerde, + sourceStep, + keyFormat, + keySerde, resultantKeyField, sourceSchemaKStreams, type, ksqlConfig, - functionRegistry, - queryContext + functionRegistry ); } - final int keyIndexInValue = schema.valueFieldIndex(proposedKey.fullName()) + final int keyIndexInValue = getSchema().valueFieldIndex(proposedKey.fullName()) .orElseThrow(IllegalStateException::new); final KStream keyedKStream = kstream @@ -691,22 +792,27 @@ public SchemaKStream selectKey( return row; }); - final KeyField newKeyField = schema.isMetaField(fieldName) + final KeyField newKeyField = getSchema().isMetaField(fieldName) ? 
resultantKeyField.withName(Optional.empty()) : resultantKeyField; final KeySerde selectKeySerde = keySerde.rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA); - - return new SchemaKStream<>( + final ExecutionStep> step = ExecutionStepFactory.streamSelectKey( + contextStacker, + sourceStep, + fieldName, + updateRowKey + ); + return (SchemaKStream) new SchemaKStream( keyedKStream, - schema, + step, + keyFormat, selectKeySerde, newKeyField, Collections.singletonList(this), Type.REKEY, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -719,9 +825,9 @@ private static boolean isRowKey(final String fieldName) { } private Object extractColumn(final int keyIndexInValue, final GenericRow value) { - if (value.getColumns().size() != schema.valueFields().size()) { + if (value.getColumns().size() != getSchema().valueFields().size()) { throw new IllegalStateException("Field count mismatch. " - + "Schema fields: " + schema + + "Schema fields: " + getSchema() + ", row:" + value); } @@ -747,7 +853,7 @@ private boolean rekeyRequired(final List groupByExpressions) { return true; } - final Optional keyField = getKeyField().resolve(schema, ksqlConfig); + final Optional keyField = getKeyField().resolve(getSchema(), ksqlConfig); if (!keyField.isPresent()) { return true; } @@ -763,11 +869,13 @@ private boolean rekeyRequired(final List groupByExpressions) { @SuppressWarnings("unchecked") public SchemaKGroupedStream groupBy( + final ValueFormat valueFormat, final Serde valSerde, final List groupByExpressions, final QueryContext.Stacker contextStacker ) { final boolean rekey = rekeyRequired(groupByExpressions); + final KeyFormat rekeyedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); if (!rekey) { final Grouped grouped = streamsFactories.getGroupedFactory() .create( @@ -783,10 +891,17 @@ public SchemaKGroupedStream groupBy( } final KeySerde structKeySerde = (KeySerde) keySerde; - + final ExecutionStep> step = + ExecutionStepFactory.streamGroupBy( + contextStacker, + sourceStep, + Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()), + groupByExpressions + ); return new SchemaKGroupedStream( kgroupedStream, - schema, + step, + keyFormat, structKeySerde, keyField, Collections.singletonList(this), @@ -814,12 +929,19 @@ public SchemaKGroupedStream groupBy( final LegacyField legacyKeyField = LegacyField .notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING); - final Optional newKeyField = schema.findValueField(groupBy.aggregateKeyName) + final Optional newKeyField = getSchema().findValueField(groupBy.aggregateKeyName) .map(Field::name); - + final ExecutionStep> source = + ExecutionStepFactory.streamGroupBy( + contextStacker, + sourceStep, + Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()), + groupByExpressions + ); return new SchemaKGroupedStream( kgroupedStream, - schema, + source, + rekeyedKeyFormat, groupedKeySerde, KeyField.of(newKeyField, Optional.of(legacyKeyField)), Collections.singletonList(this), @@ -828,12 +950,20 @@ public SchemaKGroupedStream groupBy( ); } + ExecutionStep> getSourceStep() { + return sourceStep; + } + public KeyField getKeyField() { return keyField; } + public QueryContext getQueryContext() { + return sourceProperties.getQueryContext(); + } + public LogicalSchema getSchema() { - return schema; + return sourceProperties.getSchema(); } public KeySerde getKeySerde() { @@ -853,8 +983,8 @@ public String getExecutionPlan(final String indent) { stringBuilder.append(indent) .append(" > [ ") .append(type).append(" ] | Schema: ") 
- .append(schema.toString(FORMAT_OPTIONS)) - .append(" | Logger: ").append(QueryLoggerUtil.queryLoggerName(queryContext)) + .append(getSchema().toString(FORMAT_OPTIONS)) + .append(" | Logger: ").append(QueryLoggerUtil.queryLoggerName(getQueryContext())) .append("\n"); for (final SchemaKStream schemaKStream : sourceSchemaKStreams) { stringBuilder @@ -869,6 +999,10 @@ public Type getType() { return type; } + public KeyFormat getKeyFormat() { + return keyFormat; + } + public FunctionRegistry getFunctionRegistry() { return functionRegistry; } @@ -880,7 +1014,7 @@ class GroupBy { GroupBy(final List expressions) { final List groupBy = CodeGenRunner.compileExpressions( - expressions.stream(), "Group By", schema, ksqlConfig, functionRegistry); + expressions.stream(), "Group By", getSchema(), ksqlConfig, functionRegistry); this.mapper = new GroupByMapper(groupBy); this.aggregateKeyName = GroupByMapper.keyNameFor(expressions); diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java index cbbcbb9241c1..a148323e638c 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKTable.java @@ -20,7 +20,11 @@ import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.context.QueryLoggerUtil; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.JoinType; import io.confluent.ksql.execution.plan.SelectExpression; +import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.logging.processing.ProcessingLogContext; import io.confluent.ksql.metastore.model.KeyField; @@ -28,13 +32,17 @@ import io.confluent.ksql.schema.ksql.Field; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.streams.StreamsFactories; import io.confluent.ksql.streams.StreamsUtil; import io.confluent.ksql.util.KsqlConfig; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.Set; import org.apache.kafka.common.serialization.Serde; @@ -48,66 +56,73 @@ // CHECKSTYLE_RULES.OFF: ClassDataAbstractionCoupling public class SchemaKTable extends SchemaKStream { - // CHECKSTYLE_RULES.ON: ClassDataAbstractionCoupling private final KTable ktable; + private final ExecutionStep> sourceTableStep; public SchemaKTable( final KTable ktable, - final LogicalSchema schema, + final ExecutionStep> sourceTableStep, + final KeyFormat keyFormat, final KeySerde keySerde, final KeyField keyField, final List sourceSchemaKStreams, final Type type, final KsqlConfig ksqlConfig, - final FunctionRegistry functionRegistry, - final QueryContext queryContext + final FunctionRegistry functionRegistry ) { this( - ktable, schema, + ktable, + sourceTableStep, + keyFormat, keySerde, keyField, sourceSchemaKStreams, type, ksqlConfig, functionRegistry, - StreamsFactories.create(ksqlConfig), - queryContext + StreamsFactories.create(ksqlConfig) ); } SchemaKTable( final KTable ktable, 
- final LogicalSchema schema, + final ExecutionStep> sourceTableStep, + final KeyFormat keyFormat, final KeySerde keySerde, final KeyField keyField, final List sourceSchemaKStreams, final Type type, final KsqlConfig ksqlConfig, final FunctionRegistry functionRegistry, - final StreamsFactories streamsFactories, - final QueryContext queryContext + final StreamsFactories streamsFactories ) { super( null, - schema, + null, + Objects.requireNonNull(sourceTableStep, "sourceTableStep").getProperties(), + keyFormat, keySerde, keyField, sourceSchemaKStreams, type, ksqlConfig, functionRegistry, - streamsFactories, - queryContext + streamsFactories ); this.ktable = ktable; + this.sourceTableStep = sourceTableStep; } @Override public SchemaKTable into( final String kafkaTopicName, final Serde topicValueSerDe, - final Set rowkeyIndexes + final LogicalSchema outputSchema, + final ValueFormat valueFormat, + final Set options, + final Set rowkeyIndexes, + final QueryContext.Stacker contextStacker ) { ktable.toStream() @@ -125,7 +140,24 @@ public SchemaKTable into( } ).to(kafkaTopicName, Produced.with(keySerde, topicValueSerDe)); - return this; + final ExecutionStep> step = ExecutionStepFactory.tableSink( + contextStacker, + outputSchema, + sourceTableStep, + Formats.of(keyFormat, valueFormat, options), + kafkaTopicName + ); + return new SchemaKTable<>( + ktable, + step, + keyFormat, + keySerde, + keyField, + sourceSchemaKStreams, + type, + ksqlConfig, + functionRegistry + ); } @SuppressWarnings("unchecked") @@ -137,7 +169,7 @@ public SchemaKTable filter( ) { final SqlPredicate predicate = new SqlPredicate( filterExpression, - schema, + getSchema(), ksqlConfig, functionRegistry, processingLogContext.getLoggerFactory().getLogger( @@ -146,16 +178,21 @@ public SchemaKTable filter( ); final KTable filteredKTable = ktable.filter(predicate.getPredicate()); + final ExecutionStep> step = ExecutionStepFactory.tableFilter( + contextStacker, + sourceTableStep, + filterExpression + ); return new SchemaKTable<>( filteredKTable, - schema, + step, + keyFormat, keySerde, keyField, Collections.singletonList(this), Type.FILTER, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -170,16 +207,22 @@ public SchemaKTable select( QueryLoggerUtil.queryLoggerName( contextStacker.push(Type.PROJECT.name()).getQueryContext())) ); + final ExecutionStep> step = ExecutionStepFactory.tableMapValues( + contextStacker, + sourceTableStep, + selection.getProjectedSchema(), + selectExpressions + ); return new SchemaKTable<>( ktable.mapValues(selection.getSelectValueMapper()), - selection.getProjectedSchema(), + step, + keyFormat, keySerde, selection.getKey(), Collections.singletonList(this), Type.PROJECT, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -193,14 +236,20 @@ public KTable getKtable() { return ktable; } + public ExecutionStep> getSourceTableStep() { + return sourceTableStep; + } + @Override public SchemaKGroupedStream groupBy( + final ValueFormat valueFormat, final Serde valSerde, final List groupByExpressions, final QueryContext.Stacker contextStacker ) { final GroupBy groupBy = new GroupBy(groupByExpressions); + final KeyFormat groupedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); final KeySerde groupedKeySerde = keySerde .rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA); @@ -222,12 +271,20 @@ public SchemaKGroupedStream groupBy( final LegacyField legacyKeyField = LegacyField .notInSchema(groupBy.aggregateKeyName, 
SqlTypes.STRING); - final Optional newKeyField = schema.findValueField(groupBy.aggregateKeyName) + final Optional newKeyField = getSchema().findValueField(groupBy.aggregateKeyName) .map(Field::fullName); + final ExecutionStep> step = + ExecutionStepFactory.tableGroupBy( + contextStacker, + sourceTableStep, + Formats.of(groupedKeyFormat, valueFormat, SerdeOption.none()), + groupByExpressions + ); return new SchemaKGroupedTable( kgroupedTable, - schema, + step, + groupedKeyFormat, groupedKeySerde, KeyField.of(newKeyField, Optional.of(legacyKeyField)), Collections.singletonList(this), @@ -246,17 +303,23 @@ public SchemaKTable join( schemaKTable.getKtable(), new KsqlValueJoiner(this.getSchema(), schemaKTable.getSchema()) ); - + final ExecutionStep> step = ExecutionStepFactory.tableTableJoin( + contextStacker, + JoinType.INNER, + sourceTableStep, + schemaKTable.getSourceTableStep(), + joinSchema + ); return new SchemaKTable<>( joinedKTable, - joinSchema, + step, + keyFormat, keySerde, keyField, ImmutableList.of(this, schemaKTable), Type.JOIN, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -272,17 +335,23 @@ public SchemaKTable leftJoin( schemaKTable.getKtable(), new KsqlValueJoiner(this.getSchema(), schemaKTable.getSchema()) ); - + final ExecutionStep> step = ExecutionStepFactory.tableTableJoin( + contextStacker, + JoinType.LEFT, + sourceTableStep, + schemaKTable.getSourceTableStep(), + joinSchema + ); return new SchemaKTable<>( joinedKTable, - joinSchema, + step, + keyFormat, keySerde, keyField, ImmutableList.of(this, schemaKTable), Type.JOIN, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } @@ -298,17 +367,23 @@ public SchemaKTable outerJoin( schemaKTable.getKtable(), new KsqlValueJoiner(this.getSchema(), schemaKTable.getSchema()) ); - + final ExecutionStep> step = ExecutionStepFactory.tableTableJoin( + contextStacker, + JoinType.OUTER, + sourceTableStep, + schemaKTable.getSourceTableStep(), + joinSchema + ); return new SchemaKTable<>( joinedKTable, - joinSchema, + step, + keyFormat, keySerde, keyField, ImmutableList.of(this, schemaKTable), Type.JOIN, ksqlConfig, - functionRegistry, - contextStacker.getQueryContext() + functionRegistry ); } } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/physical/PhysicalPlanBuilderTest.java b/ksql-engine/src/test/java/io/confluent/ksql/physical/PhysicalPlanBuilderTest.java index 178a32e1f7ec..2f849a595782 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/physical/PhysicalPlanBuilderTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/physical/PhysicalPlanBuilderTest.java @@ -325,7 +325,8 @@ public void shouldCreateExecutionPlan() { final String planText = metadata.getExecutionPlan(); final String[] lines = planText.split("\n"); assertThat(lines[0], startsWith( - " > [ SINK ] | Schema: [ROWKEY STRING KEY, COL0 BIGINT, KSQL_COL_1 DOUBLE, KSQL_COL_2 BIGINT] |")); + " > [ PROJECT ] | Schema: [ROWKEY STRING KEY, COL0 BIGINT, KSQL_COL_1 DOUBLE, " + + "KSQL_COL_2 BIGINT] |")); assertThat(lines[1], startsWith( "\t\t > [ AGGREGATE ] | Schema: [ROWKEY STRING KEY, KSQL_INTERNAL_COL_0 BIGINT, " + "KSQL_INTERNAL_COL_1 DOUBLE, KSQL_AGG_VARIABLE_0 DOUBLE, " @@ -444,7 +445,7 @@ public void shouldCreatePlanForInsertIntoStreamFromStream() { final String[] lines = planText.split("\n"); assertThat(lines.length, equalTo(3)); assertThat(lines[0], containsString( - "> [ SINK ] | Schema: [ROWKEY STRING KEY, ROWTIME BIGINT, ROWKEY STRING, COL0 INTEGER]")); + "> [ SINK ] 
| Schema: [ROWKEY STRING KEY, COL0 INTEGER]")); assertThat(lines[1], containsString( "> [ PROJECT ] | Schema: [ROWKEY STRING KEY, ROWTIME BIGINT, ROWKEY STRING, COL0 INTEGER]")); @@ -477,7 +478,7 @@ public void shouldFailInsertIfTheResultTypesDoNotMatch() { } @Test - public void shouldCheckSinkAndResultKeysDoNotMatch() { + public void shouldRekeyIfPartitionByDoesNotMatchResultKey() { final String csasQuery = "CREATE STREAM s1 AS SELECT col0, col1, col2 FROM test1 PARTITION BY col0;"; final String insertIntoQuery = "INSERT INTO s1 SELECT col0, col1, col2 FROM test1 PARTITION BY col0;"; givenKafkaTopicsExist("test1"); @@ -488,11 +489,11 @@ public void shouldCheckSinkAndResultKeysDoNotMatch() { final String planText = queryMetadataList.get(1).getExecutionPlan(); final String[] lines = planText.split("\n"); assertThat(lines.length, equalTo(4)); - assertThat(lines[0], - equalTo(" > [ REKEY ] | Schema: [ROWKEY STRING KEY, COL0 BIGINT, COL1 STRING, COL2 DOUBLE] " - + "| Logger: InsertQuery_1.S1")); - assertThat(lines[1], equalTo("\t\t > [ SINK ] | Schema: [ROWKEY STRING KEY, COL0 BIGINT, COL1 STRING, COL2 " + assertThat(lines[0], equalTo(" > [ SINK ] | Schema: [ROWKEY STRING KEY, COL0 BIGINT, COL1 STRING, COL2 " + "DOUBLE] | Logger: InsertQuery_1.S1")); + assertThat(lines[1], + equalTo("\t\t > [ REKEY ] | Schema: [ROWKEY STRING KEY, COL0 BIGINT, COL1 STRING, COL2 DOUBLE] " + + "| Logger: InsertQuery_1.S1")); assertThat(lines[2], equalTo("\t\t\t\t > [ PROJECT ] | Schema: [ROWKEY STRING KEY, COL0 BIGINT, COL1 STRING" + ", COL2 DOUBLE] | Logger: InsertQuery_1.Project")); } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/physical/TransientQueryQueueTest.java b/ksql-engine/src/test/java/io/confluent/ksql/physical/TransientQueryQueueTest.java index f61890c304f9..a304caa47af3 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/physical/TransientQueryQueueTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/physical/TransientQueryQueueTest.java @@ -27,7 +27,7 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.physical.TransientQueryQueue.QueuePopulator; -import io.confluent.ksql.structured.QueuedSchemaKStream; +import io.confluent.ksql.structured.SchemaKStream; import java.util.OptionalInt; import java.util.Queue; import java.util.stream.IntStream; @@ -54,7 +54,7 @@ public class TransientQueryQueueTest { @Mock private KStream kStreamsApp; @Mock - private QueuedSchemaKStream queuedKStream; + private SchemaKStream queuedKStream; @Captor private ArgumentCaptor> queuePopulatorCaptor; private Queue> queue; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java index b66c63b6173a..6d55741d0a99 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java @@ -215,7 +215,7 @@ public void before() { when(kGroupedStream.aggregate(any(), any(), any())).thenReturn(kTable); when(schemaKStreamFactory.create(any(), any(), any(), any(), anyInt(), any(), any())) .thenReturn(stream); - when(stream.toTable(any(), any())).thenReturn(table); + when(stream.toTable(any(), any(), any(), any())).thenReturn(table); } @Test @@ -483,13 +483,13 @@ public void shouldBuildSourceStreamWithCorrectParams() { same(ksqlStreamBuilder), same(dataSource), eq(StreamSource.getSchemaWithMetaAndKeyFields("name", REAL_SCHEMA)), - queryContextCaptor.capture(), + 
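The four-argument stubs above reflect the new SchemaKStream.toTable signature, which now takes the key and value formats ahead of the serde and context stacker so the resulting step can record its serialization. A minimal sketch of the call shape, using the JSON formats this patch uses elsewhere; the local names are placeholders:

    final KeyFormat keyFormat = KeyFormat.nonWindowed(FormatInfo.of(Format.JSON));
    final ValueFormat valueFormat = ValueFormat.of(FormatInfo.of(Format.JSON));
    final SchemaKTable<?> table =
        stream.toTable(keyFormat, valueFormat, valueSerde, contextStacker);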
stackerCaptor.capture(), eq(3), eq(OFFSET_RESET), same(node.getKeyField()) ); assertThat( - queryContextCaptor.getValue().getContext(), + stackerCaptor.getValue().getQueryContext().getContext(), equalTo(ImmutableList.of("0", "source")) ); } @@ -508,13 +508,13 @@ public void shouldBuildSourceStreamWithCorrectParamsWhenBuildingTable() { same(ksqlStreamBuilder), same(dataSource), eq(StreamSource.getSchemaWithMetaAndKeyFields("name", REAL_SCHEMA)), - queryContextCaptor.capture(), + stackerCaptor.capture(), eq(3), eq(OFFSET_RESET), same(node.getKeyField()) ); assertThat( - queryContextCaptor.getValue().getContext(), + stackerCaptor.getValue().getQueryContext().getContext(), equalTo(ImmutableList.of("0", "source")) ); } @@ -529,7 +529,7 @@ public void shouldBuildTableByConvertingFromStream() { final SchemaKStream returned = node.buildStream(ksqlStreamBuilder); // Then: - verify(stream).toTable(any(), any()); + verify(stream).toTable(any(), any(), any(), any()); assertThat(returned, is(table)); } @@ -552,7 +552,7 @@ public void shouldBuildReduceSerdeCorrectlyWhenBuildingTable() { queryContextCaptor.getValue().getContext(), equalTo(ImmutableList.of("0", "reduce")) ); - verify(stream).toTable(same(rowSerde), any()); + verify(stream).toTable(any(), any(), same(rowSerde), any()); } @Test @@ -565,7 +565,7 @@ public void shouldBuildTableWithCorrectContext() { node.buildStream(ksqlStreamBuilder); // Then: - verify(stream).toTable(any(), stackerCaptor.capture()); + verify(stream).toTable(any(), any(), any(), stackerCaptor.capture()); assertThat( stackerCaptor.getValue().getQueryContext().getContext(), equalTo(ImmutableList.of("0", "reduce"))); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java index 98e15a0d5e71..6d58f38d7132 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java @@ -122,6 +122,7 @@ public class JoinNodeTest { private static final Optional NO_KEY_FIELD = Optional.empty(); private static final ValueFormat VALUE_FORMAT = ValueFormat.of(FormatInfo.of(Format.JSON)); + private static final ValueFormat OTHER_FORMAT = ValueFormat.of(FormatInfo.of(Format.DELIMITED)); private final KsqlConfig ksqlConfig = new KsqlConfig(new HashMap<>()); private StreamsBuilder builder; private JoinNode joinNode; @@ -206,8 +207,8 @@ public void setUp() { when(left.getPartitions(mockKafkaTopicClient)).thenReturn(2); when(right.getPartitions(mockKafkaTopicClient)).thenReturn(2); - setUpSource(left, leftSource, "Foobar1"); - setUpSource(right, rightSource, "Foobar2"); + setUpSource(left, VALUE_FORMAT, leftSource, "Foobar1"); + setUpSource(right, OTHER_FORMAT, rightSource, "Foobar2"); when(leftSchemaKStream.getKeyField()).thenReturn(leftJoinField); when(leftSchemaKTable.getKeyField()).thenReturn(leftJoinField); @@ -360,6 +361,8 @@ public void shouldPerformStreamToStreamLeftJoin() { eq(JOIN_SCHEMA), eq(leftJoinField), eq(WITHIN_EXPRESSION.get().joinWindow()), + eq(VALUE_FORMAT), + eq(OTHER_FORMAT), any(), any(), eq(CONTEXT_STACKER)); @@ -390,6 +393,8 @@ public void shouldPerformStreamToStreamInnerJoin() { eq(JOIN_SCHEMA), eq(leftJoinField), eq(WITHIN_EXPRESSION.get().joinWindow()), + eq(VALUE_FORMAT), + eq(OTHER_FORMAT), any(), any(), eq(CONTEXT_STACKER)); @@ -420,6 +425,8 @@ public void shouldPerformStreamToStreamOuterJoin() { eq(JOIN_SCHEMA), eq(leftJoinField.withName(Optional.empty())), 
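Note that the test setup now gives the two sides distinct value formats, JSON on the left and DELIMITED on the right, so the verifications below can prove each side's format is threaded through independently rather than one format being reused for both:

    setUpSource(left, VALUE_FORMAT, leftSource, "Foobar1");   // JSON
    setUpSource(right, OTHER_FORMAT, rightSource, "Foobar2"); // DELIMITED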
eq(WITHIN_EXPRESSION.get().joinWindow()), + eq(VALUE_FORMAT), + eq(OTHER_FORMAT), any(), any(), eq(CONTEXT_STACKER)); @@ -560,6 +567,7 @@ public void shouldHandleJoinIfTableHasNoKeyAndJoinFieldIsRowKey() { eq(rightSchemaKTable), eq(JOIN_SCHEMA), eq(leftJoinField), + eq(VALUE_FORMAT), any(), eq(CONTEXT_STACKER)); } @@ -588,6 +596,7 @@ public void shouldPerformStreamToTableLeftJoin() { eq(rightSchemaKTable), eq(JOIN_SCHEMA), eq(leftJoinField), + eq(VALUE_FORMAT), any(), eq(CONTEXT_STACKER)); } @@ -616,6 +625,7 @@ public void shouldPerformStreamToTableInnerJoin() { eq(rightSchemaKTable), eq(JOIN_SCHEMA), eq(leftJoinField), + eq(VALUE_FORMAT), any(), eq(CONTEXT_STACKER)); } @@ -1108,6 +1118,7 @@ private static String getNonKeyColumn( @SuppressWarnings("unchecked") private static void setUpSource( final DataSourceNode node, + final ValueFormat valueFormat, final DataSource dataSource, final String name ) { @@ -1115,7 +1126,7 @@ private static void setUpSource( when(node.getDataSource()).thenReturn((DataSource)dataSource); final KsqlTopic ksqlTopic = mock(KsqlTopic.class); - when(ksqlTopic.getValueFormat()).thenReturn(VALUE_FORMAT); + when(ksqlTopic.getValueFormat()).thenReturn(valueFormat); when(dataSource.getKsqlTopic()).thenReturn(ksqlTopic); } } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNodeTest.java b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNodeTest.java index 9786c0012010..4cf45a6a4262 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNodeTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNodeTest.java @@ -22,10 +22,12 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; @@ -85,6 +87,7 @@ public class KsqlStructuredDataOutputNodeTest { private static final KeyField KEY_FIELD = KeyField.of("key", SCHEMA.findValueField("key").get()); private static final PlanNodeId PLAN_NODE_ID = new PlanNodeId("0"); + private static final ValueFormat JSON_FORMAT = ValueFormat.of(FormatInfo.of(Format.JSON)); @Rule public final ExpectedException expectedException = ExpectedException.none(); @@ -104,8 +107,12 @@ public class KsqlStructuredDataOutputNodeTest { @Mock private SchemaKStream resultStream; @Mock + private SchemaKStream sinkStream; + @Mock private SchemaKStream resultWithKeySelected; @Mock + private SchemaKStream sinkStreamWithKeySelected; + @Mock private KStream kstream; @Mock private KsqlTopic ksqlTopic; @@ -113,6 +120,8 @@ public class KsqlStructuredDataOutputNodeTest { private Serde rowSerde; @Captor private ArgumentCaptor queryContextCaptor; + @Captor + private ArgumentCaptor stackerCaptor; private final Set serdeOptions = SerdeOption.none(); @@ -135,18 +144,21 @@ public void before() { when(sourceStream.getKeyField()).thenReturn(KeyField.none()); - when(sourceStream.sink(any(), any())) + when(sourceStream.withKeyField(any())) .thenReturn(resultStream); - + when(resultStream.into(any(), any(), any(), any(), any(), any(), any())) + .thenReturn((SchemaKStream) 
sinkStream); when(resultStream.selectKey(any(), anyBoolean(), any())) .thenReturn((SchemaKStream) resultWithKeySelected); + when(resultWithKeySelected.into(any(), any(), any(), any(), any(), any(), any())) + .thenReturn((SchemaKStream) sinkStreamWithKeySelected); when(ksqlStreamBuilder.buildValueSerde(any(), any(), any())).thenReturn(rowSerde); when(ksqlStreamBuilder.buildNodeContext(any())).thenAnswer(inv -> new QueryContext.Stacker(QUERY_ID) .push(inv.getArgument(0).toString())); when(ksqlTopic.getKafkaTopicName()).thenReturn(SINK_KAFKA_TOPIC_NAME); - when(ksqlTopic.getValueFormat()).thenReturn(ValueFormat.of(FormatInfo.of(Format.JSON))); + when(ksqlTopic.getValueFormat()).thenReturn(JSON_FORMAT); buildNode(); } @@ -205,7 +217,7 @@ public void shouldBuildMapNodePriorToOutput() { inOrder.verify(sourceNode) .buildStream(any()); - inOrder.verify(sourceStream).sink(any(), any()); + inOrder.verify(sourceStream).withKeyField(any()); } @Test @@ -214,9 +226,8 @@ public void shouldBuildOutputNode() { outputNode.buildStream(ksqlStreamBuilder); // Then: - verify(sourceStream).sink( - KEY_FIELD.withName(Optional.empty()), - new QueryContext.Stacker(QUERY_ID).push(PLAN_NODE_ID.toString()) + verify(sourceStream).withKeyField( + KEY_FIELD.withName(Optional.empty()) ); } @@ -235,7 +246,7 @@ public void shouldPartitionByFieldNameInPartitionByProperty() { new QueryContext.Stacker(QUERY_ID).push(PLAN_NODE_ID.toString()) ); - assertThat(result, is(sameInstance(resultWithKeySelected))); + assertThat(result, is(sameInstance(sinkStreamWithKeySelected))); } @Test @@ -253,7 +264,7 @@ public void shouldPartitionByRowKey() { new QueryContext.Stacker(QUERY_ID).push(PLAN_NODE_ID.toString()) ); - assertThat(result, is(sameInstance(resultWithKeySelected))); + assertThat(result, is(sameInstance(sinkStreamWithKeySelected))); } @Test @@ -271,7 +282,7 @@ public void shouldPartitionByRowTime() { new QueryContext.Stacker(QUERY_ID).push(PLAN_NODE_ID.toString()) ); - assertThat(result, is(sameInstance(resultWithKeySelected))); + assertThat(result, is(sameInstance(sinkStreamWithKeySelected))); } @Test @@ -350,14 +361,23 @@ public void shouldBuildRowSerdeCorrectly() { @Test public void shouldCallInto() { // When: - outputNode.buildStream(ksqlStreamBuilder); + final SchemaKStream result = outputNode.buildStream(ksqlStreamBuilder); // Then: verify(resultStream).into( - SINK_KAFKA_TOPIC_NAME, - rowSerde, - ImmutableSet.of() + eq(SINK_KAFKA_TOPIC_NAME), + same(rowSerde), + eq(SCHEMA), + eq(JSON_FORMAT), + eq(SerdeOption.none()), + eq(ImmutableSet.of()), + stackerCaptor.capture() ); + assertThat( + stackerCaptor.getValue().getQueryContext().getContext(), + equalTo(ImmutableList.of("0")) + ); + assertThat(result, sameInstance(sinkStream)); } @Test @@ -371,9 +391,13 @@ public void shouldCallIntoWithIndexesToRemoveImplicitsAndRowKey() { // Then: verify(resultStream).into( - SINK_KAFKA_TOPIC_NAME, - rowSerde, - ImmutableSet.of(0, 1) + eq(SINK_KAFKA_TOPIC_NAME), + same(rowSerde), + any(), + any(), + any(), + eq(ImmutableSet.of(0, 1)), + any() ); } @@ -397,9 +421,13 @@ public void shouldCallIntoWithIndexesToRemoveImplicitsAndRowKeyRegardlessOfLocat // Then: verify(resultStream).into( - SINK_KAFKA_TOPIC_NAME, - rowSerde, - ImmutableSet.of(2, 5) + eq(SINK_KAFKA_TOPIC_NAME), + same(rowSerde), + any(), + any(), + any(), + eq(ImmutableSet.of(2, 5)), + any() ); } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java 
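Sink building is now split into two explicit calls that these mocks and verifications exercise: withKeyField(...) fixes up the result's key field, and into(...) carries everything needed to record the sink in the plan, including the logical schema, value format, and serde options. A minimal sketch of the flow, with names from this patch:

    final SchemaKStream<?> sink = sourceStream
        .withKeyField(KEY_FIELD.withName(Optional.empty()))
        .into(
            SINK_KAFKA_TOPIC_NAME,
            rowSerde,
            SCHEMA,
            JSON_FORMAT,
            SerdeOption.none(),
            ImmutableSet.of(),  // value-field indexes to drop, e.g. implicits and ROWKEY
            contextStacker);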
index 0433373bb986..3d17bf272e7f 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java @@ -15,8 +15,10 @@ package io.confluent.ksql.structured; +import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; import static org.mockito.ArgumentMatchers.any; @@ -28,9 +30,14 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.context.QueryContext; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; import io.confluent.ksql.metastore.model.KeyField; @@ -40,7 +47,12 @@ import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.Field; import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.FormatInfo; +import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.serde.WindowInfo; import io.confluent.ksql.streams.MaterializedFactory; import io.confluent.ksql.streams.StreamsUtil; @@ -70,8 +82,6 @@ @SuppressWarnings("unchecked") @RunWith(MockitoJUnitRunner.class) public class SchemaKGroupedStreamTest { - @Mock - private LogicalSchema schema; @Mock private LogicalSchema aggregateSchema; @Mock @@ -95,6 +105,8 @@ public class SchemaKGroupedStreamTest { @Mock private KsqlAggregateFunction otherFunc; @Mock + private FunctionCall aggCall; + @Mock private KTable table; @Mock private KTable table2; @@ -112,6 +124,12 @@ public class SchemaKGroupedStreamTest { private KeySerde> windowedKeySerde; @Mock private Field field; + @Mock + private ExecutionStep sourceStep; + @Mock + private KeyFormat keyFormat; + @Mock + private ValueFormat valueFormat; private final QueryContext.Stacker queryContext = new QueryContext.Stacker(new QueryId("query")).push("node"); private SchemaKGroupedStream schemaGroupedStream; @@ -119,7 +137,14 @@ public class SchemaKGroupedStreamTest { @Before public void setUp() { schemaGroupedStream = new SchemaKGroupedStream( - groupedStream, schema, keySerde, keyField, sourceStreams, config, funcRegistry, + groupedStream, + sourceStep, + keyFormat, + keySerde, + keyField, + sourceStreams, + config, + funcRegistry, materializedFactory ); @@ -186,8 +211,10 @@ public void shouldSupportSessionWindowedKey() { aggregateSchema, initializer, 0, + emptyList(), emptyMap(), windowExp, + valueFormat, topicValueSerDe, queryContext ); @@ -210,8 +237,10 @@ public void shouldSupportHoppingWindowedKey() { aggregateSchema, initializer, 0, + emptyList(), emptyMap(), windowExp, + valueFormat, topicValueSerDe, queryContext ); @@ -234,8 +263,10 @@ public void shouldSupportTumblingWindowedKey() { aggregateSchema, initializer, 0, + emptyList(), emptyMap(), 
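SchemaKGroupedStream.aggregate has grown two parameters that the updated calls below thread through: the list of aggregate FunctionCalls, which is recorded in the StreamAggregate step, and the value format, from which the step's Formats are built. The full call shape as these tests use it:

    final SchemaKTable<?> result = schemaGroupedStream.aggregate(
        aggregateSchema,
        initializer,
        0,               // index of the first aggregate value column
        emptyList(),     // aggregate function calls, recorded in the plan
        emptyMap(),      // column index -> KsqlAggregateFunction
        windowExp,       // null for a non-windowed aggregate
        valueFormat,     // feeds Formats.of(keyFormat, valueFormat, ...)
        topicValueSerDe,
        queryContext);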
windowExp, + valueFormat, topicValueSerDe, queryContext ); @@ -256,8 +287,10 @@ public void shouldUseTimeWindowKeySerdeForWindowedIfLegacyConfig() { aggregateSchema, initializer, 0, + emptyList(), emptyMap(), windowExp, + valueFormat, topicValueSerDe, queryContext ); @@ -287,9 +320,11 @@ private void assertDoesNotInstallWindowSelectMapper( final SchemaKTable result = schemaGroupedStream.aggregate( aggregateSchema, initializer, - 0, funcMap, - + 0, + emptyList(), + funcMap, windowExp, + valueFormat, topicValueSerDe, queryContext ); @@ -315,9 +350,11 @@ private void assertDoesInstallWindowSelectMapper( final SchemaKTable result = schemaGroupedStream.aggregate( aggregateSchema, initializer, - 0, funcMap, - + 0, + emptyList(), + funcMap, windowExp, + valueFormat, topicValueSerDe, queryContext ); @@ -348,8 +385,10 @@ public void shouldUseMaterializedFactoryForStateStore() { aggregateSchema, () -> null, 0, + emptyList(), Collections.emptyMap(), null, + valueFormat, topicValueSerDe, queryContext ); @@ -376,8 +415,10 @@ public void shouldUseMaterializedFactoryWindowedStateStore() { aggregateSchema, () -> null, 0, + emptyList(), Collections.emptyMap(), windowExp, + valueFormat, topicValueSerDe, queryContext); @@ -397,8 +438,10 @@ public void shouldReturnKTableWithAggregateSchema() { aggregateSchema, initializer, 0, + emptyList(), emptyMap(), windowExp, + valueFormat, topicValueSerDe, queryContext ); @@ -407,6 +450,77 @@ public void shouldReturnKTableWithAggregateSchema() { assertThat(result.getSchema(), is(aggregateSchema)); } + @Test + public void shouldBuildStepForAggregate() { + // Given: + final Map functions = ImmutableMap.of(1, otherFunc); + when(aggregateSchema.valueFields()) + .thenReturn(ImmutableList.of(mock(Field.class), mock(Field.class))); + + // When: + final SchemaKTable result = schemaGroupedStream.aggregate( + aggregateSchema, + initializer, + 1, + ImmutableList.of(aggCall), + functions, + null, + valueFormat, + topicValueSerDe, + queryContext + ); + + // Then: + assertThat( + result.getSourceTableStep(), + equalTo( + ExecutionStepFactory.streamAggregate( + queryContext, + schemaGroupedStream.getSourceStep(), + aggregateSchema, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + 1, + ImmutableList.of(aggCall) + ) + ) + ); + } + + @Test + public void shouldBuildStepKeyFormatForWindowedAggregate() { + // When: + final SchemaKTable result = schemaGroupedStream.aggregate( + aggregateSchema, + initializer, + 0, + emptyList(), + Collections.emptyMap(), + windowExp, + valueFormat, + topicValueSerDe, + queryContext + ); + + // Then: + final KeyFormat expected = KeyFormat.windowed( + FormatInfo.of(Format.KAFKA), + WindowInfo.of(WindowType.SESSION, Optional.empty()) + ); + assertThat( + result.getSourceTableStep(), + equalTo( + ExecutionStepFactory.streamAggregate( + queryContext, + schemaGroupedStream.getSourceStep(), + aggregateSchema, + Formats.of(expected, valueFormat, SerdeOption.none()), + 0, + Collections.emptyList() + ) + ) + ); + } + @Test(expected = IllegalArgumentException.class) public void shouldThrowOnColumnCountMismatch() { // Given: @@ -421,8 +535,10 @@ public void shouldThrowOnColumnCountMismatch() { aggregateSchema, initializer, 2, + ImmutableList.of(aggCall), aggColumns, windowExp, + valueFormat, topicValueSerDe, queryContext ); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java index d3c23984c83e..c36b6c91fd63 100644 --- 
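The shouldBuildStep tests that follow all use one pattern: rebuild the expected node with ExecutionStepFactory and compare with equalTo, which relies on the execution steps implementing value equality. For the windowed case, the expected key format is pinned to a windowed KAFKA format:

    final KeyFormat expected = KeyFormat.windowed(
        FormatInfo.of(Format.KAFKA),
        WindowInfo.of(WindowType.SESSION, Optional.empty()));
    final Formats formats = Formats.of(expected, valueFormat, SerdeOption.none());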
a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java @@ -15,6 +15,7 @@ package io.confluent.ksql.structured; +import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static org.easymock.EasyMock.anyObject; import static org.easymock.EasyMock.eq; @@ -30,16 +31,23 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient; import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.DereferenceExpression; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.QualifiedName; import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.function.TableAggregationFunction; import io.confluent.ksql.function.udaf.KudafInitializer; import io.confluent.ksql.logging.processing.NoopProcessingLogContext; import io.confluent.ksql.logging.processing.ProcessingLogContext; @@ -56,7 +64,10 @@ import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.GenericRowSerDe; +import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.streams.MaterializedFactory; import io.confluent.ksql.streams.StreamsUtil; import io.confluent.ksql.testutils.AnalysisTestUtil; @@ -90,6 +101,7 @@ import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; @SuppressWarnings("unchecked") @@ -107,6 +119,8 @@ public class SchemaKGroupedTableTest { private final MetaStore metaStore = MetaStoreFixture.getNewMetaStore(new InternalFunctionRegistry()); private final QueryContext.Stacker queryContext = new QueryContext.Stacker(new QueryId("query")).push("node"); + private final ValueFormat valueFormat = ValueFormat.of(FormatInfo.of(Format.JSON)); + private final KeyFormat keyFormat = KeyFormat.nonWindowed(FormatInfo.of(Format.JSON)); @Rule public final ExpectedException expectedException = ExpectedException.none(); @@ -120,9 +134,15 @@ public class SchemaKGroupedTableTest { @Mock private Serde topicValueSerDe; @Mock + private FunctionCall aggCall1; + @Mock + private FunctionCall aggCall2; + @Mock private Field field; @Mock private KsqlAggregateFunction otherFunc; + @Mock + private TableAggregationFunction tableFunc; private KTable kTable; private KsqlTable ksqlTable; @@ -150,6 +170,14 @@ public void init() { .thenReturn(Optional.of(Field.of("GROUPING_COLUMN", SqlTypes.STRING))); } + private ExecutionStep buildSourceTableStep(final LogicalSchema schema) { + final ExecutionStep step = 
Mockito.mock(ExecutionStep.class); + when(step.getProperties()).thenReturn( + new DefaultExecutionStepProperties(schema, queryContext.getQueryContext()) + ); + return step; + } + private SchemaKGroupedTable buildSchemaKGroupedTableFromQuery( final String query, final String...groupByColumns @@ -158,16 +186,16 @@ private SchemaKGroupedTable buildSchemaKGroupedTableFromQuery( final PlanNode logicalPlan = AnalysisTestUtil.buildLogicalPlan(ksqlConfig, query, metaStore); - final SchemaKTable initialSchemaKTable = new SchemaKTable<>( + final SchemaKTable initialSchemaKTable = new SchemaKTable( kTable, - logicalPlan.getTheSourceNode().getSchema(), + buildSourceTableStep(logicalPlan.getTheSourceNode().getSchema()), + keyFormat, keySerde, logicalPlan.getTheSourceNode().getKeyField(), new ArrayList<>(), SchemaKStream.Type.SOURCE, ksqlConfig, - functionRegistry, - queryContext.getQueryContext()); + functionRegistry); final List groupByExpressions = Arrays.stream(groupByColumns) @@ -184,7 +212,7 @@ private SchemaKGroupedTable buildSchemaKGroupedTableFromQuery( ); final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( - rowSerde, groupByExpressions, queryContext); + valueFormat, rowSerde, groupByExpressions, queryContext); Assert.assertThat(groupedSchemaKTable, instanceOf(SchemaKGroupedTable.class)); return (SchemaKGroupedTable)groupedSchemaKTable; } @@ -206,8 +234,10 @@ public void shouldFailWindowedTableAggregation() { aggregateSchema, initializer, 0, + emptyList(), emptyMap(), windowExp, + valueFormat, topicValueSerDe, queryContext ); @@ -231,8 +261,10 @@ public void shouldFailUnsupportedAggregateFunction() { aggregateSchema, new KudafInitializer(1), 1, + ImmutableList.of(aggCall1, aggCall2), aggValToFunctionMap, null, + valueFormat, GenericRowSerDe.from( FormatInfo.of(Format.JSON, Optional.empty()), PersistenceSchema.from(ksqlTable.getSchema().valueSchema(), false), @@ -257,7 +289,8 @@ private SchemaKGroupedTable buildSchemaKGroupedTable( ) { return new SchemaKGroupedTable( kGroupedTable, - schema, + buildSourceTableStep(schema), + keyFormat, keySerde, KeyField.of(schema.valueFields().get(0).name(), schema.valueFields().get(0)), Collections.emptyList(), @@ -297,8 +330,10 @@ public void shouldUseMaterializedFactoryForStateStore() { aggregateSchema, () -> null, 0, + emptyList(), Collections.emptyMap(), null, + valueFormat, valueSerde, queryContext); @@ -306,6 +341,44 @@ public void shouldUseMaterializedFactoryForStateStore() { verify(materializedFactory, mockKGroupedTable); } + @Test + public void shouldBuildStepForAggregate() { + // Given: + final Map functions = ImmutableMap.of(1, tableFunc); + final SchemaKGroupedTable groupedTable = + buildSchemaKGroupedTable(mockKGroupedTable, materializedFactory); + when(aggregateSchema.valueFields()).thenReturn( + ImmutableList.of(Mockito.mock(Field.class), Mockito.mock(Field.class))); + + // When: + final SchemaKTable result = groupedTable.aggregate( + aggregateSchema, + initializer, + 1, + ImmutableList.of(aggCall1), + functions, + null, + valueFormat, + topicValueSerDe, + queryContext + ); + + // Then: + assertThat( + result.getSourceTableStep(), + equalTo( + ExecutionStepFactory.tableAggregate( + queryContext, + groupedTable.getSourceTableStep(), + aggregateSchema, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + 1, + ImmutableList.of(aggCall1) + ) + ) + ); + } + @Test public void shouldReturnKTableWithAggregateSchema() { // Given: @@ -317,8 +390,10 @@ public void shouldReturnKTableWithAggregateSchema() { aggregateSchema, 
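SchemaKGroupedTable.aggregate records a TableAggregate step in the same way the stream variant records a StreamAggregate. The expected-step construction used in shouldBuildStepForAggregate below, with the holder generics reconstructed since the patch context elides them:

    final ExecutionStep<KTable<Struct, GenericRow>> expected =
        ExecutionStepFactory.tableAggregate(
            queryContext,
            groupedTable.getSourceTableStep(),
            aggregateSchema,
            Formats.of(keyFormat, valueFormat, SerdeOption.none()),
            1,                           // matches the index passed to aggregate(...)
            ImmutableList.of(aggCall1));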
initializer, 0, + emptyList(), emptyMap(), null, + valueFormat, topicValueSerDe, queryContext ); @@ -344,8 +419,10 @@ public void shouldThrowOnColumnCountMismatch() { aggregateSchema, initializer, 2, + ImmutableList.of(aggCall1), aggColumns, null, + valueFormat, topicValueSerDe, queryContext ); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java index cf2de96fb816..e938b3971759 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKStreamTest.java @@ -30,6 +30,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -42,7 +43,13 @@ import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.QualifiedName; import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.JoinType; import io.confluent.ksql.execution.plan.SelectExpression; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.ExecutionStepProperties; +import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.logging.processing.ProcessingLogContext; import io.confluent.ksql.metastore.MetaStore; @@ -64,7 +71,10 @@ import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.GenericRowSerDe; +import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.streams.GroupedFactory; import io.confluent.ksql.streams.JoinedFactory; import io.confluent.ksql.streams.MaterializedFactory; @@ -74,6 +84,7 @@ import io.confluent.ksql.testutils.AnalysisTestUtil; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.MetaStoreFixture; +import io.confluent.ksql.util.Pair; import io.confluent.ksql.util.SchemaUtil; import java.time.Duration; import java.util.ArrayList; @@ -116,7 +127,6 @@ public class SchemaKStreamTest { private static final Expression COL1 = new DereferenceExpression( new QualifiedNameReference(QualifiedName.of("TEST1")), "COL1"); - private final MockSchemaRegistryClient schemaRegistryClient = new MockSchemaRegistryClient(); private SchemaKStream initialSchemaKStream; private final KsqlConfig ksqlConfig = new KsqlConfig(Collections.emptyMap()); @@ -126,7 +136,6 @@ public class SchemaKStreamTest { "group", Serdes.String(), Serdes.String()); private final Joined joined = Joined.with( Serdes.String(), Serdes.String(), Serdes.String(), "join"); - private final KeyField validJoinKeyField = KeyField.of( Optional.of("left.COL1"), metaStore.getSource("TEST1") @@ -143,6 +152,9 @@ public class SchemaKStreamTest { private Serde rightSerde; private LogicalSchema joinSchema; private Serde rowSerde; + private KeyFormat keyFormat = KeyFormat.nonWindowed(FormatInfo.of(Format.JSON)); + private ValueFormat valueFormat = ValueFormat.of(FormatInfo.of(Format.JSON)); + 
private ValueFormat rightFormat = ValueFormat.of(FormatInfo.of(Format.DELIMITED)); private final LogicalSchema simpleSchema = LogicalSchema.builder() .valueField("key", SqlTypes.STRING) .valueField("val", SqlTypes.BIGINT) @@ -167,6 +179,16 @@ public class SchemaKStreamTest { private KeySerde keySerde; @Mock private KeySerde reboundKeySerde; + @Mock + private KeySerde windowedKeySerde; + @Mock + private ExecutionStepProperties tableSourceProperties; + @Mock + private ExecutionStep tableSourceStep; + @Mock + private ExecutionStepProperties sourceProperties; + @Mock + private ExecutionStep sourceStep; @Before public void init() { @@ -199,20 +221,25 @@ public void init() { Serdes.String(), getRowSerde(ksqlTable.getKsqlTopic(), ksqlTable.getSchema().valueSchema()))); + when(tableSourceStep.getProperties()).thenReturn(tableSourceProperties); + when(tableSourceProperties.getSchema()).thenReturn(ksqlTable.getSchema()); + when(sourceStep.getProperties()).thenReturn(sourceProperties); + secondSchemaKStream = buildSchemaKStreamForJoin(secondKsqlStream, secondKStream); leftSerde = getRowSerde(ksqlStream.getKsqlTopic(), ksqlStream.getSchema().valueSchema()); rightSerde = getRowSerde(secondKsqlStream.getKsqlTopic(), secondKsqlStream.getSchema().valueSchema()); - schemaKTable = new SchemaKTable<>( - kTable, ksqlTable.getSchema(), + schemaKTable = new SchemaKTable( + kTable, + tableSourceStep, + keyFormat, keySerde, ksqlTable.getKeyField(), new ArrayList<>(), SchemaKStream.Type.SOURCE, ksqlConfig, - functionRegistry, - parentContext); + functionRegistry); joinSchema = getJoinSchema(ksqlStream.getSchema(), secondKsqlStream.getSchema()); @@ -257,6 +284,34 @@ public void testSelectSchemaKStream() { assertThat(projectedSchemaKStream.getSourceSchemaKStreams().get(0), is(initialSchemaKStream)); } + @Test + public void shouldBuildStepForSelect() { + // Given: + final PlanNode logicalPlan = givenInitialKStreamOf( + "SELECT col0, col2, col3 FROM test1 WHERE col0 > 100;"); + final ProjectNode projectNode = (ProjectNode) logicalPlan.getSources().get(0); + final List selectExpressions = projectNode.getProjectSelectExpressions(); + + // When: + final SchemaKStream projectedSchemaKStream = initialSchemaKStream.select( + selectExpressions, + childContextStacker, + processingLogContext); + + // Then: + assertThat( + projectedSchemaKStream.getSourceStep(), + equalTo( + ExecutionStepFactory.streamMapValues( + childContextStacker, + initialSchemaKStream.getSourceStep(), + selectExpressions, + projectedSchemaKStream.getSchema() + ) + ) + ); + } + @Test public void shouldUpdateKeyIfRenamed() { // Given: @@ -422,6 +477,32 @@ public void testFilter() { assertThat(filteredSchemaKStream.getSourceSchemaKStreams().get(0), is(initialSchemaKStream)); } + @Test + public void shouldBuildStepForFilter() { + // Given: + final PlanNode logicalPlan = givenInitialKStreamOf( + "SELECT col0, col2, col3 FROM test1 WHERE col0 > 100;"); + final FilterNode filterNode = (FilterNode) logicalPlan.getSources().get(0).getSources().get(0); + + // When: + final SchemaKStream filteredSchemaKStream = initialSchemaKStream.filter( + filterNode.getPredicate(), + childContextStacker, + processingLogContext); + + // Then: + assertThat( + filteredSchemaKStream.getSourceStep(), + equalTo( + ExecutionStepFactory.streamFilter( + childContextStacker, + initialSchemaKStream.getSourceStep(), + filterNode.getPredicate() + ) + ) + ); + } + @Test public void shouldSelectKey() { // Given: @@ -443,6 +524,31 @@ public void shouldSelectKey() { 
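Projection and filtering now record StreamMapValues and StreamFilter steps respectively, and the tests assert on them by rebuilding the expected step from the same inputs. For the filter case (generics illustrative):

    final ExecutionStep<KStream<Struct, GenericRow>> expected =
        ExecutionStepFactory.streamFilter(
            childContextStacker,
            initialSchemaKStream.getSourceStep(),
            filterNode.getPredicate());
    assertThat(filteredSchemaKStream.getSourceStep(), equalTo(expected));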
assertThat(rekeyedSchemaKStream.getKeySerde(), is(reboundKeySerde)); } + @Test + public void shouldBuildStepForSelectKey() { + // Given: + givenInitialKStreamOf("SELECT col0, col2, col3 FROM test1 WHERE col0 > 100;"); + + // When: + final SchemaKStream rekeyedSchemaKStream = initialSchemaKStream.selectKey( + "TEST1.COL1", + true, + childContextStacker); + + // Then: + assertThat( + rekeyedSchemaKStream.getSourceStep(), + equalTo( + ExecutionStepFactory.streamSelectKey( + childContextStacker, + initialSchemaKStream.getSourceStep(), + "TEST1.COL1", + true + ) + ) + ); + } + @Test(expected = IllegalArgumentException.class) public void shouldThrowOnSelectKeyIfKeyNotInSchema() { givenInitialKStreamOf("SELECT col0, col2, col3 FROM test1 WHERE col0 > 100;"); @@ -467,6 +573,7 @@ public void testGroupByKey() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( + valueFormat, rowSerde, groupBy, childContextStacker); @@ -476,6 +583,37 @@ public void testGroupByKey() { assertThat(groupedSchemaKStream.getKeyField().legacy(), OptionalMatchers.of(hasName("COL0"))); } + @Test + public void shouldBuildStepForGroupBy() { + // Given: + givenInitialKStreamOf("SELECT col0, col1 FROM test1 WHERE col0 > 100;"); + final List groupBy = Collections.singletonList( + new DereferenceExpression( + new QualifiedNameReference(QualifiedName.of("TEST1")), "COL0") + ); + + // When: + final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( + valueFormat, + rowSerde, + groupBy, + childContextStacker); + + // Then: + final KeyFormat expectedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo()); + assertThat( + groupedSchemaKStream.getSourceStep(), + equalTo( + ExecutionStepFactory.streamGroupBy( + childContextStacker, + initialSchemaKStream.getSourceStep(), + Formats.of(expectedKeyFormat, valueFormat, SerdeOption.none()), + groupBy + ) + ) + ); + } + @Test public void testGroupByMultipleColumns() { // Given: @@ -490,6 +628,7 @@ public void testGroupByMultipleColumns() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( + valueFormat, rowSerde, groupBy, childContextStacker); @@ -508,6 +647,7 @@ public void testGroupByMoreComplexExpression() { // When: final SchemaKGroupedStream groupedSchemaKStream = initialSchemaKStream.groupBy( + valueFormat, rowSerde, ImmutableList.of(groupBy), childContextStacker); @@ -530,6 +670,7 @@ public void shouldUseFactoryForGroupedWithoutRekey() { // When: initialSchemaKStream.groupBy( + valueFormat, leftSerde, groupByExpressions, childContextStacker); @@ -559,6 +700,7 @@ public void shouldUseFactoryForGrouped() { // When: initialSchemaKStream.groupBy( + valueFormat, leftSerde, groupByExpressions, childContextStacker); @@ -571,6 +713,37 @@ public void shouldUseFactoryForGrouped() { verify(mockKStream).groupBy(any(KeyValueMapper.class), same(grouped)); } + @Test + public void shouldBuildStepForToTable() { + // Given: + givenInitialSchemaKStreamUsesMocks(); + when(mockKStream.mapValues(any(ValueMapper.class))).thenReturn(mockKStream); + final KGroupedStream groupedStream = mock(KGroupedStream.class); + final KTable table = mock(KTable.class); + when(mockKStream.groupByKey()).thenReturn(groupedStream); + when(groupedStream.aggregate(any(), any(), any())).thenReturn(table); + + // When: + final SchemaKTable result = initialSchemaKStream.toTable( + keyFormat, + valueFormat, + leftSerde, + childContextStacker + ); + + // Then: + assertThat( + result.getSourceTableStep(), + equalTo( + 
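Grouping a stream may rekey it, but any windowing is applied later by the aggregate, which is why shouldBuildStepForGroupBy expects the step's key format to be rebuilt as a non-windowed format carrying the source's format info:

    final KeyFormat expectedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo());
    final Formats formats = Formats.of(expectedKeyFormat, valueFormat, SerdeOption.none());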
ExecutionStepFactory.streamToTable( + childContextStacker, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + sourceStep + ) + ) + ); + } + @Test public void shouldConvertToTableWithCorrectProperties() { // Given: @@ -582,7 +755,12 @@ public void shouldConvertToTableWithCorrectProperties() { when(groupedStream.aggregate(any(), any(), any())).thenReturn(table); // When: - final SchemaKTable result = initialSchemaKStream.toTable(leftSerde, childContextStacker); + final SchemaKTable result = initialSchemaKStream.toTable( + keyFormat, + valueFormat, + leftSerde, + childContextStacker + ); // Then: assertThat(result.getSchema(), is(initialSchemaKStream.getSchema())); @@ -602,7 +780,7 @@ public void shouldConvertToOptionalBeforeGroupingInToTable() { when(groupedStream.aggregate(any(), any(), any())).thenReturn(table); // When: - initialSchemaKStream.toTable(leftSerde, childContextStacker); + initialSchemaKStream.toTable(keyFormat, valueFormat, leftSerde, childContextStacker); // Then: InOrder inOrder = Mockito.inOrder(mockKStream); @@ -625,7 +803,7 @@ public void shouldComputeAggregateCorrectlyInToTable() { when(groupedStream.aggregate(any(), any(), any())).thenReturn(table); // When: - initialSchemaKStream.toTable(leftSerde, childContextStacker); + initialSchemaKStream.toTable(keyFormat, valueFormat, leftSerde, childContextStacker); // Then: final ArgumentCaptor initCaptor = ArgumentCaptor.forClass(Initializer.class); @@ -655,13 +833,17 @@ public void shouldPerformStreamToStreamLeftJoin() { // When: final SchemaKStream joinedKStream = initialSchemaKStream - .leftJoin(secondSchemaKStream, - joinSchema, + .leftJoin( + secondSchemaKStream, + joinSchema, validJoinKeyField, - joinWindow, - leftSerde, - rightSerde, - childContextStacker); + joinWindow, + valueFormat, + valueFormat, + leftSerde, + rightSerde, + childContextStacker + ); // Then: verifyCreateJoined(rightSerde); @@ -673,12 +855,89 @@ public void shouldPerformStreamToStreamLeftJoin() { ); assertThat(joinedKStream, instanceOf(SchemaKStream.class)); assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type); - assertEquals(joinSchema, joinedKStream.schema); + assertEquals(joinSchema, joinedKStream.getSchema()); assertThat(joinedKStream.getKeyField(), is(validJoinKeyField)); assertEquals(Arrays.asList(initialSchemaKStream, secondSchemaKStream), joinedKStream.sourceSchemaKStreams); } + @FunctionalInterface + private interface StreamStreamJoin { + SchemaKStream join( + SchemaKStream otherSchemaKStream, + LogicalSchema joinSchema, + KeyField keyField, + JoinWindows joinWindows, + ValueFormat leftFormat, + ValueFormat rightFormat, + Serde leftSerde, + Serde rightSerde, + QueryContext.Stacker contextStacker); + } + + @Test + public void shouldBuildStepForStreamStreamJoin() { + // Given: + final SchemaKStream initialSchemaKStream = + buildSchemaKStreamForJoin(ksqlStream, mockKStream, mockGroupedFactory, mockJoinedFactory); + final JoinWindows joinWindow = JoinWindows.of(Duration.ofMillis(10L)); + when(mockKStream.leftJoin( + any(KStream.class), + any(SchemaKStream.KsqlValueJoiner.class), + any(JoinWindows.class), + any(Joined.class)) + ).thenReturn(mockKStream); + when(mockKStream.join( + any(KStream.class), + any(SchemaKStream.KsqlValueJoiner.class), + any(JoinWindows.class), + any(Joined.class)) + ).thenReturn(mockKStream); + when(mockKStream.outerJoin( + any(KStream.class), + any(SchemaKStream.KsqlValueJoiner.class), + any(JoinWindows.class), + any(Joined.class)) + ).thenReturn(mockKStream); + + final List> cases = ImmutableList.of( + 
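The toTable conversion above is itself recorded as a StreamToTable step whose Formats capture the key and value serialization. The expected step as the test builds it, with generics illustrative:

    final ExecutionStep<KTable<Struct, GenericRow>> expected =
        ExecutionStepFactory.streamToTable(
            childContextStacker,
            Formats.of(keyFormat, valueFormat, SerdeOption.none()),
            sourceStep);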
Pair.of(JoinType.LEFT, initialSchemaKStream::leftJoin), + Pair.of(JoinType.INNER, initialSchemaKStream::join), + Pair.of(JoinType.OUTER, initialSchemaKStream::outerJoin) + ); + + for (final Pair testcase : cases) { + final SchemaKStream joinedKStream = testcase.right.join( + secondSchemaKStream, + joinSchema, + validJoinKeyField, + joinWindow, + valueFormat, + rightFormat, + leftSerde, + rightSerde, + childContextStacker + ); + + // Then: + assertThat( + joinedKStream.getSourceStep(), + equalTo( + ExecutionStepFactory.streamStreamJoin( + childContextStacker, + testcase.left, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + Formats.of(keyFormat, rightFormat, SerdeOption.none()), + initialSchemaKStream.getSourceStep(), + secondSchemaKStream.getSourceStep(), + joinSchema, + joinWindow + ) + ) + ); + } + } + @SuppressWarnings("unchecked") @Test public void shouldPerformStreamToStreamInnerJoin() { @@ -696,12 +955,15 @@ public void shouldPerformStreamToStreamInnerJoin() { // When: final SchemaKStream joinedKStream = initialSchemaKStream - .join(secondSchemaKStream, - joinSchema, + .join( + secondSchemaKStream, + joinSchema, validJoinKeyField, - joinWindow, - leftSerde, - rightSerde, + joinWindow, + valueFormat, + valueFormat, + leftSerde, + rightSerde, childContextStacker); // Then: @@ -715,7 +977,7 @@ public void shouldPerformStreamToStreamInnerJoin() { assertThat(joinedKStream, instanceOf(SchemaKStream.class)); assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type); - assertEquals(joinSchema, joinedKStream.schema); + assertEquals(joinSchema, joinedKStream.getSchema()); Assert.assertThat(joinedKStream.getKeyField(), is(validJoinKeyField)); assertEquals(Arrays.asList(initialSchemaKStream, secondSchemaKStream), joinedKStream.sourceSchemaKStreams); @@ -738,13 +1000,17 @@ public void shouldPerformStreamToStreamOuterJoin() { // When: final SchemaKStream joinedKStream = initialSchemaKStream - .outerJoin(secondSchemaKStream, - joinSchema, + .outerJoin( + secondSchemaKStream, + joinSchema, validJoinKeyField, - joinWindow, - leftSerde, - rightSerde, - childContextStacker); + joinWindow, + valueFormat, + valueFormat, + leftSerde, + rightSerde, + childContextStacker + ); // Then: verifyCreateJoined(rightSerde); @@ -756,7 +1022,7 @@ public void shouldPerformStreamToStreamOuterJoin() { ); assertThat(joinedKStream, instanceOf(SchemaKStream.class)); assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type); - assertEquals(joinSchema, joinedKStream.schema); + assertEquals(joinSchema, joinedKStream.getSchema()); assertThat(joinedKStream.getKeyField(), is(validJoinKeyField)); assertEquals(Arrays.asList(initialSchemaKStream, secondSchemaKStream), joinedKStream.sourceSchemaKStreams); @@ -781,6 +1047,7 @@ public void shouldPerformStreamToTableLeftJoin() { schemaKTable, joinSchema, validJoinKeyField, + valueFormat, leftSerde, childContextStacker); @@ -792,7 +1059,7 @@ public void shouldPerformStreamToTableLeftJoin() { same(joined)); assertThat(joinedKStream, instanceOf(SchemaKStream.class)); assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type); - assertEquals(joinSchema, joinedKStream.schema); + assertEquals(joinSchema, joinedKStream.getSchema()); assertThat(joinedKStream.getKeyField(), is(validJoinKeyField)); assertEquals(Arrays.asList(initialSchemaKStream, schemaKTable), joinedKStream.sourceSchemaKStreams); @@ -817,6 +1084,7 @@ public void shouldPerformStreamToTableInnerJoin() { schemaKTable, joinSchema, validJoinKeyField, + valueFormat, leftSerde, childContextStacker); @@ -830,28 +1098,91 @@ 
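A stream-stream join records Formats for both sides, since each input may be serialized differently (the tests keep the two apart by using JSON on the left and DELIMITED on the right). The expected step for the LEFT case, with generics illustrative:

    final ExecutionStep<KStream<Struct, GenericRow>> expected =
        ExecutionStepFactory.streamStreamJoin(
            childContextStacker,
            JoinType.LEFT,
            Formats.of(keyFormat, valueFormat, SerdeOption.none()),  // left input
            Formats.of(keyFormat, rightFormat, SerdeOption.none()),  // right input
            initialSchemaKStream.getSourceStep(),
            secondSchemaKStream.getSourceStep(),
            joinSchema,
            joinWindow);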
public void shouldPerformStreamToTableInnerJoin() { assertThat(joinedKStream, instanceOf(SchemaKStream.class)); assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type); - assertEquals(joinSchema, joinedKStream.schema); + assertEquals(joinSchema, joinedKStream.getSchema()); assertThat(joinedKStream.getKeyField(), is(validJoinKeyField)); assertEquals(Arrays.asList(initialSchemaKStream, schemaKTable), joinedKStream.sourceSchemaKStreams); } + @FunctionalInterface + private interface StreamTableJoin { + SchemaKStream join( + SchemaKTable other, + LogicalSchema joinSchema, + KeyField keyField, + ValueFormat leftFormat, + Serde leftSerde, + QueryContext.Stacker contextStacker); + } + + @Test + public void shouldBuildStepForStreamTableJoin() { + // Given: + final SchemaKStream initialSchemaKStream = + buildSchemaKStreamForJoin(ksqlStream, mockKStream, mockGroupedFactory, mockJoinedFactory); + when( + mockKStream.leftJoin( + any(KTable.class), + any(SchemaKStream.KsqlValueJoiner.class), + any(Joined.class)) + ).thenReturn(mockKStream); + when(mockKStream.join( + any(KTable.class), + any(SchemaKStream.KsqlValueJoiner.class), + any(Joined.class)) + ).thenReturn(mockKStream); + + final List> cases = ImmutableList.of( + Pair.of(JoinType.LEFT, initialSchemaKStream::leftJoin), + Pair.of(JoinType.INNER, initialSchemaKStream::join) + ); + + for (final Pair testcase : cases) { + final SchemaKStream joinedKStream = testcase.right.join( + schemaKTable, + joinSchema, + validJoinKeyField, + valueFormat, + leftSerde, + childContextStacker + ); + + // Then: + assertThat( + joinedKStream.getSourceStep(), + equalTo( + ExecutionStepFactory.streamTableJoin( + childContextStacker, + testcase.left, + Formats.of(keyFormat, valueFormat, SerdeOption.none()), + initialSchemaKStream.getSourceStep(), + schemaKTable.getSourceTableStep(), + joinSchema + ) + ) + ); + } + } @Test public void shouldSummarizeExecutionPlanCorrectly() { // Given: + when(sourceProperties.getSchema()).thenReturn(simpleSchema); final SchemaKStream parentSchemaKStream = mock(SchemaKStream.class); when(parentSchemaKStream.getExecutionPlan(anyString())) .thenReturn("parent plan"); + when(sourceProperties.getQueryContext()).thenReturn( + queryContext.push("source").getQueryContext()); final SchemaKStream schemaKtream = new SchemaKStream( mock(KStream.class), - simpleSchema, + sourceStep, + keyFormat, keySerde, KeyField.of("key", simpleSchema.findValueField("key").get()), ImmutableList.of(parentSchemaKStream), Type.SOURCE, ksqlConfig, - functionRegistry, - queryContext.push("source").getQueryContext()); + functionRegistry + ); // When/Then: assertThat(schemaKtream.getExecutionPlan(""), equalTo( @@ -863,16 +1194,20 @@ public void shouldSummarizeExecutionPlanCorrectly() { @Test public void shouldSummarizeExecutionPlanCorrectlyForRoot() { // Given: + when(sourceProperties.getSchema()).thenReturn(simpleSchema); + when(sourceProperties.getQueryContext()).thenReturn( + queryContext.push("source").getQueryContext()); final SchemaKStream schemaKtream = new SchemaKStream( mock(KStream.class), - simpleSchema, + sourceStep, + keyFormat, keySerde, KeyField.of("key", simpleSchema.findValueField("key").get()), Collections.emptyList(), Type.SOURCE, ksqlConfig, - functionRegistry, - queryContext.push("source").getQueryContext()); + functionRegistry + ); // When/Then: assertThat(schemaKtream.getExecutionPlan(""), equalTo( @@ -889,16 +1224,20 @@ public void shouldSummarizeExecutionPlanCorrectlyWhenMultipleParents() { final SchemaKStream parentSchemaKStream2 = 
mock(SchemaKStream.class); when(parentSchemaKStream2.getExecutionPlan(anyString())) .thenReturn("parent 2 plan"); + when(sourceProperties.getSchema()).thenReturn(simpleSchema); + when(sourceProperties.getQueryContext()).thenReturn( + queryContext.push("source").getQueryContext()); final SchemaKStream schemaKtream = new SchemaKStream( mock(KStream.class), - simpleSchema, + sourceStep, + keyFormat, keySerde, KeyField.of("key", simpleSchema.findValueField("key").get()), ImmutableList.of(parentSchemaKStream1, parentSchemaKStream2), Type.SOURCE, ksqlConfig, - functionRegistry, - queryContext.push("source").getQueryContext()); + functionRegistry + ); // When/Then: assertThat(schemaKtream.getExecutionPlan(""), equalTo( @@ -927,22 +1266,33 @@ private void verifyCreateJoined(final Serde rightSerde) { ); } + private void givenSourcePropertiesWithSchema(final LogicalSchema schema) { + reset(sourceProperties); + when(sourceProperties.getSchema()).thenReturn(schema); + when(sourceProperties.withQueryContext(any())).thenAnswer( + i -> new DefaultExecutionStepProperties(schema, (QueryContext) i.getArguments()[0]) + ); + } + private SchemaKStream buildSchemaKStream( final LogicalSchema schema, final KeyField keyField, final KStream kStream, final StreamsFactories streamsFactories) { + givenSourcePropertiesWithSchema(schema); return new SchemaKStream( kStream, - schema, + sourceStep, + sourceProperties, + keyFormat, keySerde, keyField, new ArrayList<>(), Type.SOURCE, ksqlConfig, functionRegistry, - streamsFactories, - parentContext); + streamsFactories + ); } private void givenInitialSchemaKStreamUsesMocks() { @@ -1007,16 +1357,18 @@ private PlanNode givenInitialKStreamOf(final String selectQuery) { metaStore ); + givenSourcePropertiesWithSchema(logicalPlan.getTheSourceNode().getSchema()); initialSchemaKStream = new SchemaKStream( kStream, - logicalPlan.getTheSourceNode().getSchema(), + sourceStep, + keyFormat, keySerde, logicalPlan.getTheSourceNode().getKeyField(), new ArrayList<>(), SchemaKStream.Type.SOURCE, ksqlConfig, - functionRegistry, - queryContext.push("source").getQueryContext()); + functionRegistry + ); rowSerde = GenericRowSerDe.from( FormatInfo.of(Format.JSON, Optional.empty()), diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java index 83e6746a7b47..e479d2bac6c0 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKTableTest.java @@ -37,12 +37,19 @@ import com.google.common.collect.ImmutableList; import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.expression.tree.DereferenceExpression; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.QualifiedName; import io.confluent.ksql.execution.expression.tree.QualifiedNameReference; +import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties; +import io.confluent.ksql.execution.plan.ExecutionStep; +import io.confluent.ksql.execution.plan.ExecutionStepProperties; +import io.confluent.ksql.execution.plan.Formats; +import io.confluent.ksql.execution.plan.JoinType; import io.confluent.ksql.execution.plan.SelectExpression; +import 
io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.logging.processing.ProcessingLogContext; import io.confluent.ksql.metastore.MetaStore; @@ -51,6 +58,7 @@ import io.confluent.ksql.metastore.model.KsqlTable; import io.confluent.ksql.metastore.model.KsqlTopic; import io.confluent.ksql.metastore.model.MetaStoreMatchers.KeyFieldMatchers; +import io.confluent.ksql.planner.LogicalPlanNode; import io.confluent.ksql.planner.plan.FilterNode; import io.confluent.ksql.planner.plan.PlanNode; import io.confluent.ksql.planner.plan.ProjectNode; @@ -62,7 +70,10 @@ import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.GenericRowSerDe; +import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.KeySerde; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.streams.GroupedFactory; import io.confluent.ksql.streams.JoinedFactory; import io.confluent.ksql.streams.MaterializedFactory; @@ -72,6 +83,7 @@ import io.confluent.ksql.testutils.AnalysisTestUtil; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.MetaStoreFixture; +import io.confluent.ksql.util.Pair; import io.confluent.ksql.util.SchemaUtil; import java.util.ArrayList; import java.util.Arrays; @@ -97,6 +109,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; @SuppressWarnings("unchecked") @@ -125,7 +138,6 @@ public class SchemaKTableTest { private LogicalSchema joinSchema; private final QueryContext.Stacker queryContext = new QueryContext.Stacker(new QueryId("query")).push("node"); - private final QueryContext parentContext = queryContext.push("parent").getQueryContext(); private final QueryContext.Stacker childContextStacker = queryContext.push("child"); private final ProcessingLogContext processingLogContext = ProcessingLogContext.create(); private Serde rowSerde; @@ -133,6 +145,8 @@ public class SchemaKTableTest { new QualifiedNameReference(QualifiedName.of("TEST2")), "COL1"); private static final Expression TEST_2_COL_2 = new DereferenceExpression( new QualifiedNameReference(QualifiedName.of("TEST2")), "COL2"); + private static final KeyFormat keyFormat = KeyFormat.nonWindowed(FormatInfo.of(Format.JSON)); + private static final ValueFormat valueFormat = ValueFormat.of(FormatInfo.of(Format.JSON)); @Mock private KeySerde keySerde; @@ -167,13 +181,22 @@ public void init() { when(keySerde.rebind(any(PersistenceSchema.class))).thenReturn(reboundKeySerde); } + private ExecutionStep buildSourceStep(final LogicalSchema schema) { + final ExecutionStep sourceStep = Mockito.mock(ExecutionStep.class); + when(sourceStep.getProperties()).thenReturn( + new DefaultExecutionStepProperties(schema, queryContext.getQueryContext())); + return sourceStep; + } + private SchemaKTable buildSchemaKTable( final LogicalSchema schema, final KeyField keyField, final KTable kTable, final GroupedFactory groupedFactory) { return new SchemaKTable( - kTable, schema, + kTable, + buildSourceStep(schema), + keyFormat, keySerde, keyField, new ArrayList<>(), @@ -183,8 +206,22 @@ private SchemaKTable buildSchemaKTable( new StreamsFactories( groupedFactory, JoinedFactory.create(ksqlConfig), - MaterializedFactory.create(ksqlConfig)), - parentContext); + MaterializedFactory.create(ksqlConfig)) + ); + } + + private SchemaKTable 
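With the schema and query context now carried by the execution step's properties rather than passed to the SchemaKTable constructor, these tests mock a source step whose properties return both. A sketch mirroring the buildSourceStep helper below, with generics reconstructed:

    final ExecutionStep<KTable<Struct, GenericRow>> step = mock(ExecutionStep.class);
    when(step.getProperties()).thenReturn(
        new DefaultExecutionStepProperties(schema, queryContext.getQueryContext()));
    // Consumers now read both through the step:
    final LogicalSchema tableSchema = step.getProperties().getSchema();
    final QueryContext context = step.getProperties().getQueryContext();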
buildSchemaKTableFromPlan(final PlanNode logicalPlan) { + return new SchemaKTable( + kTable, + buildSourceStep(logicalPlan.getTheSourceNode().getSchema()), + keyFormat, + keySerde, + logicalPlan.getTheSourceNode().getKeyField(), + new ArrayList<>(), + SchemaKStream.Type.SOURCE, + ksqlConfig, + functionRegistry + ); } private SchemaKTable buildSchemaKTable( @@ -227,16 +264,7 @@ public void testSelectSchemaKStream() { final String selectQuery = "SELECT col0, col2, col3 FROM test2 WHERE col0 > 100;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); final ProjectNode projectNode = (ProjectNode) logicalPlan.getSources().get(0); - - initialSchemaKTable = new SchemaKTable<>( - kTable, logicalPlan.getTheSourceNode().getSchema(), - keySerde, - logicalPlan.getTheSourceNode().getKeyField(), - new ArrayList<>(), - SchemaKStream.Type.SOURCE, - ksqlConfig, - functionRegistry, - parentContext); + initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); // When: final SchemaKTable projectedSchemaKStream = initialSchemaKTable.select( @@ -255,21 +283,42 @@ public void testSelectSchemaKStream() { assertThat(projectedSchemaKStream.getSourceSchemaKStreams().get(0), is(initialSchemaKTable)); } + @Test + public void shouldBuildStepForSelect() { + // Given: + final String selectQuery = "SELECT col0, col2, col3 FROM test2 WHERE col0 > 100;"; + final PlanNode logicalPlan = buildLogicalPlan(selectQuery); + final ProjectNode projectNode = (ProjectNode) logicalPlan.getSources().get(0); + initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); + + // When: + final SchemaKTable projectedSchemaKStream = initialSchemaKTable.select( + projectNode.getProjectSelectExpressions(), + childContextStacker, + processingLogContext + ); + + // Then: + assertThat( + projectedSchemaKStream.getSourceTableStep(), + equalTo( + ExecutionStepFactory.tableMapValues( + childContextStacker, + initialSchemaKTable.getSourceTableStep(), + projectedSchemaKStream.getSchema(), + projectNode.getProjectSelectExpressions() + ) + ) + ); + } + @Test public void testSelectWithExpression() { // Given: final String selectQuery = "SELECT col0, LEN(UCASE(col2)), col3*3+5 FROM test2 WHERE col0 > 100;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); final ProjectNode projectNode = (ProjectNode) logicalPlan.getSources().get(0); - initialSchemaKTable = new SchemaKTable<>( - kTable, logicalPlan.getTheSourceNode().getSchema(), - keySerde, - logicalPlan.getTheSourceNode().getKeyField(), - new ArrayList<>(), - SchemaKStream.Type.SOURCE, - ksqlConfig, - functionRegistry, - parentContext); + initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); // When: final SchemaKTable projectedSchemaKStream = initialSchemaKTable.select( @@ -294,16 +343,7 @@ public void testFilter() { final String selectQuery = "SELECT col0, col2, col3 FROM test2 WHERE col0 > 100;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); final FilterNode filterNode = (FilterNode) logicalPlan.getSources().get(0).getSources().get(0); - - initialSchemaKTable = new SchemaKTable<>( - kTable, logicalPlan.getTheSourceNode().getSchema(), - keySerde, - logicalPlan.getTheSourceNode().getKeyField(), - new ArrayList<>(), - SchemaKStream.Type.SOURCE, - ksqlConfig, - functionRegistry, - parentContext); + initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); // When: final SchemaKTable filteredSchemaKStream = initialSchemaKTable.filter( @@ -326,32 +366,46 @@ public void testFilter() { assertThat(filteredSchemaKStream.getSourceSchemaKStreams().get(0), 
is(initialSchemaKTable)); } + @Test + public void shouldBuildStepForFilter() { + // Given: + final String selectQuery = "SELECT col0, col2, col3 FROM test2 WHERE col0 > 100;"; + final PlanNode logicalPlan = buildLogicalPlan(selectQuery); + final FilterNode filterNode = (FilterNode) logicalPlan.getSources().get(0).getSources().get(0); + initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); + + // When: + final SchemaKTable filteredSchemaKStream = initialSchemaKTable.filter( + filterNode.getPredicate(), + childContextStacker, + processingLogContext + ); + + // Then: + assertThat( + filteredSchemaKStream.getSourceTableStep(), + equalTo( + ExecutionStepFactory.tableFilter( + childContextStacker, + initialSchemaKTable.getSourceTableStep(), + filterNode.getPredicate() + ) + ) + ); + } + @Test public void testGroupBy() { // Given: final String selectQuery = "SELECT col0, col1, col2 FROM test2;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); - initialSchemaKTable = new SchemaKTable<>( - kTable, logicalPlan.getTheSourceNode().getSchema(), - keySerde, - logicalPlan.getTheSourceNode().getKeyField(), - new ArrayList<>(), - SchemaKStream.Type.SOURCE, - ksqlConfig, - functionRegistry, - parentContext); - - final Serde rowSerde = GenericRowSerDe.from( - FormatInfo.of(Format.JSON, Optional.empty()), - PersistenceSchema.from(initialSchemaKTable.getSchema().valueSchema(), false), - null, - () -> null, - "test", - processingLogContext); + initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); + final Serde rowSerde = mock(Serde.class); final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1); // When: final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( + valueFormat, rowSerde, groupByExpressions, childContextStacker); @@ -363,6 +417,36 @@ public void testGroupBy() { is(Optional.of("TEST2.COL2|+|TEST2.COL1"))); } + @Test + public void shouldBuildStepForGroupBy() { + // Given: + final String selectQuery = "SELECT col0, col1, col2 FROM test2;"; + final PlanNode logicalPlan = buildLogicalPlan(selectQuery); + initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan); + final Serde rowSerde = mock(Serde.class); + final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1); + + // When: + final SchemaKGroupedStream groupedSchemaKTable = initialSchemaKTable.groupBy( + valueFormat, + rowSerde, + groupByExpressions, + childContextStacker); + + // Then: + assertThat( + ((SchemaKGroupedTable) groupedSchemaKTable).getSourceTableStep(), + equalTo( + ExecutionStepFactory.tableGroupBy( + childContextStacker, + initialSchemaKTable.getSourceTableStep(), + Formats.of(initialSchemaKTable.keyFormat, valueFormat, SerdeOption.none()), + groupByExpressions + ) + ) + ); + } + @Test public void shouldUseOpNameForGrouped() { // Given: @@ -383,7 +467,7 @@ public void shouldUseOpNameForGrouped() { final SchemaKTable schemaKTable = buildSchemaKTable(ksqlTable, mockKTable, groupedFactory); // When: - schemaKTable.groupBy(valSerde, groupByExpressions, childContextStacker); + schemaKTable.groupBy(valueFormat, valSerde, groupByExpressions, childContextStacker); // Then: verify(mockKTable, groupedFactory); @@ -405,15 +489,17 @@ public void shouldGroupKeysCorrectly() { // Build our test object from the mocks final String selectQuery = "SELECT col0, col1, col2 FROM test2;"; final PlanNode logicalPlan = buildLogicalPlan(selectQuery); - initialSchemaKTable = new SchemaKTable<>( - mockKTable, logicalPlan.getTheSourceNode().getSchema(), + initialSchemaKTable = 
new SchemaKTable(
+ mockKTable,
+ buildSourceStep(logicalPlan.getTheSourceNode().getSchema()),
+ keyFormat,
keySerde,
logicalPlan.getTheSourceNode().getKeyField(),
new ArrayList<>(),
SchemaKStream.Type.SOURCE,
ksqlConfig,
- functionRegistry,
- parentContext);
+ functionRegistry
+ );
final List groupByExpressions = Arrays.asList(TEST_2_COL_2, TEST_2_COL_1);
final Serde rowSerde = GenericRowSerDe.from(
@@ -425,7 +511,7 @@ public void shouldGroupKeysCorrectly() {
processingLogContext);
// Call groupBy and extract the captured mapper
- initialSchemaKTable.groupBy(rowSerde, groupByExpressions, childContextStacker);
+ initialSchemaKTable.groupBy(valueFormat, rowSerde, groupByExpressions, childContextStacker);
verify(mockKTable, mockKGroupedTable);
final KeyValueMapper keySelector = capturedKeySelector.getValue();
final GenericRow value = new GenericRow(Arrays.asList("key", 0, 100, "foo", "bar"));
@@ -457,7 +543,7 @@ public void shouldPerformTableToTableLeftJoin() {
assertThat(joinedKStream, instanceOf(SchemaKTable.class));
assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type);
- assertEquals(joinSchema, joinedKStream.schema);
+ assertEquals(joinSchema, joinedKStream.getSchema());
assertThat(joinedKStream.getKeyField(), is(validKeyField));
assertEquals(Arrays.asList(firstSchemaKTable, secondSchemaKTable),
joinedKStream.sourceSchemaKStreams);
@@ -481,7 +567,7 @@ public void shouldPerformTableToTableInnerJoin() {
assertThat(joinedKStream, instanceOf(SchemaKTable.class));
assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type);
- assertEquals(joinSchema, joinedKStream.schema);
+ assertEquals(joinSchema, joinedKStream.getSchema());
assertThat(joinedKStream.getKeyField(), is(validKeyField));
assertEquals(Arrays.asList(firstSchemaKTable, secondSchemaKTable),
joinedKStream.sourceSchemaKStreams);
@@ -505,12 +591,61 @@ public void shouldPerformTableToTableOuterJoin() {
assertThat(joinedKStream, instanceOf(SchemaKTable.class));
assertEquals(SchemaKStream.Type.JOIN, joinedKStream.type);
- assertEquals(joinSchema, joinedKStream.schema);
+ assertEquals(joinSchema, joinedKStream.getSchema());
assertThat(joinedKStream.getKeyField(), is(validKeyField));
assertEquals(Arrays.asList(firstSchemaKTable, secondSchemaKTable),
joinedKStream.sourceSchemaKStreams);
}
+ interface Join {
+ SchemaKTable join(
+ SchemaKTable schemaKTable,
+ LogicalSchema joinSchema,
+ KeyField keyField,
+ QueryContext.Stacker contextStacker
+ );
+ }
+
+ @Test
+ public void shouldBuildStepForTableTableJoin() {
+ final KTable resultTable = EasyMock.niceMock(KTable.class);
+ expect(mockKTable.outerJoin(
+ eq(secondSchemaKTable.getKtable()),
+ anyObject(SchemaKStream.KsqlValueJoiner.class))
+ ).andReturn(resultTable);
+ expect(mockKTable.join(
+ eq(secondSchemaKTable.getKtable()),
+ anyObject(SchemaKStream.KsqlValueJoiner.class))
+ ).andReturn(resultTable);
+ expect(mockKTable.leftJoin(
+ eq(secondSchemaKTable.getKtable()),
+ anyObject(SchemaKStream.KsqlValueJoiner.class))
+ ).andReturn(resultTable);
+ replay(mockKTable);
+
+ final List<Pair<JoinType, Join>> cases = ImmutableList.of(
+ Pair.of(JoinType.LEFT, firstSchemaKTable::leftJoin),
+ Pair.of(JoinType.INNER, firstSchemaKTable::join),
+ Pair.of(JoinType.OUTER, firstSchemaKTable::outerJoin)
+ );
+ for (final Pair<JoinType, Join> testCase : cases) {
+ final SchemaKTable result =
+ testCase.right.join(secondSchemaKTable, joinSchema, validKeyField, childContextStacker);
+ assertThat(
+ result.getSourceTableStep(),
+ equalTo(
+ ExecutionStepFactory.tableTableJoin(
+ childContextStacker,
+ testCase.left,
+
firstSchemaKTable.getSourceTableStep(), + secondSchemaKTable.getSourceTableStep(), + joinSchema + ) + ) + ); + } + } + @Test public void shouldUpdateKeyIfRenamed() { // Given: @@ -622,7 +757,7 @@ public void shouldSetKeyOnGroupBySingleExpressionThatIsInProjection() { // When: final SchemaKGroupedStream result = initialSchemaKTable - .groupBy(rowSerde, groupByExprs, childContextStacker); + .groupBy(valueFormat, rowSerde, groupByExprs, childContextStacker); // Then: assertThat(result.getKeyField(), @@ -653,15 +788,17 @@ private List givenInitialKTableOf(final String selectQuery) { metaStore ); - initialSchemaKTable = new SchemaKTable<>( - kTable, logicalPlan.getTheSourceNode().getSchema(), + initialSchemaKTable = new SchemaKTable( + kTable, + buildSourceStep(logicalPlan.getTheSourceNode().getSchema()), + keyFormat, keySerde, logicalPlan.getTheSourceNode().getKeyField(), new ArrayList<>(), SchemaKStream.Type.SOURCE, ksqlConfig, - functionRegistry, - parentContext); + functionRegistry + ); rowSerde = GenericRowSerDe.from( FormatInfo.of(Format.JSON, Optional.empty()), diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/DefaultExecutionStepProperties.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/DefaultExecutionStepProperties.java index 44d2d88fb848..e90698cfbdd0 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/DefaultExecutionStepProperties.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/DefaultExecutionStepProperties.java @@ -21,15 +21,12 @@ @Immutable public class DefaultExecutionStepProperties implements ExecutionStepProperties { - private final String id; private final QueryContext queryContext; private final LogicalSchema schema; public DefaultExecutionStepProperties( - final String id, final LogicalSchema schema, final QueryContext queryContext) { - this.id = Objects.requireNonNull(id, "id"); this.queryContext = Objects.requireNonNull(queryContext, "queryContext"); this.schema = Objects.requireNonNull(schema, "schema"); } @@ -41,7 +38,7 @@ public LogicalSchema getSchema() { @Override public String getId() { - return id; + return queryContext.toString(); } @Override @@ -49,6 +46,11 @@ public QueryContext getQueryContext() { return queryContext; } + @Override + public ExecutionStepProperties withQueryContext(final QueryContext queryContext) { + return new DefaultExecutionStepProperties(schema, queryContext); + } + @Override public boolean equals(final Object o) { if (this == o) { @@ -58,19 +60,19 @@ public boolean equals(final Object o) { return false; } final DefaultExecutionStepProperties that = (DefaultExecutionStepProperties) o; - return Objects.equals(id, that.id) + return Objects.equals(queryContext, that.queryContext) && Objects.equals(schema, that.schema); } @Override public int hashCode() { - return Objects.hash(id, schema); + return Objects.hash(queryContext, schema); } @Override public String toString() { return "ExecutionStepProperties{" - + "id='" + id + '\'' + + "queryContext='" + queryContext.toString() + '\'' + ", schema=" + schema + '}'; } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/ExecutionStepProperties.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/ExecutionStepProperties.java index e900e8bcb712..e24c9f142026 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/ExecutionStepProperties.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/ExecutionStepProperties.java @@ -23,4 +23,6 
@@ public interface ExecutionStepProperties { String getId(); QueryContext getQueryContext(); + + ExecutionStepProperties withQueryContext(QueryContext queryContext); } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java index 0adbd2c7dadc..1a47b2ffe9ab 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamAggregate.java @@ -16,10 +16,9 @@ import com.google.errorprone.annotations.Immutable; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; -import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.execution.expression.tree.FunctionCall; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Objects; @Immutable @@ -27,20 +26,20 @@ public class StreamAggregate implements ExecutionStep { private final ExecutionStepProperties properties; private final ExecutionStep source; private final Formats formats; - private final Map aggValToFunctionMap; - private final Map aggValToValColumnMap; + private final int nonFuncColumnCount; + private final List aggregations; public StreamAggregate( final ExecutionStepProperties properties, final ExecutionStep source, final Formats formats, - final Map aggValToFunctionMap, - final Map aggValToValColumnMap) { + final int nonFuncColumnCount, + final List aggregations) { this.properties = Objects.requireNonNull(properties, "properties"); this.source = Objects.requireNonNull(source, "source"); this.formats = Objects.requireNonNull(formats, "formats"); - this.aggValToFunctionMap = Objects.requireNonNull(aggValToFunctionMap); - this.aggValToValColumnMap = Objects.requireNonNull(aggValToValColumnMap); + this.nonFuncColumnCount = nonFuncColumnCount; + this.aggregations = Objects.requireNonNull(aggregations); } @Override @@ -70,13 +69,13 @@ public boolean equals(final Object o) { return Objects.equals(properties, that.properties) && Objects.equals(source, that.source) && Objects.equals(formats, that.formats) - && Objects.equals(aggValToFunctionMap, that.aggValToFunctionMap) - && Objects.equals(aggValToValColumnMap, that.aggValToValColumnMap); + && Objects.equals(aggregations, that.aggregations) + && nonFuncColumnCount == that.nonFuncColumnCount; } @Override public int hashCode() { - return Objects.hash(properties, source, formats, aggValToFunctionMap, aggValToValColumnMap); + return Objects.hash(properties, source, formats, aggregations, nonFuncColumnCount); } } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamSelectKey.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamSelectKey.java index e6d3db57b169..fb87dd91a581 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamSelectKey.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamSelectKey.java @@ -24,14 +24,17 @@ public class StreamSelectKey implements ExecutionStep { private final ExecutionStepProperties properties; private final ExecutionStep source; + private final String fieldName; private final boolean updateRowKey; public StreamSelectKey( final ExecutionStepProperties properties, final ExecutionStep source, + final String fieldName, final boolean updateRowKey) { this.properties = Objects.requireNonNull(properties, "properties"); this.source = Objects.requireNonNull(source, "source"); + 
this.fieldName = Objects.requireNonNull(fieldName, "fieldName");
this.updateRowKey = updateRowKey;
}
diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamStreamJoin.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamStreamJoin.java
index 539cf6660c21..2222b9d47f1c 100644
--- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamStreamJoin.java
+++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/StreamStreamJoin.java
@@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
+import java.time.Duration;
import java.util.List;
import java.util.Objects;
@@ -25,21 +26,30 @@ public class StreamStreamJoin implements ExecutionStep {
private final ExecutionStepProperties properties;
private final JoinType joinType;
- private final Formats formats;
+ private final Formats leftFormats;
+ private final Formats rightFormats;
private final ExecutionStep left;
private final ExecutionStep right;
+ private final Duration before;
+ private final Duration after;
public StreamStreamJoin(
final ExecutionStepProperties properties,
final JoinType joinType,
- final Formats formats,
+ final Formats leftFormats,
+ final Formats rightFormats,
final ExecutionStep left,
- final ExecutionStep right) {
+ final ExecutionStep right,
+ final Duration before,
+ final Duration after) {
this.properties = Objects.requireNonNull(properties, "properties");
- this.formats = Objects.requireNonNull(formats, "formats");
+ this.leftFormats = Objects.requireNonNull(leftFormats, "leftFormats");
+ this.rightFormats = Objects.requireNonNull(rightFormats, "rightFormats");
this.joinType = Objects.requireNonNull(joinType, "joinType");
this.left = Objects.requireNonNull(left, "left");
this.right = Objects.requireNonNull(right, "right");
+ this.before = Objects.requireNonNull(before, "before");
+ this.after = Objects.requireNonNull(after, "after");
}
@Override
@@ -57,6 +67,7 @@ public S build(final KsqlQueryBuilder streamsBuilder) {
throw new UnsupportedOperationException();
}
+ // CHECKSTYLE_RULES.OFF: CyclomaticComplexity
@Override
public boolean equals(final Object o) {
if (this == o) {
@@ -68,14 +79,26 @@
final StreamStreamJoin that = (StreamStreamJoin) o;
return Objects.equals(properties, that.properties)
&& joinType == that.joinType
- && Objects.equals(formats, that.formats)
+ && Objects.equals(leftFormats, that.leftFormats)
+ && Objects.equals(rightFormats, that.rightFormats)
&& Objects.equals(left, that.left)
- && Objects.equals(right, that.right);
+ && Objects.equals(right, that.right)
+ && Objects.equals(before, that.before)
+ && Objects.equals(after, that.after);
}
+ // CHECKSTYLE_RULES.ON: CyclomaticComplexity
@Override
public int hashCode() {
-
- return Objects.hash(properties, joinType, formats, left, right);
+ return Objects.hash(
+ properties,
+ joinType,
+ leftFormats,
+ rightFormats,
+ left,
+ right,
+ before,
+ after
+ );
}
}
diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java
index 6c241e3432b2..768d38d7948b 100644
--- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java
+++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableAggregate.java
@@ -16,10 +16,9 @@
import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
-import io.confluent.ksql.function.KsqlAggregateFunction;
+import io.confluent.ksql.execution.expression.tree.FunctionCall;
import java.util.Collections;
import java.util.List;
-import java.util.Map;
import java.util.Objects;
@Immutable
@@ -27,20 +26,20 @@ public class TableAggregate implements ExecutionStep {
private final ExecutionStepProperties properties;
private final ExecutionStep source;
private final Formats formats;
- private final Map indexToFunctionMap;
- private final Map indexToValColumnMap;
+ private final int nonFuncColumnCount;
+ private final List aggregations;
public TableAggregate(
final ExecutionStepProperties properties,
final ExecutionStep source,
final Formats formats,
- final Map indexToFunctionMap,
- final Map indexToValColumnMap) {
+ final int nonFuncColumnCount,
+ final List aggregations) {
this.properties = Objects.requireNonNull(properties, "properties");
this.source = Objects.requireNonNull(source, "source");
this.formats = Objects.requireNonNull(formats, "formats");
- this.indexToFunctionMap = Objects.requireNonNull(indexToFunctionMap, "indexToFunctionMap");
- this.indexToValColumnMap = Objects.requireNonNull(indexToValColumnMap, "indexToValColumnMap");
+ this.nonFuncColumnCount = nonFuncColumnCount;
+ this.aggregations = Objects.requireNonNull(aggregations, "aggregations");
}
@Override
@@ -70,13 +69,13 @@ public boolean equals(final Object o) {
return Objects.equals(properties, that.properties)
&& Objects.equals(source, that.source)
&& Objects.equals(formats, that.formats)
- && Objects.equals(indexToFunctionMap, that.indexToFunctionMap)
- && Objects.equals(indexToValColumnMap, that.indexToValColumnMap);
+ && nonFuncColumnCount == that.nonFuncColumnCount
+ && Objects.equals(aggregations, that.aggregations);
}
@Override
public int hashCode() {
- return Objects.hash(properties, source, formats, indexToFunctionMap, indexToValColumnMap);
+ return Objects.hash(properties, source, formats, nonFuncColumnCount, aggregations);
}
}
diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableTableJoin.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableTableJoin.java
index 25970036c513..d5de852dd14c 100644
--- a/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableTableJoin.java
+++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/plan/TableTableJoin.java
@@ -24,19 +24,16 @@ public class TableTableJoin implements ExecutionStep {
private final ExecutionStepProperties properties;
private final JoinType joinType;
- private final Formats formats;
private final ExecutionStep left;
private final ExecutionStep right;
public TableTableJoin(
final ExecutionStepProperties properties,
final JoinType joinType,
- final Formats formats,
final ExecutionStep left,
final ExecutionStep right) {
this.properties = Objects.requireNonNull(properties, "properties");
this.joinType = Objects.requireNonNull(joinType, "joinType");
- this.formats = Objects.requireNonNull(formats, "formats");
this.left = Objects.requireNonNull(left, "left");
this.right = Objects.requireNonNull(right, "right");
}
@@ -67,7 +64,6 @@ public boolean equals(final Object o) {
final TableTableJoin that = (TableTableJoin) o;
return Objects.equals(properties, that.properties)
&& joinType == that.joinType
- && Objects.equals(formats, that.formats)
&& Objects.equals(left, that.left)
&& Objects.equals(right, that.right);
}
@@ -75,6 +71,6 @@
@Override
public
int hashCode() {
- return Objects.hash(properties, joinType, formats, left, right);
+ return Objects.hash(properties, joinType, left, right);
}
}
diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java
index a3a708385a1d..5968418b03f5 100644
--- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java
+++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/ExecutionStepFactory.java
@@ -16,23 +16,52 @@
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.context.QueryContext;
+import io.confluent.ksql.execution.expression.tree.Expression;
+import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties;
+import io.confluent.ksql.execution.plan.ExecutionStep;
import io.confluent.ksql.execution.plan.Formats;
+import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.LogicalSchemaWithMetaAndKeyFields;
+import io.confluent.ksql.execution.plan.SelectExpression;
+import io.confluent.ksql.execution.plan.StreamAggregate;
+import io.confluent.ksql.execution.plan.StreamFilter;
+import io.confluent.ksql.execution.plan.StreamGroupBy;
+import io.confluent.ksql.execution.plan.StreamMapValues;
+import io.confluent.ksql.execution.plan.StreamSelectKey;
+import io.confluent.ksql.execution.plan.StreamSink;
import io.confluent.ksql.execution.plan.StreamSource;
+import io.confluent.ksql.execution.plan.StreamStreamJoin;
+import io.confluent.ksql.execution.plan.StreamTableJoin;
+import io.confluent.ksql.execution.plan.StreamToTable;
+import io.confluent.ksql.execution.plan.TableAggregate;
+import io.confluent.ksql.execution.plan.TableFilter;
+import io.confluent.ksql.execution.plan.TableGroupBy;
+import io.confluent.ksql.execution.plan.TableMapValues;
+import io.confluent.ksql.execution.plan.TableSink;
+import io.confluent.ksql.execution.plan.TableTableJoin;
+import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.util.timestamp.TimestampExtractionPolicy;
+import java.time.Duration;
+import java.util.List;
import java.util.Optional;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.Topology.AutoOffsetReset;
+import org.apache.kafka.streams.kstream.JoinWindows;
+import org.apache.kafka.streams.kstream.KGroupedStream;
+import org.apache.kafka.streams.kstream.KGroupedTable;
import org.apache.kafka.streams.kstream.KStream;
+import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Windowed;
+// CHECKSTYLE_RULES.OFF: ClassDataAbstractionCoupling
public final class ExecutionStepFactory {
+ // CHECKSTYLE_RULES.ON: ClassDataAbstractionCoupling
private ExecutionStepFactory() {
}
public static StreamSource<KStream<Windowed<Struct>, GenericRow>> streamSourceWindowed(
- final QueryContext queryContext,
+ final QueryContext.Stacker stacker,
final LogicalSchemaWithMetaAndKeyFields schema,
final String topicName,
final Formats formats,
@@ -40,9 +69,9 @@ public static StreamSource<KStream<Windowed<Struct>, GenericRow>> streamSourceWi
final int timestampIndex,
final Optional<AutoOffsetReset> offsetReset
) {
+ final QueryContext queryContext = stacker.getQueryContext();
return new StreamSource<>(
new DefaultExecutionStepProperties(
- queryContext.toString(),
schema.getSchema(),
queryContext),
topicName,
@@ -56,7 +85,7 @@ public static StreamSource<KStream<Struct, GenericRow>> streamSource(
- final QueryContext queryContext,
+ final QueryContext.Stacker stacker,
final LogicalSchemaWithMetaAndKeyFields schema,
final String topicName,
final Formats formats,
@@ -64,9 +93,9 @@ public static StreamSource<KStream<Struct, GenericRow>> streamSource(
final int timestampIndex,
final Optional<AutoOffsetReset> offsetReset
) {
+ final QueryContext queryContext = stacker.getQueryContext();
return new StreamSource<>(
new DefaultExecutionStepProperties(
- queryContext.toString(),
schema.getSchema(),
queryContext),
topicName,
@@ -78,4 +107,249 @@ public static StreamSource<KStream<Struct, GenericRow>> streamSource(
StreamSourceBuilder::buildUnwindowed
);
}
+
+ public static <K> StreamToTable<KStream<K, GenericRow>, KTable<K, GenericRow>> streamToTable(
+ final QueryContext.Stacker stacker,
+ final Formats formats,
+ final ExecutionStep<KStream<K, GenericRow>> source
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamToTable<>(
+ source,
+ formats,
+ source.getProperties().withQueryContext(queryContext)
+ );
+ }
+
+ public static <K> StreamSink<KStream<K, GenericRow>> streamSink(
+ final QueryContext.Stacker stacker,
+ final LogicalSchema outputSchema,
+ final Formats formats,
+ final ExecutionStep<KStream<K, GenericRow>> source,
+ final String topicName
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamSink<>(
+ new DefaultExecutionStepProperties(outputSchema, queryContext),
+ source,
+ formats,
+ topicName
+ );
+ }
+
+ public static <K> StreamFilter<KStream<K, GenericRow>> streamFilter(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KStream<K, GenericRow>> source,
+ final Expression filterExpression
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamFilter<>(
+ source.getProperties().withQueryContext(queryContext),
+ source,
+ filterExpression
+ );
+ }
+
+ public static <K> StreamMapValues<KStream<K, GenericRow>> streamMapValues(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KStream<K, GenericRow>> source,
+ final List<SelectExpression> selectExpressions,
+ final LogicalSchema resultSchema
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamMapValues<>(
+ new DefaultExecutionStepProperties(resultSchema, queryContext),
+ source,
+ selectExpressions
+ );
+ }
+
+ public static <K> StreamTableJoin<KStream<K, GenericRow>, KTable<K, GenericRow>>
+ streamTableJoin(
+ final QueryContext.Stacker stacker,
+ final JoinType joinType,
+ final Formats formats,
+ final ExecutionStep<KStream<K, GenericRow>> left,
+ final ExecutionStep<KTable<K, GenericRow>> right,
+ final LogicalSchema resultSchema
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamTableJoin<>(
+ new DefaultExecutionStepProperties(resultSchema, queryContext),
+ joinType,
+ formats,
+ left,
+ right
+ );
+ }
+
+ public static <K> StreamStreamJoin<KStream<K, GenericRow>> streamStreamJoin(
+ final QueryContext.Stacker stacker,
+ final JoinType joinType,
+ final Formats leftFormats,
+ final Formats rightFormats,
+ final ExecutionStep<KStream<K, GenericRow>> left,
+ final ExecutionStep<KStream<K, GenericRow>> right,
+ final LogicalSchema resultSchema,
+ final JoinWindows joinWindows
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamStreamJoin<>(
+ new DefaultExecutionStepProperties(resultSchema, queryContext),
+ joinType,
+ leftFormats,
+ rightFormats,
+ left,
+ right,
+ Duration.ofMillis(joinWindows.beforeMs),
+ Duration.ofMillis(joinWindows.afterMs)
+ );
+ }
+
+ public static <K> StreamSelectKey<KStream<K, GenericRow>> streamSelectKey(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KStream<K, GenericRow>> source,
+ final String fieldName,
+ final boolean updateRowKey
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamSelectKey<>(
+ new DefaultExecutionStepProperties(
+ source.getProperties().getSchema(),
+ queryContext
+ ),
+ source,
+ fieldName,
+
updateRowKey
+ );
+ }
+
+ public static <K> TableSink<KTable<K, GenericRow>> tableSink(
+ final QueryContext.Stacker stacker,
+ final LogicalSchema outputSchema,
+ final ExecutionStep<KTable<K, GenericRow>> source,
+ final Formats formats,
+ final String topicName
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new TableSink<>(
+ new DefaultExecutionStepProperties(outputSchema, queryContext),
+ source,
+ formats,
+ topicName
+ );
+ }
+
+ public static <K> TableFilter<KTable<K, GenericRow>> tableFilter(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KTable<K, GenericRow>> source,
+ final Expression filterExpression
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new TableFilter<>(
+ source.getProperties().withQueryContext(queryContext),
+ source,
+ filterExpression
+ );
+ }
+
+ public static <K> TableMapValues<KTable<K, GenericRow>> tableMapValues(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KTable<K, GenericRow>> source,
+ final LogicalSchema resultSchema,
+ final List<SelectExpression> selectExpressions
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new TableMapValues<>(
+ new DefaultExecutionStepProperties(resultSchema, queryContext),
+ source,
+ selectExpressions
+ );
+ }
+
+ public static <K> TableTableJoin<KTable<K, GenericRow>> tableTableJoin(
+ final QueryContext.Stacker stacker,
+ final JoinType joinType,
+ final ExecutionStep<KTable<K, GenericRow>> left,
+ final ExecutionStep<KTable<K, GenericRow>> right,
+ final LogicalSchema resultSchema
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new TableTableJoin<>(
+ new DefaultExecutionStepProperties(resultSchema, queryContext),
+ joinType,
+ left,
+ right
+ );
+ }
+
+ public static <K> StreamAggregate<KTable<K, GenericRow>, KGroupedStream<K, GenericRow>>
+ streamAggregate(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KGroupedStream<K, GenericRow>> sourceStep,
+ final LogicalSchema resultSchema,
+ final Formats formats,
+ final int nonFuncColumnCount,
+ final List<FunctionCall> aggregations
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamAggregate<>(
+ new DefaultExecutionStepProperties(resultSchema, queryContext),
+ sourceStep,
+ formats,
+ nonFuncColumnCount,
+ aggregations
+ );
+ }
+
+ public static <K> StreamGroupBy<KStream<K, GenericRow>, KGroupedStream<K, GenericRow>>
+ streamGroupBy(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KStream<K, GenericRow>> sourceStep,
+ final Formats format,
+ final List<Expression> groupingExpressions
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new StreamGroupBy<>(
+ sourceStep.getProperties().withQueryContext(queryContext),
+ sourceStep,
+ format,
+ groupingExpressions
+ );
+ }
+
+ public static <K> TableAggregate<KTable<K, GenericRow>, KGroupedTable<K, GenericRow>>
+ tableAggregate(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KGroupedTable<K, GenericRow>> sourceStep,
+ final LogicalSchema resultSchema,
+ final Formats formats,
+ final int nonFuncColumnCount,
+ final List<FunctionCall> aggregations
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new TableAggregate<>(
+ new DefaultExecutionStepProperties(resultSchema, queryContext),
+ sourceStep,
+ formats,
+ nonFuncColumnCount,
+ aggregations
+ );
+ }
+
+ public static <K> TableGroupBy<KTable<K, GenericRow>, KGroupedTable<K, GenericRow>>
+ tableGroupBy(
+ final QueryContext.Stacker stacker,
+ final ExecutionStep<KTable<K, GenericRow>> sourceStep,
+ final Formats format,
+ final List<Expression> groupingExpressions
+ ) {
+ final QueryContext queryContext = stacker.getQueryContext();
+ return new TableGroupBy<>(
+ sourceStep.getProperties().withQueryContext(queryContext),
+ sourceStep,
+ format,
+ groupingExpressions
+ );
+ }
+ }
}
diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSourceBuilderTest.java
b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSourceBuilderTest.java
index 3507badc202d..6ab1dcb08b4c 100644
--- a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSourceBuilderTest.java
+++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/StreamSourceBuilderTest.java
@@ -171,7 +171,7 @@ public void setup() {
private void givenWindowedSource() {
streamSource = new StreamSource<>(
- new DefaultExecutionStepProperties("id", SCHEMA, ctx),
+ new DefaultExecutionStepProperties(SCHEMA, ctx),
TOPIC_NAME,
Formats.of(keyFormat, valueFormat, SERDE_OPTIONS),
extractionPolicy,
@@ -184,7 +184,7 @@ private void givenUnwindowedSource() {
streamSource = new StreamSource<>(
- new DefaultExecutionStepProperties("id", SCHEMA, ctx),
+ new DefaultExecutionStepProperties(SCHEMA, ctx),
TOPIC_NAME,
Formats.of(keyFormat, valueFormat, SERDE_OPTIONS),
extractionPolicy,
@@ -313,7 +313,7 @@ public void shouldAddNonWindowedKey() {
public void shouldThrowOnMultiFieldKey() {
// Given:
final StreamSource<KStream<Struct, GenericRow>> streamSource = new StreamSource<>(
- new DefaultExecutionStepProperties("id", SCHEMA, ctx),
+ new DefaultExecutionStepProperties(SCHEMA, ctx),
TOPIC_NAME,
Formats.of(keyFormat, valueFormat, SERDE_OPTIONS),
extractionPolicy,