diff --git a/README.md b/README.md index f6eacf9..9bc27dc 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ You can download release builds through the [releases section of this](https://g software.amazon.payloadoffloading payloadoffloading-common - 1.0.0 + 1.1.0 jar ``` diff --git a/pom.xml b/pom.xml index 39c11f2..593f665 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ software.amazon.payloadoffloading payloadoffloading-common - 1.0.0 + 1.1.0 jar Payload offloading common library for AWS Common library between extended Amazon AWS clients to save payloads up to 2GB on Amazon S3. diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java index f1cf7c2..6ea9a70 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java @@ -3,6 +3,7 @@ import com.amazonaws.AmazonClientException; import com.amazonaws.annotation.NotThreadSafe; import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.model.CannedAccessControlList; import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -20,6 +21,10 @@ public class PayloadStorageConfiguration { private int payloadSizeThreshold = 0; private boolean alwaysThroughS3 = false; private boolean payloadSupport = false; + /** + * This field is optional, it is set only when we want to add access control list to Amazon S3 buckets and objects + */ + private CannedAccessControlList cannedAccessControlList; /** * This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS. */ @@ -29,6 +34,7 @@ public PayloadStorageConfiguration() { s3 = null; s3BucketName = null; sseAwsKeyManagementParams = null; + cannedAccessControlList = null; } public PayloadStorageConfiguration(PayloadStorageConfiguration other) { @@ -38,6 +44,7 @@ public PayloadStorageConfiguration(PayloadStorageConfiguration other) { this.payloadSupport = other.isPayloadSupportEnabled(); this.alwaysThroughS3 = other.isAlwaysThroughS3(); this.payloadSizeThreshold = other.getPayloadSizeThreshold(); + this.cannedAccessControlList = other.cannedAccessControlList; } /** @@ -212,4 +219,39 @@ public boolean isAlwaysThroughS3() { public void setAlwaysThroughS3(boolean alwaysThroughS3) { this.alwaysThroughS3 = alwaysThroughS3; } + + /** + * Configures the ACL to apply to the Amazon S3 putObject request. + * @param cannedAccessControlList + * The ACL to be used when storing objects in Amazon S3 + */ + public void setCannedAccessControlList(CannedAccessControlList cannedAccessControlList) { + this.cannedAccessControlList = cannedAccessControlList; + } + + /** + * Configures the ACL to apply to the Amazon S3 putObject request. + * @param cannedAccessControlList + * The ACL to be used when storing objects in Amazon S3 + */ + public PayloadStorageConfiguration withCannedAccessControlList(CannedAccessControlList cannedAccessControlList) { + setCannedAccessControlList(cannedAccessControlList); + return this; + } + + /** + * Checks whether an ACL have been configured for storing objects in Amazon S3. + * @return True if ACL is defined + */ + public boolean isCannedAccessControlListDefined() { + return null != cannedAccessControlList; + } + + /** + * Gets the AWS ACL to apply to the Amazon S3 putObject request. + * @return Amazon S3 object ACL + */ + public CannedAccessControlList getCannedAccessControlList() { + return cannedAccessControlList; + } } diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java index 7fe7965..401b68b 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java +++ b/src/main/java/software/amazon/payloadoffloading/S3BackedPayloadStore.java @@ -1,6 +1,5 @@ package software.amazon.payloadoffloading; -import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -14,17 +13,10 @@ public class S3BackedPayloadStore implements PayloadStore { private final String s3BucketName; private final S3Dao s3Dao; - private final SSEAwsKeyManagementParams sseAwsKeyManagementParams; public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName) { - this(s3Dao, s3BucketName, null); - } - - public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName, - SSEAwsKeyManagementParams sseAwsKeyManagementParams) { this.s3BucketName = s3BucketName; this.s3Dao = s3Dao; - this.sseAwsKeyManagementParams = sseAwsKeyManagementParams; } @Override @@ -32,7 +24,7 @@ public String storeOriginalPayload(String payload, Long payloadContentSize) { String s3Key = UUID.randomUUID().toString(); // Store the payload content in S3. - s3Dao.storeTextInS3(s3BucketName, s3Key, sseAwsKeyManagementParams, payload, payloadContentSize); + s3Dao.storeTextInS3(s3BucketName, s3Key, payload, payloadContentSize); LOG.info("S3 object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); // Convert S3 pointer (bucket name, key, etc) to JSON string diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java index a4c5c07..f556248 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java @@ -19,9 +19,17 @@ public class S3Dao { private static final Log LOG = LogFactory.getLog(S3Dao.class); private final AmazonS3 s3Client; + private final SSEAwsKeyManagementParams sseAwsKeyManagementParams; + private final CannedAccessControlList cannedAccessControlList; public S3Dao(AmazonS3 s3Client) { + this(s3Client, null, null); + } + + public S3Dao(AmazonS3 s3Client, SSEAwsKeyManagementParams sseAwsKeyManagementParams, CannedAccessControlList cannedAccessControlList) { this.s3Client = s3Client; + this.sseAwsKeyManagementParams = sseAwsKeyManagementParams; + this.cannedAccessControlList = cannedAccessControlList; } public String getTextFromS3(String s3BucketName, String s3Key) { @@ -60,14 +68,17 @@ public String getTextFromS3(String s3BucketName, String s3Key) { return embeddedText; } - public void storeTextInS3(String s3BucketName, String s3Key, SSEAwsKeyManagementParams sseAwsKeyManagementParams, - String payloadContentStr, Long payloadContentSize) { + public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr, Long payloadContentSize) { InputStream payloadContentStream = new ByteArrayInputStream(payloadContentStr.getBytes(StandardCharsets.UTF_8)); ObjectMetadata payloadContentStreamMetadata = new ObjectMetadata(); payloadContentStreamMetadata.setContentLength(payloadContentSize); PutObjectRequest putObjectRequest = new PutObjectRequest(s3BucketName, s3Key, payloadContentStream, payloadContentStreamMetadata); + if (cannedAccessControlList != null) { + putObjectRequest.withCannedAcl(cannedAccessControlList); + } + // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html if (sseAwsKeyManagementParams != null) { LOG.debug("Using SSE-KMS in put object request."); @@ -89,10 +100,6 @@ public void storeTextInS3(String s3BucketName, String s3Key, SSEAwsKeyManagement } } - public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr, Long payloadContentSize) { - storeTextInS3(s3BucketName, s3Key, null, payloadContentStr, payloadContentSize); - } - public void deletePayloadFromS3(String s3BucketName, String s3Key) { try { s3Client.deleteObject(s3BucketName, s3Key); diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java index 2c51438..1f1ee95 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java @@ -1,6 +1,7 @@ package software.amazon.payloadoffloading; import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.model.CannedAccessControlList; import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; import org.junit.Before; import org.junit.Test; @@ -16,10 +17,12 @@ public class PayloadStorageConfigurationTest { private static String s3BucketName = "test-bucket-name"; private static String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id"; private SSEAwsKeyManagementParams sseAwsKeyManagementParams; + private CannedAccessControlList cannedAccessControlList; @Before public void setup() { sseAwsKeyManagementParams = new SSEAwsKeyManagementParams(s3ServerSideEncryptionKMSKeyId); + cannedAccessControlList = CannedAccessControlList.BucketOwnerFullControl; } @Test @@ -33,7 +36,8 @@ public void testCopyConstructor() { payloadStorageConfiguration.withPayloadSupportEnabled(s3, s3BucketName) .withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(payloadSizeThreshold) - .withSSEAwsKeyManagementParams(sseAwsKeyManagementParams); + .withSSEAwsKeyManagementParams(sseAwsKeyManagementParams) + .withCannedAccessControlList(cannedAccessControlList); PayloadStorageConfiguration newPayloadStorageConfiguration = new PayloadStorageConfiguration(payloadStorageConfiguration); @@ -41,6 +45,7 @@ public void testCopyConstructor() { assertEquals(s3BucketName, newPayloadStorageConfiguration.getS3BucketName()); assertEquals(sseAwsKeyManagementParams, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams()); assertEquals(s3ServerSideEncryptionKMSKeyId, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams().getAwsKmsKeyId()); + assertEquals(cannedAccessControlList, newPayloadStorageConfiguration.getCannedAccessControlList()); assertTrue(newPayloadStorageConfiguration.isPayloadSupportEnabled()); assertEquals(alwaysThroughS3, newPayloadStorageConfiguration.isAlwaysThroughS3()); assertEquals(payloadSizeThreshold, newPayloadStorageConfiguration.getPayloadSizeThreshold()); @@ -88,4 +93,16 @@ public void testSseAwsKeyManagementParams() { assertEquals(s3ServerSideEncryptionKMSKeyId, payloadStorageConfiguration.getSSEAwsKeyManagementParams() .getAwsKmsKeyId()); } + + @Test + public void testCannedAccessControlList() { + + PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); + + assertFalse(payloadStorageConfiguration.isCannedAccessControlListDefined()); + + payloadStorageConfiguration.withCannedAccessControlList(cannedAccessControlList); + assertTrue(payloadStorageConfiguration.isCannedAccessControlListDefined()); + assertEquals(cannedAccessControlList, payloadStorageConfiguration.getCannedAccessControlList()); + } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java index f6bf2dc..b97c96d 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedPayloadStoreTest.java @@ -1,9 +1,9 @@ package software.amazon.payloadoffloading; import com.amazonaws.AmazonClientException; +import com.amazonaws.services.s3.model.CannedAccessControlList; import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; import junitparams.JUnitParamsRunner; -import junitparams.Parameters; import org.hamcrest.Matchers; import org.junit.Before; import org.junit.Rule; @@ -35,65 +35,23 @@ public void setup() { payloadStore = new S3BackedPayloadStore(s3Dao, S3_BUCKET_NAME); } - private Object[] testData() { - // Here, we create separate mock of S3Dao because JUnitParamsRunner collects parameters - // for tests well before invocation of @Before or @BeforeClass methods. - // That means our default s3Dao mock isn't instantiated until then. For parameterized tests, - // we instantiate our local S3Dao mock per combination, pass it to S3BackedPayloadStore and also pass it - // as test parameter to allow verifying calls to the mockS3Dao. - S3Dao noEncryptionS3Dao = mock(S3Dao.class); - S3Dao defaultEncryptionS3Dao = mock(S3Dao.class); - S3Dao customerKMSKeyEncryptionS3Dao = mock(S3Dao.class); - return new Object[][]{ - // No S3 SSE-KMS encryption - { - new S3BackedPayloadStore(noEncryptionS3Dao, S3_BUCKET_NAME), - null, - noEncryptionS3Dao - }, - // S3 SSE-KMS encryption with AWS managed KMS keys - { - new S3BackedPayloadStore(defaultEncryptionS3Dao, S3_BUCKET_NAME, new SSEAwsKeyManagementParams()), - new SSEAwsKeyManagementParams(), - defaultEncryptionS3Dao - }, - // S3 SSE-KMS encryption with customer managed KMS key - { - new S3BackedPayloadStore(customerKMSKeyEncryptionS3Dao, S3_BUCKET_NAME, - new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID)), - new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID), - customerKMSKeyEncryptionS3Dao - } - }; - } - @Test - @Parameters(method = "testData") - public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore, - SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) { + public void testStoreOriginalPayloadOnSuccess() { String actualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); ArgumentCaptor sseArgsCaptor = ArgumentCaptor.forClass(SSEAwsKeyManagementParams.class); + ArgumentCaptor cannedArgsCaptor = ArgumentCaptor.forClass(CannedAccessControlList.class); - verify(mockS3Dao, times(1)).storeTextInS3(eq(S3_BUCKET_NAME), keyCaptor.capture(), - sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH)); + verify(s3Dao, times(1)).storeTextInS3(eq(S3_BUCKET_NAME), keyCaptor.capture(), + eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH)); PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, keyCaptor.getValue()); assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); - - if (expectedParams == null) { - assertTrue(sseArgsCaptor.getValue() == null); - } else { - assertEquals(expectedParams.getAwsKmsKeyId(), sseArgsCaptor.getValue().getAwsKmsKeyId()); - } } @Test - @Parameters(method = "testData") - public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payloadStore, - SSEAwsKeyManagementParams expectedParams, - S3Dao mockS3Dao) { + public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects() { //Store any payload String anyActualPayloadPointer = payloadStore .storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); @@ -104,11 +62,8 @@ public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payl ArgumentCaptor anyOtherKeyCaptor = ArgumentCaptor.forClass(String.class); - ArgumentCaptor sseArgsCaptor = ArgumentCaptor - .forClass(SSEAwsKeyManagementParams.class); - - verify(mockS3Dao, times(2)).storeTextInS3(eq(S3_BUCKET_NAME), anyOtherKeyCaptor.capture(), - sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH)); + verify(s3Dao, times(2)).storeTextInS3(eq(S3_BUCKET_NAME), anyOtherKeyCaptor.capture(), + eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH)); String anyS3Key = anyOtherKeyCaptor.getAllValues().get(0); String anyOtherS3Key = anyOtherKeyCaptor.getAllValues().get(1); @@ -121,26 +76,15 @@ public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payl assertThat(anyS3Key, Matchers.not(anyOtherS3Key)); assertThat(anyActualPayloadPointer, Matchers.not(anyOtherActualPayloadPointer)); - - if (expectedParams == null) { - assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams -> actualParams == null)); - } else { - assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams -> - (actualParams.getAwsKmsKeyId() == null && expectedParams.getAwsKmsKeyId() == null) - || (actualParams.getAwsKmsKeyId().equals(expectedParams.getAwsKmsKeyId())))); - } } @Test - @Parameters(method = "testData") - public void testStoreOriginalPayloadOnS3Failure(PayloadStore payloadStore, - SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) { + public void testStoreOriginalPayloadOnS3Failure() { doThrow(new AmazonClientException("S3 Exception")) - .when(mockS3Dao) + .when(s3Dao) .storeTextInS3( any(String.class), any(String.class), - expectedParams == null ? isNull() : any(SSEAwsKeyManagementParams.class), any(String.class), any(Long.class)); diff --git a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java new file mode 100644 index 0000000..eb1fdb8 --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java @@ -0,0 +1,77 @@ +package software.amazon.payloadoffloading; + +import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.model.CannedAccessControlList; +import com.amazonaws.services.s3.model.PutObjectRequest; +import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams; +import junitparams.JUnitParamsRunner; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +@RunWith(JUnitParamsRunner.class) +public class S3DaoTest { + + private static String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id"; + private static final String S3_BUCKET_NAME = "test-bucket-name"; + private static final String ANY_PAYLOAD = "AnyPayload"; + private static final String ANY_S3_KEY = "AnyS3key"; + private static final Long ANY_PAYLOAD_LENGTH = 300000L; + private SSEAwsKeyManagementParams sseAwsKeyManagementParams; + private CannedAccessControlList cannedAccessControlList; + private AmazonS3 s3Client; + private S3Dao dao; + + @Before + public void setup() { + s3Client = mock(AmazonS3.class); + sseAwsKeyManagementParams = new SSEAwsKeyManagementParams(s3ServerSideEncryptionKMSKeyId); + cannedAccessControlList = CannedAccessControlList.BucketOwnerFullControl; + } + + @Test + public void storeTextInS3WithoutSSEOrCannedTest() { + dao = new S3Dao(s3Client); + ArgumentCaptor argument = ArgumentCaptor.forClass(PutObjectRequest.class); + + dao.storeTextInS3(S3_BUCKET_NAME, ANY_S3_KEY, ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); + + verify(s3Client, times(1)).putObject(argument.capture()); + + assertNull(argument.getValue().getSSEAwsKeyManagementParams()); + assertNull(argument.getValue().getCannedAcl()); + assertEquals(S3_BUCKET_NAME, argument.getValue().getBucketName()); + } + + @Test + public void storeTextInS3WithSSETest() { + dao = new S3Dao(s3Client, sseAwsKeyManagementParams, null); + ArgumentCaptor argument = ArgumentCaptor.forClass(PutObjectRequest.class); + + dao.storeTextInS3(S3_BUCKET_NAME, ANY_S3_KEY, ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); + + verify(s3Client, times(1)).putObject(argument.capture()); + + assertEquals(sseAwsKeyManagementParams, argument.getValue().getSSEAwsKeyManagementParams()); + assertNull(argument.getValue().getCannedAcl()); + assertEquals(S3_BUCKET_NAME, argument.getValue().getBucketName()); + } + + @Test + public void storeTextInS3WithBothTest() { + dao = new S3Dao(s3Client, sseAwsKeyManagementParams, cannedAccessControlList); + ArgumentCaptor argument = ArgumentCaptor.forClass(PutObjectRequest.class); + + dao.storeTextInS3(S3_BUCKET_NAME, ANY_S3_KEY, ANY_PAYLOAD, ANY_PAYLOAD_LENGTH); + + verify(s3Client, times(1)).putObject(argument.capture()); + + assertEquals(sseAwsKeyManagementParams, argument.getValue().getSSEAwsKeyManagementParams()); + assertEquals(cannedAccessControlList, argument.getValue().getCannedAcl()); + assertEquals(S3_BUCKET_NAME, argument.getValue().getBucketName()); + } +}