diff --git a/aws-xray-recorder-sdk-aws-sdk-v2/build.gradle.kts b/aws-xray-recorder-sdk-aws-sdk-v2/build.gradle.kts index 440d091c..e458a70a 100644 --- a/aws-xray-recorder-sdk-aws-sdk-v2/build.gradle.kts +++ b/aws-xray-recorder-sdk-aws-sdk-v2/build.gradle.kts @@ -13,6 +13,8 @@ dependencies { testImplementation("org.skyscreamer:jsonassert:1.3.0") testImplementation("software.amazon.awssdk:dynamodb:2.15.20") testImplementation("software.amazon.awssdk:lambda:2.15.20") + testImplementation("software.amazon.awssdk:sqs:2.15.20") + testImplementation("software.amazon.awssdk:sns:2.15.20") } tasks.jar { diff --git a/aws-xray-recorder-sdk-aws-sdk-v2/src/main/resources/com/amazonaws/xray/interceptors/DefaultOperationParameterWhitelist.json b/aws-xray-recorder-sdk-aws-sdk-v2/src/main/resources/com/amazonaws/xray/interceptors/DefaultOperationParameterWhitelist.json index 7cc6bda9..fc6d9196 100644 --- a/aws-xray-recorder-sdk-aws-sdk-v2/src/main/resources/com/amazonaws/xray/interceptors/DefaultOperationParameterWhitelist.json +++ b/aws-xray-recorder-sdk-aws-sdk-v2/src/main/resources/com/amazonaws/xray/interceptors/DefaultOperationParameterWhitelist.json @@ -1,6 +1,6 @@ { "services": { - "SNS": { + "Sns": { "operations": { "Publish": { "request_parameters": [ @@ -152,7 +152,7 @@ } } }, - "SQS": { + "Sqs": { "operations": { "AddPermission": { "request_parameters": [ diff --git a/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java b/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java index d26a634a..c51760ff 100644 --- a/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java +++ b/aws-xray-recorder-sdk-aws-sdk-v2/src/test/java/com/amazonaws/xray/interceptors/TracingInterceptorTest.java @@ -62,6 +62,10 @@ import software.amazon.awssdk.services.lambda.LambdaAsyncClient; import software.amazon.awssdk.services.lambda.LambdaClient; import software.amazon.awssdk.services.lambda.model.InvokeRequest; +import software.amazon.awssdk.services.sns.SnsClient; +import software.amazon.awssdk.services.sns.model.PublishRequest; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; @FixMethodOrder(MethodSorters.JVM) @RunWith(MockitoJUnitRunner.class) @@ -86,50 +90,6 @@ public void teardown() { AWSXRay.endSegment(); } - private SdkHttpClient mockSdkHttpClient(SdkHttpResponse response) throws Exception { - return mockSdkHttpClient(response, "OK"); - } - - private SdkHttpClient mockSdkHttpClient(SdkHttpResponse response, String body) throws Exception { - ExecutableHttpRequest abortableCallable = Mockito.mock(ExecutableHttpRequest.class); - SdkHttpClient mockClient = Mockito.mock(SdkHttpClient.class); - - when(mockClient.prepareRequest(Mockito.any())).thenReturn(abortableCallable); - when(abortableCallable.call()).thenReturn(HttpExecuteResponse.builder() - .response(response) - .responseBody(AbortableInputStream.create( - new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8)) - )) - .build() - ); - return mockClient; - } - - private SdkAsyncHttpClient mockSdkAsyncHttpClient(SdkHttpResponse response) { - SdkAsyncHttpClient mockClient = Mockito.mock(SdkAsyncHttpClient.class); - when(mockClient.execute(Mockito.any(AsyncExecuteRequest.class))) - .thenAnswer((Answer>) invocationOnMock -> { - AsyncExecuteRequest request = invocationOnMock.getArgument(0); - SdkAsyncHttpResponseHandler handler = request.responseHandler(); - handler.onHeaders(response); - handler.onStream(new EmptyPublisher<>()); - - return CompletableFuture.completedFuture(null); - }); - - return mockClient; - } - - private SdkHttpResponse generateLambdaInvokeResponse(int statusCode) { - return SdkHttpResponse.builder() - .statusCode(statusCode) - .putHeader("x-amz-request-id", "1111-2222-3333-4444") - .putHeader("x-amz-id-2", "extended") - .putHeader("Content-Length", "2") - .putHeader("X-Amz-Function-Error", "Failure") - .build(); - } - @Test public void testResponseDescriptors() throws Exception { String responseBody = "{\"LastEvaluatedTableName\":\"baz\",\"TableNames\":[\"foo\",\"bar\",\"baz\"]}"; @@ -140,19 +100,7 @@ public void testResponseDescriptors() throws Exception { .putHeader("Content-Type", "application/x-amz-json-1.0") .build(); SdkHttpClient mockClient = mockSdkHttpClient(mockResponse, responseBody); - - DynamoDbClient client = DynamoDbClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); + DynamoDbClient client = dynamoDbClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); client.listTables(ListTablesRequest.builder() @@ -177,22 +125,73 @@ public void testResponseDescriptors() throws Exception { Assert.assertEquals(false, subsegment.isInProgress()); } + @Test + public void testSqsSendMessageSubsegmentContainsQueueUrl() throws Exception { + SdkHttpClient mockClient = mockClientWithSuccessResponse( + "" + + "" + + "b10a8db164e0754105b7a99be72e3fe5" + + "abc-def-ghi" + + "" + + "123-456-789" + + "" + ); + SqsClient client = sqsClient(mockClient); + + Segment segment = AWSXRay.getCurrentSegment(); + client.sendMessage(SendMessageRequest.builder() + .queueUrl("http://queueurl.com") + .messageBody("Hello World") + .build() + ); + + Assert.assertEquals(1, segment.getSubsegments().size()); + Subsegment subsegment = segment.getSubsegments().get(0); + Map awsStats = subsegment.getAws(); + + Assert.assertEquals("SendMessage", awsStats.get("operation")); + Assert.assertEquals("http://queueurl.com", awsStats.get("queue_url")); + Assert.assertEquals("abc-def-ghi", awsStats.get("message_id")); + Assert.assertEquals("123-456-789", awsStats.get("request_id")); + Assert.assertEquals("us-west-42", awsStats.get("region")); + Assert.assertEquals(0, awsStats.get("retries")); + Assert.assertEquals(false, subsegment.isInProgress()); + } + + @Test + public void testSnsPublishSubsegmentContainsTopicArn() throws Exception { + SdkHttpClient mockClient = mockClientWithSuccessResponse( + "" + + "abc-def-ghi" + + "123-456-789" + + "" + ); + SnsClient client = snsClient(mockClient); + + Segment segment = AWSXRay.getCurrentSegment(); + client.publish(PublishRequest.builder() + .topicArn("arn:aws:sns:us-west-42:123456789012:MyTopic") + .message("Hello World") + .build() + ); + + Assert.assertEquals(1, segment.getSubsegments().size()); + Subsegment subsegment = segment.getSubsegments().get(0); + Map awsStats = subsegment.getAws(); + + Assert.assertEquals("Publish", awsStats.get("operation")); + Assert.assertEquals("arn:aws:sns:us-west-42:123456789012:MyTopic", awsStats.get("topic_arn")); + Assert.assertEquals("us-west-42", awsStats.get("region")); + Assert.assertEquals("123-456-789", awsStats.get("request_id")); + Assert.assertEquals(0, awsStats.get("retries")); + Assert.assertEquals(false, subsegment.isInProgress()); + } + @Test public void testLambdaInvokeSubsegmentContainsFunctionName() throws Exception { SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(200)); - LambdaClient client = LambdaClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); + LambdaClient client = lambdaClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); @@ -223,22 +222,9 @@ public void testLambdaInvokeSubsegmentContainsFunctionName() throws Exception { @Test public void testAsyncLambdaInvokeSubsegmentContainsFunctionName() { SdkAsyncHttpClient mockClient = mockSdkAsyncHttpClient(generateLambdaInvokeResponse(200)); - - LambdaAsyncClient client = LambdaAsyncClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); + LambdaAsyncClient client = lambdaAsyncClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); - client.invoke(InvokeRequest.builder() .functionName("testFunctionName") .build() @@ -265,23 +251,9 @@ public void testAsyncLambdaInvokeSubsegmentContainsFunctionName() { @Test public void test400Exception() throws Exception { SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(400)); - - LambdaClient client = LambdaClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); - + LambdaClient client = lambdaClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); - try { client.invoke(InvokeRequest.builder() .functionName("testFunctionName") @@ -317,22 +289,9 @@ public void test400Exception() throws Exception { @Test public void testAsync400Exception() { SdkAsyncHttpClient mockClient = mockSdkAsyncHttpClient(generateLambdaInvokeResponse(400)); - - LambdaAsyncClient client = LambdaAsyncClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); + LambdaAsyncClient client = lambdaAsyncClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); - try { client.invoke(InvokeRequest.builder() .functionName("testFunctionName") @@ -368,23 +327,9 @@ public void testAsync400Exception() { @Test public void testThrottledException() throws Exception { SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(429)); - - LambdaClient client = LambdaClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); - + LambdaClient client = lambdaClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); - try { client.invoke(InvokeRequest.builder() .functionName("testFunctionName") @@ -418,22 +363,9 @@ public void testThrottledException() throws Exception { @Test public void testAsyncThrottledException() { SdkAsyncHttpClient mockClient = mockSdkAsyncHttpClient(generateLambdaInvokeResponse(429)); - - LambdaAsyncClient client = LambdaAsyncClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); + LambdaAsyncClient client = lambdaAsyncClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); - try { client.invoke(InvokeRequest.builder() .functionName("testFunctionName") @@ -467,23 +399,9 @@ public void testAsyncThrottledException() { @Test public void test500Exception() throws Exception { SdkHttpClient mockClient = mockSdkHttpClient(generateLambdaInvokeResponse(500)); - - LambdaClient client = LambdaClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); - + LambdaClient client = lambdaClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); - try { client.invoke(InvokeRequest.builder() .functionName("testFunctionName") @@ -517,22 +435,9 @@ public void test500Exception() throws Exception { @Test public void testAsync500Exception() { SdkAsyncHttpClient mockClient = mockSdkAsyncHttpClient(generateLambdaInvokeResponse(500)); - - LambdaAsyncClient client = LambdaAsyncClient.builder() - .httpClient(mockClient) - .endpointOverride(URI.create("http://example.com")) - .region(Region.of("us-west-42")) - .credentialsProvider(StaticCredentialsProvider.create( - AwsSessionCredentials.create("key", "secret", "session") - )) - .overrideConfiguration(ClientOverrideConfiguration.builder() - .addExecutionInterceptor(new TracingInterceptor()) - .build() - ) - .build(); + LambdaAsyncClient client = lambdaAsyncClient(mockClient); Segment segment = AWSXRay.getCurrentSegment(); - try { client.invoke(InvokeRequest.builder() .functionName("testFunctionName") @@ -579,5 +484,131 @@ public void testNoHeaderAddedWhenPropagationOff() { verify(mockRequest.toBuilder(), never()).appendHeader(anyString(), anyString()); } + + private SdkHttpClient mockSdkHttpClient(SdkHttpResponse response) throws Exception { + return mockSdkHttpClient(response, "OK"); + } + + private SdkHttpClient mockSdkHttpClient(SdkHttpResponse response, String body) throws Exception { + ExecutableHttpRequest abortableCallable = Mockito.mock(ExecutableHttpRequest.class); + SdkHttpClient mockClient = Mockito.mock(SdkHttpClient.class); + + when(mockClient.prepareRequest(Mockito.any())).thenReturn(abortableCallable); + when(abortableCallable.call()).thenReturn(HttpExecuteResponse.builder() + .response(response) + .responseBody(AbortableInputStream.create( + new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8)) + )) + .build() + ); + return mockClient; + } + + private SdkAsyncHttpClient mockSdkAsyncHttpClient(SdkHttpResponse response) { + SdkAsyncHttpClient mockClient = Mockito.mock(SdkAsyncHttpClient.class); + when(mockClient.execute(Mockito.any(AsyncExecuteRequest.class))) + .thenAnswer((Answer>) invocationOnMock -> { + AsyncExecuteRequest request = invocationOnMock.getArgument(0); + SdkAsyncHttpResponseHandler handler = request.responseHandler(); + handler.onHeaders(response); + handler.onStream(new EmptyPublisher<>()); + + return CompletableFuture.completedFuture(null); + }); + + return mockClient; + } + + private SdkHttpResponse generateLambdaInvokeResponse(int statusCode) { + return SdkHttpResponse.builder() + .statusCode(statusCode) + .putHeader("x-amz-request-id", "1111-2222-3333-4444") + .putHeader("x-amz-id-2", "extended") + .putHeader("Content-Length", "2") + .putHeader("X-Amz-Function-Error", "Failure") + .build(); + } + + private SdkHttpClient mockClientWithSuccessResponse(String responseBody) throws Exception { + SdkHttpResponse mockResponse = SdkHttpResponse.builder() + .statusCode(200) + .build(); + return mockSdkHttpClient(mockResponse, responseBody); + } + + private static LambdaClient lambdaClient(SdkHttpClient mockClient) { + return LambdaClient.builder() + .httpClient(mockClient) + .endpointOverride(URI.create("http://example.com")) + .region(Region.of("us-west-42")) + .credentialsProvider(StaticCredentialsProvider.create( + AwsSessionCredentials.create("key", "secret", "session") + )) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .addExecutionInterceptor(new TracingInterceptor()) + .build() + ) + .build(); + } + + private static LambdaAsyncClient lambdaAsyncClient(SdkAsyncHttpClient mockClient) { + return LambdaAsyncClient.builder() + .httpClient(mockClient) + .endpointOverride(URI.create("http://example.com")) + .region(Region.of("us-west-42")) + .credentialsProvider(StaticCredentialsProvider.create( + AwsSessionCredentials.create("key", "secret", "session") + )) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .addExecutionInterceptor(new TracingInterceptor()) + .build() + ) + .build(); + } + + private static DynamoDbClient dynamoDbClient(SdkHttpClient mockClient) { + return DynamoDbClient.builder() + .httpClient(mockClient) + .endpointOverride(URI.create("http://example.com")) + .region(Region.of("us-west-42")) + .credentialsProvider(StaticCredentialsProvider.create( + AwsSessionCredentials.create("key", "secret", "session") + )) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .addExecutionInterceptor(new TracingInterceptor()) + .build() + ) + .build(); + } + + private static SqsClient sqsClient(SdkHttpClient mockClient) { + return SqsClient.builder() + .httpClient(mockClient) + .endpointOverride(URI.create("http://example.com")) + .region(Region.of("us-west-42")) + .credentialsProvider(StaticCredentialsProvider.create( + AwsSessionCredentials.create("key", "secret", "session") + )) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .addExecutionInterceptor(new TracingInterceptor()) + .build() + ) + .build(); + } + + private static SnsClient snsClient(SdkHttpClient mockClient) { + return SnsClient.builder() + .httpClient(mockClient) + .endpointOverride(URI.create("http://example.com")) + .region(Region.of("us-west-42")) + .credentialsProvider(StaticCredentialsProvider.create( + AwsSessionCredentials.create("key", "secret", "session") + )) + .overrideConfiguration(ClientOverrideConfiguration.builder() + .addExecutionInterceptor(new TracingInterceptor()) + .build() + ) + .build(); + } }