From 6b5ce0c1a6f34891dcac88ad7454a06f3ed08f37 Mon Sep 17 00:00:00 2001 From: Andy Coates <8012398+big-andy-coates@users.noreply.github.com> Date: Tue, 14 Jan 2020 12:52:37 +0000 Subject: [PATCH] fix: deadlock when closing transient push query (#4297) * fix: deadlock when closing transient push query fixes: https://github.com/confluentinc/ksql/issues/4296 The produce side not calls `offer` in a loop, with a short timeout, to try and put the row into the blocking queue. When the consume side closes the query, e.g. on an `EOFException` if the user has closed the connection, the query first closes the queue; setting a flag the producers are checking on each loop; causing any producers to exit the loop. Then it can safely close the KS topology. --- .../ksql/query/BlockingRowQueue.java | 68 ++++++++++++ .../confluent/ksql/query/QueryExecutor.java | 8 +- .../ksql/query/TransientQueryQueue.java | 76 ++++++++----- .../ksql/util/PersistentQueryMetadata.java | 2 +- .../io/confluent/ksql/util/QueryMetadata.java | 5 +- .../ksql/util/TransientQueryMetadata.java | 25 ++--- .../integration/EndToEndIntegrationTest.java | 6 +- .../ksql/query/TransientQueryQueueTest.java | 102 ++++++++++++++---- .../ksql/util/TransientQueryMetadataTest.java | 90 ++++++++++++++++ .../entity/QueryDescriptionFactoryTest.java | 14 +-- .../streaming/QueryStreamWriterTest.java | 21 ++-- .../streaming/StreamedQueryResourceTest.java | 46 +++++++- 12 files changed, 376 insertions(+), 87 deletions(-) create mode 100644 ksql-engine/src/main/java/io/confluent/ksql/query/BlockingRowQueue.java create mode 100644 ksql-engine/src/test/java/io/confluent/ksql/util/TransientQueryMetadataTest.java diff --git a/ksql-engine/src/main/java/io/confluent/ksql/query/BlockingRowQueue.java b/ksql-engine/src/main/java/io/confluent/ksql/query/BlockingRowQueue.java new file mode 100644 index 000000000000..6512ec9e15ce --- /dev/null +++ b/ksql-engine/src/main/java/io/confluent/ksql/query/BlockingRowQueue.java @@ -0,0 +1,68 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.query; + +import io.confluent.ksql.GenericRow; +import java.util.Collection; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import org.apache.kafka.streams.KeyValue; + +/** + * The queue between the Kafka-streams topology and the client connection. + * + *

The KS topology writes to the queue from its {@code StreamThread}, while the KSQL server + * thread that is servicing the client request reads from the queue and writes to the client + * socket. + */ +public interface BlockingRowQueue { + + /** + * Sets the limit handler that will be called when any row limit is reached. + * + *

Replaces any previous handler. + * + * @param limitHandler the handler. + */ + void setLimitHandler(LimitHandler limitHandler); + + /** + * Poll the queue for a single row + * + * @see BlockingQueue#poll(long, TimeUnit) + */ + KeyValue poll(long timeout, TimeUnit unit) + throws InterruptedException; + + /** + * Drain the queue to the supplied {@code collection}. + * + * @see BlockingQueue#drainTo(Collection) + */ + void drainTo(Collection> collection); + + /** + * The size of the queue. + * + * @see BlockingQueue#size() + */ + int size(); + + /** + * Close the queue. + */ + void close(); +} diff --git a/ksql-engine/src/main/java/io/confluent/ksql/query/QueryExecutor.java b/ksql-engine/src/main/java/io/confluent/ksql/query/QueryExecutor.java index 34301d205afd..199f009108b3 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/query/QueryExecutor.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/query/QueryExecutor.java @@ -151,7 +151,7 @@ public TransientQueryMetadata buildTransientQuery( final LogicalSchema schema, final OptionalInt limit ) { - final TransientQueryQueue queue = buildTransientQueryQueue(queryId, physicalPlan, limit); + final BlockingRowQueue queue = buildTransientQueryQueue(queryId, physicalPlan, limit); final String applicationId = addTimeSuffix(getQueryApplicationId( getServiceId(), @@ -171,15 +171,15 @@ public TransientQueryMetadata buildTransientQuery( built.kafkaStreams, transientSchema, sources, - queue::setLimitHandler, planSummary, - queue.getQueue(), + queue, applicationId, built.topology, streamsProperties, overrides, queryCloseCallback, - ksqlConfig.getLong(KSQL_SHUTDOWN_TIMEOUT_MS_CONFIG)); + ksqlConfig.getLong(KSQL_SHUTDOWN_TIMEOUT_MS_CONFIG) + ); } private static Optional getMaterializationInfo(final Object result) { diff --git a/ksql-engine/src/main/java/io/confluent/ksql/query/TransientQueryQueue.java b/ksql-engine/src/main/java/io/confluent/ksql/query/TransientQueryQueue.java index 2d308e374a9a..eb983e560dcc 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/query/TransientQueryQueue.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/query/TransientQueryQueue.java @@ -15,12 +15,15 @@ package io.confluent.ksql.query; +import com.google.common.annotations.VisibleForTesting; import io.confluent.ksql.GenericRow; import io.confluent.ksql.util.KsqlException; +import java.util.Collection; import java.util.Objects; import java.util.OptionalInt; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.kstream.ForeachAction; import org.apache.kafka.streams.kstream.KStream; @@ -29,41 +32,62 @@ /** * A queue of rows for transient queries. */ -class TransientQueryQueue { +class TransientQueryQueue implements BlockingRowQueue { private final LimitQueueCallback callback; - private final BlockingQueue> rowQueue = - new LinkedBlockingQueue<>(100); + private final BlockingQueue> rowQueue; + private final int offerTimeoutMs; + private volatile boolean closed = false; TransientQueryQueue(final KStream kstream, final OptionalInt limit) { + this(kstream, limit, 100, 100); + } + + @VisibleForTesting + TransientQueryQueue( + final KStream kstream, + final OptionalInt limit, + final int queueSizeLimit, + final int offerTimeoutMs + ) { this.callback = limit.isPresent() ? new LimitedQueueCallback(limit.getAsInt()) : new UnlimitedQueueCallback(); + this.rowQueue = new LinkedBlockingQueue<>(queueSizeLimit); + this.offerTimeoutMs = offerTimeoutMs; - kstream.foreach(new TransientQueryQueue.QueuePopulator<>(rowQueue, callback)); + kstream.foreach(new QueuePopulator<>()); } - BlockingQueue> getQueue() { - return rowQueue; + @Override + public void setLimitHandler(final LimitHandler limitHandler) { + callback.setLimitHandler(limitHandler); } - void setLimitHandler(final LimitHandler limitHandler) { - callback.setLimitHandler(limitHandler); + @Override + public KeyValue poll(final long timeout, final TimeUnit unit) + throws InterruptedException { + return rowQueue.poll(timeout, unit); } - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - static final class QueuePopulator implements ForeachAction { + @Override + public void drainTo(final Collection> collection) { + rowQueue.drainTo(collection); + } - private final BlockingQueue> queue; - private final QueueCallback callback; + @Override + public int size() { + return rowQueue.size(); + } - QueuePopulator( - final BlockingQueue> queue, - final QueueCallback callback - ) { - this.queue = Objects.requireNonNull(queue, "queue"); - this.callback = Objects.requireNonNull(callback, "callback"); - } + @Override + public void close() { + closed = true; + } + + @VisibleForTesting + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + final class QueuePopulator implements ForeachAction { @Override public void apply(final K key, final GenericRow row) { @@ -76,18 +100,22 @@ public void apply(final K key, final GenericRow row) { return; } - final String keyString = getStringKey(key); - queue.put(new KeyValue<>(keyString, row)); + final KeyValue kv = new KeyValue<>(getStringKey(key), row); - callback.onQueued(); - } catch (final InterruptedException exception) { + while (!closed) { + if (rowQueue.offer(kv, offerTimeoutMs, TimeUnit.MILLISECONDS)) { + callback.onQueued(); + break; + } + } + } catch (final InterruptedException e) { throw new KsqlException("InterruptedException while enqueueing:" + key); } } private String getStringKey(final K key) { if (key instanceof Windowed) { - final Windowed windowedKey = (Windowed) key; + final Windowed windowedKey = (Windowed) key; return String.format("%s : %s", windowedKey.key(), windowedKey.window()); } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/util/PersistentQueryMetadata.java b/ksql-engine/src/main/java/io/confluent/ksql/util/PersistentQueryMetadata.java index 2dea992ffafd..6036ec1e78f3 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/util/PersistentQueryMetadata.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/util/PersistentQueryMetadata.java @@ -64,7 +64,7 @@ public PersistentQueryMetadata( final Map streamsProperties, final Map overriddenProperties, final Consumer closeCallback, - final Long closeTimeout) { + final long closeTimeout) { // CHECKSTYLE_RULES.ON: ParameterNumberCheck super( statementString, diff --git a/ksql-engine/src/main/java/io/confluent/ksql/util/QueryMetadata.java b/ksql-engine/src/main/java/io/confluent/ksql/util/QueryMetadata.java index 47414605ab6f..3d24c3c53c40 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/util/QueryMetadata.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/util/QueryMetadata.java @@ -62,7 +62,8 @@ protected QueryMetadata( final Map streamsProperties, final Map overriddenProperties, final Consumer closeCallback, - final Long closeTimeout) { + final long closeTimeout + ) { // CHECKSTYLE_RULES.ON: ParameterNumberCheck this.statementString = Objects.requireNonNull(statementString, "statementString"); this.kafkaStreams = Objects.requireNonNull(kafkaStreams, "kafkaStreams"); @@ -78,7 +79,7 @@ protected QueryMetadata( this.closeCallback = Objects.requireNonNull(closeCallback, "closeCallback"); this.sourceNames = Objects.requireNonNull(sourceNames, "sourceNames"); this.logicalSchema = Objects.requireNonNull(logicalSchema, "logicalSchema"); - this.closeTimeout = Objects.requireNonNull(closeTimeout, "closeTimeout"); + this.closeTimeout = closeTimeout; } protected QueryMetadata(final QueryMetadata other, final Consumer closeCallback) { diff --git a/ksql-engine/src/main/java/io/confluent/ksql/util/TransientQueryMetadata.java b/ksql-engine/src/main/java/io/confluent/ksql/util/TransientQueryMetadata.java index 8f3f33be693f..3c83876d073e 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/util/TransientQueryMetadata.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/util/TransientQueryMetadata.java @@ -15,18 +15,16 @@ package io.confluent.ksql.util; -import io.confluent.ksql.GenericRow; import io.confluent.ksql.name.SourceName; +import io.confluent.ksql.query.BlockingRowQueue; import io.confluent.ksql.query.LimitHandler; import io.confluent.ksql.schema.ksql.LogicalSchema; import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import org.apache.kafka.streams.KafkaStreams; -import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.Topology; /** @@ -34,9 +32,8 @@ */ public class TransientQueryMetadata extends QueryMetadata { - private final BlockingQueue> rowQueue; + private final BlockingRowQueue rowQueue; private final AtomicBoolean isRunning = new AtomicBoolean(true); - private final Consumer limitHandlerSetter; // CHECKSTYLE_RULES.OFF: ParameterNumberCheck public TransientQueryMetadata( @@ -44,15 +41,14 @@ public TransientQueryMetadata( final KafkaStreams kafkaStreams, final LogicalSchema logicalSchema, final Set sourceNames, - final Consumer limitHandlerSetter, final String executionPlan, - final BlockingQueue> rowQueue, + final BlockingRowQueue rowQueue, final String queryApplicationId, final Topology topology, final Map streamsProperties, final Map overriddenProperties, final Consumer closeCallback, - final Long closeTimeout) { + final long closeTimeout) { // CHECKSTYLE_RULES.ON: ParameterNumberCheck super( statementString, @@ -65,8 +61,8 @@ public TransientQueryMetadata( streamsProperties, overriddenProperties, closeCallback, - closeTimeout); - this.limitHandlerSetter = Objects.requireNonNull(limitHandlerSetter, "limitHandlerSetter"); + closeTimeout + ); this.rowQueue = Objects.requireNonNull(rowQueue, "rowQueue"); if (!logicalSchema.metadata().isEmpty() || !logicalSchema.key().isEmpty()) { @@ -78,7 +74,7 @@ public boolean isRunning() { return isRunning.get(); } - public BlockingQueue> getRowQueue() { + public BlockingRowQueue getRowQueue() { return rowQueue; } @@ -99,11 +95,16 @@ public int hashCode() { } public void setLimitHandler(final LimitHandler limitHandler) { - limitHandlerSetter.accept(limitHandler); + rowQueue.setLimitHandler(limitHandler); } @Override public void close() { + // To avoid deadlock, close the queue first to ensure producer side isn't blocked trying to + // write to the blocking queue, otherwise super.close call can deadlock: + rowQueue.close(); + + // Now safe to close: super.close(); isRunning.set(false); } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/integration/EndToEndIntegrationTest.java b/ksql-engine/src/test/java/io/confluent/ksql/integration/EndToEndIntegrationTest.java index ae3edbcfb317..580ccdf3b16b 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/integration/EndToEndIntegrationTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/integration/EndToEndIntegrationTest.java @@ -30,6 +30,7 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.function.udf.Udf; import io.confluent.ksql.function.udf.UdfDescription; +import io.confluent.ksql.query.BlockingRowQueue; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.serde.Format; import io.confluent.ksql.util.KsqlConstants; @@ -44,7 +45,6 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -216,7 +216,7 @@ public void shouldSelectAllFromDerivedStream() throws Exception { "SELECT * from pageviews_female EMIT CHANGES;"); final List> results = new ArrayList<>(); - final BlockingQueue> rowQueue = queryMetadata.getRowQueue(); + final BlockingRowQueue rowQueue = queryMetadata.getRowQueue(); // From the mock data, we expect exactly 3 page views from female users. final List expectedPages = ImmutableList.of("PAGE_2", "PAGE_5", "PAGE_5"); @@ -402,7 +402,7 @@ private static List verifyAvailableRows( final TransientQueryMetadata queryMetadata, final int expectedRows ) throws Exception { - final BlockingQueue> rowQueue = queryMetadata.getRowQueue(); + final BlockingRowQueue rowQueue = queryMetadata.getRowQueue(); TestUtils.waitForCondition( () -> rowQueue.size() >= expectedRows, diff --git a/ksql-engine/src/test/java/io/confluent/ksql/query/TransientQueryQueueTest.java b/ksql-engine/src/test/java/io/confluent/ksql/query/TransientQueryQueueTest.java index aea22e2c4e53..e59167da08ed 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/query/TransientQueryQueueTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/query/TransientQueryQueueTest.java @@ -16,9 +16,9 @@ package io.confluent.ksql.query; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -26,47 +26,58 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.query.TransientQueryQueue.QueuePopulator; +import java.util.ArrayList; +import java.util.List; import java.util.OptionalInt; -import java.util.Queue; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.kstream.KStream; +import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -@SuppressWarnings("ConstantConditions") +@SuppressWarnings("unchecked") @RunWith(MockitoJUnitRunner.class) public class TransientQueryQueueTest { private static final int SOME_LIMIT = 4; + private static final int MAX_LIMIT = SOME_LIMIT * 2; private static final GenericRow ROW_ONE = mock(GenericRow.class); private static final GenericRow ROW_TWO = mock(GenericRow.class); + @Rule + public final Timeout timeout = Timeout.seconds(10); + @Mock private LimitHandler limitHandler; @Mock private KStream kStreamsApp; @Captor private ArgumentCaptor> queuePopulatorCaptor; - private Queue> queue; private QueuePopulator queuePopulator; + private TransientQueryQueue queue; + private ScheduledExecutorService executorService; @Before public void setUp() { - final TransientQueryQueue queuer = - new TransientQueryQueue(kStreamsApp, OptionalInt.of(SOME_LIMIT)); - - queuer.setLimitHandler(limitHandler); - - queue = queuer.getQueue(); + givenQueue(OptionalInt.of(SOME_LIMIT)); + } - verify(kStreamsApp).foreach(queuePopulatorCaptor.capture()); - queuePopulator = queuePopulatorCaptor.getValue(); + @After + public void tearDown() { + if (executorService != null) { + executorService.shutdownNow(); + } } @Test @@ -76,11 +87,10 @@ public void shouldQueue() { queuePopulator.apply("key2", ROW_TWO); // Then: - assertThat(queue, hasSize(2)); - assertThat(queue.peek().key, is("key1")); - assertThat(queue.remove().value, is(ROW_ONE)); - assertThat(queue.peek().key, is("key2")); - assertThat(queue.remove().value, is(ROW_TWO)); + assertThat(drainValues(), contains( + new KeyValue<>("key1", ROW_ONE), + new KeyValue<>("key2", ROW_TWO) + )); } @Test @@ -89,7 +99,7 @@ public void shouldNotQueueNullValues() { queuePopulator.apply("key1", null); // Then: - assertThat(queue, is(empty())); + assertThat(queue.size(), is(0)); } @Test @@ -99,7 +109,21 @@ public void shouldQueueUntilLimitReached() { .forEach(idx -> queuePopulator.apply("key1", ROW_ONE)); // Then: - assertThat(queue, hasSize(SOME_LIMIT)); + assertThat(queue.size(), is(SOME_LIMIT)); + } + + @Test + public void shouldPoll() throws Exception { + // Given: + queuePopulator.apply("key1", ROW_ONE); + queuePopulator.apply("key2", ROW_TWO); + + // When: + final KeyValue result = queue.poll(1, TimeUnit.SECONDS); + + // Then: + assertThat(result, is(new KeyValue<>("key1", ROW_ONE))); + assertThat(drainValues(), contains(new KeyValue<>("key2", ROW_TWO))); } @Test @@ -131,4 +155,42 @@ public void shouldCallLimitHandlerOnlyOnce() { // Then: verify(limitHandler, times(1)).limitReached(); } + + @Test + public void shouldBlockOnProduceOnceQueueLimitReachedAndUnblockOnClose() { + // Given: + givenQueue(OptionalInt.empty()); + + IntStream.range(0, MAX_LIMIT) + .forEach(idx -> queuePopulator.apply("key1", ROW_ONE)); + + givenWillCloseQueueAsync(); + + // When: + queuePopulator.apply("should not be queued", ROW_TWO); + + // Then: did not block and: + assertThat(queue.size(), is(MAX_LIMIT)); + } + + private void givenWillCloseQueueAsync() { + executorService = Executors.newSingleThreadScheduledExecutor(); + executorService.schedule(queue::close, 200, TimeUnit.MILLISECONDS); + } + + private void givenQueue(final OptionalInt limit) { + clearInvocations(kStreamsApp); + queue = new TransientQueryQueue(kStreamsApp, limit, MAX_LIMIT, 1); + + queue.setLimitHandler(limitHandler); + + verify(kStreamsApp).foreach(queuePopulatorCaptor.capture()); + queuePopulator = queuePopulatorCaptor.getValue(); + } + + private List> drainValues() { + final List> entries = new ArrayList<>(); + queue.drainTo(entries); + return entries; + } } \ No newline at end of file diff --git a/ksql-engine/src/test/java/io/confluent/ksql/util/TransientQueryMetadataTest.java b/ksql-engine/src/test/java/io/confluent/ksql/util/TransientQueryMetadataTest.java new file mode 100644 index 000000000000..198529e24342 --- /dev/null +++ b/ksql-engine/src/test/java/io/confluent/ksql/util/TransientQueryMetadataTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.util; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; + +import io.confluent.ksql.name.SourceName; +import io.confluent.ksql.query.BlockingRowQueue; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.Topology; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class TransientQueryMetadataTest { + + private static final String QUERY_ID = "queryId"; + private static final String EXECUTION_PLAN = "execution plan"; + private static final String SQL = "sql"; + private static final long CLOSE_TIMEOUT = 10L; + + @Mock + private KafkaStreams kafkaStreams; + @Mock + private LogicalSchema logicalSchema; + @Mock + private Set sourceNames; + @Mock + private BlockingRowQueue rowQueue; + @Mock + private Topology topology; + @Mock + private Map props; + @Mock + private Map overrides; + @Mock + private Consumer closeCallback; + private TransientQueryMetadata query; + + @Before + public void setUp() { + query = new TransientQueryMetadata( + SQL, + kafkaStreams, + logicalSchema, + sourceNames, + EXECUTION_PLAN, + rowQueue, + QUERY_ID, + topology, + props, + overrides, + closeCallback, + CLOSE_TIMEOUT + ); + } + + @Test + public void shouldCloseQueueBeforeTopologyToAvoidDeadLock() { + // When: + query.close(); + + // Then: + final InOrder inOrder = inOrder(rowQueue, kafkaStreams); + inOrder.verify(rowQueue).close(); + inOrder.verify(kafkaStreams).close(any()); + } +} \ No newline at end of file diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/entity/QueryDescriptionFactoryTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/entity/QueryDescriptionFactoryTest.java index a77244684401..79c11ad3f9ff 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/entity/QueryDescriptionFactoryTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/entity/QueryDescriptionFactoryTest.java @@ -26,7 +26,7 @@ import io.confluent.ksql.metastore.model.DataSource.DataSourceType; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.SourceName; -import io.confluent.ksql.query.LimitHandler; +import io.confluent.ksql.query.BlockingRowQueue; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.PhysicalSchema; @@ -42,7 +42,6 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.Optional; -import java.util.concurrent.LinkedBlockingQueue; import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.kafka.streams.KafkaStreams; @@ -86,7 +85,7 @@ public class QueryDescriptionFactoryTest { @Mock(name = TOPOLOGY_TEXT) private TopologyDescription topologyDescription; @Mock - private Consumer limitHandler; + private BlockingRowQueue queryQueue; @Mock private KsqlTopic sinkTopic; private QueryMetadata transientQuery; @@ -103,9 +102,8 @@ public void setUp() { queryStreams, TRANSIENT_SCHEMA, SOURCE_NAMES, - limitHandler, "execution plan", - new LinkedBlockingQueue<>(), + queryQueue, "app id", topology, STREAMS_PROPS, @@ -218,9 +216,8 @@ public void shouldHandleRowTimeInValueSchemaForTransientQuery() { queryStreams, schema, SOURCE_NAMES, - limitHandler, "execution plan", - new LinkedBlockingQueue<>(), + queryQueue, "app id", topology, STREAMS_PROPS, @@ -253,9 +250,8 @@ public void shouldHandleRowKeyInValueSchemaForTransientQuery() { queryStreams, schema, SOURCE_NAMES, - limitHandler, "execution plan", - new LinkedBlockingQueue<>(), + queryQueue, "app id", topology, STREAMS_PROPS, diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriterTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriterTest.java index 1ac874e28b9a..1a2f73cf58c9 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriterTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriterTest.java @@ -32,6 +32,7 @@ import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.json.JsonMapper; import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.query.BlockingRowQueue; import io.confluent.ksql.query.LimitHandler; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; @@ -42,7 +43,6 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.apache.kafka.streams.KafkaStreams; @@ -59,7 +59,7 @@ import org.junit.rules.Timeout; import org.junit.runner.RunWith; -@SuppressWarnings({"unchecked", "ConstantConditions"}) +@SuppressWarnings("unchecked") @RunWith(EasyMockRunner.class) public class QueryStreamWriterTest { @@ -74,7 +74,7 @@ public class QueryStreamWriterTest { @Mock(MockType.NICE) private TransientQueryMetadata queryMetadata; @Mock(MockType.NICE) - private BlockingQueue> rowQueue; + private BlockingRowQueue rowQueue; private Capture ehCapture; private Capture>> drainCapture; private Capture limitHandlerCapture; @@ -115,10 +115,11 @@ public void setUp() { } @Test - public void shouldWriteAnyPendingRowsBeforeReportingException() throws Exception { + public void shouldWriteAnyPendingRowsBeforeReportingException() { // Given: expect(queryMetadata.isRunning()).andReturn(true).anyTimes(); - expect(rowQueue.drainTo(capture(drainCapture))).andAnswer(rows("Row1", "Row2", "Row3")); + rowQueue.drainTo(capture(drainCapture)); + expectLastCall().andAnswer(rows("Row1", "Row2", "Row3")); createWriter(); @@ -136,10 +137,11 @@ public void shouldWriteAnyPendingRowsBeforeReportingException() throws Exception } @Test - public void shouldExitAndDrainIfQueryStopsRunning() throws Exception { + public void shouldExitAndDrainIfQueryStopsRunning() { // Given: expect(queryMetadata.isRunning()).andReturn(true).andReturn(false); - expect(rowQueue.drainTo(capture(drainCapture))).andAnswer(rows("Row1", "Row2", "Row3")); + rowQueue.drainTo(capture(drainCapture)); + expectLastCall().andAnswer(rows("Row1", "Row2", "Row3")); createWriter(); @@ -155,10 +157,11 @@ public void shouldExitAndDrainIfQueryStopsRunning() throws Exception { } @Test - public void shouldExitAndDrainIfLimitReached() throws Exception { + public void shouldExitAndDrainIfLimitReached() { // Given: expect(queryMetadata.isRunning()).andReturn(true).anyTimes(); - expect(rowQueue.drainTo(capture(drainCapture))).andAnswer(rows("Row1", "Row2", "Row3")); + rowQueue.drainTo(capture(drainCapture)); + expectLastCall().andAnswer(rows("Row1", "Row2", "Row3")); createWriter(); diff --git a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java index 7f41dd116f26..d0935b7bdafd 100644 --- a/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java +++ b/ksql-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java @@ -24,7 +24,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; @@ -47,6 +46,8 @@ import io.confluent.ksql.parser.tree.PrintTopic; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Statement; +import io.confluent.ksql.query.BlockingRowQueue; +import io.confluent.ksql.query.LimitHandler; import io.confluent.ksql.rest.Errors; import io.confluent.ksql.rest.entity.KsqlErrorMessage; import io.confluent.ksql.rest.entity.KsqlRequest; @@ -70,12 +71,15 @@ import java.io.PipedInputStream; import java.io.PipedOutputStream; import java.time.Duration; +import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Scanner; import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -365,9 +369,8 @@ public void shouldStreamRowsCorrectly() throws Throwable { mockKafkaStreams, SOME_SCHEMA, Collections.emptySet(), - limitHandler -> {}, "", - rowQueue, + new TestRowQueue(rowQueue), "", mock(Topology.class), Collections.emptyMap(), @@ -612,4 +615,41 @@ public void shouldSuggestAlternativesIfPrintTopicDoesNotExist() { new KsqlRequest(PRINT_TOPIC, Collections.emptyMap(), null) ); } + + private static class TestRowQueue implements BlockingRowQueue { + + private final SynchronousQueue> rowQueue; + + TestRowQueue( + final SynchronousQueue> rowQueue + ) { + this.rowQueue = Objects.requireNonNull(rowQueue, "rowQueue"); + } + + @Override + public void setLimitHandler(final LimitHandler limitHandler) { + + } + + @Override + public KeyValue poll(final long timeout, final TimeUnit unit) + throws InterruptedException { + return rowQueue.poll(timeout, unit); + } + + @Override + public void drainTo(final Collection> collection) { + rowQueue.drainTo(collection); + } + + @Override + public int size() { + return rowQueue.size(); + } + + @Override + public void close() { + + } + } }