diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer.java index 4359e6d0c0..611e7acaf7 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBuffer.java @@ -24,8 +24,10 @@ import org.apache.kafka.common.errors.RecordBatchTooLargeException; import org.opensearch.dataprepper.plugins.kafka.admin.KafkaAdminAccessor; import org.opensearch.dataprepper.plugins.kafka.buffer.serialization.BufferSerializationFactory; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaMdc; import org.opensearch.dataprepper.plugins.kafka.common.serialization.CommonSerializationFactory; import org.opensearch.dataprepper.plugins.kafka.common.serialization.SerializationFactory; +import org.opensearch.dataprepper.plugins.kafka.common.thread.KafkaPluginThreadFactory; import org.opensearch.dataprepper.plugins.kafka.consumer.KafkaCustomConsumer; import org.opensearch.dataprepper.plugins.kafka.consumer.KafkaCustomConsumerFactory; import org.opensearch.dataprepper.plugins.kafka.producer.KafkaCustomProducer; @@ -33,6 +35,7 @@ import org.opensearch.dataprepper.plugins.kafka.service.TopicServiceFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.slf4j.MDC; import java.time.Duration; import java.util.Collection; @@ -53,6 +56,7 @@ public class KafkaBuffer extends AbstractBuffer> { public static final int INNER_BUFFER_BATCH_SIZE = 250000; static final String WRITE = "Write"; static final String READ = "Read"; + static final String MDC_KAFKA_PLUGIN_VALUE = "buffer"; private final KafkaCustomProducer producer; private final KafkaAdminAccessor kafkaAdminAccessor; private final AbstractBuffer> innerBuffer; @@ -80,7 +84,7 @@ public KafkaBuffer(final PluginSetting pluginSetting, final KafkaBufferConfig ka final List consumers = kafkaCustomConsumerFactory.createConsumersForTopic(kafkaBufferConfig, kafkaBufferConfig.getTopic(), innerBuffer, consumerMetrics, acknowledgementSetManager, byteDecoder, shutdownInProgress, false, circuitBreaker); this.kafkaAdminAccessor = new KafkaAdminAccessor(kafkaBufferConfig, List.of(kafkaBufferConfig.getTopic().getGroupId())); - this.executorService = Executors.newFixedThreadPool(consumers.size()); + this.executorService = Executors.newFixedThreadPool(consumers.size(), KafkaPluginThreadFactory.defaultExecutorThreadFactory(MDC_KAFKA_PLUGIN_VALUE)); consumers.forEach(this.executorService::submit); this.drainTimeout = kafkaBufferConfig.getDrainTimeout(); @@ -89,6 +93,7 @@ public KafkaBuffer(final PluginSetting pluginSetting, final KafkaBufferConfig ka @Override public void writeBytes(final byte[] bytes, final String key, int timeoutInMillis) throws Exception { try { + setMdc(); producer.produceRawData(bytes, key); } catch (final Exception e) { LOG.error(e.getMessage(), e); @@ -102,15 +107,21 @@ public void writeBytes(final byte[] bytes, final String key, int timeoutInMillis throw new RuntimeException(e); } } + finally { + resetMdc(); + } } @Override public void doWrite(Record record, int timeoutInMillis) throws TimeoutException { try { + setMdc(); producer.produceRecords(record); } catch (final Exception e) { LOG.error(e.getMessage(), e); throw new RuntimeException(e); + } finally { + resetMdc(); } } @@ -121,29 +132,50 @@ public boolean isByteBuffer() { @Override public void doWriteAll(Collection> records, int timeoutInMillis) throws Exception { - for ( Record record: records ) { + for (Record record : records) { doWrite(record, timeoutInMillis); } } @Override public Map.Entry>, CheckpointState> doRead(int timeoutInMillis) { - return innerBuffer.read(timeoutInMillis); + try { + setMdc(); + return innerBuffer.read(timeoutInMillis); + } finally { + resetMdc(); + } } @Override public void postProcess(final Long recordsInBuffer) { - innerBuffer.postProcess(recordsInBuffer); + try { + setMdc(); + + innerBuffer.postProcess(recordsInBuffer); + } finally { + resetMdc(); + } } @Override public void doCheckpoint(CheckpointState checkpointState) { - innerBuffer.doCheckpoint(checkpointState); + try { + setMdc(); + innerBuffer.doCheckpoint(checkpointState); + } finally { + resetMdc(); + } } @Override public boolean isEmpty() { - return kafkaAdminAccessor.areTopicsEmpty() && innerBuffer.isEmpty(); + try { + setMdc(); + return kafkaAdminAccessor.areTopicsEmpty() && innerBuffer.isEmpty(); + } finally { + resetMdc(); + } } @Override @@ -158,21 +190,35 @@ public boolean isWrittenOffHeapOnly() { @Override public void shutdown() { - shutdownInProgress.set(true); - executorService.shutdown(); - try { - if (executorService.awaitTermination(EXECUTOR_SERVICE_SHUTDOWN_TIMEOUT, TimeUnit.SECONDS)) { - LOG.info("Successfully waited for consumer task to terminate"); - } else { - LOG.warn("Consumer task did not terminate in time, forcing termination"); + setMdc(); + + shutdownInProgress.set(true); + executorService.shutdown(); + + try { + if (executorService.awaitTermination(EXECUTOR_SERVICE_SHUTDOWN_TIMEOUT, TimeUnit.SECONDS)) { + LOG.info("Successfully waited for consumer task to terminate"); + } else { + LOG.warn("Consumer task did not terminate in time, forcing termination"); + executorService.shutdownNow(); + } + } catch (final InterruptedException e) { + LOG.error("Interrupted while waiting for consumer task to terminate", e); executorService.shutdownNow(); } - } catch (final InterruptedException e) { - LOG.error("Interrupted while waiting for consumer task to terminate", e); - executorService.shutdownNow(); + + innerBuffer.shutdown(); + } finally { + resetMdc(); } + } + + private static void setMdc() { + MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, MDC_KAFKA_PLUGIN_VALUE); + } - innerBuffer.shutdown(); + private static void resetMdc() { + MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY); } } diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaMdc.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaMdc.java new file mode 100644 index 0000000000..9ae8985908 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/KafkaMdc.java @@ -0,0 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.kafka.common;public class KafkaMdc { + public static final String MDC_KAFKA_PLUGIN_KEY = "kafkaPluginType"; +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/thread/KafkaPluginThreadFactory.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/thread/KafkaPluginThreadFactory.java new file mode 100644 index 0000000000..a05540c320 --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/thread/KafkaPluginThreadFactory.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.kafka.common.thread; + +import org.opensearch.dataprepper.plugins.kafka.common.KafkaMdc; +import org.slf4j.MDC; + +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * An implementation of {@link ThreadFactory} for Kafka plugin threads. + */ +public class KafkaPluginThreadFactory implements ThreadFactory { + private final ThreadFactory delegateThreadFactory; + private final String threadPrefix; + private final String kafkaPluginType; + private final AtomicInteger threadNumber = new AtomicInteger(1); + + KafkaPluginThreadFactory( + final ThreadFactory delegateThreadFactory, + final String kafkaPluginType) { + this.delegateThreadFactory = delegateThreadFactory; + this.threadPrefix = "kafka-" + kafkaPluginType + "-"; + this.kafkaPluginType = kafkaPluginType; + } + + /** + * Creates an instance specifically for use with {@link Executors}. + * + * @param kafkaPluginType The name of the plugin type. e.g. sink, source, buffer + * @return An instance of the {@link KafkaPluginThreadFactory}. + */ + public static KafkaPluginThreadFactory defaultExecutorThreadFactory(final String kafkaPluginType) { + return new KafkaPluginThreadFactory(Executors.defaultThreadFactory(), kafkaPluginType); + } + + @Override + public Thread newThread(final Runnable runnable) { + final Thread thread = delegateThreadFactory.newThread(() -> { + MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, kafkaPluginType); + try { + runnable.run(); + } finally { + MDC.clear(); + } + }); + + thread.setName(threadPrefix + threadNumber.getAndIncrement()); + + return thread; + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferTest.java index ddef03577a..362bcd580c 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/buffer/KafkaBufferTest.java @@ -5,7 +5,9 @@ package org.opensearch.dataprepper.plugins.kafka.buffer; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -25,6 +27,8 @@ import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer; import org.opensearch.dataprepper.plugins.kafka.admin.KafkaAdminAccessor; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaMdc; +import org.opensearch.dataprepper.plugins.kafka.common.thread.KafkaPluginThreadFactory; import org.opensearch.dataprepper.plugins.kafka.configuration.AuthConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.EncryptionConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.EncryptionType; @@ -35,6 +39,7 @@ import org.opensearch.dataprepper.plugins.kafka.producer.KafkaCustomProducerFactory; import org.opensearch.dataprepper.plugins.kafka.producer.ProducerWorker; import org.opensearch.dataprepper.plugins.kafka.util.MessageFormat; +import org.slf4j.MDC; import java.time.Duration; import java.util.Arrays; @@ -151,7 +156,7 @@ public KafkaBuffer createObjectUnderTest(final List consume blockingBuffer = mock; })) { - executorsMockedStatic.when(() -> Executors.newFixedThreadPool(anyInt())).thenReturn(executorService); + executorsMockedStatic.when(() -> Executors.newFixedThreadPool(anyInt(), any(KafkaPluginThreadFactory.class))).thenReturn(executorService); return new KafkaBuffer(pluginSetting, bufferConfig, acknowledgementSetManager, null, awsCredentialsSupplier, circuitBreaker); } } @@ -353,4 +358,84 @@ public void testShutdown_InterruptedException() throws InterruptedException { verify(executorService).awaitTermination(eq(EXECUTOR_SERVICE_SHUTDOWN_TIMEOUT), eq(TimeUnit.SECONDS)); verify(executorService).shutdownNow(); } + + @Nested + class MdcTests { + private MockedStatic mdcMockedStatic; + + @BeforeEach + void setUp() { + mdcMockedStatic = mockStatic(MDC.class); + } + + @AfterEach + void tearDown() { + mdcMockedStatic.close(); + } + + @Test + void writeBytes_sets_and_clears_MDC() throws Exception { + createObjectUnderTest().writeBytes(new byte[] {}, UUID.randomUUID().toString(), 100); + + mdcMockedStatic.verify(() -> MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, "buffer")); + mdcMockedStatic.verify(() -> MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY)); + } + + @Test + void doWrite_sets_and_clears_MDC() throws Exception { + createObjectUnderTest().doWrite(mock(Record.class), 100); + + mdcMockedStatic.verify(() -> MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, "buffer")); + mdcMockedStatic.verify(() -> MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY)); + } + + @Test + void doWriteAll_sets_and_clears_MDC() throws Exception { + final List> records = Collections.singletonList(mock(Record.class)); + createObjectUnderTest().doWriteAll(records, 100); + + mdcMockedStatic.verify(() -> MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, "buffer")); + mdcMockedStatic.verify(() -> MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY)); + } + + @Test + void doRead_sets_and_clears_MDC() throws Exception { + createObjectUnderTest().doRead(100); + + mdcMockedStatic.verify(() -> MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, "buffer")); + mdcMockedStatic.verify(() -> MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY)); + } + + @Test + void doCheckpoint_sets_and_clears_MDC() throws Exception { + createObjectUnderTest().doCheckpoint(mock(CheckpointState.class)); + + mdcMockedStatic.verify(() -> MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, "buffer")); + mdcMockedStatic.verify(() -> MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY)); + } + + @Test + void postProcess_sets_and_clears_MDC() throws Exception { + createObjectUnderTest().postProcess(100L); + + mdcMockedStatic.verify(() -> MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, "buffer")); + mdcMockedStatic.verify(() -> MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY)); + } + + @Test + void isEmpty_sets_and_clears_MDC() throws Exception { + createObjectUnderTest().isEmpty(); + + mdcMockedStatic.verify(() -> MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, "buffer")); + mdcMockedStatic.verify(() -> MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY)); + } + + @Test + void shutdown_sets_and_clears_MDC() throws Exception { + createObjectUnderTest().shutdown(); + + mdcMockedStatic.verify(() -> MDC.put(KafkaMdc.MDC_KAFKA_PLUGIN_KEY, "buffer")); + mdcMockedStatic.verify(() -> MDC.remove(KafkaMdc.MDC_KAFKA_PLUGIN_KEY)); + } + } } diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/thread/KafkaPluginThreadFactoryTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/thread/KafkaPluginThreadFactoryTest.java new file mode 100644 index 0000000000..589f81a74c --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/thread/KafkaPluginThreadFactoryTest.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.kafka.common.thread; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.kafka.common.KafkaMdc; +import org.slf4j.MDC; + +import java.util.UUID; +import java.util.concurrent.ThreadFactory; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.not; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class KafkaPluginThreadFactoryTest { + + @Mock + private ThreadFactory delegateThreadFactory; + @Mock + private Thread innerThread; + @Mock + private Runnable runnable; + private String pluginType; + + @BeforeEach + void setUp() { + pluginType = UUID.randomUUID().toString(); + + when(delegateThreadFactory.newThread(any(Runnable.class))).thenReturn(innerThread); + } + + + private KafkaPluginThreadFactory createObjectUnderTest() { + return new KafkaPluginThreadFactory(delegateThreadFactory, pluginType); + } + + @Test + void newThread_creates_thread_from_delegate() { + assertThat(createObjectUnderTest().newThread(runnable), equalTo(innerThread)); + } + + @Test + void newThread_creates_thread_with_name() { + final KafkaPluginThreadFactory objectUnderTest = createObjectUnderTest(); + + + final Thread thread1 = objectUnderTest.newThread(runnable); + assertThat(thread1, notNullValue()); + verify(thread1).setName(String.format("kafka-%s-1", pluginType)); + + final Thread thread2 = objectUnderTest.newThread(runnable); + assertThat(thread2, notNullValue()); + verify(thread2).setName(String.format("kafka-%s-2", pluginType)); + } + + @Test + void newThread_creates_thread_with_wrapping_runnable() { + createObjectUnderTest().newThread(runnable); + + final ArgumentCaptor actualRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(delegateThreadFactory).newThread(actualRunnableCaptor.capture()); + + final Runnable actualRunnable = actualRunnableCaptor.getValue(); + + assertThat(actualRunnable, not(equalTo(runnable))); + + verifyNoInteractions(runnable); + actualRunnable.run(); + verify(runnable).run(); + } + + @Test + void newThread_creates_thread_that_calls_MDC_on_run() { + createObjectUnderTest().newThread(runnable); + + final ArgumentCaptor actualRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(delegateThreadFactory).newThread(actualRunnableCaptor.capture()); + + final Runnable actualRunnable = actualRunnableCaptor.getValue(); + + final String[] actualKafkaPluginType = new String[1]; + doAnswer(a -> { + actualKafkaPluginType[0] = MDC.get(KafkaMdc.MDC_KAFKA_PLUGIN_KEY); + return null; + }).when(runnable).run(); + + actualRunnable.run(); + + assertThat(actualKafkaPluginType[0], equalTo(pluginType)); + } +} \ No newline at end of file