feat: move aggregations to plan builder #3391
@@ -19,7 +19,6 @@
import static java.util.Objects.requireNonNull;

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
@@ -29,24 +28,20 @@
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.Literal;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.execution.function.UdafUtil;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.util.ExpressionTypeManager;
import io.confluent.ksql.function.AggregateFunctionArguments;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.udaf.KudafInitializer;
import io.confluent.ksql.materialization.MaterializationInfo;
import io.confluent.ksql.metastore.model.KeyField;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.tree.WindowExpression;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.PhysicalSchema;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.SchemaConverters.ConnectToSqlTypeConverter;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.serde.ValueFormat;
import io.confluent.ksql.services.KafkaTopicClient;
import io.confluent.ksql.structured.SchemaKGroupedStream;
@@ -66,8 +61,6 @@
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Schema;

public class AggregateNode extends PlanNode {
@@ -234,54 +227,43 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
        builder
    );

    // Aggregate computations
    final KudafInitializer initializer = new KudafInitializer(requiredColumns.size());

    final Map<Integer, KsqlAggregateFunction> aggValToFunctionMap = createAggValToFunctionMap(
        aggregateArgExpanded,
        initializer,
        requiredColumns.size(),
        builder.getFunctionRegistry(),
        internalSchema
    );
    final List<FunctionCall> functionsWithInternalIdentifiers = functionList.stream()
        .map(
            fc -> new FunctionCall(
                fc.getName(),
                internalSchema.getInternalArgsExpressionList(fc.getArguments())
            )
        )
        .collect(Collectors.toList());

    // This is the schema of the aggregation change log topic and associated state store.
    // It contains all columns from prepareSchema and columns for any aggregating functions
    // It uses internal column names, e.g. KSQL_INTERNAL_COL_0 and KSQL_AGG_VARIABLE_0
    final LogicalSchema aggregationSchema = buildLogicalSchema(
        prepareSchema,
        aggValToFunctionMap,
        functionsWithInternalIdentifiers,
        builder.getFunctionRegistry(),
        true
    );

    final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME);

    final Serde<GenericRow> aggValueGenericRowSerde = builder.buildValueSerde(
        valueFormat.getFormatInfo(),
        PhysicalSchema.from(aggregationSchema, SerdeOption.none()),
        aggregationContext.getQueryContext()
    );

    final List<FunctionCall> functionsWithInternalIdentifiers = functionList.stream()
        .map(internalSchema::resolveToInternal)
        .map(FunctionCall.class::cast)
        .collect(Collectors.toList());

    final LogicalSchema outputSchema = buildLogicalSchema(
        prepareSchema,
        aggValToFunctionMap,
        false);
        functionsWithInternalIdentifiers,
        builder.getFunctionRegistry(),
        false
    );

    SchemaKTable<?> aggregated = schemaKGroupedStream.aggregate(
        aggregationSchema,
        outputSchema,
        initializer,
        requiredColumns.size(),
        functionsWithInternalIdentifiers,
        aggValToFunctionMap,
        windowExpression,
        valueFormat,
        aggValueGenericRowSerde,
        aggregationContext
        aggregationContext,
        builder
    );

    final Optional<Expression> havingExpression = Optional.ofNullable(havingExpressions)
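The schema comment in this hunk refers to internal column names such as KSQL_INTERNAL_COL_0 and KSQL_AGG_VARIABLE_0, which resolveToInternal maps argument expressions onto. Below is a standalone sketch of that positional mapping, assuming a simple expression-string keyed table; the PR's actual InternalSchema class is not shown in this diff, so none of this code is taken from it.

import java.util.LinkedHashMap;
import java.util.Map;

final class InternalNameSketch {
  private final Map<String, String> expressionToInternal = new LinkedHashMap<>();

  // Returns the internal column name for an argument expression, assigning
  // the next positional name the first time the expression is seen.
  String internalNameFor(final String expression) {
    return expressionToInternal.computeIfAbsent(
        expression,
        e -> "KSQL_INTERNAL_COL_" + expressionToInternal.size()
    );
  }

  public static void main(final String[] args) {
    final InternalNameSketch schema = new InternalNameSketch();
    System.out.println(schema.internalNameFor("COL0"));        // KSQL_INTERNAL_COL_0
    System.out.println(schema.internalNameFor("COL1 + COL2")); // KSQL_INTERNAL_COL_1
    System.out.println(schema.internalNameFor("COL0"));        // KSQL_INTERNAL_COL_0 (reused)
  }
}

Reusing the same internal name for a repeated expression is what lets an argument's name later identify the column a UDAF reads from.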
@@ -316,61 +298,12 @@ protected int getPartitions(final KafkaTopicClient kafkaTopicClient) {
    return source.getPartitions(kafkaTopicClient);
  }

  private Map<Integer, KsqlAggregateFunction> createAggValToFunctionMap(
      final SchemaKStream aggregateArgExpanded,
      final KudafInitializer initializer,
      final int initialUdafIndex,
      final FunctionRegistry functionRegistry,
      final InternalSchema internalSchema
  ) {
    int udafIndexInAggSchema = initialUdafIndex;
    final Map<Integer, KsqlAggregateFunction> aggValToAggFunctionMap = new HashMap<>();
    for (final FunctionCall functionCall : functionList) {
      final KsqlAggregateFunction aggregateFunction = getAggregateFunction(
          functionRegistry,
          internalSchema,
          functionCall, aggregateArgExpanded.getSchema());

      aggValToAggFunctionMap.put(udafIndexInAggSchema++, aggregateFunction);
      initializer.addAggregateIntializer(aggregateFunction.getInitialValueSupplier());
    }
    return aggValToAggFunctionMap;
  }

  @SuppressWarnings("deprecation") // Need to migrate away from Connect Schema use.
  private static KsqlAggregateFunction getAggregateFunction(
      final FunctionRegistry functionRegistry,
      final InternalSchema internalSchema,
      final FunctionCall functionCall,
      final LogicalSchema schema
  ) {
    try {
      final ExpressionTypeManager expressionTypeManager =
          new ExpressionTypeManager(schema, functionRegistry);
      final List<Expression> functionArgs = internalSchema.getInternalArgsExpressionList(
          functionCall.getArguments());
      final Schema expressionType = expressionTypeManager.getExpressionSchema(functionArgs.get(0));
      final KsqlAggregateFunction aggregateFunctionInfo = functionRegistry
          .getAggregate(functionCall.getName().name(), expressionType);

      final List<String> args = functionArgs.stream()
          .map(Expression::toString)
          .collect(Collectors.toList());

      final int udafIndex = Integer
          .parseInt(args.get(0).substring(INTERNAL_COLUMN_NAME_PREFIX.length()));

      return aggregateFunctionInfo.getInstance(new AggregateFunctionArguments(udafIndex, args));
    } catch (final Exception e) {
      throw new KsqlException("Failed to create aggregate function: " + functionCall, e);
    }
  }
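For reference, the removed getAggregateFunction recovered the UDAF's column index by slicing the first argument's internal column name. A minimal sketch of that parsing, assuming INTERNAL_COLUMN_NAME_PREFIX is "KSQL_INTERNAL_COL_" (consistent with the internal names mentioned in the schema comment above; the constant's definition is not shown in this diff):

final class UdafIndexParsingSketch {
  private static final String INTERNAL_COLUMN_NAME_PREFIX = "KSQL_INTERNAL_COL_";

  public static void main(final String[] args) {
    // The first argument's internal column name encodes the column index the
    // UDAF reads from, e.g. "KSQL_INTERNAL_COL_2" -> 2.
    final String firstArg = "KSQL_INTERNAL_COL_2";
    final int udafIndex =
        Integer.parseInt(firstArg.substring(INTERNAL_COLUMN_NAME_PREFIX.length()));
    System.out.println(udafIndex); // prints 2
  }
}

The new UdafUtil.resolveAggregateFunction path removes the need for this string slicing by resolving the function directly from the FunctionCall and the input schema.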
  private LogicalSchema buildLogicalSchema(
      final LogicalSchema inputSchema,
      final Map<Integer, KsqlAggregateFunction> aggregateFunctions,
      final boolean useAggregate) {
      final List<FunctionCall> aggregations,
      final FunctionRegistry functionRegistry,
      final boolean useAggregate
  ) {
    final LogicalSchema.Builder schemaBuilder = LogicalSchema.builder();
    final List<Column> cols = inputSchema.value();
@@ -382,18 +315,13 @@ private LogicalSchema buildLogicalSchema(

    final ConnectToSqlTypeConverter converter = SchemaConverters.connectToSqlConverter();

    for (int idx = 0; idx < aggregateFunctions.size(); idx++) {

      final KsqlAggregateFunction aggregateFunction = aggregateFunctions
          .get(requiredColumns.size() + idx);

      final ColumnName colName = ColumnName.aggregate(idx);
      SqlType fieldType = null;
      if (useAggregate) {
        fieldType = converter.toSqlType(aggregateFunction.getAggregateType());
      } else {
        fieldType = converter.toSqlType(aggregateFunction.getReturnType());
      }
    for (int i = 0; i < aggregations.size(); i++) {
Review comment: Maybe we could move this building of schema, and other places we build schemas, into the
Reply: Yeah that makes sense to me for a follow-up.
      final KsqlAggregateFunction aggregateFunction =
          UdafUtil.resolveAggregateFunction(functionRegistry, aggregations.get(i), inputSchema);
      final ColumnName colName = ColumnName.aggregate(i);
      final SqlType fieldType = converter.toSqlType(
          useAggregate ? aggregateFunction.getAggregateType() : aggregateFunction.getReturnType()
      );
      schemaBuilder.valueColumn(colName, fieldType);
    }
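To see the useAggregate flag's type selection in isolation: a hedged helper, not part of the PR, built only from calls that appear in this diff. The changelog/state-store schema stores each function's intermediate aggregate type, while the output schema exposes its return type.

import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.function.UdafUtil;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.types.SqlType;

final class AggregateTypeSketch {

  // Resolves one aggregation and picks the column type for either the
  // changelog/state-store schema (intermediate aggregate type) or the
  // output schema (return type), mirroring the useAggregate flag above.
  static SqlType columnType(
      final FunctionRegistry functionRegistry,
      final LogicalSchema inputSchema,
      final FunctionCall aggregation,
      final boolean forChangelog
  ) {
    final KsqlAggregateFunction aggregateFunction =
        UdafUtil.resolveAggregateFunction(functionRegistry, aggregation, inputSchema);
    return SchemaConverters.connectToSqlConverter().toSqlType(
        forChangelog ? aggregateFunction.getAggregateType() : aggregateFunction.getReturnType()
    );
  }
}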
Review comment: Ideally.... we should be looking to move the whole internal schema thing into the physical layer. The logical layer shouldn't need to know about such things. Though such internal names should be part of the serialized form of the physical.

Reply: I agree it shouldn't be in the logical plan node (like the other code that builds the execution plan). But the code that converts from the physical plan to the streams app shouldn't have to worry about it either. I think when we have a visitor that traverses the logical plan to build the physical plan, that would be the right place for it.
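As a rough illustration of the visitor idea floated here (all type names below are hypothetical; nothing like them is introduced in this PR):

interface LogicalNode {
  <T> T accept(LogicalPlanVisitor<T> visitor);
}

interface LogicalPlanVisitor<T> {
  T visitAggregate(AggregateLogicalNode node);
  T visitProject(ProjectLogicalNode node);
}

final class AggregateLogicalNode implements LogicalNode {
  @Override
  public <T> T accept(final LogicalPlanVisitor<T> visitor) {
    return visitor.visitAggregate(this);
  }
}

final class ProjectLogicalNode implements LogicalNode {
  @Override
  public <T> T accept(final LogicalPlanVisitor<T> visitor) {
    return visitor.visitProject(this);
  }
}

// String stands in for a real physical-plan type, keeping the sketch
// self-contained and compilable.
final class PhysicalPlanBuilder implements LogicalPlanVisitor<String> {
  @Override
  public String visitAggregate(final AggregateLogicalNode node) {
    // Internal-name rewriting and schema building would happen here, at the
    // logical-to-physical boundary, instead of inside AggregateNode.buildStream.
    return "PhysicalAggregate";
  }

  @Override
  public String visitProject(final ProjectLogicalNode node) {
    return "PhysicalProject";
  }
}

With a boundary like this, the internal-schema naming would be introduced exactly once, as each logical node is lowered to its physical counterpart.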