diff --git a/ksql-cli/src/main/java/io/confluent/ksql/cli/Cli.java b/ksql-cli/src/main/java/io/confluent/ksql/cli/Cli.java index 2c9fa7437189..03b3ac959112 100644 --- a/ksql-cli/src/main/java/io/confluent/ksql/cli/Cli.java +++ b/ksql-cli/src/main/java/io/confluent/ksql/cli/Cli.java @@ -37,10 +37,8 @@ import io.confluent.ksql.rest.client.RestResponse; import io.confluent.ksql.rest.entity.CommandStatus; import io.confluent.ksql.rest.entity.CommandStatusEntity; -import io.confluent.ksql.rest.entity.FieldInfo; import io.confluent.ksql.rest.entity.KsqlEntity; import io.confluent.ksql.rest.entity.KsqlEntityList; -import io.confluent.ksql.rest.entity.QueryDescriptionEntity; import io.confluent.ksql.rest.entity.StreamedRow; import io.confluent.ksql.util.ErrorMessageUtil; import io.confluent.ksql.util.HandlerMaps; @@ -316,18 +314,6 @@ private void handleStreamedQuery( final String query, final SqlBaseParser.QueryStatementContext ignored ) { - final RestResponse explainResponse = restClient - .makeKsqlRequest("EXPLAIN " + query); - if (!explainResponse.isSuccessful()) { - terminal.printErrorMessage(explainResponse.getErrorMessage()); - return; - } - - final QueryDescriptionEntity description = - (QueryDescriptionEntity) explainResponse.getResponse().get(0); - final List fields = description.getQueryDescription().getFields(); - terminal.printRowHeader(fields); - final RestResponse queryResponse = makeKsqlRequest(query, restClient::makeQueryRequest); @@ -338,22 +324,24 @@ private void handleStreamedQuery( } else { try (QueryStream queryStream = queryResponse.getResponse(); StatusClosable toClose = terminal.setStatusMessage("Press CTRL-C to interrupt")) { - streamResults(queryStream, fields); + streamResults(queryStream); } } } - private void streamResults( - final QueryStream queryStream, - final List fields - ) { + private void streamResults(final QueryStream queryStream) { final Future queryStreamFuture = queryStreamExecutorService.submit(() -> { - for (long rowsRead = 0; limitNotReached(rowsRead) && queryStream.hasNext(); rowsRead++) { + for (long rowsRead = 0; limitNotReached(rowsRead) && queryStream.hasNext(); ) { final StreamedRow row = queryStream.next(); - terminal.printStreamedRow(row, fields); + + terminal.printStreamedRow(row); if (row.isTerminal()) { break; } + + if (row.getRow().isPresent()) { + rowsRead++; + } } }); diff --git a/ksql-cli/src/main/java/io/confluent/ksql/cli/console/Console.java b/ksql-cli/src/main/java/io/confluent/ksql/cli/console/Console.java index 5b4bbae4e509..a6a0ce7e3fd2 100644 --- a/ksql-cli/src/main/java/io/confluent/ksql/cli/console/Console.java +++ b/ksql-cli/src/main/java/io/confluent/ksql/cli/console/Console.java @@ -78,6 +78,7 @@ import io.confluent.ksql.rest.entity.TablesList; import io.confluent.ksql.rest.entity.TopicDescription; import io.confluent.ksql.rest.entity.TypeList; +import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.util.CmdLineUtil; import io.confluent.ksql.util.HandlerMaps; import io.confluent.ksql.util.HandlerMaps.ClassHandlerMap1; @@ -324,32 +325,27 @@ public void printError(final String shortMsg, final String fullMsg) { writer().println(shortMsg); } - public void printStreamedRow( - final StreamedRow row, - final List fields - ) { - if (row.getErrorMessage() != null) { - printErrorMessage(row.getErrorMessage()); - return; - } + public void printStreamedRow(final StreamedRow row) { + row.getErrorMessage().ifPresent(this::printErrorMessage); - if (row.getFinalMessage() != null) { - 
writer().println(row.getFinalMessage()); - return; - } + row.getFinalMessage().ifPresent(finalMsg -> writer().println(finalMsg)); - switch (outputFormat) { - case JSON: - printAsJson(row.getRow().getColumns()); - break; - case TABULAR: - printAsTable(row.getRow(), fields); - break; - default: - throw new RuntimeException(String.format( - "Unexpected output format: '%s'", - outputFormat.name() - )); + row.getHeader().ifPresent(header -> printRowHeader(header.getSchema())); + + if (row.getRow().isPresent()) { + switch (outputFormat) { + case JSON: + printAsJson(row.getRow().get().getColumns()); + break; + case TABULAR: + printAsTable(row.getRow().get()); + break; + default: + throw new RuntimeException(String.format( + "Unexpected output format: '%s'", + outputFormat.name() + )); + } } } @@ -376,12 +372,12 @@ public void printKsqlEntityList(final List entityList) { } } - public void printRowHeader(final List fields) { + private void printRowHeader(final LogicalSchema schema) { switch (outputFormat) { case JSON: break; case TABULAR: - writer().println(TabularRow.createHeader(getWidth(), fields)); + writer().println(TabularRow.createHeader(getWidth(), schema)); break; default: throw new RuntimeException(String.format( @@ -426,12 +422,9 @@ private Optional getCliCommand(final String line) { .findFirst(); } - private void printAsTable( - final GenericRow row, - final List fields - ) { + private void printAsTable(final GenericRow row) { rowCaptor.addRow(row); - writer().println(TabularRow.createRow(getWidth(), fields, row, config)); + writer().println(TabularRow.createRow(getWidth(), row, config)); flush(); } diff --git a/ksql-cli/src/main/java/io/confluent/ksql/util/TabularRow.java b/ksql-cli/src/main/java/io/confluent/ksql/util/TabularRow.java index 1019060eac20..31ea22070ca7 100644 --- a/ksql-cli/src/main/java/io/confluent/ksql/util/TabularRow.java +++ b/ksql-cli/src/main/java/io/confluent/ksql/util/TabularRow.java @@ -15,69 +15,71 @@ package io.confluent.ksql.util; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; import io.confluent.ksql.cli.console.CliConfig; import io.confluent.ksql.cli.console.CliConfig.OnOff; -import io.confluent.ksql.rest.entity.FieldInfo; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.schema.ksql.Column; +import io.confluent.ksql.schema.ksql.LogicalSchema; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; -public class TabularRow { +public final class TabularRow { private static final String CLIPPED = "..."; private static final int MIN_CELL_WIDTH = 5; private final int width; - private final List value; - private final List header; + private final List columns; private final boolean isHeader; private final boolean shouldWrap; - public static TabularRow createHeader(final int width, final List header) { + public static TabularRow createHeader(final int width, final LogicalSchema schema) { + final List headings = schema.columns().stream() + .map(Column::name) + .map(ColumnName::name) + .collect(Collectors.toList()); + return new TabularRow( width, - header.stream().map(FieldInfo::getName).collect(Collectors.toList()), - null, - true); + headings, + true, + true + ); } public static TabularRow createRow( final int width, - final List header, final GenericRow value, final CliConfig config ) { return new TabularRow( 
width, - header.stream().map(FieldInfo::getName).collect(Collectors.toList()), value.getColumns().stream().map(Objects::toString).collect(Collectors.toList()), + false, config.getString(CliConfig.WRAP_CONFIG).equalsIgnoreCase(OnOff.ON.toString()) ); } - @VisibleForTesting - TabularRow( + private TabularRow( final int width, - final List header, - final List value, + final List columns, + final boolean isHeader, final boolean shouldWrap ) { - this.header = Objects.requireNonNull(header, "header"); + this.columns = ImmutableList.copyOf(Objects.requireNonNull(columns, "columns")); this.width = width; - this.value = value; - this.isHeader = value == null; + this.isHeader = isHeader; this.shouldWrap = shouldWrap; } @Override public String toString() { - final List columns = isHeader ? header : value; - if (columns.isEmpty()) { return ""; } diff --git a/ksql-cli/src/test/java/io/confluent/ksql/cli/CliTest.java b/ksql-cli/src/test/java/io/confluent/ksql/cli/CliTest.java index fc835b41fda9..0a42b25099f9 100644 --- a/ksql-cli/src/test/java/io/confluent/ksql/cli/CliTest.java +++ b/ksql-cli/src/test/java/io/confluent/ksql/cli/CliTest.java @@ -22,7 +22,6 @@ import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.either; -import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -603,7 +602,7 @@ public void testTransientSelect() { } @Test - public void testTransientStaticSelectStar() { + public void shouldHandlePullQuery() { // Given: run("CREATE TABLE X AS SELECT COUNT(1) AS COUNT " + "FROM " + orderDataProvider.kstreamName() @@ -611,14 +610,25 @@ public void testTransientStaticSelectStar() { localCli ); + // When: + final Supplier runner = () -> { + // It's possible that the state store is not warm on the first invocation, hence the retry + run("SELECT * FROM X WHERE ROWKEY='ITEM_1';", localCli); + return terminal.getOutputString(); + }; + + // Wait for warm store: + assertThatEventually(runner, containsString("ROWKEY")); assertRunCommand( - "SELECT * FROM X WHERE ROWKEY='unknowwn';", - is(emptyIterable()) + "SELECT * FROM X WHERE ROWKEY='ITEM_1';", + containsRows( + row("ITEM_1", "1") + ) ); } @Test - public void testTransientStaticHeader() { + public void shouldOutputPullQueryHeader() { // Given: run("CREATE TABLE Y AS SELECT COUNT(1) AS COUNT " + "FROM " + orderDataProvider.kstreamName() @@ -655,7 +665,7 @@ public void testTransientContinuousSelectStar() { } @Test - public void testTransientContinuousHeader() { + public void shouldOutputPushQueryHeader() { // When: run("SELECT * FROM " + orderDataProvider.kstreamName() + " EMIT CHANGES LIMIT 1", localCli); diff --git a/ksql-cli/src/test/java/io/confluent/ksql/cli/console/ConsoleTest.java b/ksql-cli/src/test/java/io/confluent/ksql/cli/console/ConsoleTest.java index 1d5d3a37e612..be9c2cd66935 100644 --- a/ksql-cli/src/test/java/io/confluent/ksql/cli/console/ConsoleTest.java +++ b/ksql-cli/src/test/java/io/confluent/ksql/cli/console/ConsoleTest.java @@ -97,10 +97,12 @@ public class ConsoleTest { private static final String CLI_CMD_NAME = "some command"; private static final String WHITE_SPACE = " \t "; - private static final List HEADER = - ImmutableList.of( - new FieldInfo("foo", new SchemaInfo(SqlBaseType.STRING, null, null)), - new FieldInfo("bar", new SchemaInfo(SqlBaseType.STRING, null, null))); + + private static final LogicalSchema SCHEMA = 
LogicalSchema.builder() + .noImplicitColumns() + .keyColumn(ColumnName.of("foo"), SqlTypes.INTEGER) + .valueColumn(ColumnName.of("bar"), SqlTypes.STRING) + .build(); private final TestTerminal terminal; private final Console console; @@ -147,12 +149,12 @@ public void after() { } @Test - public void testPrintGenericStreamedRow() throws IOException { + public void testPrintGenericStreamedRow() { // Given: final StreamedRow row = StreamedRow.row(new GenericRow(ImmutableList.of("col_1", "col_2"))); // When: - console.printStreamedRow(row, HEADER); + console.printStreamedRow(row); // Then: if (console.getOutputFormat() == OutputFormat.TABULAR) { @@ -162,9 +164,12 @@ public void testPrintGenericStreamedRow() throws IOException { } @Test - public void testPrintHeader() throws IOException { + public void shouldPrintHeader() { + // Given: + final StreamedRow header = StreamedRow.header(new QueryId("id"), SCHEMA); + // When: - console.printRowHeader(HEADER); + console.printStreamedRow(header); // Then: if (console.getOutputFormat() == OutputFormat.TABULAR) { @@ -174,22 +179,22 @@ public void testPrintHeader() throws IOException { } @Test - public void testPrintErrorStreamedRow() throws IOException { + public void testPrintErrorStreamedRow() { final FakeException exception = new FakeException(); - console.printStreamedRow(StreamedRow.error(exception, Errors.ERROR_CODE_SERVER_ERROR), HEADER); + console.printStreamedRow(StreamedRow.error(exception, Errors.ERROR_CODE_SERVER_ERROR)); assertThat(terminal.getOutputString(), is(exception.getMessage() + "\n")); } @Test - public void testPrintFinalMessageStreamedRow() throws IOException { - console.printStreamedRow(StreamedRow.finalMessage("Some message"), HEADER); + public void testPrintFinalMessageStreamedRow() { + console.printStreamedRow(StreamedRow.finalMessage("Some message")); assertThat(terminal.getOutputString(), is("Some message\n")); } @Test - public void testPrintCommandStatus() throws IOException { + public void testPrintCommandStatus() { // Given: final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new CommandStatusEntity( @@ -226,7 +231,7 @@ public void testPrintCommandStatus() throws IOException { } @Test - public void testPrintPropertyList() throws IOException { + public void testPrintPropertyList() { // Given: final Map properties = new HashMap<>(); properties.put("k1", 1); @@ -267,7 +272,7 @@ public void testPrintPropertyList() throws IOException { } @Test - public void testPrintQueries() throws IOException { + public void testPrintQueries() { // Given: final List queries = new ArrayList<>(); queries.add( @@ -305,7 +310,7 @@ public void testPrintQueries() throws IOException { } @Test - public void testPrintSourceDescription() throws IOException { + public void testPrintSourceDescription() { // Given: final List fields = buildTestSchema( SqlTypes.BOOLEAN, @@ -491,7 +496,7 @@ public void testPrintSourceDescription() throws IOException { } @Test - public void testPrintTopicDescription() throws IOException { + public void testPrintTopicDescription() { // Given: final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new TopicDescription("e", "TestTopic", "TestKafkaTopic", "AVRO", "schemaString") @@ -522,7 +527,7 @@ public void testPrintTopicDescription() throws IOException { } @Test - public void testPrintConnectorDescription() throws IOException { + public void testPrintConnectorDescription() { // Given: final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new ConnectorDescription( @@ -640,7 
+645,7 @@ public void testPrintConnectorDescription() throws IOException { } @Test - public void testPrintStreamsList() throws IOException { + public void testPrintStreamsList() { // Given: final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new StreamsList("e", @@ -674,7 +679,7 @@ public void testPrintStreamsList() throws IOException { } @Test - public void testSortedPrintStreamsList() throws IOException { + public void testSortedPrintStreamsList() { // Given: final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new StreamsList("e", @@ -731,7 +736,7 @@ public void testSortedPrintStreamsList() throws IOException { } @Test - public void testPrintTablesList() throws IOException { + public void testPrintTablesList() { // Given: final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new TablesList("e", @@ -766,7 +771,7 @@ public void testPrintTablesList() throws IOException { } @Test - public void testSortedPrintTablesList() throws IOException { + public void testSortedPrintTablesList() { // Given: final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new TablesList("e", @@ -828,7 +833,7 @@ public void testSortedPrintTablesList() throws IOException { } @Test - public void shouldPrintConnectorsList() throws IOException { + public void shouldPrintConnectorsList() { // Given: final KsqlEntityList entities = new KsqlEntityList(ImmutableList.of( new ConnectorList( @@ -870,7 +875,7 @@ public void shouldPrintConnectorsList() throws IOException { } @Test - public void shouldPrintTypesList() throws IOException { + public void shouldPrintTypesList() { // Given: final KsqlEntityList entities = new KsqlEntityList(ImmutableList.of( new TypeList("statement", ImmutableMap.of( @@ -931,7 +936,7 @@ public void shouldPrintTypesList() throws IOException { } @Test - public void testPrintExecuptionPlan() throws IOException { + public void testPrintExecuptionPlan() { // Given: final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new ExecutionPlan("Test Execution Plan") @@ -959,7 +964,7 @@ public void testPrintExecuptionPlan() throws IOException { } @Test - public void shouldPrintTopicDescribeExtended() throws IOException { + public void shouldPrintTopicDescribeExtended() { // Given: final List readQueries = ImmutableList.of( new RunningQuery("read query", ImmutableSet.of("sink1"), new QueryId("readId")) @@ -1084,7 +1089,7 @@ public void shouldPrintTopicDescribeExtended() throws IOException { } @Test - public void shouldPrintWarnings() throws IOException { + public void shouldPrintWarnings() { // Given: final KsqlEntity entity = new SourceDescriptionEntity( "e", @@ -1109,7 +1114,7 @@ public void shouldPrintWarnings() throws IOException { } @Test - public void shouldPrintDropConnector() throws IOException { + public void shouldPrintDropConnector() { // Given: final KsqlEntity entity = new DropConnectorEntity("statementText", "connectorName"); @@ -1140,7 +1145,7 @@ public void shouldPrintDropConnector() throws IOException { } @Test - public void shouldPrintErrorEntityLongNonJson() throws IOException { + public void shouldPrintErrorEntityLongNonJson() { // Given: final KsqlEntity entity = new ErrorEntity( "statementText", @@ -1195,7 +1200,7 @@ public void shouldPrintErrorEntityLongJson() throws IOException { } @Test - public void shouldPrintFunctionDescription() throws IOException { + public void shouldPrintFunctionDescription() { final KsqlEntityList entityList = new KsqlEntityList(ImmutableList.of( new FunctionDescriptionList( 
"DESCRIBE FUNCTION foo;", diff --git a/ksql-cli/src/test/java/io/confluent/ksql/util/TabularRowTest.java b/ksql-cli/src/test/java/io/confluent/ksql/util/TabularRowTest.java index dd60d1169984..555d0daee2e1 100644 --- a/ksql-cli/src/test/java/io/confluent/ksql/util/TabularRowTest.java +++ b/ksql-cli/src/test/java/io/confluent/ksql/util/TabularRowTest.java @@ -18,20 +18,36 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.isEmptyString; - -import com.google.common.collect.ImmutableList; -import java.util.List; +import static org.mockito.Mockito.when; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.cli.console.CliConfig; +import io.confluent.ksql.cli.console.CliConfig.OnOff; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.class) public class TabularRowTest { + @Mock + private CliConfig config; + @Test public void shouldFormatHeader() { // Given: - final List header = ImmutableList.of("foo", "bar"); + final LogicalSchema schema = LogicalSchema.builder() + .noImplicitColumns() + .keyColumn(ColumnName.of("foo"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("bar"), SqlTypes.STRING) + .build(); // When: - final String formatted = new TabularRow(20, header, null, true).toString(); + final String formatted = TabularRow.createHeader(20, schema).toString(); // Then: assertThat(formatted, is("" @@ -43,10 +59,14 @@ public void shouldFormatHeader() { @Test public void shouldMultilineFormatHeader() { // Given: - final List header = ImmutableList.of("foo", "bar is a long string"); + final LogicalSchema schema = LogicalSchema.builder() + .noImplicitColumns() + .keyColumn(ColumnName.of("foo"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("bar is a long string"), SqlTypes.STRING) + .build(); // When: - final String formatted = new TabularRow(20, header, null, true).toString(); + final String formatted = TabularRow.createHeader(20, schema).toString(); // Then: assertThat(formatted, is("" @@ -60,10 +80,12 @@ public void shouldMultilineFormatHeader() { @Test public void shouldFormatRow() { // Given: - final List header = ImmutableList.of("foo", "bar"); + givenWrappingEnabled(); + + final GenericRow value = new GenericRow("foo", "bar"); // When: - final String formatted = new TabularRow(20, header, header, true).toString(); + final String formatted = TabularRow.createRow(20, value, config).toString(); // Then: assertThat(formatted, is("|foo |bar |")); @@ -72,10 +94,12 @@ public void shouldFormatRow() { @Test public void shouldMultilineFormatRow() { // Given: - final List header = ImmutableList.of("foo", "bar is a long string"); + givenWrappingEnabled(); + + final GenericRow value = new GenericRow("foo", "bar is a long string"); // When: - final String formatted = new TabularRow(20, header, header, true).toString(); + final String formatted = TabularRow.createRow(20, value, config).toString(); // Then: assertThat(formatted, is("" @@ -87,10 +111,12 @@ public void shouldMultilineFormatRow() { @Test public void shouldClipMultilineFormatRow() { // Given: - final List header = ImmutableList.of("foo", "bar is a long string"); + givenWrappingDisabled(); + + final GenericRow value = new GenericRow("foo", "bar is a long string"); // When: - final String formatted = 
new TabularRow(20, header, header, false).toString(); + final String formatted = TabularRow.createRow(20, value, config).toString(); // Then: assertThat(formatted, is("" @@ -100,12 +126,15 @@ public void shouldClipMultilineFormatRow() { @Test public void shouldClipMultilineFormatRowWithLotsOfWhitespace() { // Given: - final List header = ImmutableList.of( + givenWrappingDisabled(); + + final GenericRow value = new GenericRow( "foo", - "bar foo"); + "bar foo" + ); // When: - final String formatted = new TabularRow(20, header, header, false).toString(); + final String formatted = TabularRow.createRow(20, value, config).toString(); // Then: assertThat(formatted, is("" @@ -115,12 +144,15 @@ public void shouldClipMultilineFormatRowWithLotsOfWhitespace() { @Test public void shouldNotAddEllipsesMultilineFormatRowWithLotsOfWhitespace() { // Given: - final List header = ImmutableList.of( + givenWrappingDisabled(); + + final GenericRow value = new GenericRow( "foo", - "bar "); + "bar " + ); // When: - final String formatted = new TabularRow(20, header, header, false).toString(); + final String formatted = TabularRow.createRow(20, value, config).toString(); // Then: assertThat(formatted, is("" @@ -129,12 +161,14 @@ public void shouldNotAddEllipsesMultilineFormatRowWithLotsOfWhitespace() { @Test - public void shouldFormatNoColumns() { + public void shouldFormatNoColumnsHeader() { // Given: - final List header = ImmutableList.of(); + final LogicalSchema schema = LogicalSchema.builder() + .noImplicitColumns() + .build(); // When: - final String formatted = new TabularRow(20, header, null, true).toString(); + final String formatted = TabularRow.createHeader(20, schema).toString(); // Then: assertThat(formatted, isEmptyString()); @@ -143,10 +177,15 @@ public void shouldFormatNoColumns() { @Test public void shouldFormatMoreColumnsThanWidth() { // Given: - final List header = ImmutableList.of("foo", "bar", "baz"); + final LogicalSchema schema = LogicalSchema.builder() + .noImplicitColumns() + .keyColumn(ColumnName.of("foo"), SqlTypes.BIGINT) + .valueColumn(ColumnName.of("bar"), SqlTypes.STRING) + .valueColumn(ColumnName.of("baz"), SqlTypes.DOUBLE) + .build(); // When: - final String formatted = new TabularRow(3, header, null, true).toString(); + final String formatted = TabularRow.createHeader(3, schema).toString(); // Then: assertThat(formatted, @@ -156,4 +195,11 @@ public void shouldFormatMoreColumnsThanWidth() { + "+-----+-----+-----+")); } + private void givenWrappingEnabled() { + when(config.getString(CliConfig.WRAP_CONFIG)).thenReturn(OnOff.ON.toString()); + } + + private void givenWrappingDisabled() { + when(config.getString(CliConfig.WRAP_CONFIG)).thenReturn("Not ON"); + } } \ No newline at end of file diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java index 9a6cf053e311..af556878ce92 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java @@ -132,20 +132,6 @@ public static void validate( if (!queryStmt.isStatic()) { throw new KsqlRestException(Errors.queryEndpoint(statement.getStatementText())); } - - try { - final Analysis analysis = analyze(statement, executionContext); - - final PersistentQueryMetadata query = findMaterializingQuery(executionContext, analysis); - - extractWhereInfo(analysis, query); - 
} catch (final Exception e) { - throw new KsqlStatementException( - e.getMessage(), - statement.getStatementText(), - e - ); - } } public static Optional execute( @@ -153,6 +139,14 @@ public static Optional execute( final Map sessionProperties, final KsqlExecutionContext executionContext, final ServiceContext serviceContext + ) { + return Optional.of(execute(statement, executionContext, serviceContext)); + } + + public static TableRowsEntity execute( + final ConfiguredStatement statement, + final KsqlExecutionContext executionContext, + final ServiceContext serviceContext ) { try { final Analysis analysis = analyze(statement, executionContext); @@ -172,7 +166,7 @@ public static Optional execute( final KsqlNode owner = getOwner(rowKey, mat); if (!owner.isLocal()) { - return Optional.of(proxyTo(owner, statement, serviceContext)); + return proxyTo(owner, statement, serviceContext); } final Result result; @@ -206,14 +200,12 @@ public static Optional execute( rows = handleSelects(result, statement, executionContext, analysis, outputSchema); } - final TableRowsEntity entity = new TableRowsEntity( + return new TableRowsEntity( statement.getStatementText(), queryId, outputSchema, rows ); - - return Optional.of(entity); } catch (final Exception e) { throw new KsqlStatementException( e.getMessage() == null ? "Server Error" : e.getMessage(), @@ -637,7 +629,7 @@ private static PersistentQueryMetadata findMaterializingQuery( } if (queries.size() > 1) { throw new KsqlException("Multiple queries currently materialize '" + sourceName + "'." - + " KSQL currently only supports static queries when the table has only been" + + " KSQL currently only supports pull queries when the table has only been" + " materialized once."); } @@ -668,7 +660,7 @@ private static KsqlNode getOwner(final Struct rowKey, final Materialization mat) ); } - private static KsqlEntity proxyTo( + private static TableRowsEntity proxyTo( final KsqlNode owner, final ConfiguredStatement statement, final ServiceContext serviceContext @@ -686,7 +678,7 @@ private static KsqlEntity proxyTo( throw new RuntimeException("Boom - expected 1 entity, got: " + entities.size()); } - return entities.get(0); + return (TableRowsEntity) entities.get(0); } private static KsqlException notMaterializedException(final SourceName sourceTable) { diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/Flow.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/Flow.java index be32c24fc640..4d4ce3b0591c 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/Flow.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/Flow.java @@ -19,7 +19,7 @@ /** * Flow constructs borrowed from Java 9 - * https://docs.oracle.com/javase/9/docs/api/java/util/concurrent/Flow.Subscription.html + * @see Flow Java 9 Docs */ public class Flow { diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisher.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisher.java new file mode 100644 index 000000000000..83256ba431c3 --- /dev/null +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisher.java @@ -0,0 +1,134 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. 
You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.ksql.rest.server.resources.streaming;
+
+import static java.util.Objects.requireNonNull;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import io.confluent.ksql.GenericRow;
+import io.confluent.ksql.KsqlExecutionContext;
+import io.confluent.ksql.engine.KsqlEngine;
+import io.confluent.ksql.parser.tree.Query;
+import io.confluent.ksql.rest.entity.StreamedRow;
+import io.confluent.ksql.rest.entity.TableRowsEntity;
+import io.confluent.ksql.rest.server.execution.StaticQueryExecutor;
+import io.confluent.ksql.rest.server.resources.streaming.Flow.Subscriber;
+import io.confluent.ksql.services.ServiceContext;
+import io.confluent.ksql.statement.ConfiguredStatement;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.stream.Collectors;
+
+class PullQueryPublisher implements Flow.Publisher<Collection<StreamedRow>> {
+
+  private final KsqlEngine ksqlEngine;
+  private final ServiceContext serviceContext;
+  private final ConfiguredStatement<Query> query;
+  private final PullQueryExecutor pullQueryExecutor;
+
+  PullQueryPublisher(
+      final KsqlEngine ksqlEngine,
+      final ServiceContext serviceContext,
+      final ConfiguredStatement<Query> query
+  ) {
+    this(ksqlEngine, serviceContext, query, StaticQueryExecutor::execute);
+  }
+
+  @VisibleForTesting
+  PullQueryPublisher(
+      final KsqlEngine ksqlEngine,
+      final ServiceContext serviceContext,
+      final ConfiguredStatement<Query> query,
+      final PullQueryExecutor pullQueryExecutor
+  ) {
+    this.ksqlEngine = requireNonNull(ksqlEngine, "ksqlEngine");
+    this.serviceContext = requireNonNull(serviceContext, "serviceContext");
+    this.query = requireNonNull(query, "query");
+    this.pullQueryExecutor = requireNonNull(pullQueryExecutor, "pullQueryExecutor");
+  }
+
+  @Override
+  public synchronized void subscribe(final Subscriber<Collection<StreamedRow>> subscriber) {
+    final PullQuerySubscription subscription = new PullQuerySubscription(
+        subscriber,
+        () -> pullQueryExecutor.execute(query, ksqlEngine, serviceContext)
+    );
+
+    subscriber.onSubscribe(subscription);
+  }
+
+  private static final class PullQuerySubscription implements Flow.Subscription {
+
+    private final Subscriber<Collection<StreamedRow>> subscriber;
+    private final Callable<TableRowsEntity> executor;
+    private boolean done = false;
+
+    private PullQuerySubscription(
+        final Subscriber<Collection<StreamedRow>> subscriber,
+        final Callable<TableRowsEntity> executor
+    ) {
+      this.subscriber = requireNonNull(subscriber, "subscriber");
+      this.executor = requireNonNull(executor, "executor");
+    }
+
+    @Override
+    public void request(final long n) {
+      Preconditions.checkArgument(n == 1, "number of requested items must be 1");
+
+      if (done) {
+        return;
+      }
+
+      done = true;
+
+      try {
+        final TableRowsEntity entity = executor.call();
+
+        subscriber.onSchema(entity.getSchema());
+
+        final List<StreamedRow> rows = entity.getRows().stream()
+            .map(PullQuerySubscription::toGenericRow)
+            .map(StreamedRow::row)
+            .collect(Collectors.toList());
+
+        subscriber.onNext(rows);
+        subscriber.onComplete();
+      } catch (final Exception e) {
+        subscriber.onError(e);
+      }
+    }
+
+    @Override
+    public void cancel() {
+    }
+
+
@SuppressWarnings("unchecked") + private static GenericRow toGenericRow(final List values) { + return new GenericRow((List)values); + } + } + + interface PullQueryExecutor { + + TableRowsEntity execute( + ConfiguredStatement statement, + KsqlExecutionContext executionContext, + ServiceContext serviceContext + ); + } +} diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamPublisher.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PushQueryPublisher.java similarity index 87% rename from ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamPublisher.java rename to ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PushQueryPublisher.java index aa9a557abd0c..f1163bf86447 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamPublisher.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PushQueryPublisher.java @@ -33,16 +33,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class StreamPublisher implements Flow.Publisher> { +@SuppressWarnings("UnstableApiUsage") +class PushQueryPublisher implements Flow.Publisher> { - private static final Logger log = LoggerFactory.getLogger(StreamPublisher.class); + private static final Logger log = LoggerFactory.getLogger(PushQueryPublisher.class); private final KsqlEngine ksqlEngine; private final ServiceContext serviceContext; private final ConfiguredStatement query; private final ListeningScheduledExecutorService exec; - StreamPublisher( + PushQueryPublisher( final KsqlEngine ksqlEngine, final ServiceContext serviceContext, final ListeningScheduledExecutorService exec, @@ -54,7 +55,7 @@ class StreamPublisher implements Flow.Publisher> { this.query = Objects.requireNonNull(query, "query"); } - @SuppressWarnings("ConstantConditions") + @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public synchronized void subscribe(final Flow.Subscriber> subscriber) { final TransientQueryMetadata queryMetadata = @@ -62,7 +63,7 @@ public synchronized void subscribe(final Flow.Subscriber .getQuery() .get(); - final StreamSubscription subscription = new StreamSubscription(subscriber, queryMetadata); + final PushQuerySubscription subscription = new PushQuerySubscription(subscriber, queryMetadata); log.info("Running query {}", queryMetadata.getQueryApplicationId()); queryMetadata.start(); @@ -70,12 +71,12 @@ public synchronized void subscribe(final Flow.Subscriber subscriber.onSubscribe(subscription); } - class StreamSubscription extends PollingSubscription> { + class PushQuerySubscription extends PollingSubscription> { private final TransientQueryMetadata queryMetadata; private boolean closed = false; - StreamSubscription( + PushQuerySubscription( final Subscriber> subscriber, final TransientQueryMetadata queryMetadata ) { diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriter.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriter.java index 096ddd671845..0c9654cb029e 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriter.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriter.java @@ -18,8 +18,11 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; import io.confluent.ksql.GenericRow; +import 
io.confluent.ksql.query.QueryId; import io.confluent.ksql.rest.Errors; import io.confluent.ksql.rest.entity.StreamedRow; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.LogicalSchema.Builder; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.TransientQueryMetadata; import java.io.EOFException; @@ -37,6 +40,7 @@ class QueryStreamWriter implements StreamingOutput { private static final Logger log = LoggerFactory.getLogger(QueryStreamWriter.class); + private static final QueryId NO_QUERY_ID = new QueryId("none"); private final TransientQueryMetadata queryMetadata; private final long disconnectCheckInterval; @@ -60,13 +64,15 @@ class QueryStreamWriter implements StreamingOutput { @Override public void write(final OutputStream out) { try { + write(out, buildHeader()); + while (queryMetadata.isRunning() && !limitReached) { final KeyValue value = queryMetadata.getRowQueue().poll( disconnectCheckInterval, TimeUnit.MILLISECONDS ); if (value != null) { - write(out, value.value); + write(out, StreamedRow.row(value.value)); } else { // If no new rows have been written, the user may have terminated the connection without // us knowing. Check by trying to write a single newline. @@ -98,12 +104,24 @@ public void write(final OutputStream out) { } } - private void write(final OutputStream output, final GenericRow row) throws IOException { - objectMapper.writeValue(output, StreamedRow.row(row)); + private void write(final OutputStream output, final StreamedRow row) throws IOException { + objectMapper.writeValue(output, row); output.write("\n".getBytes(StandardCharsets.UTF_8)); output.flush(); } + private StreamedRow buildHeader() { + // Push queries only return value columns, but query metadata schema includes key and meta: + final LogicalSchema storedSchema = queryMetadata.getLogicalSchema(); + + final Builder actualSchemaBuilder = LogicalSchema.builder() + .noImplicitColumns(); + + storedSchema.value().forEach(actualSchemaBuilder::valueColumn); + + return StreamedRow.header(NO_QUERY_ID, actualSchemaBuilder.build()); + } + private void outputException(final OutputStream out, final Throwable exception) { try { out.write("\n".getBytes(StandardCharsets.UTF_8)); @@ -133,7 +151,7 @@ private void drain(final OutputStream out) throws IOException { queryMetadata.getRowQueue().drainTo(rows); for (final KeyValue row : rows) { - write(out, row.value); + write(out, StreamedRow.row(row.value)); } } diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java index b7801a373773..6a9f12b34234 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java @@ -15,8 +15,10 @@ package io.confluent.ksql.rest.server.resources.streaming; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import io.confluent.ksql.GenericRow; import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.json.JsonMapper; import io.confluent.ksql.parser.KsqlParser.PreparedStatement; @@ -24,9 +26,12 @@ import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.rest.Errors; import io.confluent.ksql.rest.entity.KsqlRequest; 
+import io.confluent.ksql.rest.entity.StreamedRow; +import io.confluent.ksql.rest.entity.TableRowsEntity; import io.confluent.ksql.rest.entity.Versions; import io.confluent.ksql.rest.server.StatementParser; import io.confluent.ksql.rest.server.computation.CommandQueue; +import io.confluent.ksql.rest.server.execution.StaticQueryExecutor; import io.confluent.ksql.rest.server.resources.KsqlConfigurable; import io.confluent.ksql.rest.server.resources.KsqlRestException; import io.confluent.ksql.rest.util.CommandStoreUtil; @@ -42,6 +47,7 @@ import java.time.Duration; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; @@ -175,9 +181,19 @@ private Response handleStatement( ); if (statement.getStatement() instanceof Query) { - return handleQuery( + final PreparedStatement queryStmt = (PreparedStatement) statement; + + if (queryStmt.getStatement().isStatic()) { + return handlePullQuery( + serviceContext, + queryStmt, + request.getStreamsProperties() + ); + } + + return handlePushQuery( serviceContext, - (PreparedStatement) statement, + queryStmt, request.getStreamsProperties() ); } @@ -200,7 +216,34 @@ private Response handleStatement( } } - private Response handleQuery( + private Response handlePullQuery( + final ServiceContext serviceContext, + final PreparedStatement statement, + final Map streamsProperties + ) { + final ConfiguredStatement configured = + ConfiguredStatement.of(statement, streamsProperties, ksqlConfig); + + final TableRowsEntity entity = StaticQueryExecutor + .execute(configured, ksqlEngine, serviceContext); + + final StreamedRow header = StreamedRow.header(entity.getQueryId(), entity.getSchema()); + + final List rows = entity.getRows().stream() + .map(GenericRow::new) + .map(StreamedRow::row) + .collect(Collectors.toList()); + + rows.add(0, header); + + final String data = rows.stream() + .map(this::writeValueAsString) + .collect(Collectors.joining("," + System.lineSeparator(), "[", "]")); + + return Response.ok().entity(data).build(); + } + + private Response handlePushQuery( final ServiceContext serviceContext, final PreparedStatement statement, final Map streamsProperties @@ -228,6 +271,14 @@ private Response handleQuery( return Response.ok().entity(queryStreamWriter).build(); } + private String writeValueAsString(final Object object) { + try { + return objectMapper.writeValueAsString(object); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + private Response handlePrintTopic( final ServiceContext serviceContext, final Map streamProperties, diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java index 7e830c275d73..9d44798b9c55 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java @@ -84,7 +84,8 @@ public class WSQueryEndpoint { private final CommandQueue commandQueue; private final ListeningScheduledExecutorService exec; private final ActivenessRegistrar activenessRegistrar; - private final QueryPublisher queryPublisher; + private final QueryPublisher pushQueryPublisher; + private final QueryPublisher pullQueryPublisher; private final PrintTopicPublisher topicPublisher; private final Duration 
commandQueueCatchupTimeout; private final KsqlAuthorizationValidator authorizationValidator; @@ -117,7 +118,8 @@ public WSQueryEndpoint( ksqlEngine, commandQueue, exec, - WSQueryEndpoint::startQueryPublisher, + WSQueryEndpoint::startPushQueryPublisher, + WSQueryEndpoint::startPullQueryPublisher, WSQueryEndpoint::startPrintPublisher, activenessRegistrar, commandQueueCatchupTimeout, @@ -137,7 +139,8 @@ public WSQueryEndpoint( final KsqlEngine ksqlEngine, final CommandQueue commandQueue, final ListeningScheduledExecutorService exec, - final QueryPublisher queryPublisher, + final QueryPublisher pushQueryPublisher, + final QueryPublisher pullQueryPublisher, final PrintTopicPublisher topicPublisher, final ActivenessRegistrar activenessRegistrar, final Duration commandQueueCatchupTimeout, @@ -154,7 +157,8 @@ public WSQueryEndpoint( this.commandQueue = Objects.requireNonNull(commandQueue, "commandQueue"); this.exec = Objects.requireNonNull(exec, "exec"); - this.queryPublisher = Objects.requireNonNull(queryPublisher, "queryPublisher"); + this.pushQueryPublisher = Objects.requireNonNull(pushQueryPublisher, "pushQueryPublisher"); + this.pullQueryPublisher = Objects.requireNonNull(pullQueryPublisher, "pullQueryPublisher"); this.topicPublisher = Objects.requireNonNull(topicPublisher, "topicPublisher"); this.activenessRegistrar = Objects.requireNonNull(activenessRegistrar, "activenessRegistrar"); @@ -352,10 +356,21 @@ private void handleQuery(final RequestContext info, final Query query) { final PreparedStatement statement = PreparedStatement.of(info.request.getKsql(), query); + final ConfiguredStatement configured = ConfiguredStatement.of(statement, clientLocalProperties, ksqlConfig); - queryPublisher.start(ksqlEngine, info.serviceContext, exec, configured, streamSubscriber); + final QueryPublisher queryPublisher = query.isStatic() + ? 
pullQueryPublisher + : pushQueryPublisher; + + queryPublisher.start( + ksqlEngine, + info.serviceContext, + exec, + configured, + streamSubscriber + ); } private void handlePrintTopic(final RequestContext info, final PrintTopic printTopic) { @@ -390,14 +405,25 @@ private void handleUnsupportedStatement( )); } - private static void startQueryPublisher( + private static void startPushQueryPublisher( final KsqlEngine ksqlEngine, final ServiceContext serviceContext, final ListeningScheduledExecutorService exec, final ConfiguredStatement query, final WebSocketSubscriber streamSubscriber ) { - new StreamPublisher(ksqlEngine, serviceContext, exec, query) + new PushQueryPublisher(ksqlEngine, serviceContext, exec, query) + .subscribe(streamSubscriber); + } + + private static void startPullQueryPublisher( + final KsqlEngine ksqlEngine, + final ServiceContext serviceContext, + final ListeningScheduledExecutorService ignored, + final ConfiguredStatement query, + final WebSocketSubscriber streamSubscriber + ) { + new PullQueryPublisher(ksqlEngine, serviceContext, query) .subscribe(streamSubscriber); } @@ -413,6 +439,7 @@ private static void startPrintPublisher( } interface QueryPublisher { + void start( KsqlEngine ksqlEngine, ServiceContext serviceContext, diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java index 289d3a99fe56..77526251d71d 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java @@ -15,22 +15,31 @@ package io.confluent.ksql.rest.integration; +import static io.confluent.ksql.test.util.AssertEventually.assertThatEventually; import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.VALID_USER1; import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.VALID_USER2; import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.ops; import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.prefixedResource; import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.resource; +import static junit.framework.TestCase.fail; import static org.apache.kafka.common.acl.AclOperation.ALL; +import static org.apache.kafka.common.acl.AclOperation.CREATE; +import static org.apache.kafka.common.acl.AclOperation.DESCRIBE; import static org.apache.kafka.common.acl.AclOperation.DESCRIBE_CONFIGS; import static org.apache.kafka.common.resource.ResourceType.CLUSTER; import static org.apache.kafka.common.resource.ResourceType.GROUP; import static org.apache.kafka.common.resource.ResourceType.TOPIC; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import com.fasterxml.jackson.core.type.TypeReference; import io.confluent.common.utils.IntegrationTest; import io.confluent.ksql.integration.IntegrationTestHarness; +import io.confluent.ksql.json.JsonMapper; import io.confluent.ksql.rest.entity.Versions; import io.confluent.ksql.rest.server.TestKsqlRestApp; import io.confluent.ksql.serde.Format; @@ -40,12 +49,15 @@ import io.confluent.ksql.test.util.secure.Credentials; import io.confluent.ksql.test.util.secure.SecureKafkaHelper; import 
io.confluent.ksql.util.PageViewDataProvider; +import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import javax.websocket.CloseReason.CloseCodes; import javax.ws.rs.core.MediaType; import org.eclipse.jetty.websocket.api.Session; @@ -54,6 +66,7 @@ import org.eclipse.jetty.websocket.api.annotations.OnWebSocketMessage; import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.After; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; @@ -64,11 +77,14 @@ public class RestApiTest { private static final int HEADER = 1; // <-- some responses include a header as the first message. + private static final int FOOTER = 1; // <-- some responses include a footer as the last message. private static final int LIMIT = 2; private static final String PAGE_VIEW_TOPIC = "pageviews"; private static final String PAGE_VIEW_STREAM = "pageviews_original"; + private static final String AGG_TABLE = "AGG_TABLE"; private static final Credentials SUPER_USER = VALID_USER1; private static final Credentials NORMAL_USER = VALID_USER2; + private static final String AN_AGG_KEY = "USER_1"; private static final IntegrationTestHarness TEST_HARNESS = IntegrationTestHarness.builder() .withKafkaCluster( @@ -79,11 +95,11 @@ public class RestApiTest { .withAcl( NORMAL_USER, resource(CLUSTER, "kafka-cluster"), - ops(DESCRIBE_CONFIGS) + ops(DESCRIBE_CONFIGS, CREATE) ) .withAcl( NORMAL_USER, - resource(TOPIC, "_confluent-ksql-default__command_topic"), + prefixedResource(TOPIC, "_confluent-ksql-default_"), ops(ALL) ) .withAcl( @@ -106,6 +122,16 @@ public class RestApiTest { resource(TOPIC, "X"), ops(ALL) ) + .withAcl( + NORMAL_USER, + resource(TOPIC, "AGG_TABLE"), + ops(ALL) + ) + .withAcl( + NORMAL_USER, + resource(TOPIC, "__consumer_offsets"), + ops(DESCRIBE) + ) ) .build(); @@ -120,6 +146,8 @@ public class RestApiTest { @ClassRule public static final RuleChain CHAIN = RuleChain.outerRule(TEST_HARNESS).around(REST_APP); + private ServiceContext serviceContext; + @BeforeClass public static void setUpClass() { TEST_HARNESS.ensureTopics(PAGE_VIEW_TOPIC); @@ -127,70 +155,231 @@ public static void setUpClass() { TEST_HARNESS.produceRows(PAGE_VIEW_TOPIC, new PageViewDataProvider(), Format.JSON); RestIntegrationTestUtil.createStreams(REST_APP, PAGE_VIEW_STREAM, PAGE_VIEW_TOPIC); + + makeKsqlRequest("CREATE TABLE " + AGG_TABLE + " AS " + + "SELECT COUNT(1) AS COUNT FROM " + PAGE_VIEW_STREAM + " GROUP BY USERID;"); + } + + @After + public void tearDown() { + if (serviceContext != null) { + serviceContext.close(); + } } @Test - public void shouldExecuteStreamingQueryWithV1ContentType() throws Exception { + public void shouldExecutePushQueryOverWebSocketWithV1ContentType() { // When: - final List messages = makeStreamingRequest( + final List messages = makeWebSocketRequest( "SELECT * from " + PAGE_VIEW_STREAM + " EMIT CHANGES LIMIT " + LIMIT + ";", Versions.KSQL_V1_JSON_TYPE, Versions.KSQL_V1_JSON_TYPE ); // Then: - assertThat(messages, hasSize(is(HEADER + LIMIT))); + assertThat(messages, hasSize(HEADER + LIMIT)); + assertValidJsonMessages(messages); } @Test - public void shouldExecuteStreamingQueryWithJsonContentType() throws Exception { + public void 
shouldExecutePushQueryOverWebSocketWithJsonContentType() { // When: - final List messages = makeStreamingRequest( + final List messages = makeWebSocketRequest( "SELECT * from " + PAGE_VIEW_STREAM + " EMIT CHANGES LIMIT " + LIMIT + ";", MediaType.APPLICATION_JSON_TYPE, MediaType.APPLICATION_JSON_TYPE ); // Then: - assertThat(messages, hasSize(is(HEADER + LIMIT))); + assertThat(messages, hasSize(HEADER + LIMIT)); + assertValidJsonMessages(messages); + } + + @Test + public void shouldExecutePushQueryOverRest() { + // When: + final String response = rawRestQueryRequest( + "SELECT USERID, PAGEID, VIEWTIME, ROWKEY from " + PAGE_VIEW_STREAM + " EMIT CHANGES LIMIT " + + LIMIT + ";" + ); + + // Then: + final String[] messages = response.split(System.lineSeparator()); + assertThat(messages.length, is(HEADER + LIMIT + FOOTER)); + assertThat(messages[0], + is("{\"header\":{\"queryId\":\"none\",\"schema\":\"`USERID` STRING, `PAGEID` STRING, `VIEWTIME` BIGINT, `ROWKEY` STRING\"}}")); + assertThat(messages[1], is("{\"row\":{\"columns\":[\"USER_1\",\"PAGE_1\",1,\"1\"]}}")); + assertThat(messages[2], is("{\"row\":{\"columns\":[\"USER_2\",\"PAGE_2\",2,\"2\"]}}")); + assertThat(messages[3], is("{\"finalMessage\":\"Limit Reached\"}")); + } + + @Test + public void shouldExecutePullQueryOverWebSocketWithV1ContentType() { + // When: + final Supplier> call = () -> makeWebSocketRequest( + "SELECT * from " + AGG_TABLE + " WHERE ROWKEY='" + AN_AGG_KEY + "';", + Versions.KSQL_V1_JSON_TYPE, + Versions.KSQL_V1_JSON_TYPE + ); + + // Then: + final List messages = assertThatEventually(call, hasSize(HEADER + 1)); + assertValidJsonMessages(messages); + assertThat(messages.get(0), + is("[{\"name\":\"COUNT\",\"schema\":{\"type\":\"BIGINT\",\"fields\":null,\"memberSchema\":null}}]")); + assertThat(messages.get(1), + is("{\"row\":{\"columns\":[\"USER_1\",1]}}")); + } + + @Test + public void shouldExecutePullQueryOverWebSocketWithJsonContentType() { + // When: + final Supplier> call = () -> makeWebSocketRequest( + "SELECT * from " + AGG_TABLE + " WHERE ROWKEY='" + AN_AGG_KEY + "';", + MediaType.APPLICATION_JSON_TYPE, + MediaType.APPLICATION_JSON_TYPE + ); + + // Then: + final List messages = assertThatEventually(call, hasSize(HEADER + 1)); + assertValidJsonMessages(messages); + assertThat(messages.get(0), + is("[{\"name\":\"COUNT\",\"schema\":{\"type\":\"BIGINT\",\"fields\":null,\"memberSchema\":null}}]")); + assertThat(messages.get(1), + is("{\"row\":{\"columns\":[\"USER_1\",1]}}")); + } + + @Test + public void shouldReturnCorrectSchemaForPullQueryWithOnlyKeyInSelect() { + // When: + final Supplier> call = () -> makeWebSocketRequest( + "SELECT * from " + AGG_TABLE + " WHERE ROWKEY='" + AN_AGG_KEY + "';", + MediaType.APPLICATION_JSON_TYPE, + MediaType.APPLICATION_JSON_TYPE + ); + + // Then: + final List messages = assertThatEventually(call, hasSize(HEADER + 1)); + assertValidJsonMessages(messages); + assertThat(messages.get(0), + is("[{\"name\":\"COUNT\",\"schema\":{\"type\":\"BIGINT\",\"fields\":null,\"memberSchema\":null}}]")); + assertThat(messages.get(1), + is("{\"row\":{\"columns\":[\"USER_1\",1]}}")); + } + + @Test + public void shouldReturnCorrectSchemaForPullQueryWithOnlyValueColumnInSelect() { + // When: + final Supplier> call = () -> makeWebSocketRequest( + "SELECT COUNT from " + AGG_TABLE + " WHERE ROWKEY='" + AN_AGG_KEY + "';", + MediaType.APPLICATION_JSON_TYPE, + MediaType.APPLICATION_JSON_TYPE + ); + + // Then: + final List messages = assertThatEventually(call, hasSize(HEADER + 1)); + 
assertValidJsonMessages(messages); + assertThat(messages.get(0), + is("[{\"name\":\"COUNT\",\"schema\":{\"type\":\"BIGINT\",\"fields\":null,\"memberSchema\":null}}]")); + assertThat(messages.get(1), + is("{\"row\":{\"columns\":[1]}}")); + } + + @Test + public void shouldExecutePullQueryOverRest() { + // When: + final Supplier> call = () -> { + final String response = rawRestQueryRequest( + "SELECT * from " + AGG_TABLE + " WHERE ROWKEY='" + AN_AGG_KEY + "';" + ); + return Arrays.asList(response.split(System.lineSeparator())); + }; + + // Then: + final List messages = assertThatEventually(call, hasSize(HEADER + 1)); + final List> parsed = parseRawRestQueryResponse(String.join("", messages)); + assertThat(parsed, hasSize(HEADER + 1)); + assertThat(parsed.get(0).get("header"), instanceOf(Map.class)); + assertThat(((Map) parsed.get(0).get("header")).get("queryId"), is(notNullValue())); + assertThat(((Map) parsed.get(0).get("header")).get("schema"), + is("`ROWKEY` STRING KEY, `COUNT` BIGINT")); + assertThat(messages.get(1), is("{\"row\":{\"columns\":[[\"USER_1\",1]]}}]")); + } + + @Test + public void shouldReportErrorOnInvalidPullQueryOverRest() { + // When: + final String response = rawRestQueryRequest( + "SELECT * from " + AGG_TABLE + ";" + ); + + // Then: + assertThat(response, containsString("Missing WHERE clause")); } @Test - public void shouldPrintTopic() throws Exception { + public void shouldPrintTopicOverWebSocket() { // When: - final List messages = makeStreamingRequest( + final List messages = makeWebSocketRequest( "PRINT '" + PAGE_VIEW_TOPIC + "' FROM BEGINNING LIMIT " + LIMIT + ";", MediaType.APPLICATION_JSON_TYPE, MediaType.APPLICATION_JSON_TYPE); // Then: - assertThat(messages, hasSize(is(LIMIT))); + assertThat(messages, hasSize(LIMIT)); } @Test public void shouldDeleteTopic() { - try (final ServiceContext serviceContext = REST_APP.getServiceContext()) { - // Given: - RestIntegrationTestUtil.makeKsqlRequest( - REST_APP, - "CREATE STREAM X AS SELECT * FROM " + PAGE_VIEW_STREAM + ";"); - assertThat("Expected topic X to be created", serviceContext.getTopicClient().isTopicExists("X")); - - // When: - RestIntegrationTestUtil.makeKsqlRequest( - REST_APP, - "TERMINATE QUERY CSAS_X_1; DROP STREAM X DELETE TOPIC;"); - - // Then: - assertThat("Expected topic X to be deleted", !serviceContext.getTopicClient().isTopicExists("X")); + // Given: + makeKsqlRequest("CREATE STREAM X AS SELECT * FROM " + PAGE_VIEW_STREAM + ";" + + "TERMINATE QUERY CSAS_X_2; "); + + assertThat("Expected topic X to be created", topicExists("X")); + + // When: + makeKsqlRequest("DROP STREAM X DELETE TOPIC;"); + + // Then: + assertThat("Expected topic X to be deleted", !topicExists("X")); + } + + private boolean topicExists(final String topicName) { + return getServiceContext().getTopicClient().isTopicExists(topicName); + } + + private ServiceContext getServiceContext() { + if (serviceContext == null) { + serviceContext = REST_APP.getServiceContext(); } + return serviceContext; } - private static List makeStreamingRequest( + private static void makeKsqlRequest(final String sql) { + RestIntegrationTestUtil.makeKsqlRequest(REST_APP, sql); + } + + private static String rawRestQueryRequest(final String sql) { + return RestIntegrationTestUtil.rawRestQueryRequest(REST_APP, sql, Optional.empty()); + } + + private static List> parseRawRestQueryResponse(final String response) { + try { + return JsonMapper.INSTANCE.mapper.readValue( + response, + new TypeReference>>() { + } + ); + } catch (final Exception e) { + throw new 
AssertionError("Invalid JSON received: " + response); + } + } + + private static List makeWebSocketRequest( final String sql, final MediaType mediaType, final MediaType contentType - ) throws Exception { + ) { final WebSocketListener listener = new WebSocketListener(); final WebSocketClient wsClient = RestIntegrationTestUtil.makeWsRequest( @@ -205,7 +394,21 @@ private static List makeStreamingRequest( try { return listener.awaitMessages(); } finally { - wsClient.stop(); + try { + wsClient.stop(); + } catch (final Exception e) { + fail("Failed to close ws"); + } + } + } + + private static void assertValidJsonMessages(final Iterable messages) { + for (final String msg : messages) { + try { + JsonMapper.INSTANCE.mapper.readValue(msg, Object.class); + } catch (final Exception e) { + throw new AssertionError("Invalid JSON message received: " + msg, e); + } } } diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java index 5ab5245b1ead..389e29185afe 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java @@ -17,7 +17,9 @@ import static org.junit.Assert.assertEquals; +import com.google.common.collect.ImmutableMap; import com.google.common.net.UrlEscapers; +import io.confluent.ksql.json.JsonMapper; import io.confluent.ksql.rest.client.BasicCredentials; import io.confluent.ksql.rest.client.KsqlRestClient; import io.confluent.ksql.rest.client.RestResponse; @@ -29,12 +31,14 @@ import io.confluent.ksql.rest.entity.KsqlRequest; import io.confluent.ksql.rest.server.TestKsqlRestApp; import io.confluent.ksql.test.util.secure.Credentials; +import io.confluent.rest.validation.JacksonMessageBodyProvider; import java.net.URI; import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; import javax.ws.rs.client.Client; +import javax.ws.rs.client.ClientBuilder; import javax.ws.rs.client.Entity; import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.MediaType; @@ -63,7 +67,45 @@ static List makeKsqlRequest( throwOnError(res); - return awaitResults(restApp, res.getResponse()); + return awaitResults(restClient, res.getResponse()); + } + } + + /** + * Make a query request using a basic Http client. 
+ * + * @param restApp the test app instance to issue the request against + * @param sql the sql payload + * @param cmdSeqNum optional sequence number of previous command + * @return the response payload + */ + static String rawRestQueryRequest( + final TestKsqlRestApp restApp, + final String sql, + final Optional cmdSeqNum + ) { + final KsqlRequest request = new KsqlRequest( + sql, + ImmutableMap.of(), + cmdSeqNum.orElse(null) + ); + + final URI listener = restApp.getHttpListener(); + + final Client httpClient = ClientBuilder.newBuilder() + .register(new JacksonMessageBodyProvider(JsonMapper.INSTANCE.mapper)) + .build(); + + try { + final Response response = httpClient + .target(listener) + .path("/query") + .request(MediaType.APPLICATION_JSON_TYPE) + .post(Entity.json(request)); + + return response.readEntity(String.class); + } finally { + httpClient.close(); } } @@ -85,19 +127,17 @@ static void createStreams(final TestKsqlRestApp restApp, final String streamName } } - static Entity ksqlRequest(final String sql) { + private static Entity ksqlRequest(final String sql) { return Entity.json(new KsqlRequest(sql, Collections.emptyMap(), null)); } private static List awaitResults( - final TestKsqlRestApp restApp, + final KsqlRestClient ksqlRestClient, final List pending ) { - try (final KsqlRestClient ksqlRestClient = restApp.buildKsqlClient()) { - return pending.stream() - .map(e -> awaitResult(e, ksqlRestClient)) - .collect(Collectors.toList()); - } + return pending.stream() + .map(e -> awaitResult(e, ksqlRestClient)) + .collect(Collectors.toList()); } private static KsqlEntity awaitResult( @@ -136,31 +176,34 @@ private static void throwOnError(final RestResponse res) { } } - public static WebSocketClient makeWsRequest( + static WebSocketClient makeWsRequest( final URI baseUri, final String sql, final Object listener, final Optional mediaType, final Optional contentType, final Optional credentials - ) throws Exception { - - final WebSocketClient wsClient = new WebSocketClient(); - wsClient.start(); + ) { + try { + final WebSocketClient wsClient = new WebSocketClient(); + wsClient.start(); - final ClientUpgradeRequest request = new ClientUpgradeRequest(); + final ClientUpgradeRequest request = new ClientUpgradeRequest(); - credentials.ifPresent(creds -> request - .setHeader(HttpHeaders.AUTHORIZATION, "Basic " + buildBasicAuthHeader(creds))); + credentials.ifPresent(creds -> request + .setHeader(HttpHeaders.AUTHORIZATION, "Basic " + buildBasicAuthHeader(creds))); - mediaType.ifPresent(mt -> request.setHeader(HttpHeaders.ACCEPT, mt.toString())); - contentType.ifPresent(ct -> request.setHeader(HttpHeaders.CONTENT_TYPE, ct.toString())); + mediaType.ifPresent(mt -> request.setHeader(HttpHeaders.ACCEPT, mt.toString())); + contentType.ifPresent(ct -> request.setHeader(HttpHeaders.CONTENT_TYPE, ct.toString())); - final URI wsUri = baseUri.resolve("/ws/query?request=" + buildStreamingRequest(sql)); + final URI wsUri = baseUri.resolve("/ws/query?request=" + buildStreamingRequest(sql)); - wsClient.connect(listener, wsUri, request); + wsClient.connect(listener, wsUri, request); - return wsClient; + return wsClient; + } catch (final Exception e) { + throw new RuntimeException("failed to create ws client", e); + } } private static String buildBasicAuthHeader(final Credentials credentials) { diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisherTest.java 
b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisherTest.java new file mode 100644 index 000000000000..64e62a2d5b0b --- /dev/null +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisherTest.java @@ -0,0 +1,181 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.rest.server.resources.streaming; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.engine.KsqlEngine; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.parser.tree.Query; +import io.confluent.ksql.rest.entity.StreamedRow; +import io.confluent.ksql.rest.entity.TableRowsEntity; +import io.confluent.ksql.rest.server.resources.streaming.Flow.Subscriber; +import io.confluent.ksql.rest.server.resources.streaming.Flow.Subscription; +import io.confluent.ksql.rest.server.resources.streaming.PullQueryPublisher.PullQueryExecutor; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.services.ServiceContext; +import io.confluent.ksql.statement.ConfiguredStatement; +import java.util.Collection; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; + +@RunWith(MockitoJUnitRunner.class) +public class PullQueryPublisherTest { + + private static final LogicalSchema SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("bob"), SqlTypes.BIGINT) + .build(); + + @Mock + private KsqlEngine engine; + @Mock + private ServiceContext serviceContext; + @Mock + private ConfiguredStatement statement; + @Mock + private Subscriber> subscriber; + @Mock + private PullQueryExecutor pullQueryExecutor; + @Mock + private TableRowsEntity entity; + @Captor + private ArgumentCaptor subscriptionCaptor; + + private Subscription subscription; + private PullQueryPublisher publisher; + + @Before + public void setUp() { + publisher = new PullQueryPublisher(engine, serviceContext, statement, pullQueryExecutor); + + when(pullQueryExecutor.execute(any(), any(), any())).thenReturn(entity); + + when(entity.getSchema()).thenReturn(SCHEMA); + + doAnswer(callRequestAgain()).when(subscriber).onNext(any()); + } + + @Test + public void shouldSubscribe() { + // When: + publisher.subscribe(subscriber); + + // Then: + verify(subscriber).onSubscribe(any()); + } + + @Test + public void shouldRunQueryWithCorrectParams() { + // Given: + givenSubscribed(); + + // When: + subscription.request(1); + + // Then: + 
verify(pullQueryExecutor).execute(statement, engine, serviceContext); + } + + @Test + public void shouldOnlyExecuteOnce() { + // Given: + givenSubscribed(); + + // When: + subscription.request(1); + + // Then: + verify(subscriber).onNext(any()); + verify(pullQueryExecutor).execute(statement, engine, serviceContext); + } + + @Test + public void shouldCallOnSchemaThenOnNextThenOnCompleteOnSuccess() { + // Given: + givenSubscribed(); + + // When: + subscription.request(1); + + // Then: + final InOrder inOrder = inOrder(subscriber); + inOrder.verify(subscriber).onSchema(SCHEMA); + inOrder.verify(subscriber).onNext(ImmutableList.of()); + inOrder.verify(subscriber).onComplete(); + } + + @Test + public void shouldCallOnErrorOnFailure() { + // Given: + givenSubscribed(); + + final Throwable e = new RuntimeException("Boom!"); + when(pullQueryExecutor.execute(any(), any(), any())).thenThrow(e); + + // When: + subscription.request(1); + + // Then: + verify(subscriber).onError(e); + } + + @Test + public void shouldBuildStreamingRows() { + // Given: + givenSubscribed(); + + when(entity.getRows()).thenReturn(ImmutableList.of( + ImmutableList.of("a", 1, 2L, 3.0f), + ImmutableList.of("b", 1, 2L, 3.0f) + )); + + // When: + subscription.request(1); + + // Then: + verify(subscriber).onNext(ImmutableList.of( + StreamedRow.row(new GenericRow("a", 1, 2L, 3.0f)), + StreamedRow.row(new GenericRow("b", 1, 2L, 3.0f)) + )); + } + + private Answer callRequestAgain() { + return inv -> { + subscription.request(1); + return null; + }; + } + + private void givenSubscribed() { + publisher.subscribe(subscriber); + verify(subscriber).onSubscribe(subscriptionCaptor.capture()); + subscription = subscriptionCaptor.getValue(); + } +} \ No newline at end of file diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java index 5b6fdf1cc79e..cd3cd1a2f299 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java @@ -15,11 +15,11 @@ package io.confluent.ksql.rest.server.resources.streaming; -import static io.confluent.ksql.metastore.model.DataSource.DataSourceType; import static io.confluent.ksql.rest.entity.KsqlErrorMessageMatchers.errorCode; import static io.confluent.ksql.rest.entity.KsqlErrorMessageMatchers.errorMessage; import static io.confluent.ksql.rest.server.resources.KsqlRestExceptionMatchers.exceptionErrorMessage; import static io.confluent.ksql.rest.server.resources.KsqlRestExceptionMatchers.exceptionStatusCode; +import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; @@ -339,16 +339,29 @@ public void shouldStreamRowsCorrectly() throws Throwable { throw new Exception("Response input stream failed to have expected line available"); } final String responseLine = responseScanner.nextLine(); - if (responseLine.trim().isEmpty()) { + + if (responseLine.isEmpty()) { i--; - } else { - final GenericRow expectedRow; - synchronized (writtenRows) { - expectedRow = writtenRows.poll(); - } - final GenericRow testRow = objectMapper.readValue(responseLine, StreamedRow.class).getRow(); - assertEquals(expectedRow, testRow); + continue; + } + + if (i == 0) { + 
// Header: + assertThat(responseLine, is("{\"header\":{\"queryId\":\"none\",\"schema\":\"`f1` INTEGER\"}}")); + continue; } + + final GenericRow expectedRow; + synchronized (writtenRows) { + expectedRow = writtenRows.poll(); + } + + final GenericRow testRow = objectMapper + .readValue(responseLine, StreamedRow.class) + .getRow() + .orElse(null); + + assertEquals(expectedRow, testRow); } responseOutputStream.close(); diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpointTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpointTest.java index 552f106cf9bd..fbdc4f1d57b9 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpointTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpointTest.java @@ -130,7 +130,9 @@ public class WSQueryEndpointTest { @Mock private CommandQueue commandQueue; @Mock - private QueryPublisher queryPublisher; + private QueryPublisher pushQueryPublisher; + @Mock + private QueryPublisher pullQueryPublisher; @Mock private PrintTopicPublisher topicPublisher; @Mock @@ -158,9 +160,7 @@ public class WSQueryEndpointTest { @Before public void setUp() { - query = new Query(Optional.empty(), mock(Select.class), mock(Relation.class), Optional.empty(), - Optional.empty(), Optional.empty(), Optional.empty(), ResultMaterialization.CHANGES, false, - OptionalInt.empty()); + givenQueryIs(QueryType.PUSH); when(session.getId()).thenReturn("session-id"); when(session.getUserPrincipal()).thenReturn(principal); when(statementParser.parseSingleStatement(anyString())) @@ -179,10 +179,23 @@ public void setUp() { givenRequest(VALID_REQUEST); wsQueryEndpoint = new WSQueryEndpoint( - ksqlConfig, OBJECT_MAPPER, statementParser, ksqlEngine, commandQueue, exec, - queryPublisher, topicPublisher, activenessRegistrar, COMMAND_QUEUE_CATCHUP_TIMEOUT, - authorizationValidator, securityExtension, serviceContextFactory, - defaultServiceContextProvider, serverState); + ksqlConfig, + OBJECT_MAPPER, + statementParser, + ksqlEngine, + commandQueue, + exec, + pushQueryPublisher, + pullQueryPublisher, + topicPublisher, + activenessRegistrar, + COMMAND_QUEUE_CATCHUP_TIMEOUT, + authorizationValidator, + securityExtension, + serviceContextFactory, + defaultServiceContextProvider, + serverState + ); } @Test @@ -348,8 +361,31 @@ public void shouldReturnErrorOnFailedStateCheck() throws Exception { } @Test - public void shouldHandleQuery() { + public void shouldHandlePushQuery() { + // Given: + givenRequestIs(query); + + // When: + wsQueryEndpoint.onOpen(session, null); + + // Then: + final ConfiguredStatement configuredStatement = ConfiguredStatement.of( + PreparedStatement.of(VALID_REQUEST.getKsql(), query), + VALID_REQUEST.getStreamsProperties(), + ksqlConfig); + + verify(pushQueryPublisher).start( + eq(ksqlEngine), + eq(serviceContext), + eq(exec), + eq(configuredStatement), + any()); + } + + @Test + public void shouldHandlePullQuery() { // Given: + givenQueryIs(QueryType.PULL); givenRequestIs(query); // When: @@ -361,7 +397,7 @@ public void shouldHandleQuery() { VALID_REQUEST.getStreamsProperties(), ksqlConfig); - verify(queryPublisher).start( + verify(pullQueryPublisher).start( eq(ksqlEngine), eq(serviceContext), eq(exec), @@ -514,6 +550,15 @@ public void shouldCloseSessionOnError() throws Exception { verifyClosedWithReason("ahh scary", CloseCodes.UNEXPECTED_CONDITION); } + @Test + public void 
shouldUpdateTheLastRequestTime() { + // When: + wsQueryEndpoint.onOpen(session, null); + + // Then: + verify(activenessRegistrar).updateLastRequestTime(); + } + private void givenVersions(final String... versions) { givenRequestAndVersions( @@ -578,15 +623,20 @@ private void verifyClosedWithReason(final String reason, final CloseCodes code) assertThat(closeReason.getCloseCode(), is(code)); } - @Test - public void shouldUpdateTheLastRequestTime() { - // Given: + private enum QueryType {PUSH, PULL}; - // When: - wsQueryEndpoint.onOpen(session, null); - - // Then: - verify(activenessRegistrar).updateLastRequestTime(); + private void givenQueryIs(final QueryType type) { + query = new Query( + Optional.empty(), + mock(Select.class), + mock(Relation.class), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ResultMaterialization.CHANGES, + type == QueryType.PULL, + OptionalInt.empty() + ); } - } diff --git a/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/StreamedRow.java b/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/StreamedRow.java index 5a1dea51ef19..36b65f6594b7 100644 --- a/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/StreamedRow.java +++ b/ksql-rest-model/src/main/java/io/confluent/ksql/rest/entity/StreamedRow.java @@ -15,64 +15,103 @@ package io.confluent.ksql.rest.entity; +import static java.util.Objects.requireNonNull; + import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.google.errorprone.annotations.Immutable; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.LogicalSchema; import java.util.Arrays; -import java.util.List; import java.util.Objects; +import java.util.Optional; @JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(Include.NON_ABSENT) @JsonSubTypes({}) -public class StreamedRow { +public final class StreamedRow { - private final GenericRow row; - private final KsqlErrorMessage errorMessage; - private final String finalMessage; + private final Optional
<Header> header;
+  private final Optional<GenericRow> row;
+  private final Optional<KsqlErrorMessage> errorMessage;
+  private final Optional<String> finalMessage;
+
+  public static StreamedRow header(final QueryId queryId, final LogicalSchema schema) {
+    return new StreamedRow(
+        Optional.of(Header.of(queryId, schema)),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty()
+    );
+  }

   public static StreamedRow row(final GenericRow row) {
-    return new StreamedRow(row, null, null);
+    return new StreamedRow(
+        Optional.empty(),
+        Optional.of(row),
+        Optional.empty(),
+        Optional.empty()
+    );
   }

   public static StreamedRow error(final Throwable exception, final int errorCode) {
     return new StreamedRow(
-        null,
-        new KsqlErrorMessage(errorCode, exception),
-        null);
+        Optional.empty(),
+        Optional.empty(),
+        Optional.of(new KsqlErrorMessage(errorCode, exception)),
+        Optional.empty()
+    );
   }

   public static StreamedRow finalMessage(final String finalMessage) {
-    return new StreamedRow(null, null, finalMessage);
+    return new StreamedRow(
+        Optional.empty(),
+        Optional.empty(),
+        Optional.empty(),
+        Optional.of(finalMessage)
+    );
   }

   @JsonCreator
-  public StreamedRow(
-      @JsonProperty("row") final GenericRow row,
-      @JsonProperty("errorMessage") final KsqlErrorMessage errorMessage,
-      @JsonProperty("finalMessage") final String finalMessage
+  private StreamedRow(
+      @JsonProperty("header") final Optional<Header> header,
+      @JsonProperty("row") final Optional<GenericRow> row,
+      @JsonProperty("errorMessage") final Optional<KsqlErrorMessage> errorMessage,
+      @JsonProperty("finalMessage") final Optional<String> finalMessage
   ) {
-    checkUnion(row, errorMessage, finalMessage);
-    this.row = row;
-    this.errorMessage = errorMessage;
-    this.finalMessage = finalMessage;
+    this.header = requireNonNull(header, "header");
+    this.row = requireNonNull(row, "row");
+    this.errorMessage = requireNonNull(errorMessage, "errorMessage");
+    this.finalMessage = requireNonNull(finalMessage, "finalMessage");
+
+    checkUnion(header, row, errorMessage, finalMessage);
   }

-  public GenericRow getRow() {
+  public Optional<Header> getHeader() {
+    return header;
+  }
+
+  public Optional<GenericRow> getRow() {
     return row;
   }

-  public KsqlErrorMessage getErrorMessage() {
+  public Optional<KsqlErrorMessage> getErrorMessage() {
     return errorMessage;
   }

-  public String getFinalMessage() {
+  public Optional<String> getFinalMessage() {
     return finalMessage;
   }

+  @JsonIgnore
   public boolean isTerminal() {
-    return finalMessage != null || errorMessage != null;
+    return finalMessage.isPresent() || errorMessage.isPresent();
   }

   @Override
@@ -84,24 +123,53 @@ public boolean equals(final Object o) {
       return false;
     }
     final StreamedRow that = (StreamedRow) o;
-    return Objects.equals(row, that.row)
-        && Objects.equals(errorMessage, that.errorMessage)
-        && Objects.equals(finalMessage, that.finalMessage);
+    return Objects.equals(header, that.header)
+        && Objects.equals(row, that.row)
+        && Objects.equals(errorMessage, that.errorMessage)
+        && Objects.equals(finalMessage, that.finalMessage);
   }

   @Override
   public int hashCode() {
-    return Objects.hash(row, errorMessage, finalMessage);
+    return Objects.hash(header, row, errorMessage, finalMessage);
   }

-  private static void checkUnion(final Object... fields) {
-    final List<Object> fs = Arrays.asList(fields);
-    final long count = fs.stream()
-        .filter(Objects::nonNull)
+  private static void checkUnion(final Optional<?>... fs) {
+    final long count = Arrays.stream(fs)
+        .filter(Optional::isPresent)
         .count();
     if (count != 1) {
-      throw new IllegalArgumentException("Exactly one parameter should be non-null. got: " + fs);
+      throw new IllegalArgumentException("Exactly one parameter should be non-null. got: " + count);
+    }
+  }
+
+  @Immutable
+  @JsonIgnoreProperties(ignoreUnknown = true)
+  public static final class Header {
+
+    private final QueryId queryId;
+    private final LogicalSchema schema;
+
+    @JsonCreator
+    public static Header of(
+        @JsonProperty(value = "queryId", required = true) final QueryId queryId,
+        @JsonProperty(value = "schema", required = true) final LogicalSchema schema
+    ) {
+      return new Header(queryId, schema);
+    }
+
+    public QueryId getQueryId() {
+      return queryId;
+    }
+
+    public LogicalSchema getSchema() {
+      return schema;
+    }
+
+    private Header(final QueryId queryId, final LogicalSchema schema) {
+      this.queryId = requireNonNull(queryId, "queryId");
+      this.schema = requireNonNull(schema, "schema");
    }
  }
}
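
Note (reviewer sketch, not part of the patch): the wire format produced by the changed StreamedRow is line-delimited JSON in which exactly one of "header", "row", "errorMessage" or "finalMessage" is set per line, as enforced by checkUnion above and asserted in shouldExecutePushQueryOverRest. Below is a minimal, self-contained sketch of how a client might consume such lines, assuming only jackson-databind on the classpath; the class and variable names are illustrative, and the sample lines are copied from the integration test above.

// Illustrative only -- parses each response line with plain Jackson rather than
// the project's StreamedRow class, so it assumes nothing about its deserializer.
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public final class StreamedRowSketch {

  private static final ObjectMapper MAPPER = new ObjectMapper();

  public static void main(final String[] args) throws Exception {
    // Sample lines taken from shouldExecutePushQueryOverRest:
    final String[] lines = {
        "{\"header\":{\"queryId\":\"none\",\"schema\":\"`USERID` STRING, `PAGEID` STRING, `VIEWTIME` BIGINT, `ROWKEY` STRING\"}}",
        "{\"row\":{\"columns\":[\"USER_1\",\"PAGE_1\",1,\"1\"]}}",
        "{\"finalMessage\":\"Limit Reached\"}"
    };

    for (final String line : lines) {
      final JsonNode node = MAPPER.readTree(line);

      // Exactly one of the union fields is present per line (see checkUnion above).
      if (node.has("header")) {
        System.out.println("schema: " + node.get("header").get("schema").asText());
      } else if (node.has("row")) {
        System.out.println("columns: " + node.get("row").get("columns"));
      } else if (node.has("errorMessage")) {
        System.out.println("error: " + node.get("errorMessage"));
      } else if (node.has("finalMessage")) {
        System.out.println(node.get("finalMessage").asText());
      }
    }
  }
}

A client following this shape should treat an errorMessage or finalMessage line as terminal, mirroring isTerminal() above and the break in Cli.streamResults.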