diff --git a/contrib/aws/aws-sdk-go-v2/aws/aws.go b/contrib/aws/aws-sdk-go-v2/aws/aws.go index 7a1d0b4e6a..f193914cbd 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/aws.go +++ b/contrib/aws/aws-sdk-go-v2/aws/aws.go @@ -229,21 +229,23 @@ func tableName(requestInput middleware.InitializeInput) string { func streamName(requestInput middleware.InitializeInput) string { switch params := requestInput.Parameters.(type) { case *kinesis.PutRecordInput: - return *params.StreamName + return coalesceNameOrArnResource(params.StreamName, params.StreamARN) case *kinesis.PutRecordsInput: - return *params.StreamName + return coalesceNameOrArnResource(params.StreamName, params.StreamARN) case *kinesis.AddTagsToStreamInput: - return *params.StreamName + return coalesceNameOrArnResource(params.StreamName, params.StreamARN) case *kinesis.RemoveTagsFromStreamInput: - return *params.StreamName + return coalesceNameOrArnResource(params.StreamName, params.StreamARN) case *kinesis.CreateStreamInput: - return *params.StreamName + if params.StreamName != nil { + return *params.StreamName + } case *kinesis.DeleteStreamInput: - return *params.StreamName + return coalesceNameOrArnResource(params.StreamName, params.StreamARN) case *kinesis.DescribeStreamInput: - return *params.StreamName + return coalesceNameOrArnResource(params.StreamName, params.StreamARN) case *kinesis.DescribeStreamSummaryInput: - return *params.StreamName + return coalesceNameOrArnResource(params.StreamName, params.StreamARN) case *kinesis.GetRecordsInput: if params.StreamARN != nil { streamArnValue := *params.StreamARN @@ -353,3 +355,16 @@ func serviceName(cfg *config, awsService string) string { defaultName := fmt.Sprintf("aws.%s", awsService) return namingschema.ServiceNameOverrideV0(defaultName, defaultName) } + +func coalesceNameOrArnResource(name *string, arnVal *string) string { + if name != nil { + return *name + } + + if arnVal != nil { + parts := strings.Split(*arnVal, "/") + return parts[len(parts)-1] + } + + return "" +} diff --git a/contrib/aws/aws-sdk-go-v2/aws/aws_test.go b/contrib/aws/aws-sdk-go-v2/aws/aws_test.go index 3c45d1a1a7..88fba42c5d 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/aws_test.go +++ b/contrib/aws/aws-sdk-go-v2/aws/aws_test.go @@ -30,6 +30,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/aws/smithy-go/middleware" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1077,3 +1078,64 @@ func TestWithErrorCheck(t *testing.T) { }) } } + +func TestStreamName(t *testing.T) { + dummyName := `my-stream` + dummyArn := `arn:aws:kinesis:us-east-1:111111111111:stream/` + dummyName + + tests := []struct { + name string + input any + expected string + }{ + { + name: "PutRecords with ARN", + input: &kinesis.PutRecordsInput{StreamARN: &dummyArn}, + expected: dummyName, + }, + { + name: "PutRecords with Name", + input: &kinesis.PutRecordsInput{StreamName: &dummyName}, + expected: dummyName, + }, + { + name: "PutRecords with both", + input: &kinesis.PutRecordsInput{StreamName: &dummyName, StreamARN: &dummyArn}, + expected: dummyName, + }, + { + name: "PutRecord with Name", + input: &kinesis.PutRecordInput{StreamName: &dummyName}, + expected: dummyName, + }, + { + name: "CreateStream", + input: &kinesis.CreateStreamInput{StreamName: &dummyName}, + expected: dummyName, + }, + { + name: "CreateStream with nothing", + input: &kinesis.CreateStreamInput{}, + expected: "", + }, + { + name: "GetRecords", + input: &kinesis.GetRecordsInput{StreamARN: &dummyArn}, + expected: dummyName, + }, + { + name: "GetRecords with nothing", + input: &kinesis.GetRecordsInput{}, + expected: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := middleware.InitializeInput{ + Parameters: tt.input, + } + val := streamName(req) + assert.Equal(t, tt.expected, val) + }) + } +}