diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProvider.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProvider.java index 3abe62ffb2..c89e069542 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProvider.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProvider.java @@ -1,6 +1,7 @@ package org.opensearch.dataprepper.plugins.kafka.common.key; import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext; +import org.opensearch.dataprepper.plugins.kafka.configuration.KmsConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; @@ -18,12 +19,13 @@ public KmsKeyProvider(AwsContext awsContext) { @Override public boolean supportsConfiguration(TopicConfig topicConfig) { - return topicConfig.getKmsKeyId() != null; + return topicConfig.getKmsConfig() != null && topicConfig.getKmsConfig().getKeyId() != null; } @Override public byte[] apply(TopicConfig topicConfig) { - String kmsKeyId = topicConfig.getKmsKeyId(); + KmsConfig kmsConfig = topicConfig.getKmsConfig(); + String kmsKeyId = kmsConfig.getKeyId(); AwsCredentialsProvider awsCredentialsProvider = awsContext.get(); @@ -36,6 +38,7 @@ public byte[] apply(TopicConfig topicConfig) { DecryptResponse decryptResponse = kmsClient.decrypt(builder -> builder .keyId(kmsKeyId) .ciphertextBlob(SdkBytes.fromByteArray(decodedEncryptionKey)) + .encryptionContext(kmsConfig.getEncryptionContext()) ); return decryptResponse.plaintext().asByteArray(); diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/KmsConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/KmsConfig.java new file mode 100644 index 0000000000..11294db13f --- /dev/null +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/KmsConfig.java @@ -0,0 +1,21 @@ +package org.opensearch.dataprepper.plugins.kafka.configuration; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Map; + +public class KmsConfig { + @JsonProperty("key_id") + private String keyId; + + @JsonProperty("encryption_context") + private Map encryptionContext; + + public String getKeyId() { + return keyId; + } + + public Map getEncryptionContext() { + return encryptionContext; + } +} diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java index f0441d7114..546587a15a 100644 --- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java +++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/configuration/TopicConfig.java @@ -128,8 +128,8 @@ public class TopicConfig { @JsonProperty("encryption_key") private String encryptionKey; - @JsonProperty("kms_key_id") - private String kmsKeyId; + @JsonProperty("kms") + private KmsConfig kmsConfig; public Long getRetentionPeriod() { return retentionPeriod; @@ -151,8 +151,8 @@ public String getEncryptionKey() { return encryptionKey; } - public String getKmsKeyId() { - return kmsKeyId; + public KmsConfig getKmsConfig() { + return kmsConfig; } public Boolean getAutoCommit() { diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProviderTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProviderTest.java index 4fd241f65f..ed8caa2c19 100644 --- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProviderTest.java +++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/common/key/KmsKeyProviderTest.java @@ -9,6 +9,7 @@ import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext; +import org.opensearch.dataprepper.plugins.kafka.configuration.KmsConfig; import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; @@ -19,6 +20,7 @@ import software.amazon.awssdk.services.kms.model.DecryptResponse; import java.util.Base64; +import java.util.Map; import java.util.UUID; import java.util.function.Consumer; @@ -26,6 +28,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.verify; @@ -39,19 +42,28 @@ class KmsKeyProviderTest { private AwsCredentialsProvider awsCredentialsProvider; @Mock private TopicConfig topicConfig; + @Mock + private KmsConfig kmsConfig; private KmsKeyProvider createObjectUnderTest() { return new KmsKeyProvider(awsContext); } @Test - void supportsConfiguration_returns_false_if_kmsKeyId_is_null() { + void supportsConfiguration_returns_false_if_kms_config_is_null() { + assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(false)); + } + + @Test + void supportsConfiguration_returns_false_if_kms_keyId_is_null() { + when(topicConfig.getKmsConfig()).thenReturn(kmsConfig); assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(false)); } @Test - void supportsConfiguration_returns_true_if_kmsKeyId_is_present() { - when(topicConfig.getKmsKeyId()).thenReturn(UUID.randomUUID().toString()); + void supportsConfiguration_returns_true_if_kms_keyId_is_present() { + when(topicConfig.getKmsConfig()).thenReturn(kmsConfig); + when(kmsConfig.getKeyId()).thenReturn(UUID.randomUUID().toString()); assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(true)); } @@ -77,7 +89,8 @@ void setUp() { encryptionKey = UUID.randomUUID().toString(); String base64EncryptionKey = Base64.getEncoder().encodeToString(encryptionKey.getBytes()); when(topicConfig.getEncryptionKey()).thenReturn(base64EncryptionKey); - when(topicConfig.getKmsKeyId()).thenReturn(kmsKeyId); + when(topicConfig.getKmsConfig()).thenReturn(kmsConfig); + when(kmsConfig.getKeyId()).thenReturn(kmsKeyId); kmsClient = mock(KmsClient.class); DecryptResponse decryptResponse = mock(DecryptResponse.class); @@ -104,8 +117,48 @@ void apply_returns_plaintext_from_decrypt_request() { } @Test - void apply_calls_decrypt_with_correct_values() { + void apply_calls_decrypt_with_correct_values_when_encryption_context_is_null() { + KmsKeyProvider objectUnderTest = createObjectUnderTest(); + + when(kmsConfig.getEncryptionContext()).thenReturn(null); + + try (MockedStatic kmsClientMockedStatic = mockStatic(KmsClient.class)) { + kmsClientMockedStatic.when(() -> KmsClient.builder()).thenReturn(kmsClientBuilder); + objectUnderTest.apply(topicConfig); + } + + ArgumentCaptor> consumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class); + verify(kmsClient).decrypt(consumerArgumentCaptor.capture()); + + Consumer actualConsumer = consumerArgumentCaptor.getValue(); + + DecryptRequest.Builder builder = mock(DecryptRequest.Builder.class); + when(builder.keyId(anyString())).thenReturn(builder); + when(builder.ciphertextBlob(any())).thenReturn(builder); + when(builder.encryptionContext(any())).thenReturn(builder); + actualConsumer.accept(builder); + + verify(builder).keyId(kmsKeyId); + ArgumentCaptor actualBytesCaptor = ArgumentCaptor.forClass(SdkBytes.class); + verify(builder).ciphertextBlob(actualBytesCaptor.capture()); + + SdkBytes actualSdkBytes = actualBytesCaptor.getValue(); + assertThat(actualSdkBytes.asByteArray(), equalTo(encryptionKey.getBytes())); + + verify(builder).encryptionContext(isNull()); + } + + @Test + void apply_calls_decrypt_with_correct_values_when_encryption_context_is_present() { + Map encryptionContext = Map.of( + UUID.randomUUID().toString(), UUID.randomUUID().toString(), + UUID.randomUUID().toString(), UUID.randomUUID().toString(), + UUID.randomUUID().toString(), UUID.randomUUID().toString() + ); KmsKeyProvider objectUnderTest = createObjectUnderTest(); + + when(kmsConfig.getEncryptionContext()).thenReturn(encryptionContext); + try (MockedStatic kmsClientMockedStatic = mockStatic(KmsClient.class)) { kmsClientMockedStatic.when(() -> KmsClient.builder()).thenReturn(kmsClientBuilder); objectUnderTest.apply(topicConfig); @@ -119,6 +172,7 @@ void apply_calls_decrypt_with_correct_values() { DecryptRequest.Builder builder = mock(DecryptRequest.Builder.class); when(builder.keyId(anyString())).thenReturn(builder); when(builder.ciphertextBlob(any())).thenReturn(builder); + when(builder.encryptionContext(any())).thenReturn(builder); actualConsumer.accept(builder); verify(builder).keyId(kmsKeyId); @@ -127,6 +181,8 @@ void apply_calls_decrypt_with_correct_values() { SdkBytes actualSdkBytes = actualBytesCaptor.getValue(); assertThat(actualSdkBytes.asByteArray(), equalTo(encryptionKey.getBytes())); + + verify(builder).encryptionContext(encryptionContext); } } } \ No newline at end of file