feat: move aggregations to plan builder #3391

Merged
Changes from all commits
----
@@ -27,6 +27,7 @@
 import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
 import io.confluent.ksql.execution.expression.tree.Expression;
 import io.confluent.ksql.execution.plan.SelectExpression;
+import io.confluent.ksql.execution.windows.KsqlWindowExpression;
 import io.confluent.ksql.metastore.MetaStore;
 import io.confluent.ksql.metastore.model.DataSource;
 import io.confluent.ksql.name.ColumnName;
@@ -40,7 +41,6 @@
 import io.confluent.ksql.parser.tree.GroupingElement;
 import io.confluent.ksql.parser.tree.Join;
 import io.confluent.ksql.parser.tree.JoinOn;
-import io.confluent.ksql.parser.tree.KsqlWindowExpression;
 import io.confluent.ksql.parser.tree.Query;
 import io.confluent.ksql.parser.tree.Select;
 import io.confluent.ksql.parser.tree.SelectItem;
----
@@ -31,7 +31,6 @@
 import io.confluent.ksql.parser.tree.GroupingElement;
 import io.confluent.ksql.parser.tree.InsertInto;
 import io.confluent.ksql.parser.tree.Join;
-import io.confluent.ksql.parser.tree.KsqlWindowExpression;
 import io.confluent.ksql.parser.tree.Query;
 import io.confluent.ksql.parser.tree.RegisterType;
 import io.confluent.ksql.parser.tree.Relation;
@@ -236,14 +235,8 @@ protected AstNode visitWindowExpression(final WindowExpression node, final C context) {
     return new WindowExpression(
         node.getLocation(),
         node.getWindowName(),
-        (KsqlWindowExpression) rewriter.apply(node.getKsqlWindowExpression(), context));
-  }
-
-  @Override
-  protected AstNode visitKsqlWindowExpression(
-      final KsqlWindowExpression node,
-      final C context) {
-    return node;
+        node.getKsqlWindowExpression()
+    );
   }
 
   @Override
----
@@ -15,6 +15,7 @@
 
 package io.confluent.ksql.function;
 
+import io.confluent.ksql.execution.function.TableAggregationFunction;
 import io.confluent.ksql.function.udaf.TableUdaf;
 import java.util.List;
 import java.util.Optional;
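Background for this and the following import moves: table-backed UDAFs implement TableAggregationFunction (now living in the execution module) because a KTable aggregate must be able to retract the old value of an updated row, not just add the new one. A toy illustration of the add/undo pairing — plain Java for this note, not the actual ksql interfaces:

    // Toy add/undo aggregate (illustration only). When a table row is
    // updated, the old value's contribution is undone first, then the
    // new value is aggregated.
    final class RunningCount {
      private long count;

      long aggregate() {  // invoked for the new value of a row
        return ++count;
      }

      long undo() {       // invoked for the replaced (old) value
        return --count;
      }
    }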
----
@@ -15,10 +15,10 @@
 
 package io.confluent.ksql.function.udaf.count;
 
+import io.confluent.ksql.execution.function.TableAggregationFunction;
 import io.confluent.ksql.function.AggregateFunctionArguments;
 import io.confluent.ksql.function.BaseAggregateFunction;
 import io.confluent.ksql.function.KsqlAggregateFunction;
-import io.confluent.ksql.function.TableAggregationFunction;
 import java.util.Collections;
 import java.util.List;
 import java.util.function.Function;
----
@@ -15,10 +15,10 @@
 
 package io.confluent.ksql.function.udaf.sum;
 
+import io.confluent.ksql.execution.function.TableAggregationFunction;
 import io.confluent.ksql.function.AggregateFunctionArguments;
 import io.confluent.ksql.function.BaseAggregateFunction;
 import io.confluent.ksql.function.KsqlAggregateFunction;
-import io.confluent.ksql.function.TableAggregationFunction;
 import io.confluent.ksql.util.DecimalUtil;
 import java.math.BigDecimal;
 import java.math.MathContext;
----
@@ -15,10 +15,10 @@
 
 package io.confluent.ksql.function.udaf.sum;
 
+import io.confluent.ksql.execution.function.TableAggregationFunction;
 import io.confluent.ksql.function.AggregateFunctionArguments;
 import io.confluent.ksql.function.BaseAggregateFunction;
 import io.confluent.ksql.function.KsqlAggregateFunction;
-import io.confluent.ksql.function.TableAggregationFunction;
 import java.util.Collections;
 import java.util.function.Function;
 import org.apache.kafka.connect.data.Schema;
----
@@ -15,10 +15,10 @@
 
 package io.confluent.ksql.function.udaf.sum;
 
+import io.confluent.ksql.execution.function.TableAggregationFunction;
 import io.confluent.ksql.function.AggregateFunctionArguments;
 import io.confluent.ksql.function.BaseAggregateFunction;
 import io.confluent.ksql.function.KsqlAggregateFunction;
-import io.confluent.ksql.function.TableAggregationFunction;
 import java.util.Collections;
 import java.util.function.Function;
 import org.apache.kafka.connect.data.Schema;
----
@@ -15,10 +15,10 @@
 
 package io.confluent.ksql.function.udaf.sum;
 
+import io.confluent.ksql.execution.function.TableAggregationFunction;
 import io.confluent.ksql.function.AggregateFunctionArguments;
 import io.confluent.ksql.function.BaseAggregateFunction;
 import io.confluent.ksql.function.KsqlAggregateFunction;
-import io.confluent.ksql.function.TableAggregationFunction;
 import java.util.Collections;
 import java.util.function.Function;
 import org.apache.kafka.connect.data.Schema;
----
WindowEndKudaf.java
@@ -15,6 +15,7 @@
 
 package io.confluent.ksql.function.udaf.window;
 
+import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper;
 import io.confluent.ksql.function.udaf.TableUdaf;
 import io.confluent.ksql.function.udaf.UdafDescription;
 import io.confluent.ksql.function.udaf.UdafFactory;
@@ -37,7 +38,7 @@ private WindowEndKudaf() {
   }
 
   static String getFunctionName() {
-    return "WindowEnd";
+    return WindowSelectMapper.WINDOW_END_NAME;
   }
 
   @UdafFactory(description = "Extracts the window end time")
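The point of this change (and the matching one in WindowStartKudaf below): WindowSelectMapper recognizes these function names when it post-processes windowed aggregates, so the UDAF and the mapper must agree on the exact strings. Referencing one shared constant removes the duplicated literal. A sketch of the assumed declarations — the constant names come from the diff, and the values presumably equal the old literals, since the change is behavior-preserving:

    // Sketch (assumption): the shared name constants as they would appear on
    // io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper.
    public final class WindowSelectMapper {
      public static final String WINDOW_START_NAME = "WindowStart";
      public static final String WINDOW_END_NAME = "WindowEnd";
      // ... the mapper's matching logic is elided here ...
    }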
----
WindowStartKudaf.java
@@ -15,6 +15,7 @@
 
 package io.confluent.ksql.function.udaf.window;
 
+import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper;
 import io.confluent.ksql.function.udaf.TableUdaf;
 import io.confluent.ksql.function.udaf.UdafDescription;
 import io.confluent.ksql.function.udaf.UdafFactory;
@@ -37,7 +38,7 @@ private WindowStartKudaf() {
   }
 
   static String getFunctionName() {
-    return "WindowStart";
+    return WindowSelectMapper.WINDOW_START_NAME;
  }
 
   @UdafFactory(description = "Extracts the window start time")
----
AggregateNode.java
@@ -19,7 +19,6 @@
 import static java.util.Objects.requireNonNull;
 
 import com.google.common.collect.ImmutableList;
-import io.confluent.ksql.GenericRow;
 import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter;
 import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
 import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
@@ -29,24 +28,20 @@
 import io.confluent.ksql.execution.expression.tree.FunctionCall;
 import io.confluent.ksql.execution.expression.tree.Literal;
 import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
+import io.confluent.ksql.execution.function.UdafUtil;
 import io.confluent.ksql.execution.plan.SelectExpression;
-import io.confluent.ksql.execution.util.ExpressionTypeManager;
-import io.confluent.ksql.function.AggregateFunctionArguments;
 import io.confluent.ksql.function.FunctionRegistry;
 import io.confluent.ksql.function.KsqlAggregateFunction;
-import io.confluent.ksql.function.udaf.KudafInitializer;
 import io.confluent.ksql.materialization.MaterializationInfo;
 import io.confluent.ksql.metastore.model.KeyField;
 import io.confluent.ksql.name.ColumnName;
 import io.confluent.ksql.parser.tree.WindowExpression;
 import io.confluent.ksql.schema.ksql.Column;
 import io.confluent.ksql.schema.ksql.ColumnRef;
 import io.confluent.ksql.schema.ksql.LogicalSchema;
-import io.confluent.ksql.schema.ksql.PhysicalSchema;
 import io.confluent.ksql.schema.ksql.SchemaConverters;
 import io.confluent.ksql.schema.ksql.SchemaConverters.ConnectToSqlTypeConverter;
 import io.confluent.ksql.schema.ksql.types.SqlType;
-import io.confluent.ksql.serde.SerdeOption;
 import io.confluent.ksql.serde.ValueFormat;
 import io.confluent.ksql.services.KafkaTopicClient;
 import io.confluent.ksql.structured.SchemaKGroupedStream;
@@ -66,8 +61,6 @@
 import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
-import org.apache.kafka.common.serialization.Serde;
-import org.apache.kafka.connect.data.Schema;
 
 
 public class AggregateNode extends PlanNode {
@@ -234,54 +227,43 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
         builder
     );
 
-    // Aggregate computations
-    final KudafInitializer initializer = new KudafInitializer(requiredColumns.size());
-
-    final Map<Integer, KsqlAggregateFunction> aggValToFunctionMap = createAggValToFunctionMap(
-        aggregateArgExpanded,
-        initializer,
-        requiredColumns.size(),
-        builder.getFunctionRegistry(),
-        internalSchema
-    );
+    final List<FunctionCall> functionsWithInternalIdentifiers = functionList.stream()
+        .map(
+            fc -> new FunctionCall(
+                fc.getName(),
+                internalSchema.getInternalArgsExpressionList(fc.getArguments())
+            )
+        )
+        .collect(Collectors.toList());
 
     // This is the schema of the aggregation change log topic and associated state store.
     // It contains all columns from prepareSchema and columns for any aggregating functions
     // It uses internal column names, e.g. KSQL_INTERNAL_COL_0 and KSQL_AGG_VARIABLE_0
     final LogicalSchema aggregationSchema = buildLogicalSchema(
         prepareSchema,
-        aggValToFunctionMap,
+        functionsWithInternalIdentifiers,
+        builder.getFunctionRegistry(),
         true
     );
 
     final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME);
 
-    final Serde<GenericRow> aggValueGenericRowSerde = builder.buildValueSerde(
-        valueFormat.getFormatInfo(),
-        PhysicalSchema.from(aggregationSchema, SerdeOption.none()),
-        aggregationContext.getQueryContext()
-    );
-
-    final List<FunctionCall> functionsWithInternalIdentifiers = functionList.stream()
-        .map(internalSchema::resolveToInternal)
-        .map(FunctionCall.class::cast)
-        .collect(Collectors.toList());
-
     final LogicalSchema outputSchema = buildLogicalSchema(
         prepareSchema,
-        aggValToFunctionMap,
-        false);
+        functionsWithInternalIdentifiers,
+        builder.getFunctionRegistry(),
+        false
+    );
 
     SchemaKTable<?> aggregated = schemaKGroupedStream.aggregate(
         aggregationSchema,
         outputSchema,
-        initializer,
+        requiredColumns.size(),
         functionsWithInternalIdentifiers,
-        aggValToFunctionMap,
         windowExpression,
         valueFormat,
-        aggValueGenericRowSerde,
-        aggregationContext
+        aggregationContext,
+        builder
     );
 
     final Optional<Expression> havingExpression = Optional.ofNullable(havingExpressions)

Review thread on the internalSchema.getInternalArgsExpressionList(...) line:

Contributor: Ideally... we should be looking to move the whole internal schema thing into the physical layer. The logical layer shouldn't need to know about such things. Though such internal names should be part of the serialized form of the physical.

Contributor Author: I agree it shouldn't be in the logical plan node (like the other code that builds the execution plan). But the code that converts from the physical plan to the streams app shouldn't have to worry about it either. I think when we have a visitor that traverses the logical plan to build the physical plan, that would be the right place for it.
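To make the internal-identifier rewrite concrete: the prepare step repositions each aggregate argument into a fixed slot of the prepare schema, and the rewritten FunctionCalls reference those slots by position (the KSQL_INTERNAL_COL_n naming from the code comment above). A self-contained toy of that rewrite — names and shapes here are illustrative; the real logic lives in AggregateNode's private InternalSchema helper:

    import java.util.List;
    import java.util.stream.Collectors;

    // Toy internal-identifier rewrite: each aggregate argument becomes a
    // positional internal column reference, e.g. SUM(col1) -> SUM(KSQL_INTERNAL_COL_1).
    final class InternalNameRewrite {
      static String internalName(final int position) {
        return "KSQL_INTERNAL_COL_" + position;
      }

      static List<String> rewriteArgs(final List<Integer> argPositions) {
        return argPositions.stream()
            .map(InternalNameRewrite::internalName)
            .collect(Collectors.toList());
      }

      public static void main(final String[] args) {
        // An aggregate whose single argument was placed in slot 1:
        System.out.println("SUM(" + String.join(", ", rewriteArgs(List.of(1))) + ")");
      }
    }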
@@ -316,61 +298,12 @@ protected int getPartitions(final KafkaTopicClient kafkaTopicClient) {
     return source.getPartitions(kafkaTopicClient);
   }
 
-  private Map<Integer, KsqlAggregateFunction> createAggValToFunctionMap(
-      final SchemaKStream aggregateArgExpanded,
-      final KudafInitializer initializer,
-      final int initialUdafIndex,
-      final FunctionRegistry functionRegistry,
-      final InternalSchema internalSchema
-  ) {
-    int udafIndexInAggSchema = initialUdafIndex;
-    final Map<Integer, KsqlAggregateFunction> aggValToAggFunctionMap = new HashMap<>();
-    for (final FunctionCall functionCall : functionList) {
-      final KsqlAggregateFunction aggregateFunction = getAggregateFunction(
-          functionRegistry,
-          internalSchema,
-          functionCall, aggregateArgExpanded.getSchema());
-
-      aggValToAggFunctionMap.put(udafIndexInAggSchema++, aggregateFunction);
-      initializer.addAggregateIntializer(aggregateFunction.getInitialValueSupplier());
-    }
-    return aggValToAggFunctionMap;
-  }
-
-  @SuppressWarnings("deprecation") // Need to migrate away from Connect Schema use.
-  private static KsqlAggregateFunction getAggregateFunction(
-      final FunctionRegistry functionRegistry,
-      final InternalSchema internalSchema,
-      final FunctionCall functionCall,
-      final LogicalSchema schema
-  ) {
-    try {
-      final ExpressionTypeManager expressionTypeManager =
-          new ExpressionTypeManager(schema, functionRegistry);
-      final List<Expression> functionArgs = internalSchema.getInternalArgsExpressionList(
-          functionCall.getArguments());
-      final Schema expressionType = expressionTypeManager.getExpressionSchema(functionArgs.get(0));
-      final KsqlAggregateFunction aggregateFunctionInfo = functionRegistry
-          .getAggregate(functionCall.getName().name(), expressionType);
-
-      final List<String> args = functionArgs.stream()
-          .map(Expression::toString)
-          .collect(Collectors.toList());
-
-      final int udafIndex = Integer
-          .parseInt(args.get(0).substring(INTERNAL_COLUMN_NAME_PREFIX.length()));
-
-      return aggregateFunctionInfo.getInstance(new AggregateFunctionArguments(udafIndex, args));
-    } catch (final Exception e) {
-      throw new KsqlException("Failed to create aggregate function: " + functionCall, e);
-    }
-  }
-
   private LogicalSchema buildLogicalSchema(
       final LogicalSchema inputSchema,
-      final Map<Integer, KsqlAggregateFunction> aggregateFunctions,
-      final boolean useAggregate) {
-
+      final List<FunctionCall> aggregations,
+      final FunctionRegistry functionRegistry,
+      final boolean useAggregate
+  ) {
     final LogicalSchema.Builder schemaBuilder = LogicalSchema.builder();
     final List<Column> cols = inputSchema.value();
 
@@ -382,18 +315,13 @@ private LogicalSchema buildLogicalSchema(
 
     final ConnectToSqlTypeConverter converter = SchemaConverters.connectToSqlConverter();
 
-    for (int idx = 0; idx < aggregateFunctions.size(); idx++) {
-
-      final KsqlAggregateFunction aggregateFunction = aggregateFunctions
-          .get(requiredColumns.size() + idx);
-
-      final ColumnName colName = ColumnName.aggregate(idx);
-      SqlType fieldType = null;
-      if (useAggregate) {
-        fieldType = converter.toSqlType(aggregateFunction.getAggregateType());
-      } else {
-        fieldType = converter.toSqlType(aggregateFunction.getReturnType());
-      }
+    for (int i = 0; i < aggregations.size(); i++) {
+      final KsqlAggregateFunction aggregateFunction =
+          UdafUtil.resolveAggregateFunction(functionRegistry, aggregations.get(i), inputSchema);
+      final ColumnName colName = ColumnName.aggregate(i);
+      final SqlType fieldType = converter.toSqlType(
+          useAggregate ? aggregateFunction.getAggregateType() : aggregateFunction.getReturnType()
+      );
       schemaBuilder.valueColumn(colName, fieldType);
     }
 

Review thread on the new for loop:

Contributor: Maybe we could move this building of schema, and other places we build schemas, into the QueryAnalyzer / Analyzer. What do you think?

Contributor Author: Yeah, that makes sense to me for a follow-up.
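Why buildLogicalSchema is called twice with a useAggregate flag: a UDAF's intermediate state can have a different type from its final result, so the changelog/state-store schema (useAggregate = true) and the output schema (useAggregate = false) can differ. A standalone illustration with an AVG-style aggregate — toy code, not the ksql Udaf interface:

    // Toy AVG-style aggregate: the aggregate (state) type is a sum/count
    // pair, while the return type is a double. buildLogicalSchema with
    // useAggregate = true describes the former; with false, the latter.
    final class AvgToy {
      static final class SumCount {   // aggregate type (state-store schema)
        double sum;
        long count;
      }

      static SumCount aggregate(final SumCount acc, final double value) {
        acc.sum += value;
        acc.count++;
        return acc;
      }

      static double result(final SumCount acc) { // return type (output schema)
        return acc.count == 0 ? 0.0 : acc.sum / acc.count;
      }

      public static void main(final String[] args) {
        SumCount acc = new SumCount();
        acc = aggregate(acc, 1.0);
        acc = aggregate(acc, 2.0);
        System.out.println(result(acc)); // 1.5
      }
    }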