Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: expression support for PARTITION BY #4032

Merged
merged 3 commits into from
Dec 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class Analysis implements ImmutableAnalysis {
private final Set<ColumnRef> selectColumnRefs = new HashSet<>();
private final List<Expression> groupByExpressions = new ArrayList<>();
private Optional<WindowExpression> windowExpression = Optional.empty();
private Optional<ColumnRef> partitionBy = Optional.empty();
private Optional<Expression> partitionBy = Optional.empty();
private ImmutableSet<SerdeOption> serdeOptions = ImmutableSet.of();
private Optional<Expression> havingExpression = Optional.empty();
private OptionalInt limitClause = OptionalInt.empty();
Expand Down Expand Up @@ -134,11 +134,11 @@ void setHavingExpression(final Expression havingExpression) {
this.havingExpression = Optional.of(havingExpression);
}

public Optional<ColumnRef> getPartitionBy() {
public Optional<Expression> getPartitionBy() {
return partitionBy;
}

void setPartitionBy(final ColumnRef partitionBy) {
void setPartitionBy(final Expression partitionBy) {
this.partitionBy = Optional.of(partitionBy);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -509,13 +509,7 @@ private void analyzeGroupBy(final GroupBy groupBy) {
}

private void analyzePartitionBy(final Expression partitionBy) {
if (partitionBy instanceof ColumnReferenceExp) {
analysis.setPartitionBy(((ColumnReferenceExp) partitionBy).getReference());
return;
}

throw new KsqlException(
"Expected partition by to be a valid column but got " + partitionBy);
analysis.setPartitionBy(partitionBy);
}

private void analyzeWindowExpression(final WindowExpression windowExpression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,8 @@ protected AstNode visitQuery(final Query node, final C context) {
final Optional<GroupBy> groupBy = node.getGroupBy()
.map(exp -> ((GroupBy) rewriter.apply(exp, context)));

// don't rewrite the partitionBy because we expect it to be
// exactly as it was (a single, un-aliased, column reference)
final Optional<Expression> partitionBy = node.getPartitionBy();
final Optional<Expression> partitionBy = node.getPartitionBy()
.map(exp -> processExpression(exp, context));

final Optional<Expression> having = node.getHaving()
.map(exp -> (processExpression(exp, context)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import io.confluent.ksql.schema.ksql.LogicalSchema.Builder;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -204,28 +203,34 @@ private static FilterNode buildFilterNode(

private static RepartitionNode buildRepartitionNode(
final PlanNode sourceNode,
final ColumnRef partitionBy
final Expression partitionBy
) {
if (!sourceNode.getSchema().withoutAlias().findValueColumn(partitionBy).isPresent()) {
throw new KsqlException("Invalid identifier for PARTITION BY clause: '" + partitionBy
+ "'. Only columns from the source schema can be referenced in the PARTITION BY clause.");
if (!(partitionBy instanceof ColumnReferenceExp)) {
return new RepartitionNode(
new PlanNodeId("PartitionBy"),
sourceNode,
partitionBy,
KeyField.none());
}

final KeyField keyField;
final ColumnRef partitionColumn = ((ColumnReferenceExp) partitionBy).getReference();
final LogicalSchema schema = sourceNode.getSchema();
if (schema.isMetaColumn(partitionBy.name())) {

final KeyField keyField;
if (schema.isMetaColumn(partitionColumn.name())) {
keyField = KeyField.none();
} else if (schema.isKeyColumn(partitionBy.name())) {
} else if (schema.isKeyColumn(partitionColumn.name())) {
keyField = sourceNode.getKeyField();
} else {
keyField = KeyField.of(partitionBy);
keyField = KeyField.of(partitionColumn);
}

return new RepartitionNode(
new PlanNodeId("PartitionBy"),
sourceNode,
partitionBy,
keyField);

}

private FlatMapNode buildFlatMapNode(final PlanNode sourcePlanNode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryContext.Stacker;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.streams.JoinParamsFactory;
import io.confluent.ksql.metastore.model.DataSource.DataSourceType;
Expand Down Expand Up @@ -285,7 +286,7 @@ static <K> SchemaKStream<K> maybeRePartitionByKey(
final ColumnRef joinFieldName,
final Stacker contextStacker
) {
return stream.selectKey(joinFieldName, contextStacker);
return stream.selectKey(new ColumnReferenceExp(joinFieldName), contextStacker);
}

static ValueFormat getFormatForSource(final DataSourceNode sourceNode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

package io.confluent.ksql.planner.plan;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.metastore.model.KeyField;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.services.KafkaTopicClient;
import io.confluent.ksql.structured.SchemaKStream;
Expand All @@ -32,14 +32,18 @@
public class RepartitionNode extends PlanNode {

private final PlanNode source;
private final ColumnRef partitionBy;
private final Expression partitionBy;
private final KeyField keyField;

public RepartitionNode(PlanNodeId id, PlanNode source, ColumnRef partitionBy, KeyField keyField) {
public RepartitionNode(
PlanNodeId id,
PlanNode source,
Expression partitionBy,
KeyField keyField
) {
super(id, source.getNodeOutputType());
final SourceName alias = source.getTheSourceNode().getAlias();
this.source = Objects.requireNonNull(source, "source");
this.partitionBy = Objects.requireNonNull(partitionBy, "partitionBy").withSource(alias);
this.partitionBy = Objects.requireNonNull(partitionBy, "partitionBy");
this.keyField = Objects.requireNonNull(keyField, "keyField");
}

Expand Down Expand Up @@ -73,4 +77,9 @@ public SchemaKStream<?> buildStream(KsqlQueryBuilder builder) {
return source.buildStream(builder)
.selectKey(partitionBy, builder.buildNodeContext(getId().toString()));
}

@VisibleForTesting
public Expression getPartitionBy() {
return partitionBy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.FormatOptions;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.KeyFormat;
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.serde.ValueFormat;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -321,56 +323,65 @@ public SchemaKStream<K> outerJoin(

@SuppressWarnings("unchecked")
public SchemaKStream<Struct> selectKey(
final ColumnRef columnRef,
final Expression keyExpression,
final QueryContext.Stacker contextStacker
) {
if (keyFormat.isWindowed()) {
throw new UnsupportedOperationException("Can not selectKey of windowed stream");
}

final Optional<Column> existingKey = keyField.resolve(getSchema());

final Column proposedKey = getSchema().findValueColumn(columnRef)
.orElseThrow(IllegalArgumentException::new);

final KeyField resultantKeyField = isRowKey(columnRef)
? keyField
: KeyField.of(columnRef);

final boolean namesMatch = existingKey
.map(kf -> kf.ref().equals(proposedKey.ref()))
.orElse(false);

if (namesMatch || isRowKey(proposedKey.ref())) {
return (SchemaKStream<Struct>) new SchemaKStream<>(
sourceStep,
schema,
keyFormat,
resultantKeyField,
ksqlConfig,
functionRegistry
);
if (!needsRepartition(keyExpression)) {
return (SchemaKStream<Struct>) this;
}

final KeyField newKeyField = getSchema().isMetaColumn(columnRef.name())
? KeyField.none()
: resultantKeyField;

final StreamSelectKey step = ExecutionStepFactory.streamSelectKey(
contextStacker,
sourceStep,
columnRef
keyExpression
);

return new SchemaKStream<>(
step,
resolveSchema(step),
keyFormat,
newKeyField,
getNewKeyField(keyExpression),
ksqlConfig,
functionRegistry
);
}

private KeyField getNewKeyField(final Expression expression) {
if (!(expression instanceof ColumnReferenceExp)) {
return KeyField.none();
}

final ColumnRef columnRef = ((ColumnReferenceExp) expression).getReference();
final KeyField newKeyField = isRowKey(columnRef) ? keyField : KeyField.of(columnRef);
return getSchema().isMetaColumn(columnRef.name()) ? KeyField.none() : newKeyField;
}

private boolean needsRepartition(final Expression expression) {
if (!(expression instanceof ColumnReferenceExp)) {
return true;
}

final ColumnRef columnRef = ((ColumnReferenceExp) expression).getReference();
final Optional<Column> existingKey = keyField.resolve(getSchema());

final Column proposedKey = getSchema()
.findValueColumn(columnRef)
.orElseThrow(() -> new KsqlException("Invalid identifier for PARTITION BY clause: '"
+ columnRef.name().toString(FormatOptions.noEscape()) + "' Only columns from the "
+ "source schema can be referenced in the PARTITION BY clause."));


final boolean namesMatch = existingKey
.map(kf -> kf.ref().equals(proposedKey.ref()))
.orElse(false);

return !namesMatch && !isRowKey(columnRef);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think comparing with the ROWKEY is actually safe, e.g.:

CREATE STREAM LEFT (A STRING, B STRING) WITH (KEY='A');
CREATE STREAM RIGHT (C STRING, D STRING) WITH (KEY='C');
CREATE STREAM JOINED AS SELECT LEFT.*, RIGHT.* FROM LEFT JOIN RIGHT ON L.B=R.D PARTITION BY LEFT.ROWKEY;

In this case, when we hit the repartition, the stream will be partitioned on B/D, so we do want to repartition, even though the column ref is ROWKEY.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @big-andy-coates - I think @rodesai brings up a good point here, though I think it's an existing bug. I think we should fix it in the short-term by re-introducing the "unnecessary" repartition step in the case of isRowKey but before I do that I wanted to run it past you since you implemented the original optimization (#2735).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest adding this as a QTT test to ensure we're handling this correctly. Do what you need to do to get this functionally working.

Raise a github issue to track the outstanding piece.

I think we can clean this up once we have cleaner code around the handling of schemas and duplicating ROWKEY and ROWTIME into the value schema.

Specifically, once we clean up / remove the use of source-aliases on schema fields and have arbitrary key column names, then I think this will be easy to solve because the ROWKEY name won't be ambiguous.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f8bfefa - I will fix this in a future PR

Copy link
Contributor Author

@agavra agavra Dec 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

private static boolean isRowKey(final ColumnRef fieldName) {
return fieldName.name().equals(SchemaUtil.ROWKEY_NAME);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.analyzer.Analysis.Into;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.tree.ResultMaterialization;
Expand Down Expand Up @@ -126,7 +127,7 @@ public void shouldThrowOnGroupBy() {
public void shouldThrowOnPartitionBy() {
// Given:
when(analysis.getPartitionBy())
.thenReturn(Optional.of(ColumnRef.withoutSource(ColumnName.of("Something"))));
.thenReturn(Optional.of(new ColumnReferenceExp(ColumnRef.withoutSource(ColumnName.of("Something")))));

// Then:
expectedException.expect(KsqlException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,11 @@ public void shouldRewriteQueryWithGroupBy() {
}

@Test
public void shouldNotRewriteQueryWithPartitionBy() {
public void shouldRewriteQueryWithPartitionBy() {
// Given:
final Query query =
givenQuery(Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(expression), Optional.empty());
when(expressionRewriter.apply(expression, context)).thenReturn(rewrittenExpression);

// When:
final AstNode rewritten = rewriter.rewrite(query, context);
Expand All @@ -293,7 +294,7 @@ public void shouldNotRewriteQueryWithPartitionBy() {
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.of(expression),
Optional.of(rewrittenExpression),
Optional.empty(),
resultMaterialization,
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.streams.KSPlanBuilder;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.InternalFunctionRegistry;
Expand Down Expand Up @@ -878,7 +879,7 @@ public void shouldSelectLeftKeyField() {

// Then:
verify(leftSchemaKStream).selectKey(
eq(LEFT_JOIN_FIELD_REF),
eq(new ColumnReferenceExp(LEFT_JOIN_FIELD_REF)),
any()
);
}
Expand Down
Loading