refactor: build table with correct aggregate schema (MINOR) (#3259)
* refactor: build table with correct aggregate schema (MINOR)

The current code calls `SchemaKGroupedStream.aggregate`, but the return value has the wrong schema, so the caller has to build a new `SchemaKTable` with the correct schema. This change removes that extra step by passing the correct aggregate schema into `aggregate` and sanity-checking it there.
big-andy-coates authored Aug 28, 2019
1 parent 6a1a69f commit d3f8075
Showing 7 changed files with 260 additions and 69 deletions.
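
To make the commit message concrete: the new sanity check compares the supplied aggregate schema's value-field count against the expected count, which is the number of non-aggregate (group-by) columns plus one column per aggregate function. Below is a minimal, self-contained sketch of that rule with illustrative stand-in types; the real helper is the `throwOnValueFieldCountMismatch` method added to `SchemaKGroupedStream` in the diff, which operates on `LogicalSchema` and `KsqlAggregateFunction`.

// Illustrative sketch only; not the ksql implementation. The real helper takes a
// LogicalSchema and a Map<Integer, KsqlAggregateFunction>; simple stand-ins are used here.
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

public final class AggregateSchemaCheckSketch {

  static void throwOnValueFieldCountMismatch(
      final List<String> valueFields,                              // stand-in for aggregateSchema.valueFields()
      final int nonFuncColumnCount,                                // group-by / pass-through columns
      final Map<Integer, BinaryOperator<Long>> aggValToFunctionMap // stand-in for the aggregate functions
  ) {
    final int expected = nonFuncColumnCount + aggValToFunctionMap.size();
    final int actual = valueFields.size();
    if (actual != expected) {
      throw new IllegalArgumentException(
          "Aggregate schema value field count does not match expected."
              + " expected: " + expected
              + ", actual: " + actual);
    }
  }

  public static void main(final String[] args) {
    // Two group-by columns plus one aggregate function => the schema needs three value fields.
    throwOnValueFieldCountMismatch(
        List.of("K0", "K1", "AGG0"),
        2,
        Map.of(2, Long::sum));
    System.out.println("aggregate schema matches the expected field count");
  }
}

Both `SchemaKGroupedStream.aggregate` and `SchemaKGroupedTable.aggregate` run this check before building the resulting `SchemaKTable` with the supplied schema.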
@@ -184,18 +184,17 @@ public <C, R> R accept(final PlanVisitor<C, R> visitor, final C context) {
return visitor.visitAggregate(this, context);
}

@SuppressWarnings("unchecked") // needs investigating
@Override
public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
final QueryContext.Stacker contextStacker = builder.buildNodeContext(getId().toString());
final DataSourceNode streamSourceNode = getTheSourceNode();
final SchemaKStream sourceSchemaKStream = getSource().buildStream(builder);
final SchemaKStream<?> sourceSchemaKStream = getSource().buildStream(builder);

// Pre aggregate computations
final InternalSchema internalSchema = new InternalSchema(getRequiredColumns(),
getAggregateFunctionArguments());

final SchemaKStream aggregateArgExpanded =
final SchemaKStream<?> aggregateArgExpanded =
sourceSchemaKStream.select(
internalSchema.getAggArgExpansionList(),
contextStacker.push(PREPARE_OP_NAME),
@@ -247,34 +246,25 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
aggregationContext.getQueryContext()
);

final SchemaKTable<?> schemaKTable = schemaKGroupedStream.aggregate(
SchemaKTable<?> aggregated = schemaKGroupedStream.aggregate(
aggStageSchema,
initializer,
requiredColumns.size(),
aggValToFunctionMap,
getWindowExpression(),
aggValueGenericRowSerde,
aggregationContext
);

SchemaKTable<?> result = new SchemaKTable<>(
schemaKTable.getKtable(), aggStageSchema,
schemaKTable.getKeySerde(),
schemaKTable.getKeyField(),
schemaKTable.getSourceSchemaKStreams(),
SchemaKStream.Type.AGGREGATE,
builder.getKsqlConfig(),
builder.getFunctionRegistry(),
aggregationContext.getQueryContext()
);

if (havingExpressions != null) {
result = result.filter(
aggregated = aggregated.filter(
internalSchema.resolveToInternal(havingExpressions),
contextStacker.push(FILTER_OP_NAME),
builder.getProcessingLogContext());
}

return result.select(
return aggregated.select(
internalSchema.updateFinalSelectExpressions(getFinalSelectExpressions()),
contextStacker.push(PROJECT_OP_NAME),
builder.getProcessingLogContext());
@@ -102,24 +102,29 @@ public KeyField getKeyField() {

@SuppressWarnings("unchecked")
public SchemaKTable<?> aggregate(
final LogicalSchema aggregateSchema,
final Initializer initializer,
final int nonFuncColumnCount,
final Map<Integer, KsqlAggregateFunction> aggValToFunctionMap,
final WindowExpression windowExpression,
final Serde<GenericRow> topicValueSerDe,
final QueryContext.Stacker contextStacker
) {
throwOnValueFieldCountMismatch(aggregateSchema, nonFuncColumnCount, aggValToFunctionMap);

final KTable table;
final KeySerde<?> newKeySerde;
if (windowExpression != null) {
newKeySerde = getKeySerde(windowExpression);

table = aggregateWindowed(
initializer,
nonFuncColumnCount,
aggValToFunctionMap,
windowExpression,
topicValueSerDe,
contextStacker);
contextStacker
);
} else {
newKeySerde = keySerde;

@@ -128,12 +133,13 @@ public SchemaKTable<?> aggregate(
nonFuncColumnCount,
aggValToFunctionMap,
topicValueSerDe,
contextStacker);
contextStacker
);
}

return new SchemaKTable(
table,
schema,
aggregateSchema,
newKeySerde,
keyField,
sourceSchemaKStreams,
@@ -202,4 +208,23 @@ private KeySerde<Windowed<Struct>> getKeySerde(final WindowExpression windowExpr

return keySerde.rebind(windowExpression.getKsqlWindowExpression().getWindowInfo());
}

static void throwOnValueFieldCountMismatch(
final LogicalSchema aggregateSchema,
final int nonFuncColumnCount,
final Map<Integer, KsqlAggregateFunction> aggValToFunctionMap
) {
final int nonAggColumnCount = aggValToFunctionMap.size();
final int totalColumnCount = nonAggColumnCount + nonFuncColumnCount;

final int valueColumnCount = aggregateSchema.valueFields().size();
if (valueColumnCount != totalColumnCount) {
throw new IllegalArgumentException(
"Aggregate schema value field count does not match expected."
+ " expected: " + totalColumnCount
+ ", actual: " + valueColumnCount
+ ", schema: " + aggregateSchema
);
}
}
}
@@ -83,6 +83,7 @@ public class SchemaKGroupedTable extends SchemaKGroupedStream {
@SuppressWarnings("unchecked")
@Override
public SchemaKTable<Struct> aggregate(
final LogicalSchema aggregateSchema,
final Initializer initializer,
final int nonFuncColumnCount,
final Map<Integer, KsqlAggregateFunction> aggValToFunctionMap,
@@ -94,6 +95,8 @@ public SchemaKTable<Struct> aggregate(
throw new KsqlException("Windowing not supported for table aggregations.");
}

throwOnValueFieldCountMismatch(aggregateSchema, nonFuncColumnCount, aggValToFunctionMap);

final List<String> unsupportedFunctionNames = aggValToFunctionMap.values()
.stream()
.filter(function -> !(function instanceof TableAggregationFunction))
@@ -134,7 +137,7 @@ public SchemaKTable<Struct> aggregate(

return new SchemaKTable<>(
aggKtable,
schema,
aggregateSchema,
keySerde,
keyField,
sourceSchemaKStreams,
@@ -38,6 +38,7 @@
import io.confluent.ksql.parser.tree.KsqlWindowExpression;
import io.confluent.ksql.parser.tree.WindowExpression;
import io.confluent.ksql.query.QueryId;
import io.confluent.ksql.schema.ksql.Field;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.serde.KeySerde;
import io.confluent.ksql.serde.WindowInfo;
@@ -49,6 +50,8 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.streams.kstream.Initializer;
@@ -70,6 +73,8 @@ public class SchemaKGroupedStreamTest {
@Mock
private LogicalSchema schema;
@Mock
private LogicalSchema aggregateSchema;
@Mock
private KGroupedStream groupedStream;
@Mock
private KeyField keyField;
@@ -105,6 +110,8 @@ public class SchemaKGroupedStreamTest {
private KeySerde<Struct> keySerde;
@Mock
private KeySerde<Windowed<Struct>> windowedKeySerde;
@Mock
private Field field;
private final QueryContext.Stacker queryContext
= new QueryContext.Stacker(new QueryId("query")).push("node");
private SchemaKGroupedStream schemaGroupedStream;
@@ -175,8 +182,15 @@ public void shouldSupportSessionWindowedKey() {
when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo);

// When:
final SchemaKTable result = schemaGroupedStream
.aggregate(initializer, 0, emptyMap(), windowExp, topicValueSerDe, queryContext);
final SchemaKTable result = schemaGroupedStream.aggregate(
aggregateSchema,
initializer,
0,
emptyMap(),
windowExp,
topicValueSerDe,
queryContext
);

// Then:
verify(keySerde).rebind(windowInfo);
@@ -192,8 +206,15 @@ public void shouldSupportHoppingWindowedKey() {
when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo);

// When:
final SchemaKTable result = schemaGroupedStream
.aggregate(initializer, 0, emptyMap(), windowExp, topicValueSerDe, queryContext);
final SchemaKTable result = schemaGroupedStream.aggregate(
aggregateSchema,
initializer,
0,
emptyMap(),
windowExp,
topicValueSerDe,
queryContext
);

// Then:
verify(keySerde).rebind(windowInfo);
@@ -209,8 +230,15 @@ public void shouldSupportTumblingWindowedKey() {
when(ksqlWindowExp.getWindowInfo()).thenReturn(windowInfo);

// When:
final SchemaKTable result = schemaGroupedStream
.aggregate(initializer, 0, emptyMap(), windowExp, topicValueSerDe, queryContext);
final SchemaKTable result = schemaGroupedStream.aggregate(
aggregateSchema,
initializer,
0,
emptyMap(),
windowExp,
topicValueSerDe,
queryContext
);

// Then:
verify(keySerde).rebind(windowInfo);
@@ -224,8 +252,15 @@ public void shouldUseTimeWindowKeySerdeForWindowedIfLegacyConfig() {
.thenReturn(true);

// When:
final SchemaKTable result = schemaGroupedStream
.aggregate(initializer, 0, emptyMap(), windowExp, topicValueSerDe, queryContext);
final SchemaKTable result = schemaGroupedStream.aggregate(
aggregateSchema,
initializer,
0,
emptyMap(),
windowExp,
topicValueSerDe,
queryContext
);

// Then:
verify(keySerde)
Expand All @@ -246,9 +281,18 @@ private void assertDoesNotInstallWindowSelectMapper(
.thenReturn(table);
}

givenAggregateSchemaFieldCount(funcMap.size());

// When:
final SchemaKTable result = schemaGroupedStream
.aggregate(initializer, 0, funcMap, windowExp, topicValueSerDe, queryContext);
final SchemaKTable result = schemaGroupedStream.aggregate(
aggregateSchema,
initializer,
0, funcMap,
windowExp,
topicValueSerDe,
queryContext
);

// Then:
assertThat(result.getKtable(), is(sameInstance(table)));
@@ -265,10 +309,18 @@ private void assertDoesInstallWindowSelectMapper(
when(table.mapValues(any(ValueMapperWithKey.class)))
.thenReturn(table2);

givenAggregateSchemaFieldCount(funcMap.size());

// When:
final SchemaKTable result = schemaGroupedStream
.aggregate(initializer, 0, funcMap, windowExp, topicValueSerDe, queryContext);
final SchemaKTable result = schemaGroupedStream.aggregate(
aggregateSchema,
initializer,
0, funcMap,
windowExp,
topicValueSerDe,
queryContext
);

// Then:
assertThat(result.getKtable(), is(sameInstance(table2)));
@@ -293,6 +345,7 @@ public void shouldUseMaterializedFactoryForStateStore() {

// When:
schemaGroupedStream.aggregate(
aggregateSchema,
() -> null,
0,
Collections.emptyMap(),
@@ -320,6 +373,7 @@ public void shouldUseMaterializedFactoryWindowedStateStore() {

// When:
schemaGroupedStream.aggregate(
aggregateSchema,
() -> null,
0,
Collections.emptyMap(),
@@ -335,4 +389,51 @@ public void shouldUseMaterializedFactoryWindowedStateStore() {
eq(StreamsUtil.buildOpName(queryContext.getQueryContext())));
verify(ksqlWindowExp, times(1)).applyAggregate(any(), any(), any(), same(materialized));
}

@Test
public void shouldReturnKTableWithAggregateSchema() {
// When:
final SchemaKTable result = schemaGroupedStream.aggregate(
aggregateSchema,
initializer,
0,
emptyMap(),
windowExp,
topicValueSerDe,
queryContext
);

// Then:
assertThat(result.getSchema(), is(aggregateSchema));
}

@Test(expected = IllegalArgumentException.class)
public void shouldThrowOnColumnCountMismatch() {
// Given:
// Agg schema has 2 fields:
givenAggregateSchemaFieldCount(2);

// Whereas the params specify 2 non-agg columns and 1 agg function (3 expected fields):
final Map<Integer, KsqlAggregateFunction> aggColumns = ImmutableMap.of(2, otherFunc);

// When:
schemaGroupedStream.aggregate(
aggregateSchema,
initializer,
2,
aggColumns,
windowExp,
topicValueSerDe,
queryContext
);
}

private void givenAggregateSchemaFieldCount(final int count) {
final List<Field> valueFields = IntStream
.range(0, count)
.mapToObj(i -> field)
.collect(Collectors.toList());

when(aggregateSchema.valueFields()).thenReturn(valueFields);
}
}