diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractMembershipManager.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractMembershipManager.java index c6aa70d805e0c..90c2b3a647d7e 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractMembershipManager.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractMembershipManager.java @@ -1175,10 +1175,16 @@ private CompletableFuture assignPartitions( // Invoke user call back. CompletableFuture result = signalPartitionsAssigned(addedPartitions); + // Once the callback completes, enable fetching and position updates for the assigned partitions. result.whenComplete((__, exception) -> { if (exception == null) { - // Enable newly added partitions to start fetching and updating positions for them. - subscriptions.enablePartitionsAwaitingCallback(addedPartitions); + // Enable assigned partitions to start fetching and updating positions for them. + // We use assignedPartitions here instead of addedPartitions because there's a chance that the callback + // might throw an exception, leaving addedPartitions empty when the callback is retried. This would result in the poll operation + // returning no records, as no topic partitions are marked as fetchable. In contrast, with the classic consumer, + // if the first callback fails but the next one succeeds, polling can still retrieve data. To align with + // this behavior, we rely on assignedPartitions to avoid such scenarios. + subscriptions.enablePartitionsAwaitingCallback(toTopicPartitionSet(assignedPartitions)); } else { // Keeping newly added partitions as non-fetchable after the callback failure. 
// They will be retried on the next reconciliation loop, until it succeeds or the diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java index ff679e5542d11..f5e12407be52e 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java @@ -84,6 +84,7 @@ import org.apache.kafka.common.errors.InterruptException; import org.apache.kafka.common.errors.InvalidGroupIdException; import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.internals.ClusterResourceListeners; import org.apache.kafka.common.metrics.KafkaMetric; import org.apache.kafka.common.metrics.Metrics; @@ -2072,23 +2073,27 @@ static ConsumerRebalanceListenerCallbackCompletedEvent invokeRebalanceCallbacks( ConsumerRebalanceListenerMethodName methodName, SortedSet partitions, CompletableFuture future) { - final Exception e; + Exception e; - switch (methodName) { - case ON_PARTITIONS_REVOKED: - e = rebalanceListenerInvoker.invokePartitionsRevoked(partitions); - break; + try { + switch (methodName) { + case ON_PARTITIONS_REVOKED: + e = rebalanceListenerInvoker.invokePartitionsRevoked(partitions); + break; - case ON_PARTITIONS_ASSIGNED: - e = rebalanceListenerInvoker.invokePartitionsAssigned(partitions); - break; + case ON_PARTITIONS_ASSIGNED: + e = rebalanceListenerInvoker.invokePartitionsAssigned(partitions); + break; - case ON_PARTITIONS_LOST: - e = rebalanceListenerInvoker.invokePartitionsLost(partitions); - break; + case ON_PARTITIONS_LOST: + e = rebalanceListenerInvoker.invokePartitionsLost(partitions); + break; - default: - throw new IllegalArgumentException("The method " + methodName.fullyQualifiedMethodName() + " to invoke was not expected"); + 
default: + throw new IllegalArgumentException("The method " + methodName.fullyQualifiedMethodName() + " to invoke was not expected"); + } + } catch (WakeupException | InterruptException ex) { + e = ex; } final Optional error; diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java index f700b8706ca60..bd45e71c884d9 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java @@ -898,8 +898,8 @@ public synchronized void assignFromSubscribedAwaitingCallback(Collection partitions) { diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMembershipManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMembershipManagerTest.java index 9517e04e05456..d42d81d7ce427 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMembershipManagerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMembershipManagerTest.java @@ -29,6 +29,8 @@ import org.apache.kafka.common.TopicIdPartition; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.message.ConsumerGroupHeartbeatResponseData; import org.apache.kafka.common.message.ConsumerGroupHeartbeatResponseData.Assignment; import org.apache.kafka.common.message.ConsumerGroupHeartbeatResponseData.TopicPartitions; @@ -96,6 +98,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +@SuppressWarnings("ClassDataAbstractionCoupling") public class ConsumerMembershipManagerTest { private static final String GROUP_ID = "test-group"; 
@@ -1738,6 +1741,12 @@ public void testListenerCallbacksBasic() { @Test public void testListenerCallbacksThrowsErrorOnPartitionsRevoked() { + testErrorsOnPartitionsRevoked(new WakeupException()); + testErrorsOnPartitionsRevoked(new InterruptException("Intentional onPartitionsRevoked() error")); + testErrorsOnPartitionsRevoked(new IllegalArgumentException("Intentional onPartitionsRevoked() error")); + } + + private void testErrorsOnPartitionsRevoked(RuntimeException error) { // Step 1: set up mocks String topicName = "topic1"; Uuid topicId = Uuid.randomUuid(); @@ -1745,7 +1754,7 @@ public void testListenerCallbacksThrowsErrorOnPartitionsRevoked() { ConsumerMembershipManager membershipManager = createMemberInStableState(); mockOwnedPartition(membershipManager, topicId, topicName); CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener( - Optional.of(new IllegalArgumentException("Intentional onPartitionsRevoked() error")), + Optional.ofNullable(error), Optional.empty(), Optional.empty() ); @@ -1792,6 +1801,12 @@ public void testListenerCallbacksThrowsErrorOnPartitionsRevoked() { @Test public void testListenerCallbacksThrowsErrorOnPartitionsAssigned() { + testErrorsOnPartitionsAssigned(new WakeupException()); + testErrorsOnPartitionsAssigned(new InterruptException("Intentional error")); + testErrorsOnPartitionsAssigned(new IllegalArgumentException("Intentional error")); + } + + private void testErrorsOnPartitionsAssigned(RuntimeException error) { // Step 1: set up mocks ConsumerMembershipManager membershipManager = createMemberInStableState(); String topicName = "topic1"; @@ -1799,7 +1814,7 @@ public void testListenerCallbacksThrowsErrorOnPartitionsAssigned() { mockOwnedPartition(membershipManager, topicId, topicName); CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener( Optional.empty(), - Optional.of(new IllegalArgumentException("Intentional onPartitionsAssigned() error")), + Optional.ofNullable(error), 
Optional.empty() ); ConsumerRebalanceListenerInvoker invoker = consumerRebalanceListenerInvoker(); @@ -1879,7 +1894,7 @@ public void testAddedPartitionsTemporarilyDisabledAwaitingOnPartitionsAssignedCa true ); - verify(subscriptionState).enablePartitionsAwaitingCallback(addedPartitions); + verify(subscriptionState).enablePartitionsAwaitingCallback(assignedPartitions); } @Test @@ -1915,12 +1930,14 @@ public void testAddedPartitionsNotEnabledAfterFailedOnPartitionsAssignedCallback @Test public void testOnPartitionsLostNoError() { - testOnPartitionsLost(Optional.empty()); + testOnPartitionsLost(null); } @Test public void testOnPartitionsLostError() { - testOnPartitionsLost(Optional.of(new KafkaException("Intentional error for test"))); + testOnPartitionsLost(new KafkaException("Intentional error for test")); + testOnPartitionsLost(new WakeupException()); + testOnPartitionsLost(new InterruptException("Intentional error for test")); } private void assertLeaveGroupDueToExpiredPollAndTransitionToStale(ConsumerMembershipManager membershipManager) { @@ -2054,7 +2071,7 @@ private void mockPartitionOwnedAndNewPartitionAdded(String topicName, receiveAssignment(topicId, Arrays.asList(partitionOwned, partitionAdded), membershipManager); } - private void testOnPartitionsLost(Optional lostError) { + private void testOnPartitionsLost(RuntimeException lostError) { // Step 1: set up mocks ConsumerMembershipManager membershipManager = createMemberInStableState(); String topicName = "topic1"; @@ -2063,7 +2080,7 @@ private void testOnPartitionsLost(Optional lostError) { CounterConsumerRebalanceListener listener = new CounterConsumerRebalanceListener( Optional.empty(), Optional.empty(), - lostError + Optional.ofNullable(lostError) ); ConsumerRebalanceListenerInvoker invoker = consumerRebalanceListenerInvoker(); diff --git a/core/src/test/java/kafka/clients/consumer/AsyncKafkaConsumerIntegrationTest.java b/core/src/test/java/kafka/clients/consumer/AsyncKafkaConsumerIntegrationTest.java 
deleted file mode 100644 index 7e1c062c0fcf7..0000000000000 --- a/core/src/test/java/kafka/clients/consumer/AsyncKafkaConsumerIntegrationTest.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package kafka.clients.consumer; - -import org.apache.kafka.clients.consumer.ConsumerConfig; -import org.apache.kafka.clients.consumer.GroupProtocol; -import org.apache.kafka.clients.consumer.KafkaConsumer; -import org.apache.kafka.clients.consumer.internals.AbstractHeartbeatRequestManager; -import org.apache.kafka.common.errors.UnsupportedVersionException; -import org.apache.kafka.common.serialization.StringDeserializer; -import org.apache.kafka.common.test.TestUtils; -import org.apache.kafka.common.test.api.ClusterConfigProperty; -import org.apache.kafka.common.test.api.ClusterInstance; -import org.apache.kafka.common.test.api.ClusterTest; -import org.apache.kafka.common.test.api.ClusterTestExtensions; -import org.apache.kafka.common.test.api.ClusterTests; - -import org.junit.jupiter.api.extension.ExtendWith; - -import java.time.Duration; -import java.util.Collections; -import java.util.Map; - -@ExtendWith(ClusterTestExtensions.class) -public class AsyncKafkaConsumerIntegrationTest { - - 
@ClusterTests({ - @ClusterTest(serverProperties = { - @ClusterConfigProperty(key = "offsets.topic.num.partitions", value = "1"), - @ClusterConfigProperty(key = "offsets.topic.replication.factor", value = "1"), - @ClusterConfigProperty(key = "group.coordinator.new.enable", value = "false") - }), - @ClusterTest(serverProperties = { - @ClusterConfigProperty(key = "offsets.topic.num.partitions", value = "1"), - @ClusterConfigProperty(key = "offsets.topic.replication.factor", value = "1"), - @ClusterConfigProperty(key = "group.coordinator.rebalance.protocols", value = "classic") - }) - }) - public void testAsyncConsumerWithOldGroupCoordinator(ClusterInstance clusterInstance) throws Exception { - String topic = "test-topic"; - clusterInstance.createTopic(topic, 1, (short) 1); - try (KafkaConsumer consumer = new KafkaConsumer<>(Map.of( - ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, clusterInstance.bootstrapServers(), - ConsumerConfig.GROUP_ID_CONFIG, "test-group", - ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName(), - ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName(), - ConsumerConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.CONSUMER.name()))) { - consumer.subscribe(Collections.singletonList(topic)); - TestUtils.waitForCondition(() -> { - try { - consumer.poll(Duration.ofMillis(1000)); - return false; - } catch (UnsupportedVersionException e) { - return e.getMessage().equals(AbstractHeartbeatRequestManager.CONSUMER_PROTOCOL_NOT_SUPPORTED_MSG); - } - }, "Should get UnsupportedVersionException and how to revert to classic protocol"); - } - } -} diff --git a/core/src/test/java/kafka/clients/consumer/ConsumerIntegrationTest.java b/core/src/test/java/kafka/clients/consumer/ConsumerIntegrationTest.java new file mode 100644 index 0000000000000..00ba1fbfcad5b --- /dev/null +++ b/core/src/test/java/kafka/clients/consumer/ConsumerIntegrationTest.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.clients.consumer; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.GroupProtocol; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.internals.AbstractHeartbeatRequestManager; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.test.TestUtils; +import org.apache.kafka.common.test.api.ClusterConfigProperty; +import org.apache.kafka.common.test.api.ClusterInstance; +import org.apache.kafka.common.test.api.ClusterTest; +import org.apache.kafka.common.test.api.ClusterTestExtensions; +import org.apache.kafka.common.test.api.ClusterTests; + +import org.junit.jupiter.api.extension.ExtendWith; + +import java.time.Duration; +import java.util.Collection; +import 
java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@ExtendWith(ClusterTestExtensions.class) +public class ConsumerIntegrationTest { + + @ClusterTests({ + @ClusterTest(serverProperties = { + @ClusterConfigProperty(key = "offsets.topic.num.partitions", value = "1"), + @ClusterConfigProperty(key = "offsets.topic.replication.factor", value = "1"), + @ClusterConfigProperty(key = "group.coordinator.new.enable", value = "false") + }), + @ClusterTest(serverProperties = { + @ClusterConfigProperty(key = "offsets.topic.num.partitions", value = "1"), + @ClusterConfigProperty(key = "offsets.topic.replication.factor", value = "1"), + @ClusterConfigProperty(key = "group.coordinator.rebalance.protocols", value = "classic") + }) + }) + public void testAsyncConsumerWithOldGroupCoordinator(ClusterInstance clusterInstance) throws Exception { + String topic = "test-topic"; + clusterInstance.createTopic(topic, 1, (short) 1); + try (KafkaConsumer consumer = new KafkaConsumer<>(Map.of( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, clusterInstance.bootstrapServers(), + ConsumerConfig.GROUP_ID_CONFIG, "test-group", + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName(), + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName(), + ConsumerConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.CONSUMER.name()))) { + consumer.subscribe(Collections.singletonList(topic)); + TestUtils.waitForCondition(() -> { + try { + consumer.poll(Duration.ofMillis(1000)); + return false; + } catch (UnsupportedVersionException e) { + return e.getMessage().equals(AbstractHeartbeatRequestManager.CONSUMER_PROTOCOL_NOT_SUPPORTED_MSG); + } + }, "Should get UnsupportedVersionException and how to revert to classic protocol"); + } + } + + @ClusterTest(serverProperties = { + @ClusterConfigProperty(key = "offsets.topic.num.partitions", value = "1"), + @ClusterConfigProperty(key = 
"offsets.topic.replication.factor", value = "1"), + }) + public void testFetchPartitionsAfterFailedListenerWithGroupProtocolClassic(ClusterInstance clusterInstance) + throws InterruptedException { + testFetchPartitionsAfterFailedListener(clusterInstance, GroupProtocol.CLASSIC); + } + + @ClusterTest(serverProperties = { + @ClusterConfigProperty(key = "offsets.topic.num.partitions", value = "1"), + @ClusterConfigProperty(key = "offsets.topic.replication.factor", value = "1"), + }) + public void testFetchPartitionsAfterFailedListenerWithGroupProtocolConsumer(ClusterInstance clusterInstance) + throws InterruptedException { + testFetchPartitionsAfterFailedListener(clusterInstance, GroupProtocol.CONSUMER); + } + + private static void testFetchPartitionsAfterFailedListener(ClusterInstance clusterInstance, GroupProtocol groupProtocol) + throws InterruptedException { + var topic = "topic"; + try (var producer = clusterInstance.producer(Map.of( + ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class))) { + producer.send(new ProducerRecord<>(topic, "key".getBytes(), "value".getBytes())); + } + + try (var consumer = clusterInstance.consumer(Map.of( + ConsumerConfig.GROUP_PROTOCOL_CONFIG, groupProtocol.name()))) { + consumer.subscribe(List.of(topic), new ConsumerRebalanceListener() { + private int count = 0; + @Override + public void onPartitionsRevoked(Collection partitions) { + } + + @Override + public void onPartitionsAssigned(Collection partitions) { + count++; + if (count == 1) throw new IllegalArgumentException("temporary error"); + } + }); + + TestUtils.waitForCondition(() -> consumer.poll(Duration.ofSeconds(1)).count() == 1, + 5000, + "failed to poll data"); + } + } + + @ClusterTest(serverProperties = { + @ClusterConfigProperty(key = "offsets.topic.num.partitions", value = "1"), + @ClusterConfigProperty(key = "offsets.topic.replication.factor", value = "1"), + }) + public void 
testFetchPartitionsWithAlwaysFailedListenerWithGroupProtocolClassic(ClusterInstance clusterInstance) + throws InterruptedException { + testFetchPartitionsWithAlwaysFailedListener(clusterInstance, GroupProtocol.CLASSIC); + } + + @ClusterTest(serverProperties = { + @ClusterConfigProperty(key = "offsets.topic.num.partitions", value = "1"), + @ClusterConfigProperty(key = "offsets.topic.replication.factor", value = "1"), + }) + public void testFetchPartitionsWithAlwaysFailedListenerWithGroupProtocolConsumer(ClusterInstance clusterInstance) + throws InterruptedException { + testFetchPartitionsWithAlwaysFailedListener(clusterInstance, GroupProtocol.CONSUMER); + } + + private static void testFetchPartitionsWithAlwaysFailedListener(ClusterInstance clusterInstance, GroupProtocol groupProtocol) + throws InterruptedException { + var topic = "topic"; + try (var producer = clusterInstance.producer(Map.of( + ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class))) { + producer.send(new ProducerRecord<>(topic, "key".getBytes(), "value".getBytes())); + } + + try (var consumer = clusterInstance.consumer(Map.of( + ConsumerConfig.GROUP_PROTOCOL_CONFIG, groupProtocol.name()))) { + consumer.subscribe(List.of(topic), new ConsumerRebalanceListener() { + @Override + public void onPartitionsRevoked(Collection partitions) { + } + + @Override + public void onPartitionsAssigned(Collection partitions) { + throw new IllegalArgumentException("always failed"); + } + }); + + long startTimeMillis = System.currentTimeMillis(); + long currentTimeMillis = System.currentTimeMillis(); + while (currentTimeMillis < startTimeMillis + 3000) { + currentTimeMillis = System.currentTimeMillis(); + try { + // In the async consumer, there is a possibility that the ConsumerRebalanceListenerCallbackCompletedEvent + // has not yet reached the application thread. 
And a poll operation might still succeed, but it + // should not return any records since none of the assigned topic partitions are marked as fetchable. + assertEquals(0, consumer.poll(Duration.ofSeconds(1)).count()); + } catch (KafkaException ex) { + assertEquals("User rebalance callback throws an error", ex.getMessage()); + } + Thread.sleep(300); + } + } + } +}