fix: unify behavior for PARTITION BY and GROUP BY
BREAKING CHANGE: PARTITION BY statements now resolve the partitioning
column against the _source_ schema, not the value/projection schema.
This is consistent with GROUP BY and with standard SQL GROUP BY
semantics. Any statement that previously used PARTITION BY may need to
be reworked.
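
To illustrate the new resolution rule (stream and column names below are hypothetical):

    -- Given a source stream such as:
    CREATE STREAM input (ID INT, NAME STRING)
        WITH (kafka_topic='input', value_format='JSON');

    -- Before this change, PARTITION BY was resolved against the projection,
    -- so a select alias could be used:
    CREATE STREAM output AS SELECT NAME AS NEW_KEY FROM input PARTITION BY NEW_KEY;

    -- After this change, PARTITION BY must name a column of the source schema:
    CREATE STREAM output AS SELECT NAME AS NEW_KEY FROM input PARTITION BY NAME;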
agavra committed Nov 27, 2019
1 parent 8b4bf27 commit 808b04e
Showing 39 changed files with 427 additions and 577 deletions.
2 changes: 1 addition & 1 deletion docs/developer-guide/syntax-reference.rst
@@ -635,7 +635,7 @@ its corresponding topic.

If the PARTITION BY clause is present, then the resulting stream will
have the specified column as its key. The `column_name` must be present
in the `select_expr`. For more information, see :ref:`partition-data-to-enable-joins`.
in the `from_stream`. For more information, see :ref:`partition-data-to-enable-joins`.

For joins, the key of the resulting stream will be the value from the column
from the left stream that was used in the join criteria. This column will be
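
As an illustrative example of the documented behavior (names hypothetical), the following creates a stream keyed by `userid`, which must be a column of the source stream `pageviews`:

    CREATE STREAM pageviews_by_user AS
        SELECT userid, pageid FROM pageviews PARTITION BY userid;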
15 changes: 12 additions & 3 deletions ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java
@@ -164,9 +164,7 @@ private final class Visitor extends DefaultTraversalVisitor<AstNode, Void> {

private void analyzeNonStdOutSink(final Sink sink) {
analysis.setProperties(sink.getProperties());
sink.getPartitionBy()
.map(name -> ColumnRef.withoutSource(name.name()))
.ifPresent(analysis::setPartitionBy);


setSerdeOptions(sink);

@@ -317,6 +315,7 @@ protected AstNode visitQuery(

node.getWhere().ifPresent(this::analyzeWhere);
node.getGroupBy().ifPresent(this::analyzeGroupBy);
node.getPartitionBy().ifPresent(this::analyzePartitionBy);
node.getWindow().ifPresent(this::analyzeWindowExpression);
node.getHaving().ifPresent(this::analyzeHaving);
node.getLimit().ifPresent(analysis::setLimitClause);
@@ -543,6 +542,16 @@ private void analyzeGroupBy(final GroupBy groupBy) {
}
}

private void analyzePartitionBy(final Expression partitionBy) {
if (partitionBy instanceof ColumnReferenceExp) {
analysis.setPartitionBy(((ColumnReferenceExp) partitionBy).getReference());
return;
}

throw new KsqlException(
"Expected partition by to be a valid column but got " + partitionBy);
}

private void analyzeWindowExpression(final WindowExpression windowExpression) {
analysis.setWindowExpression(windowExpression);
}
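
Note that the new `analyzePartitionBy` accepts only a bare column reference; anything else now fails at analysis time. Illustratively (hypothetical names):

    -- Accepted: a plain column of the source schema
    CREATE STREAM out1 AS SELECT * FROM input PARTITION BY col0;

    -- Rejected with "Expected partition by to be a valid column but got ...":
    CREATE STREAM out2 AS SELECT * FROM input PARTITION BY ABS(col0);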
@@ -88,8 +88,7 @@ protected Optional<AstNode> visitCreateStreamAsSelect(
node.getName(),
(Query) ctx.process(node.getQuery()),
node.isNotExists(),
node.getProperties(),
node.getPartitionByColumn()
node.getProperties()
)
);
}
@@ -117,8 +116,7 @@ protected Optional<AstNode> visitInsertInto(
new InsertInto(
node.getLocation(),
node.getTarget(),
(Query) ctx.process(node.getQuery()),
node.getPartitionByColumn()
(Query) ctx.process(node.getQuery())
)
);
}
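
With `getPartitionByColumn` removed from both statement types, PARTITION BY is now carried on the Query node itself (see the `visitQuery` changes above and below), so the clause is handled uniformly for both kinds of sink. Illustrative statements that exercise the same code path:

    CREATE STREAM out AS SELECT * FROM input PARTITION BY col0;
    INSERT INTO out SELECT * FROM input2 PARTITION BY col0;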
@@ -188,6 +188,10 @@ protected AstNode visitQuery(final Query node, final C context) {
final Optional<GroupBy> groupBy = node.getGroupBy()
.map(exp -> ((GroupBy) rewriter.apply(exp, context)));

// don't rewrite the partitionBy because we expect it to be
// exactly as it was (a single, un-aliased, column reference)
final Optional<Expression> partitionBy = node.getPartitionBy();

final Optional<Expression> having = node.getHaving()
.map(exp -> (processExpression(exp, context)));

@@ -198,6 +202,7 @@ protected AstNode visitQuery(final Query node, final C context) {
windowExpression,
where,
groupBy,
partitionBy,
having,
node.getResultMaterialization(),
node.isPullQuery(),
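
In other words, whatever rewriting is applied to the other expressions in the query, the PARTITION BY reference is passed through verbatim so the analyzer can later resolve it against the source schema. For example (hypothetical query), in

    CREATE STREAM out AS SELECT col0 AS k FROM input PARTITION BY col0;

the projection expressions may be rewritten, but `col0` in the PARTITION BY clause is not.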
@@ -364,16 +369,12 @@ protected AstNode visitCreateStreamAsSelect(
return result.get();
}

final Optional<Expression> partitionBy = node.getPartitionByColumn()
.map(exp -> processExpression(exp, context));

return new CreateStreamAsSelect(
node.getLocation(),
node.getName(),
(Query) rewriter.apply(node.getQuery(), context),
node.isNotExists(),
node.getProperties(),
partitionBy
node.getProperties()
);
}

@@ -416,14 +417,11 @@ protected AstNode visitInsertInto(final InsertInto node, final C context) {
return result.get();
}

final Optional<Expression> rewrittenPartitionBy = node.getPartitionByColumn()
.map(exp -> processExpression(exp, context));

return new InsertInto(
node.getLocation(),
node.getTarget(),
(Query) rewriter.apply(node.getQuery(), context),
rewrittenPartitionBy);
(Query) rewriter.apply(node.getQuery(), context)
);
}

@Override
ksql-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java
@@ -38,9 +38,9 @@
import io.confluent.ksql.planner.plan.PlanNode;
import io.confluent.ksql.planner.plan.PlanNodeId;
import io.confluent.ksql.planner.plan.ProjectNode;
import io.confluent.ksql.planner.plan.RepartitionNode;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.FormatOptions;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.LogicalSchema.Builder;
import io.confluent.ksql.schema.ksql.types.SqlType;
@@ -81,6 +81,10 @@ public OutputNode buildPlan() {
currentNode = buildFilterNode(currentNode, analysis.getWhereExpression().get());
}

if (analysis.getPartitionBy().isPresent()) {
currentNode = buildRepartitionNode(currentNode, analysis.getPartitionBy().get());
}

if (!analysis.getTableFunctions().isEmpty()) {
currentNode = buildFlatMapNode(currentNode);
}
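
For orientation, a rough sketch of where the new repartition step lands in the logical plan for a query such as (hypothetical):

    CREATE STREAM out AS SELECT col0, col1 FROM input WHERE col1 > 0 PARTITION BY col0;

Based on `buildPlan` above, the node order is approximately Source, then Filter, then PartitionBy (the new RepartitionNode), then FlatMap if table functions are present, before the projection and output steps in the rest of buildPlan (not shown here).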
@@ -111,56 +115,20 @@ private OutputNode buildOutputNode(final PlanNode sourcePlanNode) {

final Into intoDataSource = analysis.getInto().get();

final Optional<ColumnRef> partitionByField = analysis.getPartitionBy();

partitionByField.ifPresent(keyName ->
inputSchema.findValueColumn(keyName)
.orElseThrow(() -> new KsqlException(
"Column " + keyName.name().toString(FormatOptions.noEscape())
+ " does not exist in the result schema. Error in Partition By clause.")
));

final KeyField keyField = buildOutputKeyField(sourcePlanNode);

return new KsqlStructuredDataOutputNode(
new PlanNodeId(intoDataSource.getName().name()),
sourcePlanNode,
inputSchema,
extractionPolicy,
keyField,
sourcePlanNode.getKeyField(),
intoDataSource.getKsqlTopic(),
partitionByField,
analysis.getLimitClause(),
intoDataSource.isCreate(),
analysis.getSerdeOptions(),
intoDataSource.getName()
);
}

private KeyField buildOutputKeyField(
final PlanNode sourcePlanNode
) {
final KeyField sourceKeyField = sourcePlanNode.getKeyField();

final Optional<ColumnRef> partitionByField = analysis.getPartitionBy();
if (!partitionByField.isPresent()) {
return sourceKeyField;
}

final ColumnRef partitionBy = partitionByField.get();
final LogicalSchema schema = sourcePlanNode.getSchema();

if (schema.isMetaColumn(partitionBy.name())) {
return KeyField.none();
}

if (schema.isKeyColumn(partitionBy.name())) {
return sourceKeyField;
}

return KeyField.of(partitionBy);
}

private TimestampExtractionPolicy getTimestampExtractionPolicy(
final LogicalSchema inputSchema,
final Analysis analysis
@@ -229,6 +197,31 @@ private static FilterNode buildFilterNode(
return new FilterNode(new PlanNodeId("Filter"), sourcePlanNode, filterExpression);
}

private static RepartitionNode buildRepartitionNode(
final PlanNode sourceNode,
final ColumnRef partitionBy
) {
if (!sourceNode.getSchema().withoutAlias().findValueColumn(partitionBy).isPresent()) {
throw new KsqlException("Invalid identifier for PARTITION BY clause: " + partitionBy);
}

final KeyField keyField;
final LogicalSchema schema = sourceNode.getSchema();
if (schema.isMetaColumn(partitionBy.name())) {
keyField = KeyField.none();
} else if (schema.isKeyColumn(partitionBy.name())) {
keyField = sourceNode.getKeyField();
} else {
keyField = KeyField.of(partitionBy);
}

return new RepartitionNode(
new PlanNodeId("PartitionBy"),
sourceNode,
partitionBy,
keyField);
}

private FlatMapNode buildFlatMapNode(final PlanNode sourcePlanNode) {
return new FlatMapNode(new PlanNodeId("FlatMap"), sourcePlanNode, functionRegistry, analysis);
}
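
The key-field logic in `buildRepartitionNode` distinguishes three cases; illustratively (hypothetical stream `input` with value column `col0`):

    -- Meta column: the resulting stream has no key field
    CREATE STREAM s1 AS SELECT * FROM input PARTITION BY ROWTIME;

    -- Existing key column: the source's key field is kept
    CREATE STREAM s2 AS SELECT * FROM input PARTITION BY ROWKEY;

    -- Value column: that column becomes the new key field
    CREATE STREAM s3 AS SELECT * FROM input PARTITION BY col0;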
ksql-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java
@@ -26,22 +26,17 @@
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.query.QueryId;
import io.confluent.ksql.query.id.QueryIdGenerator;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.structured.SchemaKStream;
import io.confluent.ksql.structured.SchemaKTable;
import io.confluent.ksql.util.timestamp.TimestampExtractionPolicy;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;

public class KsqlStructuredDataOutputNode extends OutputNode {

private final KsqlTopic ksqlTopic;
private final KeyField keyField;
private final Optional<ColumnRef> partitionByField;
private final boolean doCreateInto;
private final Set<SerdeOption> serdeOptions;
private final SourceName intoSourceName;
@@ -54,7 +49,6 @@ public KsqlStructuredDataOutputNode(
final TimestampExtractionPolicy timestampExtractionPolicy,
final KeyField keyField,
final KsqlTopic ksqlTopic,
final Optional<ColumnRef> partitionByField,
final OptionalInt limit,
final boolean doCreateInto,
final Set<SerdeOption> serdeOptions,
@@ -76,11 +70,8 @@ public KsqlStructuredDataOutputNode(
this.keyField = requireNonNull(keyField, "keyField")
.validateKeyExistsIn(schema);
this.ksqlTopic = requireNonNull(ksqlTopic, "ksqlTopic");
this.partitionByField = Objects.requireNonNull(partitionByField, "partitionByField");
this.doCreateInto = doCreateInto;
this.intoSourceName = requireNonNull(intoSourceName, "intoSourceName");

validatePartitionByField();
}

public boolean isDoCreateInto() {
@@ -119,52 +110,16 @@ public KeyField getKeyField() {
@Override
public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
final PlanNode source = getSource();
final SchemaKStream schemaKStream = source.buildStream(builder);
final SchemaKStream<?> schemaKStream = source.buildStream(builder);

final QueryContext.Stacker contextStacker = builder.buildNodeContext(getId().toString());

final SchemaKStream<?> result = createOutputStream(
schemaKStream,
contextStacker
);

return result.into(
return schemaKStream.into(
getKsqlTopic().getKafkaTopicName(),
getSchema(),
getKsqlTopic().getValueFormat(),
serdeOptions,
contextStacker
);
}

private SchemaKStream<?> createOutputStream(
final SchemaKStream schemaKStream,
final QueryContext.Stacker contextStacker
) {
if (schemaKStream instanceof SchemaKTable) {
return schemaKStream;
}

if (!partitionByField.isPresent()) {
return schemaKStream;
}

return schemaKStream.selectKey(partitionByField.get(), false, contextStacker);
}

private void validatePartitionByField() {
if (!partitionByField.isPresent()) {
return;
}

final ColumnRef fieldName = partitionByField.get();

if (getSchema().isMetaColumn(fieldName.name()) || getSchema().isKeyColumn(fieldName.name())) {
return;
}

if (!keyField.ref().equals(Optional.of(fieldName))) {
throw new IllegalArgumentException("keyField must match partition by field");
}
}
}