From 296e1cfba9a5451acb5bf1ca197fbbb17cf53ea1 Mon Sep 17 00:00:00 2001 From: Krishna Kondaka Date: Sat, 10 Feb 2024 00:45:56 +0000 Subject: [PATCH] Addressed review comments Signed-off-by: Krishna Kondaka --- .../kafka/consumer/KafkaCustomConsumer.java | 23 +++-- .../consumer/KafkaCustomConsumerTest.java | 84 +++++++++++++++++++ 2 files changed, 98 insertions(+), 9 deletions(-) diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java index efdf7325a0..38a4bf125d 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java @@ -67,6 +67,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener private static final Logger LOG = LoggerFactory.getLogger(KafkaCustomConsumer.class); private static final Long COMMIT_OFFSET_INTERVAL_MS = 300000L; private static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 1; + private static final int RETRY_ON_EXCEPTION_SLEEP_MS = 1000; static final String DEFAULT_KEY = "message"; private volatile long lastCommitTime; @@ -95,6 +96,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener private long numRecordsCommitted = 0; private final LogRateLimiter errLogRateLimiter; private final ByteDecoder byteDecoder; + private final long maxRetriesOnException; public KafkaCustomConsumer(final KafkaConsumer consumer, final AtomicBoolean shutdownInProgress, @@ -114,6 +116,7 @@ public KafkaCustomConsumer(final KafkaConsumer consumer, this.paused = false; this.byteDecoder = byteDecoder; this.topicMetrics = topicMetrics; + this.maxRetriesOnException = topicConfig.getMaxPollInterval().toMillis() / (2 * RETRY_ON_EXCEPTION_SLEEP_MS); this.pauseConsumePredicate = pauseConsumePredicate; this.topicMetrics.register(consumer); this.offsetsToCommit = new HashMap<>(); @@ -418,7 +421,7 @@ private Record getRecord(ConsumerRecord consumerRecord, in return new Record(event); } - private void processRecord(final AcknowledgementSet acknowledgementSet, final Record record) { + private void processRecord(final AcknowledgementSet acknowledgementSet, final Record record) { // Always add record to acknowledgementSet before adding to // buffer because another thread may take and process // buffer contents before the event record is added @@ -427,15 +430,16 @@ private void processRecord(final AcknowledgementSet acknowledgementSet, fina acknowledgementSet.add(record.getData()); } long numRetries = 0; - final int retrySleepTimeMs = 100; - // Donot pause until half the poll interval time has expired - final long maxRetries = topicConfig.getMaxPollInterval().toMillis() / (2 * retrySleepTimeMs); while (true) { try { - bufferAccumulator.add(record); + if (numRetries == 0) { + bufferAccumulator.add(record); + } else { + bufferAccumulator.flush(); + } break; } catch (Exception e) { - if (!paused && numRetries++ > maxRetries) { + if (!paused && numRetries++ > maxRetriesOnException) { paused = true; consumer.pause(consumer.assignment()); } @@ -445,11 +449,11 @@ private void processRecord(final AcknowledgementSet acknowledgementSet, fina LOG.debug("Error while adding record to buffer, retrying ", e); } try { - Thread.sleep(retrySleepTimeMs); + Thread.sleep(RETRY_ON_EXCEPTION_SLEEP_MS); if (paused) { - ConsumerRecords records = doPoll(); + ConsumerRecords records = doPoll(); if (records.count() > 0) { - LOG.debug("Unexpected records received while the consumer is paused. Resetting the paritions to retry from last read pointer"); + LOG.warn("Unexpected records received while the consumer is paused. Resetting the paritions to retry from last read pointer"); synchronized(this) { partitionsToReset.addAll(consumer.assignment()); }; @@ -459,6 +463,7 @@ private void processRecord(final AcknowledgementSet acknowledgementSet, fina } catch (Exception ex) {} // ignore the exception because it only means the thread slept for shorter time } } + if (paused) { consumer.resume(consumer.assignment()); paused = false; diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java index 7d3a0f3fb9..968639f674 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java @@ -28,6 +28,7 @@ import org.opensearch.dataprepper.model.CheckpointState; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.buffer.SizeOverflowException; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -54,6 +55,7 @@ import static org.hamcrest.Matchers.hasEntry; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -70,6 +72,9 @@ public class KafkaCustomConsumerTest { private Buffer> buffer; + @Mock + private Buffer> mockBuffer; + @Mock private KafkaConsumerConfig sourceConfig; @@ -106,21 +111,31 @@ public class KafkaCustomConsumerTest { private Counter posCounter; @Mock private Counter negCounter; + @Mock + private Counter overflowCounter; private Duration delayTime; private double posCount; private double negCount; + private double overflowCount; + private boolean paused; + private boolean resumed; @BeforeEach public void setUp() { delayTime = Duration.ofMillis(10); + paused = false; + resumed = false; kafkaConsumer = mock(KafkaConsumer.class); topicMetrics = mock(KafkaTopicConsumerMetrics.class); counter = mock(Counter.class); posCounter = mock(Counter.class); + mockBuffer = mock(Buffer.class); negCounter = mock(Counter.class); + overflowCounter = mock(Counter.class); topicConfig = mock(TopicConsumerConfig.class); when(topicMetrics.getNumberOfPositiveAcknowledgements()).thenReturn(posCounter); when(topicMetrics.getNumberOfNegativeAcknowledgements()).thenReturn(negCounter); + when(topicMetrics.getNumberOfBufferSizeOverflows()).thenReturn(overflowCounter); when(topicMetrics.getNumberOfRecordsCommitted()).thenReturn(counter); when(topicMetrics.getNumberOfDeserializationErrors()).thenReturn(counter); when(topicConfig.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(1)); @@ -128,6 +143,16 @@ public void setUp() { when(topicConfig.getAutoCommit()).thenReturn(false); when(kafkaConsumer.committed(any(TopicPartition.class))).thenReturn(null); + doAnswer((i)-> { + paused = true; + return null; + }).when(kafkaConsumer).pause(any()); + + doAnswer((i)-> { + resumed = true; + return null; + }).when(kafkaConsumer).resume(any()); + doAnswer((i)-> { posCount += 1.0; return null; @@ -136,6 +161,10 @@ public void setUp() { negCount += 1.0; return null; }).when(negCounter).increment(); + doAnswer((i)-> { + overflowCount += 1.0; + return null; + }).when(overflowCounter).increment(); doAnswer((i)-> {return posCount;}).when(posCounter).count(); doAnswer((i)-> {return negCount;}).when(negCounter).count(); callbackExecutor = Executors.newScheduledThreadPool(2); @@ -147,6 +176,11 @@ public void setUp() { when(topicConfig.getName()).thenReturn(TOPIC_NAME); } + public KafkaCustomConsumer createObjectUnderTestWithMockBuffer(String schemaType) { + return new KafkaCustomConsumer(kafkaConsumer, shutdownInProgress, mockBuffer, sourceConfig, topicConfig, schemaType, + acknowledgementSetManager, null, topicMetrics, pauseConsumePredicate); + } + public KafkaCustomConsumer createObjectUnderTest(String schemaType, boolean acknowledgementsEnabled) { when(sourceConfig.getAcknowledgementsEnabled()).thenReturn(acknowledgementsEnabled); return new KafkaCustomConsumer(kafkaConsumer, shutdownInProgress, buffer, sourceConfig, topicConfig, schemaType, @@ -162,6 +196,56 @@ private BlockingBuffer> getBuffer() { return new BlockingBuffer<>(pluginSetting); } + @Test + public void testBufferOverflowPauseResume() throws InterruptedException, Exception { + when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(4000)); + String topic = topicConfig.getName(); + consumerRecords = createPlainTextRecords(topic, 0L); + doAnswer((i)-> { + if (!paused && !resumed) + throw new SizeOverflowException("size overflow"); + buffer.writeAll(i.getArgument(0), i.getArgument(1)); + return null; + }).when(mockBuffer).writeAll(any(), anyInt()); + + doAnswer((i) -> { + if (paused && !resumed) + return List.of(); + return consumerRecords; + }).when(kafkaConsumer).poll(any(Duration.class)); + consumer = createObjectUnderTestWithMockBuffer("plaintext"); + try { + consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testPartition))); + consumer.consumeRecords(); + } catch (Exception e){} + assertTrue(paused); + assertTrue(resumed); + + final Map.Entry>, CheckpointState> bufferRecords = buffer.read(1000); + ArrayList> bufferedRecords = new ArrayList<>(bufferRecords.getKey()); + Assertions.assertEquals(consumerRecords.count(), bufferedRecords.size()); + Map offsetsToCommit = consumer.getOffsetsToCommit(); + Assertions.assertEquals(offsetsToCommit.size(), 1); + offsetsToCommit.forEach((topicPartition, offsetAndMetadata) -> { + Assertions.assertEquals(topicPartition.partition(), testPartition); + Assertions.assertEquals(topicPartition.topic(), topic); + Assertions.assertEquals(offsetAndMetadata.offset(), 2L); + }); + Assertions.assertEquals(consumer.getNumRecordsCommitted(), 2L); + + for (Record record: bufferedRecords) { + Event event = record.getData(); + String value1 = event.get(testKey1, String.class); + String value2 = event.get(testKey2, String.class); + assertTrue(value1 != null || value2 != null); + if (value1 != null) { + Assertions.assertEquals(value1, testValue1); + } + if (value2 != null) { + Assertions.assertEquals(value2, testValue2); + } + } + } @Test public void testPlainTextConsumeRecords() throws InterruptedException { String topic = topicConfig.getName();