diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index 4dcad01c3143d..1ec6e7b297964 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -744,7 +744,9 @@ public void run() { private void runLoop() { subscribeConsumer(); - while (isRunning()) { + // if the thread is still in the middle of a rebalance, we should keep polling + // until the rebalance is completed before we close and commit the tasks + while (isRunning() || taskManager.isRebalanceInProgress()) { try { runOnce(); if (assignmentErrorCode.get() == AssignorError.VERSION_PROBING.code()) { @@ -806,6 +808,10 @@ void runOnce() { // try to fetch some records with normal poll time // in order to get long polling records = pollRequests(pollTime); + } else if (state == State.PENDING_SHUTDOWN) { + // we are only here because there's rebalance in progress, + // just poll with zero to complete it + records = pollRequests(Duration.ZERO); } else { // any other state should not happen log.error("Unexpected state {} during normal iteration", state); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListener.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListener.java index 56b0dc698a301..2c85eaa89820d 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListener.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListener.java @@ -50,10 +50,10 @@ public void onPartitionsAssigned(final Collection partitions) { log.error("Received error code {} - shutdown", streamThread.getAssignmentErrorCode()); streamThread.shutdown(); } else { - taskManager.handleRebalanceComplete(); - streamThread.setState(State.PARTITIONS_ASSIGNED); } + + taskManager.handleRebalanceComplete(); } @Override diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java index b36fb36ffec17..84facc5bacae2 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -112,6 +112,10 @@ InternalTopologyBuilder builder() { return builder; } + boolean isRebalanceInProgress() { + return rebalanceInProgress; + } + void handleRebalanceStart(final Set subscribedTopics) { builder.addSubscribedTopicsFromMetadata(subscribedTopics, logPrefix); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index 765e4c196b9fa..5651c4e4909c9 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -690,6 +690,51 @@ public void shouldInjectProducerPerTaskUsingClientSupplierOnCreateIfEosEnable() assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer); } + @Test + public void shouldOnlyCompleteShutdownAfterRebalanceNotInProgress() throws InterruptedException { + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(true)), true); + + thread.start(); + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.STARTING, + 10 * 1000, + "Thread never started."); + + thread.rebalanceListener.onPartitionsRevoked(Collections.emptyList()); + thread.taskManager().handleRebalanceStart(Collections.singleton(topic1)); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + assignedPartitions.add(t1p2); + activeTasks.put(task1, Collections.singleton(t1p1)); + activeTasks.put(task2, Collections.singleton(t1p2)); + + thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap()); + + thread.shutdown(); + + // even if thread is no longer running, it should still be polling + // as long as the rebalance is still ongoing + assertFalse(thread.isRunning()); + + Thread.sleep(1000); + assertEquals(Utils.mkSet(task1, task2), thread.taskManager().activeTaskIds()); + assertEquals(StreamThread.State.PENDING_SHUTDOWN, thread.state()); + + thread.rebalanceListener.onPartitionsAssigned(assignedPartitions); + + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.DEAD, + 10 * 1000, + "Thread never shut down."); + assertEquals(Collections.emptySet(), thread.taskManager().activeTaskIds()); + } + @Test public void shouldCloseAllTaskProducersOnCloseIfEosEnabled() throws InterruptedException { internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);