Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: move groupBy into plan builders #3359

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,6 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
.getKsqlTopic()
.getValueFormat();

final Serde<GenericRow> genericRowSerde = builder.buildValueSerde(
valueFormat.getFormatInfo(),
PhysicalSchema.from(prepareSchema, SerdeOption.none()),
groupByContext.getQueryContext()
);

final List<Expression> internalGroupByColumns = internalSchema.resolveGroupByExpressions(
getGroupByExpressions(),
aggregateArgExpanded,
Expand All @@ -235,9 +229,9 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {

final SchemaKGroupedStream schemaKGroupedStream = aggregateArgExpanded.groupBy(
valueFormat,
genericRowSerde,
internalGroupByColumns,
groupByContext
groupByContext,
builder
);

// Aggregate computations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.streams;

import io.confluent.ksql.execution.streams.GroupedFactory;
import io.confluent.ksql.execution.streams.MaterializedFactory;
import io.confluent.ksql.util.KsqlConfig;
import java.util.Objects;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import com.google.common.collect.ImmutableList;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.codegen.CodeGenRunner;
import io.confluent.ksql.execution.codegen.ExpressionMetadata;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryLoggerUtil;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
Expand All @@ -37,12 +35,15 @@
import io.confluent.ksql.execution.plan.LogicalSchemaWithMetaAndKeyFields;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.StreamFilter;
import io.confluent.ksql.execution.plan.StreamGroupBy;
import io.confluent.ksql.execution.plan.StreamGroupByKey;
import io.confluent.ksql.execution.plan.StreamMapValues;
import io.confluent.ksql.execution.plan.StreamSelectKey;
import io.confluent.ksql.execution.plan.StreamSource;
import io.confluent.ksql.execution.plan.StreamToTable;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.execution.streams.StreamFilterBuilder;
import io.confluent.ksql.execution.streams.StreamGroupByBuilder;
import io.confluent.ksql.execution.streams.StreamMapValuesBuilder;
import io.confluent.ksql.execution.streams.StreamSelectKeyBuilder;
import io.confluent.ksql.execution.streams.StreamSourceBuilder;
Expand Down Expand Up @@ -72,12 +73,11 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.Topology.AutoOffsetReset;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.KGroupedStream;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Produced;
Expand All @@ -92,6 +92,8 @@ public class SchemaKStream<K> {
private static final FormatOptions FORMAT_OPTIONS =
FormatOptions.of(IdentifierUtil::needsQuotes);

static final String GROUP_BY_COLUMN_SEPARATOR = "|+|";

public enum Type { SOURCE, PROJECT, FILTER, AGGREGATE, SINK, REKEY, JOIN }

final KStream<K, GenericRow> kstream;
Expand Down Expand Up @@ -777,73 +779,38 @@ private boolean rekeyRequired(final List<Expression> groupByExpressions) {

public SchemaKGroupedStream groupBy(
final ValueFormat valueFormat,
final Serde<GenericRow> valSerde,
final List<Expression> groupByExpressions,
final QueryContext.Stacker contextStacker
final QueryContext.Stacker contextStacker,
final KsqlQueryBuilder queryBuilder
) {
final boolean rekey = rekeyRequired(groupByExpressions);
final KeyFormat rekeyedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo());
if (!rekey) {
final Grouped<K, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
keySerde,
valSerde
);

final KGroupedStream kgroupedStream = kstream.groupByKey(grouped);

final KeySerde<Struct> structKeySerde = getGroupByKeyKeySerde();

final ExecutionStep<KGroupedStream<Struct, GenericRow>> step =
ExecutionStepFactory.streamGroupBy(
contextStacker,
sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
return new SchemaKGroupedStream(
kgroupedStream,
step,
keyFormat,
structKeySerde,
keyField,
Collections.singletonList(this),
ksqlConfig,
functionRegistry
);
return groupByKey(rekeyedKeyFormat, valueFormat, contextStacker, queryBuilder);
}

final GroupBy groupBy = new GroupBy(groupByExpressions);

final KeySerde<Struct> groupedKeySerde = keySerde
.rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA);

final Grouped<Struct, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
groupedKeySerde,
valSerde
);

final KGroupedStream kgroupedStream = kstream
.filter((key, value) -> value != null)
.groupBy(groupBy.mapper, grouped);

final String aggregateKeyName = groupedKeyNameFor(groupByExpressions);
final LegacyField legacyKeyField = LegacyField
.notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING);

final Optional<String> newKeyCol = getSchema().findValueColumn(groupBy.aggregateKeyName)
.notInSchema(aggregateKeyName, SqlTypes.STRING);
final Optional<String> newKeyCol = getSchema().findValueColumn(aggregateKeyName)
.map(Column::name);
final ExecutionStep<KGroupedStream<Struct, GenericRow>> source =
ExecutionStepFactory.streamGroupBy(
contextStacker,
sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);

final StreamGroupBy<K> source = ExecutionStepFactory.streamGroupBy(
contextStacker,
sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
return new SchemaKGroupedStream(
kgroupedStream,
StreamGroupByBuilder.build(
kstream,
source,
queryBuilder,
streamsFactories.getGroupedFactory()
),
source,
rekeyedKeyFormat,
groupedKeySerde,
Expand All @@ -854,6 +821,37 @@ public SchemaKGroupedStream groupBy(
);
}

@SuppressWarnings("unchecked")
private SchemaKGroupedStream groupByKey(
final KeyFormat rekeyedKeyFormat,
final ValueFormat valueFormat,
final QueryContext.Stacker contextStacker,
final KsqlQueryBuilder queryBuilder
) {
final KeySerde<Struct> structKeySerde = getGroupByKeyKeySerde();
final StreamGroupByKey step =
ExecutionStepFactory.streamGroupByKey(
contextStacker,
(ExecutionStep) sourceStep,
Formats.of(rekeyedKeyFormat, valueFormat, SerdeOption.none())
);
return new SchemaKGroupedStream(
StreamGroupByBuilder.build(
(KStream) kstream,
step,
queryBuilder,
streamsFactories.getGroupedFactory()
),
step,
keyFormat,
structKeySerde,
keyField,
Collections.singletonList(this),
ksqlConfig,
functionRegistry
);
}

@SuppressWarnings("unchecked")
private KeySerde<Struct> getGroupByKeyKeySerde() {
if (keySerde.isWindowed()) {
Expand Down Expand Up @@ -920,18 +918,10 @@ public FunctionRegistry getFunctionRegistry() {
return functionRegistry;
}

class GroupBy {

final String aggregateKeyName;
final GroupByMapper mapper;

GroupBy(final List<Expression> expressions) {
final List<ExpressionMetadata> groupBy = CodeGenRunner.compileExpressions(
expressions.stream(), "Group By", getSchema(), ksqlConfig, functionRegistry);

this.mapper = new GroupByMapper(groupBy);
this.aggregateKeyName = GroupByMapper.keyNameFor(expressions);
}
String groupedKeyNameFor(final List<Expression> groupByExpressions) {
return groupByExpressions.stream()
.map(Expression::toString)
.collect(Collectors.joining(GROUP_BY_COLUMN_SEPARATOR));
}

protected static class KsqlValueJoiner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.TableFilter;
import io.confluent.ksql.execution.plan.TableGroupBy;
import io.confluent.ksql.execution.plan.TableMapValues;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.execution.streams.StreamsUtil;
import io.confluent.ksql.execution.streams.TableFilterBuilder;
import io.confluent.ksql.execution.streams.TableGroupByBuilder;
import io.confluent.ksql.execution.streams.TableMapValuesBuilder;
import io.confluent.ksql.execution.util.StructKeyUtil;
import io.confluent.ksql.function.FunctionRegistry;
Expand All @@ -51,9 +52,6 @@
import java.util.Set;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KGroupedTable;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Produced;
Expand Down Expand Up @@ -231,48 +229,37 @@ public ExecutionStep<KTable<K, GenericRow>> getSourceTableStep() {
}

@Override
@SuppressWarnings("unchecked")
public SchemaKGroupedStream groupBy(
final ValueFormat valueFormat,
final Serde<GenericRow> valSerde,
final List<Expression> groupByExpressions,
final QueryContext.Stacker contextStacker
final QueryContext.Stacker contextStacker,
final KsqlQueryBuilder queryBuilder
) {

final GroupBy groupBy = new GroupBy(groupByExpressions);
final KeyFormat groupedKeyFormat = KeyFormat.nonWindowed(keyFormat.getFormatInfo());

final KeySerde<Struct> groupedKeySerde = keySerde
.rebind(StructKeyUtil.ROWKEY_SERIALIZED_SCHEMA);

final Grouped<Struct, GenericRow> grouped = streamsFactories.getGroupedFactory()
.create(
StreamsUtil.buildOpName(contextStacker.getQueryContext()),
groupedKeySerde,
valSerde
);

final KGroupedTable kgroupedTable = ktable
.filter((key, value) -> value != null)
.groupBy(
(key, value) -> new KeyValue<>(groupBy.mapper.apply(key, value), value),
grouped
);
final String aggregateKeyName = groupedKeyNameFor(groupByExpressions);
final LegacyField legacyKeyField = LegacyField.notInSchema(aggregateKeyName, SqlTypes.STRING);
final Optional<String> newKeyField =
getSchema().findValueColumn(aggregateKeyName).map(Column::fullName);

final LegacyField legacyKeyField = LegacyField
.notInSchema(groupBy.aggregateKeyName, SqlTypes.STRING);

final Optional<String> newKeyField = getSchema().findValueColumn(groupBy.aggregateKeyName)
.map(Column::fullName);

final ExecutionStep<KGroupedTable<Struct, GenericRow>> step =
ExecutionStepFactory.tableGroupBy(
contextStacker,
sourceTableStep,
Formats.of(groupedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
final TableGroupBy<K> step = ExecutionStepFactory.tableGroupBy(
contextStacker,
sourceTableStep,
Formats.of(groupedKeyFormat, valueFormat, SerdeOption.none()),
groupByExpressions
);
return new SchemaKGroupedTable(
kgroupedTable,
TableGroupByBuilder.build(
ktable,
step,
queryBuilder,
streamsFactories.getGroupedFactory()
),
step,
groupedKeyFormat,
groupedKeySerde,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.PersistenceSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.FormatInfo;
import io.confluent.ksql.serde.KeySerde;
import io.confluent.ksql.serde.WindowInfo;
import io.confluent.ksql.structured.SchemaKStream;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.FormatInfo;
import io.confluent.ksql.serde.KeySerde;
import io.confluent.ksql.structured.SchemaKStream;
import io.confluent.ksql.testutils.AnalysisTestUtil;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.streams.GroupedFactory;
import io.confluent.ksql.util.KsqlConfig;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.streams.StreamsConfig;
Expand Down
Loading