KAFKA-9623: Keep polling until the task manager is no longer rebalancing #8190

Merged
@@ -744,7 +744,9 @@ public void run() {
     private void runLoop() {
         subscribeConsumer();
 
-        while (isRunning()) {
+        // if the thread is still in the middle of a rebalance, we should keep polling
+        // until the rebalance is completed before we close and commit the tasks
+        while (isRunning() || taskManager.isRebalanceInProgress()) {
Contributor (Author):
This is the actual fix.
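
The new exit condition is easiest to see in isolation. Below is a minimal, self-contained sketch of the same pattern (all names are illustrative, not Kafka's): a shutdown request alone no longer ends the loop; an in-flight rebalance must finish first.

    import java.util.concurrent.atomic.AtomicBoolean;

    public class ShutdownDrainSketch {
        final AtomicBoolean running = new AtomicBoolean(true);
        final AtomicBoolean rebalanceInProgress = new AtomicBoolean(true);

        void runLoop() {
            // mirror of the patched condition: keep iterating while either
            // the thread is running or a rebalance is still in flight
            while (running.get() || rebalanceInProgress.get()) {
                runOnce();
            }
            // only now is it safe to close and commit the tasks
        }

        void runOnce() {
            // stand-in for pollRequests(...); pretend this poll completed the rebalance
            rebalanceInProgress.set(false);
        }

        public static void main(String[] args) {
            ShutdownDrainSketch thread = new ShutdownDrainSketch();
            thread.running.set(false);   // simulate shutdown() arriving mid-rebalance
            thread.runLoop();            // still runs one more iteration to drain the rebalance
        }
    }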

Member:
We need to avoid adding these new records to the PartitionGroup down on line 825, or else they'll be included in the committed offsets.
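
One way to address this concern, assuming runOnce() hands polled records to tasks via an addRecordsToTasks(records) helper near the cited line (an assumption from context; that code is not shown in this diff), would be to skip the hand-off while shutting down:

    // hypothetical guard in runOnce(): during PENDING_SHUTDOWN we poll only to
    // let the consumer finish the rebalance, so drop the records instead of
    // buffering them for processing (and, later, offset commits)
    if (!records.isEmpty() && state != State.PENDING_SHUTDOWN) {
        addRecordsToTasks(records);
    }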

Member:
For example, we may need to poll more than once to complete the rebalance.

             try {
                 runOnce();
                 if (assignmentErrorCode.get() == AssignorError.VERSION_PROBING.code()) {
@@ -806,6 +808,10 @@ void runOnce() {
             // try to fetch some records with normal poll time
             // in order to get long polling
             records = pollRequests(pollTime);
+        } else if (state == State.PENDING_SHUTDOWN) {
+            // we are only here because there's a rebalance in progress,
+            // so just poll with zero timeout to complete it
+            records = pollRequests(Duration.ZERO);
         } else {
             // any other state should not happen
             log.error("Unexpected state {} during normal iteration", state);
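
Note that KafkaConsumer.poll() participates in the group protocol regardless of the timeout: poll(Duration.ZERO) still advances the join/sync round trips and fires the rebalance callbacks, it just does not wait for records. That is why the outer loop, rather than a single call, drains the rebalance; roughly (a sketch, assuming the listener clears the flag from within poll()):

    // each PENDING_SHUTDOWN iteration advances the rebalance without blocking;
    // the loop ends once onPartitionsAssigned() has run handleRebalanceComplete()
    while (taskManager.isRebalanceInProgress()) {
        records = pollRequests(Duration.ZERO);
    }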
@@ -50,10 +50,10 @@ public void onPartitionsAssigned(final Collection<TopicPartition> partitions) {
log.error("Received error code {} - shutdown", streamThread.getAssignmentErrorCode());
streamThread.shutdown();
} else {
taskManager.handleRebalanceComplete();

streamThread.setState(State.PARTITIONS_ASSIGNED);
}

taskManager.handleRebalanceComplete();
Contributor (Author):
We need to move this out of the conditional since, even if we get an error code, we should still clear the rebalance-in-progress flag so that the thread can complete its shutdown.
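
Condensed, the listener after this change reads roughly as follows (the error-code check is reconstructed from the log line above, so treat it as approximate):

    @Override
    public void onPartitionsAssigned(final Collection<TopicPartition> partitions) {
        if (streamThread.getAssignmentErrorCode() != AssignorError.NONE.code()) {
            log.error("Received error code {} - shutdown", streamThread.getAssignmentErrorCode());
            streamThread.shutdown();
        } else {
            streamThread.setState(State.PARTITIONS_ASSIGNED);
        }
        // runs on both paths now, so a shutting-down thread still observes
        // isRebalanceInProgress() == false and can exit runLoop()
        taskManager.handleRebalanceComplete();
    }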

     }
 
     @Override
@@ -112,6 +112,10 @@ InternalTopologyBuilder builder() {
         return builder;
     }
 
+    boolean isRebalanceInProgress() {
+        return rebalanceInProgress;
+    }
+
     void handleRebalanceStart(final Set<String> subscribedTopics) {
         builder.addSubscribedTopicsFromMetadata(subscribedTopics, logPrefix);
 
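The flag read by the new accessor is presumably toggled by the rebalance lifecycle hooks in this class. A sketch of the assumed wiring (the field name is inferred from the getter; the handleRebalanceComplete body is illustrative):

    // assumed lifecycle: the thread's rebalance listener calls
    // handleRebalanceStart(...) when a rebalance begins and
    // handleRebalanceComplete() from onPartitionsAssigned(...)
    private boolean rebalanceInProgress = false;

    void handleRebalanceStart(final Set<String> subscribedTopics) {
        builder.addSubscribedTopicsFromMetadata(subscribedTopics, logPrefix);
        rebalanceInProgress = true;   // runLoop() will now refuse to exit
    }

    void handleRebalanceComplete() {
        // ... reconcile tasks with the new assignment (elided) ...
        rebalanceInProgress = false;  // lets a PENDING_SHUTDOWN thread leave runLoop()
    }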
@@ -690,6 +690,51 @@ public void shouldInjectProducerPerTaskUsingClientSupplierOnCreateIfEosEnable()
         assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer);
     }

+    @Test
+    public void shouldOnlyCompleteShutdownAfterRebalanceNotInProgress() throws InterruptedException {
+        internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);
+
+        final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(true)), true);
+
+        thread.start();
+        TestUtils.waitForCondition(
+            () -> thread.state() == StreamThread.State.STARTING,
+            10 * 1000,
+            "Thread never started.");
+
+        thread.rebalanceListener.onPartitionsRevoked(Collections.emptyList());
+        thread.taskManager().handleRebalanceStart(Collections.singleton(topic1));
+
+        final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
+        final List<TopicPartition> assignedPartitions = new ArrayList<>();
+
+        // assign one partition to each of the two tasks
+        assignedPartitions.add(t1p1);
+        assignedPartitions.add(t1p2);
+        activeTasks.put(task1, Collections.singleton(t1p1));
+        activeTasks.put(task2, Collections.singleton(t1p2));
+
+        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+
+        thread.shutdown();
+
+        // even if the thread is no longer running, it should still be polling
+        // as long as the rebalance is still ongoing
+        assertFalse(thread.isRunning());
+
+        Thread.sleep(1000);
+        assertEquals(Utils.mkSet(task1, task2), thread.taskManager().activeTaskIds());
+        assertEquals(StreamThread.State.PENDING_SHUTDOWN, thread.state());
+
+        thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
+
+        TestUtils.waitForCondition(
+            () -> thread.state() == StreamThread.State.DEAD,
+            10 * 1000,
+            "Thread never shut down.");
+        assertEquals(Collections.emptySet(), thread.taskManager().activeTaskIds());
+    }

     @Test
     public void shouldCloseAllTaskProducersOnCloseIfEosEnabled() throws InterruptedException {
         internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);
@@ -485,20 +485,20 @@ private static boolean verify(final PrintStream resultStream,
         if (!expected.equals(actual)) {
             resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected);
 
-            if (printResults) {
-                resultStream.printf("\t inputEvents=%n%s%n\t" +
-                    "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n",
-                    indent("\t\t", observedInputEvents.get(key)),
-                    indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())),
-                    indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())),
-                    indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new LinkedList<>())),
-                    indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())),
-                    indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())),
-                    indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>())));
-
-                if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic))
-                    resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue()));
-            }
+            // if (printResults) {
+            //     resultStream.printf("\t inputEvents=%n%s%n\t" +
+            //         "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n",
+            //         indent("\t\t", observedInputEvents.get(key)),
+            //         indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())),
+            //         indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())),
+            //         indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new LinkedList<>())),
+            //         indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())),
+            //         indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())),
+            //         indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>())));
+            //
+            //     if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic))
+            //         resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue()));
+            // }
 
             return false;
         }