From 9218766775744e894df0a9ea5c2af7d37117f238 Mon Sep 17 00:00:00 2001 From: Andy Coates <8012398+big-andy-coates@users.noreply.github.com> Date: Wed, 2 Oct 2019 10:02:34 -0700 Subject: [PATCH] feat(static): support ROWKEY in the projection of static queries (#3439) * feat(static): support ROWKEY in the projection of static queries --- .../execution/util/ExpressionTypeManager.java | 2 +- .../codegen/SqlToJavaVisitorTest.java | 7 +- .../util/ExpressionTypeManagerTest.java | 46 +++-- ...materialized-aggregate-static-queries.json | 175 ++++++++++++++---- .../server/execution/StaticQueryExecutor.java | 125 +++++++++---- .../ksql/rest/entity/TableRowsEntity.java | 2 +- 6 files changed, 259 insertions(+), 98 deletions(-) diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 24b8ec92f5eb..617ebb4069bd 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -210,7 +210,7 @@ public Void visitColumnReference( final ColumnReferenceExp node, final ExpressionTypeContext expressionTypeContext ) { - final Column schemaColumn = schema.findValueColumn(node.getReference().toString()) + final Column schemaColumn = schema.findColumn(node.getReference().name()) .orElseThrow(() -> new KsqlException(String.format("Invalid Expression %s.", node.toString()))); diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index 3961b46e448d..46d7bf988a24 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -36,6 +36,7 @@ 
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression; import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression.Sign; import io.confluent.ksql.execution.expression.tree.Cast; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.ComparisonExpression; import io.confluent.ksql.execution.expression.tree.DoubleLiteral; import io.confluent.ksql.execution.expression.tree.Expression; @@ -44,9 +45,6 @@ import io.confluent.ksql.execution.expression.tree.InPredicate; import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.LikePredicate; -import io.confluent.ksql.name.FunctionName; -import io.confluent.ksql.schema.ksql.ColumnRef; -import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.SearchedCaseExpression; import io.confluent.ksql.execution.expression.tree.SimpleCaseExpression; import io.confluent.ksql.execution.expression.tree.StringLiteral; @@ -58,12 +56,13 @@ import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlFunction; import io.confluent.ksql.function.UdfFactory; +import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.Operator; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.types.SqlDecimal; import io.confluent.ksql.schema.ksql.types.SqlPrimitiveType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import java.util.Optional; -import java.util.function.Function; import org.apache.kafka.connect.data.Schema; import org.junit.Before; import org.junit.Rule; diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java index 7cc80ef1555f..665a763bc35e 100644 --- 
a/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java @@ -16,7 +16,6 @@ package io.confluent.ksql.execution.util; import static io.confluent.ksql.execution.testutil.TestExpressions.ADDRESS; -import static io.confluent.ksql.execution.testutil.TestExpressions.COL0; import static io.confluent.ksql.execution.testutil.TestExpressions.COL1; import static io.confluent.ksql.execution.testutil.TestExpressions.COL2; import static io.confluent.ksql.execution.testutil.TestExpressions.COL3; @@ -54,6 +53,7 @@ import io.confluent.ksql.execution.expression.tree.TimestampLiteral; import io.confluent.ksql.execution.expression.tree.WhenClause; import io.confluent.ksql.execution.function.udf.structfieldextractor.FetchFieldFromStruct; +import io.confluent.ksql.execution.testutil.TestExpressions; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlFunction; import io.confluent.ksql.function.UdfFactory; @@ -79,6 +79,10 @@ @SuppressWarnings("OptionalGetWithoutIsPresent") public class ExpressionTypeManagerTest { + + private static final SourceName TEST1 = SourceName.of("TEST1"); + private static final ColumnName COL0 = ColumnName.of("COL0"); + @Mock private FunctionRegistry functionRegistry; @Mock @@ -116,7 +120,8 @@ private void givenUdfWithNameAndReturnType( @Test public void shouldResolveTypeForAddBigIntDouble() { - final Expression expression = new ArithmeticBinaryExpression(Operator.ADD, COL0, COL3); + final Expression expression = new ArithmeticBinaryExpression(Operator.ADD, TestExpressions.COL0, + COL3); final SqlType type = expressionTypeManager.getExpressionSqlType(expression); @@ -134,7 +139,8 @@ public void shouldResolveTypeForAddDoubleIntegerLiteral() { @Test public void shouldResolveTypeForAddBigintIntegerLiteral() { - final Expression expression = new ArithmeticBinaryExpression(Operator.ADD, COL0, 
literal(10)); + final Expression expression = new ArithmeticBinaryExpression(Operator.ADD, TestExpressions.COL0, + literal(10)); final SqlType type = expressionTypeManager.getExpressionSqlType(expression); @@ -144,7 +150,7 @@ public void shouldResolveTypeForAddBigintIntegerLiteral() { @Test public void shouldResolveTypeForMultiplyBigintIntegerLiteral() { final Expression expression = - new ArithmeticBinaryExpression(Operator.MULTIPLY, COL0, literal(10)); + new ArithmeticBinaryExpression(Operator.MULTIPLY, TestExpressions.COL0, literal(10)); final SqlType type = expressionTypeManager.getExpressionSqlType(expression); @@ -153,7 +159,8 @@ public void shouldResolveTypeForMultiplyBigintIntegerLiteral() { @Test public void testComparisonExpr() { - final Expression expression = new ComparisonExpression(Type.GREATER_THAN, COL0, COL3); + final Expression expression = new ComparisonExpression(Type.GREATER_THAN, TestExpressions.COL0, + COL3); final SqlType exprType = expressionTypeManager.getExpressionSqlType(expression); @@ -163,7 +170,8 @@ public void testComparisonExpr() { @Test public void shouldFailIfComparisonOperandsAreIncompatible() { // Given: - final ComparisonExpression expr = new ComparisonExpression(Type.GREATER_THAN, COL0, COL1); + final ComparisonExpression expr = new ComparisonExpression(Type.GREATER_THAN, + TestExpressions.COL0, COL1); expectedException.expect(KsqlException.class); expectedException.expectMessage("Operator GREATER_THAN cannot be used to compare BIGINT and STRING"); @@ -283,7 +291,7 @@ public void shouldThrowOnStructFieldDereference() { // Given: final Expression expression = new DereferenceExpression( Optional.empty(), - new ColumnReferenceExp(ColumnRef.of(SourceName.of("TEST1"), ColumnName.of("COL6"))), + new ColumnReferenceExp(ColumnRef.of(TEST1, ColumnName.of("COL6"))), "STREET" ); @@ -322,10 +330,11 @@ public void shouldEvaluateTypeForStructDereferenceInArray() { // Given: final SqlStruct inner = SqlTypes.struct().field("IN0", 
SqlTypes.INTEGER).build(); final LogicalSchema schema = LogicalSchema.builder() - .valueColumn(ColumnName.of("TEST1.COL0"), SqlTypes.array(inner)) + .valueColumn(TEST1, COL0, SqlTypes.array(inner)) .build(); expressionTypeManager = new ExpressionTypeManager(schema, functionRegistry); - final Expression arrayRef = new SubscriptExpression(COL0, new IntegerLiteral(1)); + final Expression arrayRef = new SubscriptExpression(TestExpressions.COL0, + new IntegerLiteral(1)); final Expression expression = new FunctionCall( FunctionName.of(FetchFieldFromStruct.FUNCTION_NAME), ImmutableList.of(arrayRef, new StringLiteral("IN0")) @@ -343,12 +352,12 @@ public void shouldEvaluateTypeForArrayReferenceInStruct() { .field("IN0", SqlTypes.array(SqlTypes.INTEGER)) .build(); final LogicalSchema schema = LogicalSchema.builder() - .valueColumn(ColumnName.of("TEST1.COL0"), inner) + .valueColumn(TEST1, COL0, inner) .build(); expressionTypeManager = new ExpressionTypeManager(schema, functionRegistry); final Expression structRef = new FunctionCall( FunctionName.of(FetchFieldFromStruct.FUNCTION_NAME), - ImmutableList.of(COL0, new StringLiteral("IN0")) + ImmutableList.of(TestExpressions.COL0, new StringLiteral("IN0")) ); final Expression expression = new SubscriptExpression(structRef, new IntegerLiteral(1)); @@ -387,7 +396,7 @@ public void shouldGetCorrectSchemaForSearchedCaseWhenStruct() { final Expression expression = new SearchedCaseExpression( ImmutableList.of( new WhenClause( - new ComparisonExpression(Type.EQUAL, COL0, new IntegerLiteral(10)), + new ComparisonExpression(Type.EQUAL, TestExpressions.COL0, new IntegerLiteral(10)), ADDRESS) ), Optional.empty() @@ -407,7 +416,8 @@ public void shouldFailIfWhenIsNotBoolean() { final Expression expression = new SearchedCaseExpression( ImmutableList.of( new WhenClause( - new ArithmeticBinaryExpression(Operator.ADD, COL0, new IntegerLiteral(10)), + new ArithmeticBinaryExpression(Operator.ADD, TestExpressions.COL0, + new IntegerLiteral(10)), new 
StringLiteral("foo")) ), Optional.empty() @@ -425,10 +435,10 @@ public void shouldFailOnInconsistentWhenResultType() { final Expression expression = new SearchedCaseExpression( ImmutableList.of( new WhenClause( - new ComparisonExpression(Type.EQUAL, COL0, new IntegerLiteral(100)), + new ComparisonExpression(Type.EQUAL, TestExpressions.COL0, new IntegerLiteral(100)), new StringLiteral("one-hundred")), new WhenClause( - new ComparisonExpression(Type.EQUAL, COL0, new IntegerLiteral(10)), + new ComparisonExpression(Type.EQUAL, TestExpressions.COL0, new IntegerLiteral(10)), new IntegerLiteral(10)) ), Optional.empty() @@ -447,7 +457,7 @@ public void shouldFailIfDefaultHasDifferentTypeToWhen() { final Expression expression = new SearchedCaseExpression( ImmutableList.of( new WhenClause( - new ComparisonExpression(Type.EQUAL, COL0, new IntegerLiteral(10)), + new ComparisonExpression(Type.EQUAL, TestExpressions.COL0, new IntegerLiteral(10)), new StringLiteral("good")) ), Optional.of(new BooleanLiteral("true")) @@ -481,7 +491,7 @@ public void shouldThrowOnTimestampLiteral() { public void shouldThrowOnIn() { // Given: final Expression expression = new InPredicate( - COL0, + TestExpressions.COL0, new InListExpression(ImmutableList.of(new IntegerLiteral(1), new IntegerLiteral(2))) ); @@ -495,7 +505,7 @@ public void shouldThrowOnIn() { @Test public void shouldThrowOnSimpleCase() { final Expression expression = new SimpleCaseExpression( - COL0, + TestExpressions.COL0, ImmutableList.of(new WhenClause(new IntegerLiteral(10), new StringLiteral("ten"))), Optional.empty() ); diff --git a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json index 3bcf1b3fcdd0..9b5708c88e3e 100644 --- a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json +++ 
b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json @@ -26,45 +26,6 @@ {"@type": "rows", "rows": []} ] }, - { - "name": "non-windowed with projection", - "statements": [ - "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE TABLE AGGREGATE AS SELECT ROWKEY AS ID, COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;", - "SELECT COUNT, CONCAT(ID, 'x') AS ID, COUNT * 2 FROM AGGREGATE WHERE ROWKEY='10';" - ], - "inputs": [ - {"topic": "test_topic", "key": "11", "value": {}}, - {"topic": "test_topic", "key": "10", "value": {}} - ], - "responses": [ - {"@type": "currentStatus"}, - {"@type": "currentStatus"}, - {"@type": "rows", "rows": [ - ["10", 1, "10x", 2] - ]} - ] - }, - { - "name": "windowed with projection", - "statements": [ - "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE TABLE AGGREGATE AS SELECT ROWKEY AS ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", - "SELECT COUNT, CONCAT(ID, 'x') AS ID, COUNT * 2 FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=12000;" - ], - "inputs": [ - {"topic": "test_topic", "timestamp": 12345, "key": "11", "value": {}}, - {"topic": "test_topic", "timestamp": 11345, "key": "10", "value": {}}, - {"topic": "test_topic", "timestamp": 12345, "key": "10", "value": {}} - ], - "responses": [ - {"@type": "currentStatus"}, - {"@type": "currentStatus"}, - {"@type": "rows", "rows": [ - ["10", 12000, 1, "10x", 2] - ]} - ] - }, { "name": "tumbling windowed single key lookup with exact window start", "statements": [ @@ -342,6 +303,92 @@ } ] }, + { + "name": "non-windowed with projection", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT ROWKEY AS ID, COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;", + "SELECT COUNT, CONCAT(ID, 'x') AS ID, COUNT * 2 FROM AGGREGATE 
WHERE ROWKEY='10';" + ], + "inputs": [ + {"topic": "test_topic", "key": "11", "value": {}}, + {"topic": "test_topic", "key": "10", "value": {}} + ], + "responses": [ + {"@type": "currentStatus"}, + {"@type": "currentStatus"}, + { + "@type": "rows", + "schema": "`COUNT` BIGINT, `ID` STRING, `KSQL_COL_2` BIGINT", + "rows": [[1, "10x", 2]] + } + ] + }, + { + "name": "windowed with projection", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT ROWKEY AS ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", + "SELECT COUNT, CONCAT(ID, 'x') AS ID, COUNT * 2 FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=12000;" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12345, "key": "11", "value": {}}, + {"topic": "test_topic", "timestamp": 11345, "key": "10", "value": {}}, + {"topic": "test_topic", "timestamp": 12345, "key": "10", "value": {}} + ], + "responses": [ + {"@type": "currentStatus"}, + {"@type": "currentStatus"}, + { + "@type": "rows", + "schema": "`COUNT` BIGINT, `ID` STRING, `KSQL_COL_2` BIGINT", + "rows": [[1, "10x", 2]] + } + ] + }, + { + "name": "non-windowed projection WITH ROWKEY", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;", + "SELECT ROWKEY, COUNT FROM AGGREGATE WHERE ROWKEY='10';" + ], + "inputs": [ + {"topic": "test_topic", "key": "11", "value": {}}, + {"topic": "test_topic", "key": "10", "value": {}} + ], + "responses": [ + {"@type": "currentStatus"}, + {"@type": "currentStatus"}, + { + "@type": "rows", + "schema": "`ROWKEY` STRING KEY, `COUNT` BIGINT", + "rows": [["10", 1]] + } + ] + }, + { + "name": "windowed with projection with ROWKEY", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE 
AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", + "SELECT COUNT, ROWKEY FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=12000;" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12345, "key": "11", "value": {}}, + {"topic": "test_topic", "timestamp": 11345, "key": "10", "value": {}}, + {"topic": "test_topic", "timestamp": 12345, "key": "10", "value": {}} + ], + "responses": [ + {"@type": "currentStatus"}, + {"@type": "currentStatus"}, + { + "@type": "rows", + "schema": "`COUNT` BIGINT, `ROWKEY` STRING KEY", + "rows": [[1, "10"]] + } + ] + }, { "name": "non-windowed projection WITH ROWTIME", "statements": [ @@ -368,12 +415,56 @@ "status": 400 } }, + { + "name": "non-windowed projection with ROWKEY and more columns in aggregate", + "statements": [ + "CREATE STREAM INPUT (VAL INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT, SUM(VAL) AS SUM, MIN(VAL) AS MIN FROM INPUT GROUP BY ROWKEY;", + "SELECT ROWKEY, COUNT FROM AGGREGATE WHERE ROWKEY='10';" + ], + "inputs": [ + {"topic": "test_topic", "key": "11", "value": {"val": 1}}, + {"topic": "test_topic", "key": "10", "value": {"val": 2}} + ], + "responses": [ + {"@type": "currentStatus"}, + {"@type": "currentStatus"}, + { + "@type": "rows", + "schema": "`ROWKEY` STRING KEY, `COUNT` BIGINT", + "rows": [["10", 1]] + } + ] + }, + { + "name": "non-windowed projection with ROWKEY and more columns in lookup", + "statements": [ + "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;", + "SELECT COUNT, ROWKEY, COUNT AS COUNT2 FROM AGGREGATE WHERE ROWKEY='10';" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12345, "key": "11", "value": {}}, + {"topic": "test_topic", "timestamp": 11345, "key": "10", "value": {}}, + {"topic": "test_topic", "timestamp": 12345, "key": "10", "value": {}} + ], + 
"responses": [ + {"@type": "currentStatus"}, + {"@type": "currentStatus"}, + { + "@type": "rows", + "schema": "`COUNT` BIGINT, `ROWKEY` STRING KEY, `COUNT2` BIGINT", + "rows": [[2,"10",2]] + } + ] + }, { "name": "text datetime window bounds", + "comment": "Note: this test must specify a timezone in the exact lookup so that it works when run from any TZ.", "statements": [ "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", - "SELECT * FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart='2020-02-23T23:45:12.000';" + "SELECT * FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart='2020-02-23T23:45:12.000+0000';" ], "inputs": [ {"topic": "test_topic", "timestamp": 1582501512456, "key": "11", "value": {}}, @@ -408,13 +499,15 @@ }, { "name": "partial text datetime window bounds", + "comment": "Note: this test has side enough range on dates to ensure running in different timezones do not cause it to fail", "statements": [ "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", - "SELECT * FROM AGGREGATE WHERE '2020-02-23T23:45' <= WindowStart AND WindowStart < '2020-02-23T24' AND ROWKEY='10';" + "SELECT * FROM AGGREGATE WHERE '2020-02-22T23:45' <= WindowStart AND WindowStart < '2020-02-24T24' AND ROWKEY='10';" ], "inputs": [ {"topic": "test_topic", "timestamp": 1582501512456, "key": "11", "value": {}}, + {"topic": "test_topic", "timestamp": 1582401512456, "key": "10", "value": {}}, {"topic": "test_topic", "timestamp": 1582501512456, "key": "10", "value": {}}, {"topic": "test_topic", "timestamp": 1582501552456, "key": "10", "value": {}} ], diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java 
b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java index dea53c607163..285b21e50823 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java @@ -53,6 +53,7 @@ import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.tree.AllColumns; import io.confluent.ksql.parser.tree.Query; +import io.confluent.ksql.parser.tree.Select; import io.confluent.ksql.parser.tree.SelectItem; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.rest.Errors; @@ -65,6 +66,7 @@ import io.confluent.ksql.rest.server.resources.KsqlRestException; import io.confluent.ksql.schema.ksql.FormatOptions; import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.LogicalSchema.Builder; import io.confluent.ksql.schema.ksql.PhysicalSchema; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.serde.SerdeOption; @@ -86,6 +88,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.kafka.connect.data.Struct; @@ -160,7 +163,7 @@ public static Optional execute( return Optional.of(proxyTo(owner, statement)); } - Result result; + final Result result; if (whereInfo.windowStartBounds.isPresent()) { final Range windowStart = whereInfo.windowStartBounds.get(); @@ -177,12 +180,24 @@ public static Optional execute( result = new Result(mat.schema(), rows); } - result = handleSelects(result, statement, executionContext, analysis); + final LogicalSchema outputSchema; + final List> rows; + if (isSelectStar(statement.getStatement().getSelect())) { + outputSchema = TableRowsEntityFactory.buildSchema(result.schema, mat.windowType()); + rows = TableRowsEntityFactory.createRows(result.rows); + } else { + final 
LogicalSchema.Builder schemaBuilder = + selectSchemaBuilder(result, executionContext, analysis); + + outputSchema = schemaBuilder.build(); + + rows = handleSelects(result, statement, executionContext, analysis, outputSchema); + } final TableRowsEntity entity = new TableRowsEntity( statement.getStatementText(), - TableRowsEntityFactory.buildSchema(result.schema, mat.windowType()), - TableRowsEntityFactory.createRows(result.rows) + outputSchema, + rows ); return Optional.of(entity); @@ -493,58 +508,102 @@ private static ComparisonTarget extractWhereClauseTarget(final ComparisonExpress } } - private static boolean isSelectStar(final List selects) { + private static boolean isSelectStar(final Select select) { + final List selects = select.getSelectItems(); return selects.size() == 1 && selects.get(0) instanceof AllColumns; } - private static Result handleSelects( + private static List> handleSelects( final Result input, final ConfiguredStatement statement, final KsqlExecutionContext executionContext, - final Analysis analysis + final Analysis analysis, + final LogicalSchema outputSchema ) { - final List selectItems = statement.getStatement().getSelect().getSelectItems(); - if (input.rows.isEmpty() || isSelectStar(selectItems)) { - return input; - } - - final LogicalSchema.Builder schemaBuilder = LogicalSchema.builder(); - schemaBuilder.keyColumns(input.schema.key()); - - final ExpressionTypeManager expressionTypeManager = new ExpressionTypeManager( - input.schema, - executionContext.getMetaStore() - ); - - for (int idx = 0; idx < analysis.getSelectExpressions().size(); idx++) { - final SelectExpression select = analysis.getSelectExpressions().get(idx); - final SqlType type = expressionTypeManager.getExpressionSqlType(select.getExpression()); - schemaBuilder.valueColumn(select.getName(), type); + final LogicalSchema intermediateSchema; + final BiFunction preSelectTransform; + if (outputSchema.key().isEmpty()) { + intermediateSchema = input.schema; + preSelectTransform = 
(key, value) -> value; + } else { + // SelectValueMapper requires the key fields in the value schema :( + intermediateSchema = LogicalSchema.builder() + .keyColumns(input.schema.key()) + .valueColumns(input.schema.value()) + .valueColumns(input.schema.key()) + .build(); + + preSelectTransform = (key, value) -> { + key.schema().fields().forEach(f -> { + final Object keyField = key.get(f); + value.getColumns().add(keyField); + }); + return value; + }; } - final LogicalSchema schema = schemaBuilder.build(); - final SourceName sourceName = getSourceName(analysis); final KsqlConfig ksqlConfig = statement.getConfig() .cloneWithPropertyOverwrite(statement.getOverrides()); - final SelectValueMapper mapper = SelectValueMapperFactory.create( + final SelectValueMapper select = SelectValueMapperFactory.create( analysis.getSelectExpressions(), - input.schema.withAlias(sourceName), + intermediateSchema.withAlias(sourceName), ksqlConfig, executionContext.getMetaStore(), NoopProcessingLogContext.INSTANCE.getLoggerFactory().getLogger("any") ); - final ImmutableList.Builder output = ImmutableList.builder(); + final ImmutableList.Builder> output = ImmutableList.builder(); input.rows.forEach(r -> { - final GenericRow mapped = mapper.apply(r.value()); - final TableRow tableRow = r.withValue(mapped, schema); - output.add(tableRow); + final GenericRow intermediate = preSelectTransform.apply(r.key(), r.value()); + final GenericRow mapped = select.apply(intermediate); + validateProjection(mapped, outputSchema); + output.add(mapped.getColumns()); }); - return new Result(schema, output.build()); + return output.build(); + } + + private static void validateProjection( + final GenericRow fullRow, + final LogicalSchema schema + ) { + final int actual = fullRow.getColumns().size(); + final int expected = schema.columns().size(); + if (actual != expected) { + throw new IllegalStateException("Row column count mismatch." 
+ + " expected:" + expected + + ", got:" + actual + ); + } + } + + private static LogicalSchema.Builder selectSchemaBuilder( + final Result input, + final KsqlExecutionContext executionContext, + final Analysis analysis + ) { + final Builder schemaBuilder = LogicalSchema.builder() + .noImplicitColumns(); + + final ExpressionTypeManager expressionTypeManager = new ExpressionTypeManager( + input.schema, + executionContext.getMetaStore() + ); + + for (int idx = 0; idx < analysis.getSelectExpressions().size(); idx++) { + final SelectExpression select = analysis.getSelectExpressions().get(idx); + final SqlType type = expressionTypeManager.getExpressionSqlType(select.getExpression()); + + if (input.schema.isKeyColumn(select.getName())) { + schemaBuilder.keyColumn(select.getName(), type); + } else { + schemaBuilder.valueColumn(select.getName(), type); + } + } + return schemaBuilder; } private static PersistentQueryMetadata findMaterializingQuery( diff --git a/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/TableRowsEntity.java b/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/TableRowsEntity.java index e4c2e47643f0..d3db574830b4 100644 --- a/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/TableRowsEntity.java +++ b/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/TableRowsEntity.java @@ -76,7 +76,7 @@ private void validate(final List row) { final int actualSize = row.size(); if (expectedSize != actualSize) { - throw new IllegalArgumentException("field count mismatch." + throw new IllegalArgumentException("column count mismatch." + " expected: " + expectedSize + ", got: " + actualSize );