Skip to content

Commit

Permalink
KafkaConsumer should continue to poll while waiting for buffer (#4023)
Browse files Browse the repository at this point in the history
* KafkaConsumer should continue to poll while waiting for buffer

Signed-off-by: Krishna Kondaka <[email protected]>

* Modified to call pause() whenever parititon assignment changes

Signed-off-by: Krishna Kondaka <[email protected]>

* Addressed review comments

Signed-off-by: Krishna Kondaka <[email protected]>

* Addressed review comments

Signed-off-by: Krishna Kondaka <[email protected]>

---------

Signed-off-by: Krishna Kondaka <[email protected]>
Co-authored-by: Krishna Kondaka <[email protected]>
  • Loading branch information
kkondaka and Krishna Kondaka authored Feb 16, 2024
1 parent 8bf0daa commit c0776ef
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
private static final Logger LOG = LoggerFactory.getLogger(KafkaCustomConsumer.class);
private static final Long COMMIT_OFFSET_INTERVAL_MS = 300000L;
private static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 1;
private static final int RETRY_ON_EXCEPTION_SLEEP_MS = 1000;
static final String DEFAULT_KEY = "message";

private volatile long lastCommitTime;
Expand All @@ -75,6 +76,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
private final String topicName;
private final TopicConsumerConfig topicConfig;
private MessageFormat schema;
private boolean paused;
private final BufferAccumulator<Record<Event>> bufferAccumulator;
private final Buffer<Record<Event>> buffer;
private static final ObjectMapper objectMapper = new ObjectMapper();
Expand All @@ -94,6 +96,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
private long numRecordsCommitted = 0;
private final LogRateLimiter errLogRateLimiter;
private final ByteDecoder byteDecoder;
private final long maxRetriesOnException;

public KafkaCustomConsumer(final KafkaConsumer consumer,
final AtomicBoolean shutdownInProgress,
Expand All @@ -110,8 +113,10 @@ public KafkaCustomConsumer(final KafkaConsumer consumer,
this.shutdownInProgress = shutdownInProgress;
this.consumer = consumer;
this.buffer = buffer;
this.paused = false;
this.byteDecoder = byteDecoder;
this.topicMetrics = topicMetrics;
this.maxRetriesOnException = topicConfig.getMaxPollInterval().toMillis() / (2 * RETRY_ON_EXCEPTION_SLEEP_MS);
this.pauseConsumePredicate = pauseConsumePredicate;
this.topicMetrics.register(consumer);
this.offsetsToCommit = new HashMap<>();
Expand Down Expand Up @@ -170,10 +175,15 @@ private AcknowledgementSet createAcknowledgementSet(Map<TopicPartition, CommitOf
return acknowledgementSet;
}

<T> void consumeRecords() throws Exception {
try {
<T> ConsumerRecords<String, T> doPoll() throws Exception {
ConsumerRecords<String, T> records =
consumer.poll(Duration.ofMillis(topicConfig.getThreadWaitingTime().toMillis()/2));
return records;
}

<T> void consumeRecords() throws Exception {
try {
ConsumerRecords<String, T> records = doPoll();
if (Objects.nonNull(records) && !records.isEmpty() && records.count() > 0) {
Map<TopicPartition, CommitOffsetRange> offsets = new HashMap<>();
AcknowledgementSet acknowledgementSet = null;
Expand Down Expand Up @@ -419,21 +429,45 @@ private void processRecord(final AcknowledgementSet acknowledgementSet, final Re
if (acknowledgementSet != null) {
acknowledgementSet.add(record.getData());
}
long numRetries = 0;
while (true) {
try {
bufferAccumulator.add(record);
if (numRetries == 0) {
bufferAccumulator.add(record);
} else {
bufferAccumulator.flush();
}
break;
} catch (Exception e) {
if (!paused && numRetries++ > maxRetriesOnException) {
paused = true;
consumer.pause(consumer.assignment());
}
if (e instanceof SizeOverflowException) {
topicMetrics.getNumberOfBufferSizeOverflows().increment();
} else {
LOG.debug("Error while adding record to buffer, retrying ", e);
}
try {
Thread.sleep(100);
Thread.sleep(RETRY_ON_EXCEPTION_SLEEP_MS);
if (paused) {
ConsumerRecords<String, ?> records = doPoll();
if (records.count() > 0) {
LOG.warn("Unexpected records received while the consumer is paused. Resetting the paritions to retry from last read pointer");
synchronized(this) {
partitionsToReset.addAll(consumer.assignment());
};
break;
}
}
} catch (Exception ex) {} // ignore the exception because it only means the thread slept for shorter time
}
}

if (paused) {
consumer.resume(consumer.assignment());
paused = false;
}
}

private <T> void iterateRecordPartitions(ConsumerRecords<String, T> records, final AcknowledgementSet acknowledgementSet,
Expand Down Expand Up @@ -503,6 +537,9 @@ public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
LOG.info("Assigned partition {}", topicPartition);
ownedPartitionsEpoch.put(topicPartition, epoch);
}
if (paused) {
consumer.pause(consumer.assignment());
}
}
dumpTopicPartitionOffsets(partitions);
}
Expand All @@ -520,6 +557,9 @@ public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
ownedPartitionsEpoch.remove(topicPartition);
partitionCommitTrackerMap.remove(topicPartition.partition());
}
if (paused) {
consumer.pause(consumer.assignment());
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.dataprepper.model.CheckpointState;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.buffer.SizeOverflowException;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
Expand All @@ -54,6 +55,7 @@
import static org.hamcrest.Matchers.hasEntry;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -70,6 +72,9 @@ public class KafkaCustomConsumerTest {

private Buffer<Record<Event>> buffer;

@Mock
private Buffer<Record<Event>> mockBuffer;

@Mock
private KafkaConsumerConfig sourceConfig;

Expand Down Expand Up @@ -106,28 +111,48 @@ public class KafkaCustomConsumerTest {
private Counter posCounter;
@Mock
private Counter negCounter;
@Mock
private Counter overflowCounter;
private Duration delayTime;
private double posCount;
private double negCount;
private double overflowCount;
private boolean paused;
private boolean resumed;

@BeforeEach
public void setUp() {
delayTime = Duration.ofMillis(10);
paused = false;
resumed = false;
kafkaConsumer = mock(KafkaConsumer.class);
topicMetrics = mock(KafkaTopicConsumerMetrics.class);
counter = mock(Counter.class);
posCounter = mock(Counter.class);
mockBuffer = mock(Buffer.class);
negCounter = mock(Counter.class);
overflowCounter = mock(Counter.class);
topicConfig = mock(TopicConsumerConfig.class);
when(topicMetrics.getNumberOfPositiveAcknowledgements()).thenReturn(posCounter);
when(topicMetrics.getNumberOfNegativeAcknowledgements()).thenReturn(negCounter);
when(topicMetrics.getNumberOfBufferSizeOverflows()).thenReturn(overflowCounter);
when(topicMetrics.getNumberOfRecordsCommitted()).thenReturn(counter);
when(topicMetrics.getNumberOfDeserializationErrors()).thenReturn(counter);
when(topicConfig.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(1));
when(topicConfig.getSerdeFormat()).thenReturn(MessageFormat.PLAINTEXT);
when(topicConfig.getAutoCommit()).thenReturn(false);
when(kafkaConsumer.committed(any(TopicPartition.class))).thenReturn(null);

doAnswer((i)-> {
paused = true;
return null;
}).when(kafkaConsumer).pause(any());

doAnswer((i)-> {
resumed = true;
return null;
}).when(kafkaConsumer).resume(any());

doAnswer((i)-> {
posCount += 1.0;
return null;
Expand All @@ -136,6 +161,10 @@ public void setUp() {
negCount += 1.0;
return null;
}).when(negCounter).increment();
doAnswer((i)-> {
overflowCount += 1.0;
return null;
}).when(overflowCounter).increment();
doAnswer((i)-> {return posCount;}).when(posCounter).count();
doAnswer((i)-> {return negCount;}).when(negCounter).count();
callbackExecutor = Executors.newScheduledThreadPool(2);
Expand All @@ -147,6 +176,11 @@ public void setUp() {
when(topicConfig.getName()).thenReturn(TOPIC_NAME);
}

public KafkaCustomConsumer createObjectUnderTestWithMockBuffer(String schemaType) {
return new KafkaCustomConsumer(kafkaConsumer, shutdownInProgress, mockBuffer, sourceConfig, topicConfig, schemaType,
acknowledgementSetManager, null, topicMetrics, pauseConsumePredicate);
}

public KafkaCustomConsumer createObjectUnderTest(String schemaType, boolean acknowledgementsEnabled) {
when(sourceConfig.getAcknowledgementsEnabled()).thenReturn(acknowledgementsEnabled);
return new KafkaCustomConsumer(kafkaConsumer, shutdownInProgress, buffer, sourceConfig, topicConfig, schemaType,
Expand All @@ -162,6 +196,56 @@ private BlockingBuffer<Record<Event>> getBuffer() {
return new BlockingBuffer<>(pluginSetting);
}

@Test
public void testBufferOverflowPauseResume() throws InterruptedException, Exception {
when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(4000));
String topic = topicConfig.getName();
consumerRecords = createPlainTextRecords(topic, 0L);
doAnswer((i)-> {
if (!paused && !resumed)
throw new SizeOverflowException("size overflow");
buffer.writeAll(i.getArgument(0), i.getArgument(1));
return null;
}).when(mockBuffer).writeAll(any(), anyInt());

doAnswer((i) -> {
if (paused && !resumed)
return List.of();
return consumerRecords;
}).when(kafkaConsumer).poll(any(Duration.class));
consumer = createObjectUnderTestWithMockBuffer("plaintext");
try {
consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testPartition)));
consumer.consumeRecords();
} catch (Exception e){}
assertTrue(paused);
assertTrue(resumed);

final Map.Entry<Collection<Record<Event>>, CheckpointState> bufferRecords = buffer.read(1000);
ArrayList<Record<Event>> bufferedRecords = new ArrayList<>(bufferRecords.getKey());
Assertions.assertEquals(consumerRecords.count(), bufferedRecords.size());
Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = consumer.getOffsetsToCommit();
Assertions.assertEquals(offsetsToCommit.size(), 1);
offsetsToCommit.forEach((topicPartition, offsetAndMetadata) -> {
Assertions.assertEquals(topicPartition.partition(), testPartition);
Assertions.assertEquals(topicPartition.topic(), topic);
Assertions.assertEquals(offsetAndMetadata.offset(), 2L);
});
Assertions.assertEquals(consumer.getNumRecordsCommitted(), 2L);

for (Record<Event> record: bufferedRecords) {
Event event = record.getData();
String value1 = event.get(testKey1, String.class);
String value2 = event.get(testKey2, String.class);
assertTrue(value1 != null || value2 != null);
if (value1 != null) {
Assertions.assertEquals(value1, testValue1);
}
if (value2 != null) {
Assertions.assertEquals(value2, testValue2);
}
}
}
@Test
public void testPlainTextConsumeRecords() throws InterruptedException {
String topic = topicConfig.getName();
Expand Down

0 comments on commit c0776ef

Please sign in to comment.