Skip to content

Commit

Permalink
Addressed review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Krishna Kondaka <[email protected]>
  • Loading branch information
Krishna Kondaka committed Feb 10, 2024
1 parent a373071 commit 296e1cf
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
private static final Logger LOG = LoggerFactory.getLogger(KafkaCustomConsumer.class);
private static final Long COMMIT_OFFSET_INTERVAL_MS = 300000L;
private static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 1;
private static final int RETRY_ON_EXCEPTION_SLEEP_MS = 1000;
static final String DEFAULT_KEY = "message";

private volatile long lastCommitTime;
Expand Down Expand Up @@ -95,6 +96,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
private long numRecordsCommitted = 0;
private final LogRateLimiter errLogRateLimiter;
private final ByteDecoder byteDecoder;
private final long maxRetriesOnException;

public KafkaCustomConsumer(final KafkaConsumer consumer,
final AtomicBoolean shutdownInProgress,
Expand All @@ -114,6 +116,7 @@ public KafkaCustomConsumer(final KafkaConsumer consumer,
this.paused = false;
this.byteDecoder = byteDecoder;
this.topicMetrics = topicMetrics;
this.maxRetriesOnException = topicConfig.getMaxPollInterval().toMillis() / (2 * RETRY_ON_EXCEPTION_SLEEP_MS);
this.pauseConsumePredicate = pauseConsumePredicate;
this.topicMetrics.register(consumer);
this.offsetsToCommit = new HashMap<>();
Expand Down Expand Up @@ -418,7 +421,7 @@ private <T> Record<Event> getRecord(ConsumerRecord<String, T> consumerRecord, in
return new Record<Event>(event);
}

private <T> void processRecord(final AcknowledgementSet acknowledgementSet, final Record<Event> record) {
private void processRecord(final AcknowledgementSet acknowledgementSet, final Record<Event> record) {
// Always add record to acknowledgementSet before adding to
// buffer because another thread may take and process
// buffer contents before the event record is added
Expand All @@ -427,15 +430,16 @@ private <T> void processRecord(final AcknowledgementSet acknowledgementSet, fina
acknowledgementSet.add(record.getData());
}
long numRetries = 0;
final int retrySleepTimeMs = 100;
// Donot pause until half the poll interval time has expired
final long maxRetries = topicConfig.getMaxPollInterval().toMillis() / (2 * retrySleepTimeMs);
while (true) {
try {
bufferAccumulator.add(record);
if (numRetries == 0) {
bufferAccumulator.add(record);
} else {
bufferAccumulator.flush();
}
break;
} catch (Exception e) {
if (!paused && numRetries++ > maxRetries) {
if (!paused && numRetries++ > maxRetriesOnException) {
paused = true;
consumer.pause(consumer.assignment());
}
Expand All @@ -445,11 +449,11 @@ private <T> void processRecord(final AcknowledgementSet acknowledgementSet, fina
LOG.debug("Error while adding record to buffer, retrying ", e);
}
try {
Thread.sleep(retrySleepTimeMs);
Thread.sleep(RETRY_ON_EXCEPTION_SLEEP_MS);
if (paused) {
ConsumerRecords<String, T> records = doPoll();
ConsumerRecords<String, ?> records = doPoll();
if (records.count() > 0) {
LOG.debug("Unexpected records received while the consumer is paused. Resetting the paritions to retry from last read pointer");
LOG.warn("Unexpected records received while the consumer is paused. Resetting the paritions to retry from last read pointer");
synchronized(this) {
partitionsToReset.addAll(consumer.assignment());
};
Expand All @@ -459,6 +463,7 @@ private <T> void processRecord(final AcknowledgementSet acknowledgementSet, fina
} catch (Exception ex) {} // ignore the exception because it only means the thread slept for shorter time
}
}

if (paused) {
consumer.resume(consumer.assignment());
paused = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.dataprepper.model.CheckpointState;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.buffer.SizeOverflowException;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
Expand All @@ -54,6 +55,7 @@
import static org.hamcrest.Matchers.hasEntry;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -70,6 +72,9 @@ public class KafkaCustomConsumerTest {

private Buffer<Record<Event>> buffer;

@Mock
private Buffer<Record<Event>> mockBuffer;

@Mock
private KafkaConsumerConfig sourceConfig;

Expand Down Expand Up @@ -106,28 +111,48 @@ public class KafkaCustomConsumerTest {
private Counter posCounter;
@Mock
private Counter negCounter;
@Mock
private Counter overflowCounter;
private Duration delayTime;
private double posCount;
private double negCount;
private double overflowCount;
private boolean paused;
private boolean resumed;

@BeforeEach
public void setUp() {
delayTime = Duration.ofMillis(10);
paused = false;
resumed = false;
kafkaConsumer = mock(KafkaConsumer.class);
topicMetrics = mock(KafkaTopicConsumerMetrics.class);
counter = mock(Counter.class);
posCounter = mock(Counter.class);
mockBuffer = mock(Buffer.class);
negCounter = mock(Counter.class);
overflowCounter = mock(Counter.class);
topicConfig = mock(TopicConsumerConfig.class);
when(topicMetrics.getNumberOfPositiveAcknowledgements()).thenReturn(posCounter);
when(topicMetrics.getNumberOfNegativeAcknowledgements()).thenReturn(negCounter);
when(topicMetrics.getNumberOfBufferSizeOverflows()).thenReturn(overflowCounter);
when(topicMetrics.getNumberOfRecordsCommitted()).thenReturn(counter);
when(topicMetrics.getNumberOfDeserializationErrors()).thenReturn(counter);
when(topicConfig.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(1));
when(topicConfig.getSerdeFormat()).thenReturn(MessageFormat.PLAINTEXT);
when(topicConfig.getAutoCommit()).thenReturn(false);
when(kafkaConsumer.committed(any(TopicPartition.class))).thenReturn(null);

doAnswer((i)-> {
paused = true;
return null;
}).when(kafkaConsumer).pause(any());

doAnswer((i)-> {
resumed = true;
return null;
}).when(kafkaConsumer).resume(any());

doAnswer((i)-> {
posCount += 1.0;
return null;
Expand All @@ -136,6 +161,10 @@ public void setUp() {
negCount += 1.0;
return null;
}).when(negCounter).increment();
doAnswer((i)-> {
overflowCount += 1.0;
return null;
}).when(overflowCounter).increment();
doAnswer((i)-> {return posCount;}).when(posCounter).count();
doAnswer((i)-> {return negCount;}).when(negCounter).count();
callbackExecutor = Executors.newScheduledThreadPool(2);
Expand All @@ -147,6 +176,11 @@ public void setUp() {
when(topicConfig.getName()).thenReturn(TOPIC_NAME);
}

public KafkaCustomConsumer createObjectUnderTestWithMockBuffer(String schemaType) {
return new KafkaCustomConsumer(kafkaConsumer, shutdownInProgress, mockBuffer, sourceConfig, topicConfig, schemaType,
acknowledgementSetManager, null, topicMetrics, pauseConsumePredicate);
}

public KafkaCustomConsumer createObjectUnderTest(String schemaType, boolean acknowledgementsEnabled) {
when(sourceConfig.getAcknowledgementsEnabled()).thenReturn(acknowledgementsEnabled);
return new KafkaCustomConsumer(kafkaConsumer, shutdownInProgress, buffer, sourceConfig, topicConfig, schemaType,
Expand All @@ -162,6 +196,56 @@ private BlockingBuffer<Record<Event>> getBuffer() {
return new BlockingBuffer<>(pluginSetting);
}

@Test
public void testBufferOverflowPauseResume() throws InterruptedException, Exception {
when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(4000));
String topic = topicConfig.getName();
consumerRecords = createPlainTextRecords(topic, 0L);
doAnswer((i)-> {
if (!paused && !resumed)
throw new SizeOverflowException("size overflow");
buffer.writeAll(i.getArgument(0), i.getArgument(1));
return null;
}).when(mockBuffer).writeAll(any(), anyInt());

doAnswer((i) -> {
if (paused && !resumed)
return List.of();
return consumerRecords;
}).when(kafkaConsumer).poll(any(Duration.class));
consumer = createObjectUnderTestWithMockBuffer("plaintext");
try {
consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testPartition)));
consumer.consumeRecords();
} catch (Exception e){}
assertTrue(paused);
assertTrue(resumed);

final Map.Entry<Collection<Record<Event>>, CheckpointState> bufferRecords = buffer.read(1000);
ArrayList<Record<Event>> bufferedRecords = new ArrayList<>(bufferRecords.getKey());
Assertions.assertEquals(consumerRecords.count(), bufferedRecords.size());
Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = consumer.getOffsetsToCommit();
Assertions.assertEquals(offsetsToCommit.size(), 1);
offsetsToCommit.forEach((topicPartition, offsetAndMetadata) -> {
Assertions.assertEquals(topicPartition.partition(), testPartition);
Assertions.assertEquals(topicPartition.topic(), topic);
Assertions.assertEquals(offsetAndMetadata.offset(), 2L);
});
Assertions.assertEquals(consumer.getNumRecordsCommitted(), 2L);

for (Record<Event> record: bufferedRecords) {
Event event = record.getData();
String value1 = event.get(testKey1, String.class);
String value2 = event.get(testKey2, String.class);
assertTrue(value1 != null || value2 != null);
if (value1 != null) {
Assertions.assertEquals(value1, testValue1);
}
if (value2 != null) {
Assertions.assertEquals(value2, testValue2);
}
}
}
@Test
public void testPlainTextConsumeRecords() throws InterruptedException {
String topic = topicConfig.getName();
Expand Down

0 comments on commit 296e1cf

Please sign in to comment.