Skip to content

Commit

Permalink
Cross account support (#8)
Browse files Browse the repository at this point in the history
* Add support for Canned access control lists

* Moved SSEAwsKeyManagementParams and CannedAccessControlList as properties in S3Dao
  • Loading branch information
adam-aws authored Sep 17, 2020
1 parent 07587a7 commit a784c1b
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 84 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ You can download release builds through the [releases section of this](https://g
<dependency>
<groupId>software.amazon.payloadoffloading</groupId>
<artifactId>payloadoffloading-common</artifactId>
<version>1.0.0</version>
<version>1.1.0</version>
<type>jar</type>
</dependency>
```
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>software.amazon.payloadoffloading</groupId>
<artifactId>payloadoffloading-common</artifactId>
<version>1.0.0</version>
<version>1.1.0</version>
<packaging>jar</packaging>
<name>Payload offloading common library for AWS</name>
<description>Common library between extended Amazon AWS clients to save payloads up to 2GB on Amazon S3.</description>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.amazonaws.AmazonClientException;
import com.amazonaws.annotation.NotThreadSafe;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.CannedAccessControlList;
import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -20,6 +21,10 @@ public class PayloadStorageConfiguration {
private int payloadSizeThreshold = 0;
private boolean alwaysThroughS3 = false;
private boolean payloadSupport = false;
/**
* This field is optional, it is set only when we want to add access control list to Amazon S3 buckets and objects
*/
private CannedAccessControlList cannedAccessControlList;
/**
* This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS.
*/
Expand All @@ -29,6 +34,7 @@ public PayloadStorageConfiguration() {
s3 = null;
s3BucketName = null;
sseAwsKeyManagementParams = null;
cannedAccessControlList = null;
}

public PayloadStorageConfiguration(PayloadStorageConfiguration other) {
Expand All @@ -38,6 +44,7 @@ public PayloadStorageConfiguration(PayloadStorageConfiguration other) {
this.payloadSupport = other.isPayloadSupportEnabled();
this.alwaysThroughS3 = other.isAlwaysThroughS3();
this.payloadSizeThreshold = other.getPayloadSizeThreshold();
this.cannedAccessControlList = other.cannedAccessControlList;
}

/**
Expand Down Expand Up @@ -212,4 +219,39 @@ public boolean isAlwaysThroughS3() {
public void setAlwaysThroughS3(boolean alwaysThroughS3) {
this.alwaysThroughS3 = alwaysThroughS3;
}

/**
* Configures the ACL to apply to the Amazon S3 putObject request.
* @param cannedAccessControlList
* The ACL to be used when storing objects in Amazon S3
*/
public void setCannedAccessControlList(CannedAccessControlList cannedAccessControlList) {
this.cannedAccessControlList = cannedAccessControlList;
}

/**
* Configures the ACL to apply to the Amazon S3 putObject request.
* @param cannedAccessControlList
* The ACL to be used when storing objects in Amazon S3
*/
public PayloadStorageConfiguration withCannedAccessControlList(CannedAccessControlList cannedAccessControlList) {
setCannedAccessControlList(cannedAccessControlList);
return this;
}

/**
* Checks whether an ACL have been configured for storing objects in Amazon S3.
* @return True if ACL is defined
*/
public boolean isCannedAccessControlListDefined() {
return null != cannedAccessControlList;
}

/**
* Gets the AWS ACL to apply to the Amazon S3 putObject request.
* @return Amazon S3 object ACL
*/
public CannedAccessControlList getCannedAccessControlList() {
return cannedAccessControlList;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package software.amazon.payloadoffloading;

import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

Expand All @@ -14,25 +13,18 @@ public class S3BackedPayloadStore implements PayloadStore {

private final String s3BucketName;
private final S3Dao s3Dao;
private final SSEAwsKeyManagementParams sseAwsKeyManagementParams;

public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName) {
this(s3Dao, s3BucketName, null);
}

public S3BackedPayloadStore(S3Dao s3Dao, String s3BucketName,
SSEAwsKeyManagementParams sseAwsKeyManagementParams) {
this.s3BucketName = s3BucketName;
this.s3Dao = s3Dao;
this.sseAwsKeyManagementParams = sseAwsKeyManagementParams;
}

@Override
public String storeOriginalPayload(String payload, Long payloadContentSize) {
String s3Key = UUID.randomUUID().toString();

// Store the payload content in S3.
s3Dao.storeTextInS3(s3BucketName, s3Key, sseAwsKeyManagementParams, payload, payloadContentSize);
s3Dao.storeTextInS3(s3BucketName, s3Key, payload, payloadContentSize);
LOG.info("S3 object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + ".");

// Convert S3 pointer (bucket name, key, etc) to JSON string
Expand Down
19 changes: 13 additions & 6 deletions src/main/java/software/amazon/payloadoffloading/S3Dao.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@
public class S3Dao {
private static final Log LOG = LogFactory.getLog(S3Dao.class);
private final AmazonS3 s3Client;
private final SSEAwsKeyManagementParams sseAwsKeyManagementParams;
private final CannedAccessControlList cannedAccessControlList;

public S3Dao(AmazonS3 s3Client) {
this(s3Client, null, null);
}

public S3Dao(AmazonS3 s3Client, SSEAwsKeyManagementParams sseAwsKeyManagementParams, CannedAccessControlList cannedAccessControlList) {
this.s3Client = s3Client;
this.sseAwsKeyManagementParams = sseAwsKeyManagementParams;
this.cannedAccessControlList = cannedAccessControlList;
}

public String getTextFromS3(String s3BucketName, String s3Key) {
Expand Down Expand Up @@ -60,14 +68,17 @@ public String getTextFromS3(String s3BucketName, String s3Key) {
return embeddedText;
}

public void storeTextInS3(String s3BucketName, String s3Key, SSEAwsKeyManagementParams sseAwsKeyManagementParams,
String payloadContentStr, Long payloadContentSize) {
public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr, Long payloadContentSize) {
InputStream payloadContentStream = new ByteArrayInputStream(payloadContentStr.getBytes(StandardCharsets.UTF_8));
ObjectMetadata payloadContentStreamMetadata = new ObjectMetadata();
payloadContentStreamMetadata.setContentLength(payloadContentSize);
PutObjectRequest putObjectRequest = new PutObjectRequest(s3BucketName, s3Key,
payloadContentStream, payloadContentStreamMetadata);

if (cannedAccessControlList != null) {
putObjectRequest.withCannedAcl(cannedAccessControlList);
}

// https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html
if (sseAwsKeyManagementParams != null) {
LOG.debug("Using SSE-KMS in put object request.");
Expand All @@ -89,10 +100,6 @@ public void storeTextInS3(String s3BucketName, String s3Key, SSEAwsKeyManagement
}
}

public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr, Long payloadContentSize) {
storeTextInS3(s3BucketName, s3Key, null, payloadContentStr, payloadContentSize);
}

public void deletePayloadFromS3(String s3BucketName, String s3Key) {
try {
s3Client.deleteObject(s3BucketName, s3Key);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package software.amazon.payloadoffloading;

import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.CannedAccessControlList;
import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -16,10 +17,12 @@ public class PayloadStorageConfigurationTest {
private static String s3BucketName = "test-bucket-name";
private static String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id";
private SSEAwsKeyManagementParams sseAwsKeyManagementParams;
private CannedAccessControlList cannedAccessControlList;

@Before
public void setup() {
sseAwsKeyManagementParams = new SSEAwsKeyManagementParams(s3ServerSideEncryptionKMSKeyId);
cannedAccessControlList = CannedAccessControlList.BucketOwnerFullControl;
}

@Test
Expand All @@ -33,14 +36,16 @@ public void testCopyConstructor() {

payloadStorageConfiguration.withPayloadSupportEnabled(s3, s3BucketName)
.withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(payloadSizeThreshold)
.withSSEAwsKeyManagementParams(sseAwsKeyManagementParams);
.withSSEAwsKeyManagementParams(sseAwsKeyManagementParams)
.withCannedAccessControlList(cannedAccessControlList);

PayloadStorageConfiguration newPayloadStorageConfiguration = new PayloadStorageConfiguration(payloadStorageConfiguration);

assertEquals(s3, newPayloadStorageConfiguration.getAmazonS3Client());
assertEquals(s3BucketName, newPayloadStorageConfiguration.getS3BucketName());
assertEquals(sseAwsKeyManagementParams, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams());
assertEquals(s3ServerSideEncryptionKMSKeyId, newPayloadStorageConfiguration.getSSEAwsKeyManagementParams().getAwsKmsKeyId());
assertEquals(cannedAccessControlList, newPayloadStorageConfiguration.getCannedAccessControlList());
assertTrue(newPayloadStorageConfiguration.isPayloadSupportEnabled());
assertEquals(alwaysThroughS3, newPayloadStorageConfiguration.isAlwaysThroughS3());
assertEquals(payloadSizeThreshold, newPayloadStorageConfiguration.getPayloadSizeThreshold());
Expand Down Expand Up @@ -88,4 +93,16 @@ public void testSseAwsKeyManagementParams() {
assertEquals(s3ServerSideEncryptionKMSKeyId, payloadStorageConfiguration.getSSEAwsKeyManagementParams()
.getAwsKmsKeyId());
}

@Test
public void testCannedAccessControlList() {

PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();

assertFalse(payloadStorageConfiguration.isCannedAccessControlListDefined());

payloadStorageConfiguration.withCannedAccessControlList(cannedAccessControlList);
assertTrue(payloadStorageConfiguration.isCannedAccessControlListDefined());
assertEquals(cannedAccessControlList, payloadStorageConfiguration.getCannedAccessControlList());
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package software.amazon.payloadoffloading;

import com.amazonaws.AmazonClientException;
import com.amazonaws.services.s3.model.CannedAccessControlList;
import com.amazonaws.services.s3.model.SSEAwsKeyManagementParams;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.hamcrest.Matchers;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -35,65 +35,23 @@ public void setup() {
payloadStore = new S3BackedPayloadStore(s3Dao, S3_BUCKET_NAME);
}

private Object[] testData() {
// Here, we create separate mock of S3Dao because JUnitParamsRunner collects parameters
// for tests well before invocation of @Before or @BeforeClass methods.
// That means our default s3Dao mock isn't instantiated until then. For parameterized tests,
// we instantiate our local S3Dao mock per combination, pass it to S3BackedPayloadStore and also pass it
// as test parameter to allow verifying calls to the mockS3Dao.
S3Dao noEncryptionS3Dao = mock(S3Dao.class);
S3Dao defaultEncryptionS3Dao = mock(S3Dao.class);
S3Dao customerKMSKeyEncryptionS3Dao = mock(S3Dao.class);
return new Object[][]{
// No S3 SSE-KMS encryption
{
new S3BackedPayloadStore(noEncryptionS3Dao, S3_BUCKET_NAME),
null,
noEncryptionS3Dao
},
// S3 SSE-KMS encryption with AWS managed KMS keys
{
new S3BackedPayloadStore(defaultEncryptionS3Dao, S3_BUCKET_NAME, new SSEAwsKeyManagementParams()),
new SSEAwsKeyManagementParams(),
defaultEncryptionS3Dao
},
// S3 SSE-KMS encryption with customer managed KMS key
{
new S3BackedPayloadStore(customerKMSKeyEncryptionS3Dao, S3_BUCKET_NAME,
new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID)),
new SSEAwsKeyManagementParams(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID),
customerKMSKeyEncryptionS3Dao
}
};
}

@Test
@Parameters(method = "testData")
public void testStoreOriginalPayloadOnSuccess(PayloadStore payloadStore,
SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) {
public void testStoreOriginalPayloadOnSuccess() {
String actualPayloadPointer = payloadStore.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH);

ArgumentCaptor<String> keyCaptor = ArgumentCaptor.forClass(String.class);
ArgumentCaptor<SSEAwsKeyManagementParams> sseArgsCaptor = ArgumentCaptor.forClass(SSEAwsKeyManagementParams.class);
ArgumentCaptor<CannedAccessControlList> cannedArgsCaptor = ArgumentCaptor.forClass(CannedAccessControlList.class);

verify(mockS3Dao, times(1)).storeTextInS3(eq(S3_BUCKET_NAME), keyCaptor.capture(),
sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));
verify(s3Dao, times(1)).storeTextInS3(eq(S3_BUCKET_NAME), keyCaptor.capture(),
eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));

PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, keyCaptor.getValue());
assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer);

if (expectedParams == null) {
assertTrue(sseArgsCaptor.getValue() == null);
} else {
assertEquals(expectedParams.getAwsKmsKeyId(), sseArgsCaptor.getValue().getAwsKmsKeyId());
}
}

@Test
@Parameters(method = "testData")
public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payloadStore,
SSEAwsKeyManagementParams expectedParams,
S3Dao mockS3Dao) {
public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects() {
//Store any payload
String anyActualPayloadPointer = payloadStore
.storeOriginalPayload(ANY_PAYLOAD, ANY_PAYLOAD_LENGTH);
Expand All @@ -104,11 +62,8 @@ public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payl

ArgumentCaptor<String> anyOtherKeyCaptor = ArgumentCaptor.forClass(String.class);

ArgumentCaptor<SSEAwsKeyManagementParams> sseArgsCaptor = ArgumentCaptor
.forClass(SSEAwsKeyManagementParams.class);

verify(mockS3Dao, times(2)).storeTextInS3(eq(S3_BUCKET_NAME), anyOtherKeyCaptor.capture(),
sseArgsCaptor.capture(), eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));
verify(s3Dao, times(2)).storeTextInS3(eq(S3_BUCKET_NAME), anyOtherKeyCaptor.capture(),
eq(ANY_PAYLOAD), eq(ANY_PAYLOAD_LENGTH));

String anyS3Key = anyOtherKeyCaptor.getAllValues().get(0);
String anyOtherS3Key = anyOtherKeyCaptor.getAllValues().get(1);
Expand All @@ -121,26 +76,15 @@ public void testStoreOriginalPayloadDoesAlwaysCreateNewObjects(PayloadStore payl

assertThat(anyS3Key, Matchers.not(anyOtherS3Key));
assertThat(anyActualPayloadPointer, Matchers.not(anyOtherActualPayloadPointer));

if (expectedParams == null) {
assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams -> actualParams == null));
} else {
assertTrue(sseArgsCaptor.getAllValues().stream().allMatch(actualParams ->
(actualParams.getAwsKmsKeyId() == null && expectedParams.getAwsKmsKeyId() == null)
|| (actualParams.getAwsKmsKeyId().equals(expectedParams.getAwsKmsKeyId()))));
}
}

@Test
@Parameters(method = "testData")
public void testStoreOriginalPayloadOnS3Failure(PayloadStore payloadStore,
SSEAwsKeyManagementParams expectedParams, S3Dao mockS3Dao) {
public void testStoreOriginalPayloadOnS3Failure() {
doThrow(new AmazonClientException("S3 Exception"))
.when(mockS3Dao)
.when(s3Dao)
.storeTextInS3(
any(String.class),
any(String.class),
expectedParams == null ? isNull() : any(SSEAwsKeyManagementParams.class),
any(String.class),
any(Long.class));

Expand Down
Loading

0 comments on commit a784c1b

Please sign in to comment.