diff --git a/powertools-sqs/pom.xml b/powertools-sqs/pom.xml index 6bd975945..737ce9f79 100644 --- a/powertools-sqs/pom.xml +++ b/powertools-sqs/pom.xml @@ -54,12 +54,12 @@ aws-lambda-java-events - software.amazon.payloadoffloading - payloadoffloading-common + software.amazon.awssdk + sqs software.amazon.awssdk - sqs + s3 com.fasterxml.jackson.core diff --git a/powertools-sqs/src/main/java/software/amazon/lambda/powertools/sqs/SqsUtils.java b/powertools-sqs/src/main/java/software/amazon/lambda/powertools/sqs/SqsUtils.java index d69e7aec8..d962135a6 100644 --- a/powertools-sqs/src/main/java/software/amazon/lambda/powertools/sqs/SqsUtils.java +++ b/powertools-sqs/src/main/java/software/amazon/lambda/powertools/sqs/SqsUtils.java @@ -22,14 +22,13 @@ import com.amazonaws.services.lambda.runtime.events.SQSEvent; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; - +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.lambda.powertools.sqs.internal.BatchContext; -import software.amazon.lambda.powertools.sqs.internal.SqsLargeMessageAspect; import software.amazon.payloadoffloading.PayloadS3Pointer; +import software.amazon.lambda.powertools.sqs.internal.SqsLargeMessageAspect; import static com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage; import static software.amazon.lambda.powertools.sqs.internal.SqsLargeMessageAspect.processMessages; @@ -42,6 +41,7 @@ public final class SqsUtils { private static final ObjectMapper objectMapper = new ObjectMapper(); private static SqsClient client; + private static S3Client s3Client; private SqsUtils() { } @@ -98,6 +98,16 @@ public static void overrideSqsClient(SqsClient client) { SqsUtils.client = client; } + /** + * Provides ability to set default {@link S3Client} to be used by utility. + * If no default configuration is provided, client is instantiated via {@link S3Client#create()} + * + * @param s3Client {@link S3Client} to be used by utility + */ + public static void overrideS3Client(S3Client s3Client) { + SqsUtils.s3Client = s3Client; + } + /** * This utility method is used to process each {@link SQSMessage} inside the received {@link SQSEvent} * @@ -524,4 +534,12 @@ private static SQSMessage clonedMessage(final SQSMessage sqsMessage) { public static ObjectMapper objectMapper() { return objectMapper; } + + public static S3Client s3Client() { + if(null == s3Client) { + SqsUtils.s3Client = S3Client.create(); + } + + return s3Client; + } } diff --git a/powertools-sqs/src/main/java/software/amazon/lambda/powertools/sqs/internal/SqsLargeMessageAspect.java b/powertools-sqs/src/main/java/software/amazon/lambda/powertools/sqs/internal/SqsLargeMessageAspect.java index 8698cc737..072d903d0 100644 --- a/powertools-sqs/src/main/java/software/amazon/lambda/powertools/sqs/internal/SqsLargeMessageAspect.java +++ b/powertools-sqs/src/main/java/software/amazon/lambda/powertools/sqs/internal/SqsLargeMessageAspect.java @@ -5,35 +5,33 @@ import java.util.List; import java.util.function.Function; -import com.amazonaws.AmazonServiceException; -import com.amazonaws.SdkClientException; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.events.SQSEvent; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.util.IOUtils; - import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Pointcut; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.utils.IoUtils; import software.amazon.lambda.powertools.sqs.SqsLargeMessage; import software.amazon.payloadoffloading.PayloadS3Pointer; import static com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage; import static java.lang.String.format; import static software.amazon.lambda.powertools.core.internal.LambdaHandlerProcessor.isHandlerMethod; +import static software.amazon.lambda.powertools.sqs.SqsUtils.s3Client; @Aspect public class SqsLargeMessageAspect { private static final Logger LOG = LoggerFactory.getLogger(SqsLargeMessageAspect.class); - private static AmazonS3 amazonS3 = AmazonS3ClientBuilder.defaultClient(); @SuppressWarnings({"EmptyMethod"}) @Pointcut("@annotation(sqsLargeMessage)") @@ -52,7 +50,7 @@ && placedOnSqsEventRequestHandler(pjp)) { Object proceed = pjp.proceed(proceedArgs); if (sqsLargeMessage.deletePayloads()) { - pointersToDelete.forEach(this::deleteMessageFromS3); + pointersToDelete.forEach(SqsLargeMessageAspect::deleteMessage); } return proceed; } @@ -69,15 +67,21 @@ public static List processMessages(final List reco List s3Pointers = new ArrayList<>(); for (SQSMessage sqsMessage : records) { if (isBodyLargeMessagePointer(sqsMessage.getBody())) { - PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(sqsMessage.getBody()); - S3Object s3Object = callS3Gracefully(s3Pointer, pointer -> { - S3Object object = amazonS3.getObject(pointer.getS3BucketName(), pointer.getS3Key()); + PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(sqsMessage.getBody()) + .orElseThrow(() -> new FailedProcessingLargePayloadException(format("Failed processing SQS body to extract S3 details. [ %s ].", sqsMessage.getBody()))); + + ResponseInputStream s3Object = callS3Gracefully(s3Pointer, pointer -> { + ResponseInputStream response = s3Client().getObject(GetObjectRequest.builder() + .bucket(pointer.getS3BucketName()) + .key(pointer.getS3Key()) + .build()); + LOG.debug("Object downloaded with key: " + s3Pointer.getS3Key()); - return object; + return response; }); - sqsMessage.setBody(readStringFromS3Object(s3Object)); + sqsMessage.setBody(readStringFromS3Object(s3Object, s3Pointer)); s3Pointers.add(s3Pointer); } } @@ -89,26 +93,22 @@ private static boolean isBodyLargeMessagePointer(String record) { return record.startsWith("[\"software.amazon.payloadoffloading.PayloadS3Pointer\""); } - private static String readStringFromS3Object(S3Object object) { - try (S3ObjectInputStream is = object.getObjectContent()) { - return IOUtils.toString(is); + private static String readStringFromS3Object(ResponseInputStream response, + PayloadS3Pointer s3Pointer) { + try (ResponseInputStream content = response) { + return IoUtils.toUtf8String(content); } catch (IOException e) { LOG.error("Error converting S3 object to String", e); - throw new FailedProcessingLargePayloadException(format("Failed processing S3 record with [Bucket Name: %s Bucket Key: %s]", object.getBucketName(), object.getKey()), e); + throw new FailedProcessingLargePayloadException(format("Failed processing S3 record with [Bucket Name: %s Bucket Key: %s]", s3Pointer.getS3BucketName(), s3Pointer.getS3Key()), e); } } - private void deleteMessageFromS3(PayloadS3Pointer s3Pointer) { - callS3Gracefully(s3Pointer, pointer -> { - amazonS3.deleteObject(s3Pointer.getS3BucketName(), s3Pointer.getS3Key()); - LOG.info("Message deleted from S3: " + s3Pointer.toJson()); - return null; - }); - } - public static void deleteMessage(PayloadS3Pointer s3Pointer) { callS3Gracefully(s3Pointer, pointer -> { - amazonS3.deleteObject(s3Pointer.getS3BucketName(), s3Pointer.getS3Key()); + s3Client().deleteObject(DeleteObjectRequest.builder() + .bucket(pointer.getS3BucketName()) + .key(pointer.getS3Key()) + .build()); LOG.info("Message deleted from S3: " + s3Pointer.toJson()); return null; }); @@ -118,7 +118,7 @@ private static R callS3Gracefully(final PayloadS3Pointer pointer, final Function function) { try { return function.apply(pointer); - } catch (AmazonServiceException e) { + } catch (S3Exception e) { LOG.error("A service exception", e); throw new FailedProcessingLargePayloadException(format("Failed processing S3 record with [Bucket Name: %s Bucket Key: %s]", pointer.getS3BucketName(), pointer.getS3Key()), e); } catch (SdkClientException e) { @@ -137,5 +137,9 @@ public static class FailedProcessingLargePayloadException extends RuntimeExcepti public FailedProcessingLargePayloadException(String message, Throwable cause) { super(message, cause); } + + public FailedProcessingLargePayloadException(String message) { + super(message); + } } } diff --git a/powertools-sqs/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java b/powertools-sqs/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java new file mode 100644 index 000000000..078b9a773 --- /dev/null +++ b/powertools-sqs/src/main/java/software/amazon/payloadoffloading/PayloadS3Pointer.java @@ -0,0 +1,59 @@ +package software.amazon.payloadoffloading; + +import java.util.Optional; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectWriter; +import com.fasterxml.jackson.databind.SerializationFeature; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static java.util.Optional.empty; +import static java.util.Optional.ofNullable; + +public class PayloadS3Pointer { + private static final Logger LOG = LoggerFactory.getLogger(PayloadS3Pointer.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private String s3BucketName; + private String s3Key; + + static { + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + objectMapper.activateDefaultTyping(objectMapper.getPolymorphicTypeValidator(), ObjectMapper.DefaultTyping.NON_FINAL); + } + + private PayloadS3Pointer() { + + } + + public String getS3BucketName() { + return this.s3BucketName; + } + + public String getS3Key() { + return this.s3Key; + } + + public static Optional fromJson(String s3PointerJson) { + try { + return ofNullable(objectMapper.readValue(s3PointerJson, PayloadS3Pointer.class)); + } catch (Exception e) { + LOG.error("Failed to read the S3 object pointer from given string.", e); + return empty(); + } + } + + public Optional toJson() { + try { + ObjectWriter objectWriter = objectMapper.writer(); + return ofNullable(objectWriter.writeValueAsString(this)); + + } catch (Exception e) { + LOG.error("Failed to convert S3 object pointer to text.", e); + return empty(); + } + } +} diff --git a/powertools-sqs/src/test/java/software/amazon/lambda/powertools/sqs/SqsUtilsLargeMessageTest.java b/powertools-sqs/src/test/java/software/amazon/lambda/powertools/sqs/SqsUtilsLargeMessageTest.java index d704b04e0..48de3e6a9 100644 --- a/powertools-sqs/src/test/java/software/amazon/lambda/powertools/sqs/SqsUtilsLargeMessageTest.java +++ b/powertools-sqs/src/test/java/software/amazon/lambda/powertools/sqs/SqsUtilsLargeMessageTest.java @@ -4,31 +4,35 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.function.Consumer; import java.util.stream.Stream; -import com.amazonaws.AmazonServiceException; -import com.amazonaws.SdkClientException; import com.amazonaws.services.lambda.runtime.events.SQSEvent; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.util.StringInputStream; -import org.apache.http.client.methods.HttpRequestBase; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.utils.StringInputStream; import software.amazon.lambda.powertools.sqs.internal.SqsLargeMessageAspect; import static com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage; import static java.util.Collections.singletonList; -import static org.apache.commons.lang3.reflect.FieldUtils.writeStaticField; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.mockito.Mockito.mock; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -38,22 +42,21 @@ class SqsUtilsLargeMessageTest { @Mock - private AmazonS3 amazonS3; + private S3Client s3Client; private static final String BUCKET_NAME = "ms-extended-sqs-client"; private static final String BUCKET_KEY = "c71eb2ae-37e0-4265-8909-32f4153faddf"; @BeforeEach - void setUp() throws IllegalAccessException { + void setUp() { openMocks(this); - writeStaticField(SqsLargeMessageAspect.class, "amazonS3", amazonS3, true); + SqsUtils.overrideS3Client(s3Client); } @Test public void testLargeMessage() { - S3Object s3Response = new S3Object(); - s3Response.setObjectContent(new ByteArrayInputStream("A big message".getBytes())); + ResponseInputStream s3Response = new ResponseInputStream<>(GetObjectResponse.builder().build(), AbortableInputStream.create(new ByteArrayInputStream("A big message".getBytes()))); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3Response); - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3Response); SQSEvent sqsEvent = messageWithBody("[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"); Map sqsMessage = SqsUtils.enrichedMessageFromS3(sqsEvent, sqsMessages -> { @@ -66,16 +69,27 @@ public void testLargeMessage() { .hasSize(1) .containsEntry("Message", "A big message"); - verify(amazonS3).deleteObject(BUCKET_NAME, BUCKET_KEY); + ArgumentCaptor delete = ArgumentCaptor.forClass(DeleteObjectRequest.class); + + verify(s3Client).deleteObject(delete.capture()); + + Assertions.assertThat(delete.getValue()) + .satisfies((Consumer) deleteObjectRequest -> { + assertThat(deleteObjectRequest.bucket()) + .isEqualTo(BUCKET_NAME); + + assertThat(deleteObjectRequest.key()) + .isEqualTo(BUCKET_KEY); + }); } @ParameterizedTest @ValueSource(booleans = {true, false}) public void testLargeMessageDeleteFromS3Toggle(boolean deleteS3Payload) { - S3Object s3Response = new S3Object(); - s3Response.setObjectContent(new ByteArrayInputStream("A big message".getBytes())); + ResponseInputStream s3Response = new ResponseInputStream<>(GetObjectResponse.builder().build(), AbortableInputStream.create(new ByteArrayInputStream("A big message".getBytes()))); + + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3Response); - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3Response); SQSEvent sqsEvent = messageWithBody("[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"); Map sqsMessage = SqsUtils.enrichedMessageFromS3(sqsEvent, deleteS3Payload, sqsMessages -> { @@ -88,18 +102,29 @@ public void testLargeMessageDeleteFromS3Toggle(boolean deleteS3Payload) { .hasSize(1) .containsEntry("Message", "A big message"); if (deleteS3Payload) { - verify(amazonS3).deleteObject(BUCKET_NAME, BUCKET_KEY); + ArgumentCaptor delete = ArgumentCaptor.forClass(DeleteObjectRequest.class); + + verify(s3Client).deleteObject(delete.capture()); + + Assertions.assertThat(delete.getValue()) + .satisfies((Consumer) deleteObjectRequest -> { + assertThat(deleteObjectRequest.bucket()) + .isEqualTo(BUCKET_NAME); + + assertThat(deleteObjectRequest.key()) + .isEqualTo(BUCKET_KEY); + }); } else { - verify(amazonS3, never()).deleteObject(BUCKET_NAME, BUCKET_KEY); + verify(s3Client, never()).deleteObject(any(DeleteObjectRequest.class)); } } @Test public void shouldNotProcessSmallMessageBody() { - S3Object s3Response = new S3Object(); - s3Response.setObjectContent(new ByteArrayInputStream("A big message".getBytes())); + ResponseInputStream s3Response = new ResponseInputStream<>(GetObjectResponse.builder().build(), AbortableInputStream.create(new ByteArrayInputStream("A big message".getBytes()))); + + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3Response); - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3Response); SQSEvent sqsEvent = messageWithBody("This is small message"); Map sqsMessage = SqsUtils.enrichedMessageFromS3(sqsEvent, sqsMessages -> { @@ -111,13 +136,13 @@ public void shouldNotProcessSmallMessageBody() { assertThat(sqsMessage) .containsEntry("Message", "This is small message"); - verifyNoInteractions(amazonS3); + verifyNoInteractions(s3Client); } @ParameterizedTest @MethodSource("exception") public void shouldFailEntireBatchIfFailedDownloadingFromS3(RuntimeException exception) { - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenThrow(exception); + when(s3Client.getObject(any(GetObjectRequest.class))).thenThrow(exception); String messageBody = "[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"; SQSEvent sqsEvent = messageWithBody(messageBody); @@ -126,21 +151,19 @@ public void shouldFailEntireBatchIfFailedDownloadingFromS3(RuntimeException exce .isThrownBy(() -> SqsUtils.enrichedMessageFromS3(sqsEvent, sqsMessages -> sqsMessages.get(0).getBody())) .withCause(exception); - verify(amazonS3, never()).deleteObject(BUCKET_NAME, BUCKET_KEY); + verify(s3Client, never()).deleteObject(any(DeleteObjectRequest.class)); } @Test - public void shouldFailEntireBatchIfFailedProcessingDownloadMessageFromS3() throws IOException { - S3Object s3Response = new S3Object(); - - s3Response.setObjectContent(new S3ObjectInputStream(new StringInputStream("test") { + public void shouldFailEntireBatchIfFailedProcessingDownloadMessageFromS3() { + ResponseInputStream s3Response = new ResponseInputStream<>(GetObjectResponse.builder().build(), AbortableInputStream.create(new StringInputStream("test") { @Override public void close() throws IOException { throw new IOException("Failed"); } - }, mock(HttpRequestBase.class))); + })); - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3Response); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3Response); String messageBody = "[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"; SQSEvent sqsEvent = messageWithBody(messageBody); @@ -149,12 +172,16 @@ public void close() throws IOException { .isThrownBy(() -> SqsUtils.enrichedMessageFromS3(sqsEvent, sqsMessages -> sqsMessages.get(0).getBody())) .withCauseInstanceOf(IOException.class); - verify(amazonS3, never()).deleteObject(BUCKET_NAME, BUCKET_KEY); + verify(s3Client, never()).deleteObject(any(DeleteObjectRequest.class)); } private static Stream exception() { - return Stream.of(Arguments.of(new AmazonServiceException("Service Exception")), - Arguments.of(new SdkClientException("Client Exception"))); + return Stream.of(Arguments.of(S3Exception.builder() + .message("Service Exception") + .build()), + Arguments.of(SdkClientException.builder() + .message("Client Exception") + .build())); } private SQSEvent messageWithBody(String messageBody) { diff --git a/powertools-sqs/src/test/java/software/amazon/lambda/powertools/sqs/internal/SqsLargeMessageAspectTest.java b/powertools-sqs/src/test/java/software/amazon/lambda/powertools/sqs/internal/SqsLargeMessageAspectTest.java index 1837686f1..22844ab4c 100644 --- a/powertools-sqs/src/test/java/software/amazon/lambda/powertools/sqs/internal/SqsLargeMessageAspectTest.java +++ b/powertools-sqs/src/test/java/software/amazon/lambda/powertools/sqs/internal/SqsLargeMessageAspectTest.java @@ -2,35 +2,40 @@ import java.io.ByteArrayInputStream; import java.io.IOException; +import java.util.function.Consumer; import java.util.stream.Stream; -import com.amazonaws.AmazonServiceException; -import com.amazonaws.SdkClientException; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.RequestHandler; import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; import com.amazonaws.services.lambda.runtime.events.SQSEvent; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; -import com.amazonaws.util.StringInputStream; -import org.apache.http.client.methods.HttpRequestBase; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.utils.StringInputStream; +import software.amazon.lambda.powertools.sqs.SqsUtils; import software.amazon.lambda.powertools.sqs.handlers.LambdaHandlerApiGateway; import software.amazon.lambda.powertools.sqs.handlers.SqsMessageHandler; import software.amazon.lambda.powertools.sqs.handlers.SqsNoDeleteMessageHandler; import static com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage; import static java.util.Collections.singletonList; -import static org.apache.commons.lang3.reflect.FieldUtils.writeStaticField; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.mockito.Mockito.mock; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -46,22 +51,22 @@ public class SqsLargeMessageAspectTest { private Context context; @Mock - private AmazonS3 amazonS3; + private S3Client s3Client; private static final String BUCKET_NAME = "bucketname"; private static final String BUCKET_KEY = "c71eb2ae-37e0-4265-8909-32f4153faddf"; @BeforeEach - void setUp() throws IllegalAccessException { + void setUp() { openMocks(this); setupContext(); - writeStaticField(SqsLargeMessageAspect.class, "amazonS3", amazonS3, true); + SqsUtils.overrideS3Client(s3Client); requestHandler = new SqsMessageHandler(); } @Test public void testLargeMessage() { - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3ObjectWithLargeMessage()); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3ObjectWithLargeMessage()); SQSEvent sqsEvent = messageWithBody("[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"); String response = requestHandler.handleRequest(sqsEvent, context); @@ -69,15 +74,24 @@ public void testLargeMessage() { assertThat(response) .isEqualTo("A big message"); - verify(amazonS3).deleteObject(BUCKET_NAME, BUCKET_KEY); + ArgumentCaptor delete = ArgumentCaptor.forClass(DeleteObjectRequest.class); + + verify(s3Client).deleteObject(delete.capture()); + + Assertions.assertThat(delete.getValue()) + .satisfies((Consumer) deleteObjectRequest -> { + assertThat(deleteObjectRequest.bucket()) + .isEqualTo(BUCKET_NAME); + + assertThat(deleteObjectRequest.key()) + .isEqualTo(BUCKET_KEY); + }); } @Test public void shouldNotProcessSmallMessageBody() { - S3Object s3Response = new S3Object(); - s3Response.setObjectContent(new ByteArrayInputStream("A big message".getBytes())); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3ObjectWithLargeMessage()); - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3Response); SQSEvent sqsEvent = messageWithBody("This is small message"); String response = requestHandler.handleRequest(sqsEvent, context); @@ -85,13 +99,13 @@ public void shouldNotProcessSmallMessageBody() { assertThat(response) .isEqualTo("This is small message"); - verifyNoInteractions(amazonS3); + verifyNoInteractions(s3Client); } @ParameterizedTest @MethodSource("exception") public void shouldFailEntireBatchIfFailedDownloadingFromS3(RuntimeException exception) { - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenThrow(exception); + when(s3Client.getObject(any(GetObjectRequest.class))).thenThrow(exception); String messageBody = "[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"; SQSEvent sqsEvent = messageWithBody(messageBody); @@ -100,35 +114,34 @@ public void shouldFailEntireBatchIfFailedDownloadingFromS3(RuntimeException exce .isThrownBy(() -> requestHandler.handleRequest(sqsEvent, context)) .withCause(exception); - verify(amazonS3, never()).deleteObject(BUCKET_NAME, BUCKET_KEY); + verify(s3Client, never()).deleteObject(any(DeleteObjectRequest.class)); } @Test public void testLargeMessageWithDeletionOff() { requestHandler = new SqsNoDeleteMessageHandler(); - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3ObjectWithLargeMessage()); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3ObjectWithLargeMessage()); SQSEvent sqsEvent = messageWithBody("[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"); String response = requestHandler.handleRequest(sqsEvent, context); assertThat(response).isEqualTo("A big message"); - verify(amazonS3, never()).deleteObject(BUCKET_NAME, BUCKET_KEY); + verify(s3Client, never()).deleteObject(any(DeleteObjectRequest.class)); } @Test - public void shouldFailEntireBatchIfFailedProcessingDownloadMessageFromS3() throws IOException { - S3Object s3Response = new S3Object(); - s3Response.setObjectContent(new S3ObjectInputStream(new StringInputStream("test") { + public void shouldFailEntireBatchIfFailedProcessingDownloadMessageFromS3() { + ResponseInputStream s3Response = new ResponseInputStream<>(GetObjectResponse.builder().build(), AbortableInputStream.create(new StringInputStream("test") { @Override public void close() throws IOException { throw new IOException("Failed"); } - }, mock(HttpRequestBase.class))); + })); - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3Response); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3Response); String messageBody = "[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"; SQSEvent sqsEvent = messageWithBody(messageBody); @@ -137,7 +150,7 @@ public void close() throws IOException { .isThrownBy(() -> requestHandler.handleRequest(sqsEvent, context)) .withCauseInstanceOf(IOException.class); - verify(amazonS3, never()).deleteObject(BUCKET_NAME, BUCKET_KEY); + verify(s3Client, never()).deleteObject(any(DeleteObjectRequest.class)); } @Test @@ -153,18 +166,20 @@ public void shouldNotDoAnyProcessingWhenNotSqsEvent() { assertThat(response) .isEqualTo(messageBody); - verifyNoInteractions(amazonS3); + verifyNoInteractions(s3Client); } - private S3Object s3ObjectWithLargeMessage() { - S3Object s3Response = new S3Object(); - s3Response.setObjectContent(new ByteArrayInputStream("A big message".getBytes())); - return s3Response; + private ResponseInputStream s3ObjectWithLargeMessage() { + return new ResponseInputStream<>(GetObjectResponse.builder().build(), AbortableInputStream.create(new ByteArrayInputStream("A big message".getBytes()))); } private static Stream exception() { - return Stream.of(Arguments.of(new AmazonServiceException("Service Exception")), - Arguments.of(new SdkClientException("Client Exception"))); + return Stream.of(Arguments.of(S3Exception.builder() + .message("Service Exception") + .build()), + Arguments.of(SdkClientException.builder() + .message("Client Exception") + .build())); } private SQSEvent messageWithBody(String messageBody) { diff --git a/powertools-test-suite/src/test/java/software/amazon/lambda/powertools/testsuite/LoggingOrderTest.java b/powertools-test-suite/src/test/java/software/amazon/lambda/powertools/testsuite/LoggingOrderTest.java index 285e7d2fa..7c3e79112 100644 --- a/powertools-test-suite/src/test/java/software/amazon/lambda/powertools/testsuite/LoggingOrderTest.java +++ b/powertools-test-suite/src/test/java/software/amazon/lambda/powertools/testsuite/LoggingOrderTest.java @@ -16,8 +16,6 @@ import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.events.SQSEvent; import com.amazonaws.services.lambda.runtime.events.models.s3.S3EventNotification; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.S3Object; import com.amazonaws.xray.AWSXRay; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -27,9 +25,14 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.lambda.powertools.core.internal.LambdaHandlerProcessor; import software.amazon.lambda.powertools.logging.internal.LambdaLoggingAspect; -import software.amazon.lambda.powertools.sqs.internal.SqsLargeMessageAspect; +import software.amazon.lambda.powertools.sqs.SqsUtils; import software.amazon.lambda.powertools.testsuite.handler.LoggingOrderMessageHandler; import software.amazon.lambda.powertools.testsuite.handler.TracingLoggingStreamMessageHandler; @@ -38,6 +41,7 @@ import static org.apache.commons.lang3.reflect.FieldUtils.writeStaticField; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; @@ -50,12 +54,12 @@ public class LoggingOrderTest { private Context context; @Mock - private AmazonS3 amazonS3; + private S3Client s3Client; @BeforeEach void setUp() throws IllegalAccessException, IOException, NoSuchMethodException, InvocationTargetException { openMocks(this); - writeStaticField(SqsLargeMessageAspect.class, "amazonS3", amazonS3, true); + SqsUtils.overrideS3Client(s3Client); ThreadContext.clearAll(); writeStaticField(LambdaHandlerProcessor.class, "IS_COLD_START", null, true); setupContext(); @@ -76,10 +80,9 @@ void tearDown() { */ @Test public void testThatLoggingAnnotationActsLast() throws IOException { - S3Object s3Response = new S3Object(); - s3Response.setObjectContent(new ByteArrayInputStream("A big message".getBytes())); + ResponseInputStream s3Response = new ResponseInputStream<>(GetObjectResponse.builder().build(), AbortableInputStream.create(new ByteArrayInputStream("A big message".getBytes()))); - when(amazonS3.getObject(BUCKET_NAME, BUCKET_KEY)).thenReturn(s3Response); + when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3Response); SQSEvent sqsEvent = messageWithBody("[\"software.amazon.payloadoffloading.PayloadS3Pointer\",{\"s3BucketName\":\"" + BUCKET_NAME + "\",\"s3Key\":\"" + BUCKET_KEY + "\"}]"); LoggingOrderMessageHandler requestHandler = new LoggingOrderMessageHandler(); diff --git a/spotbugs-exclude.xml b/spotbugs-exclude.xml index f42f72f84..e695069a7 100644 --- a/spotbugs-exclude.xml +++ b/spotbugs-exclude.xml @@ -59,6 +59,10 @@ + + + + @@ -76,6 +80,10 @@ + + + +