diff --git a/xrayaws-v2/client.go b/xrayaws-v2/client.go index bd752ab..3856409 100644 --- a/xrayaws-v2/client.go +++ b/xrayaws-v2/client.go @@ -3,6 +3,7 @@ package xrayaws import ( "context" "reflect" + "strings" "sync" "time" @@ -77,7 +78,8 @@ func (m xrayMiddleware) HandleInitialize( out middleware.InitializeOutput, metadata middleware.Metadata, err error, ) { segs := &subsegments{ - ctx: ctx, + name: m.o.name, + ctx: ctx, } ctx = context.WithValue(ctx, segmentsContextKey, segs) segs.initializeTime = time.Now() @@ -127,23 +129,59 @@ func (endMarshalMiddleware) ID() string { return "XRayEndMarshalMiddleware" } -func (endMarshalMiddleware) HandleBuild( - ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, -) ( - out middleware.BuildOutput, metadata middleware.Metadata, err error, +func (m endMarshalMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, ) { if segs := contextSubsegments(ctx); segs != nil { segs.mu.Lock() - name := awsmiddle.GetSigningName(ctx) - segs.name = name - segs.awsCtx, segs.awsSeg = xray.BeginSubsegmentAt(ctx, segs.initializeTime, name) + if segs.name == "" { + segs.name = getServiceName(ctx, in) + } + segs.awsCtx, segs.awsSeg = xray.BeginSubsegmentAt(ctx, segs.initializeTime, segs.name) _, marshalSeg := xray.BeginSubsegmentAt(segs.awsCtx, segs.marshalTime, "marshal") marshalSeg.Close() segs.mu.Unlock() } - return next.HandleBuild(ctx, in) + return next.HandleFinalize(ctx, in) +} + +func getServiceName(ctx context.Context, in middleware.FinalizeInput) string { + if name := awsmiddle.GetSigningName(ctx); name != "" { + return name + } + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return awsmiddle.GetServiceID(ctx) + } + + const prefix = "AWS4-HMAC-SHA256 " + auth := req.Header.Get("Authorization") + if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { + return awsmiddle.GetServiceID(ctx) + } + auth = auth[len(prefix):] + parts := strings.Split(auth, ",") + for _, part := range parts { + key, value, ok := strings.Cut(part, "=") + if !ok { + continue + } + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key != "Credential" { + continue + } + + cred := strings.Split(value, "/") + if len(cred) < 4 { + return awsmiddle.GetServiceID(ctx) + } + return cred[3] + } + + return awsmiddle.GetServiceID(ctx) } type beginAttemptMiddleware struct{} @@ -251,13 +289,21 @@ var segmentsContextKey = &contextKey{"segments"} // WithXRay is the X-Ray tracing option. func WithXRay() config.LoadOptionsFunc { - return WithWhitelist(defaultWhitelist) + return WithServiceName("", defaultWhitelist) } // WithWhitelist returns a X-Ray tracing option with custom whitelist. func WithWhitelist(whitelist *whitelist.Whitelist) config.LoadOptionsFunc { + return WithServiceName("", whitelist) +} + +// WithServiceName returns a X-Ray tracing option with custom service name. +func WithServiceName(name string, whitelist *whitelist.Whitelist) config.LoadOptionsFunc { return func(o *config.LoadOptions) error { - newOption := option{whitelist: whitelist} + newOption := option{ + name: name, + whitelist: whitelist, + } o.APIOptions = append( o.APIOptions, newOption.addMiddleware, @@ -267,19 +313,23 @@ func WithWhitelist(whitelist *whitelist.Whitelist) config.LoadOptionsFunc { } type option struct { + name string whitelist *whitelist.Whitelist } func (o *option) addMiddleware(stack *middleware.Stack) error { stack.Initialize.Add(xrayMiddleware{o: o}, middleware.After) stack.Serialize.Add(beginMarshalMiddleware{}, middleware.Before) - stack.Build.Add(endMarshalMiddleware{}, middleware.After) + stack.Finalize.Insert(endMarshalMiddleware{}, "Signing", middleware.After) stack.Deserialize.Add(beginAttemptMiddleware{}, middleware.Before) stack.Deserialize.Insert(endAttemptMiddleware{}, "OperationDeserializer", middleware.After) return nil } func (o *option) insertParameter(aws schema.AWS, serviceName, operationName string, params, result any) { + if o.whitelist == nil { + return + } service, ok := o.whitelist.Services[serviceName] if !ok { return diff --git a/xrayaws-v2/client_test.go b/xrayaws-v2/client_test.go index a8e6575..1746b32 100644 --- a/xrayaws-v2/client_test.go +++ b/xrayaws-v2/client_test.go @@ -1,6 +1,7 @@ package xrayaws import ( + "context" "errors" "fmt" "io" @@ -13,10 +14,14 @@ import ( "unicode" "github.com/aws/aws-sdk-go-v2/aws" + awsmiddle "github.com/aws/aws-sdk-go-v2/aws/middleware" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/lambda" + smithyendpoints "github.com/aws/smithy-go/endpoints" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/google/go-cmp/cmp" "github.com/shogo82148/aws-xray-yasdk-go/xray" "github.com/shogo82148/aws-xray-yasdk-go/xray/schema" @@ -107,6 +112,7 @@ func TestClient(t *testing.T) { SigningName: "lambda", }, nil }), + HTTPClient: opt.HTTPClient, Retryer: func() aws.Retryer { return aws.NopRetryer{} }, @@ -212,6 +218,284 @@ func TestClient(t *testing.T) { } } +func TestClient_CustomServiceName(t *testing.T) { + // setup dummy X-Ray daemon + ctx, td := xray.NewTestDaemon(nil) + defer td.Close() + + // setup dummy aws service + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := io.WriteString(w, "{}"); err != nil { + panic(err) + } + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + + var opt config.LoadOptions + WithServiceName("lunar-lambda", nil)(&opt) + cfg := aws.Config{ + Region: "fake-moon-1", + Retryer: func() aws.Retryer { + return aws.NopRetryer{} + }, + HTTPClient: opt.HTTPClient, + APIOptions: opt.APIOptions, + Credentials: credentials.NewStaticCredentialsProvider("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", ""), + } + + r := lambdaEndpointResolver(func(ctx context.Context, params lambda.EndpointParameters) (smithyendpoints.Endpoint, error) { + return smithyendpoints.Endpoint{ + URI: *u, + }, nil + }) + + // start testing + svc := lambda.NewFromConfig(cfg, lambda.WithEndpointResolverV2(r)) + ctx, root := xray.BeginSegment(ctx, "Test") + _, err = svc.ListFunctions(ctx, &lambda.ListFunctionsInput{}) + root.Close() + if err != nil { + t.Fatal(err) + } + + // check the segment + got, err := td.Recv() + if err != nil { + t.Fatal(err) + } + want := &schema.Segment{ + Name: "Test", + ID: "xxxxxxxxxxxxxxxx", + TraceID: "x-xxxxxxxx-xxxxxxxxxxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + Subsegments: []*schema.Segment{ + { + Name: "lunar-lambda", + ID: "xxxxxxxxxxxxxxxx", + Namespace: "aws", + StartTime: timeFilled, + EndTime: timeFilled, + Subsegments: []*schema.Segment{ + { + Name: "marshal", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + }, + { + Name: "attempt", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + Subsegments: []*schema.Segment{ + { + Name: "connect", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + Subsegments: []*schema.Segment{ + { + Name: "dial", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + Metadata: map[string]any{ + "http": map[string]any{ + "dial": map[string]any{ + "network": "tcp", + "address": u.Host, + }, + }, + }, + }, + }, + }, + { + Name: "request", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + }, + }, + }, + { + Name: "unmarshal", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + }, + }, + HTTP: &schema.HTTP{ + Response: &schema.HTTPResponse{ + Status: 200, + ContentLength: 2, + }, + }, + AWS: schema.AWS{ + "operation": "ListFunctions", + "region": "fake-moon-1", + "request_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + // "retries": 0.0, + }, + }, + }, + Service: xray.ServiceData, + } + if diff := cmp.Diff(want, got, ignoreVariableField); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } +} + +type lambdaEndpointResolver func(ctx context.Context, params lambda.EndpointParameters) (smithyendpoints.Endpoint, error) + +func (r lambdaEndpointResolver) ResolveEndpoint(ctx context.Context, params lambda.EndpointParameters) (smithyendpoints.Endpoint, error) { + return r(ctx, params) +} + +func TestClient_ResolveEndpointV2(t *testing.T) { + // setup dummy X-Ray daemon + ctx, td := xray.NewTestDaemon(nil) + defer td.Close() + + // setup dummy aws service + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := io.WriteString(w, "{}"); err != nil { + panic(err) + } + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + + var opt config.LoadOptions + WithXRay()(&opt) + cfg := aws.Config{ + Region: "fake-moon-1", + Retryer: func() aws.Retryer { + return aws.NopRetryer{} + }, + HTTPClient: opt.HTTPClient, + APIOptions: opt.APIOptions, + Credentials: credentials.NewStaticCredentialsProvider("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", ""), + } + + r := lambdaEndpointResolver(func(ctx context.Context, params lambda.EndpointParameters) (smithyendpoints.Endpoint, error) { + return smithyendpoints.Endpoint{ + URI: *u, + }, nil + }) + + // start testing + svc := lambda.NewFromConfig(cfg, lambda.WithEndpointResolverV2(r)) + ctx, root := xray.BeginSegment(ctx, "Test") + _, err = svc.ListFunctions(ctx, &lambda.ListFunctionsInput{}) + root.Close() + if err != nil { + t.Fatal(err) + } + + // check the segment + got, err := td.Recv() + if err != nil { + t.Fatal(err) + } + want := &schema.Segment{ + Name: "Test", + ID: "xxxxxxxxxxxxxxxx", + TraceID: "x-xxxxxxxx-xxxxxxxxxxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + Subsegments: []*schema.Segment{ + { + Name: "lambda", + ID: "xxxxxxxxxxxxxxxx", + Namespace: "aws", + StartTime: timeFilled, + EndTime: timeFilled, + Subsegments: []*schema.Segment{ + { + Name: "marshal", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + }, + { + Name: "attempt", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + Subsegments: []*schema.Segment{ + { + Name: "connect", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + Subsegments: []*schema.Segment{ + { + Name: "dial", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + Metadata: map[string]any{ + "http": map[string]any{ + "dial": map[string]any{ + "network": "tcp", + "address": u.Host, + }, + }, + }, + }, + }, + }, + { + Name: "request", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + }, + }, + }, + { + Name: "unmarshal", + ID: "xxxxxxxxxxxxxxxx", + StartTime: timeFilled, + EndTime: timeFilled, + }, + }, + HTTP: &schema.HTTP{ + Response: &schema.HTTPResponse{ + Status: 200, + ContentLength: 2, + }, + }, + AWS: schema.AWS{ + "operation": "ListFunctions", + "region": "fake-moon-1", + "request_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + // "retries": 0.0, + }, + }, + }, + Service: xray.ServiceData, + } + if diff := cmp.Diff(want, got, ignoreVariableField); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } +} + func TestClient_FailDial(t *testing.T) { // setup dummy X-Ray daemon ctx, td := xray.NewTestDaemon(nil) @@ -586,3 +870,95 @@ func TestInsertDescriptor_value(t *testing.T) { t.Errorf("want bar, got %s", got) } } + +func TestGetServiceName(t *testing.T) { + t.Run("legacySigningMethod", func(t *testing.T) { + ctx := awsmiddle.SetSigningName(context.Background(), "lambda") + if got := getServiceName(ctx, middleware.FinalizeInput{}); got != "lambda" { + t.Errorf("want lambda, got %s", got) + } + }) + + t.Run("v4SigningMethod", func(t *testing.T) { + ctx := context.Background() + in := middleware.FinalizeInput{ + Request: &smithyhttp.Request{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{ + "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20231220/fake-moon-1/lambda/aws4_request, SignedHeaders=amz-sdk-invocation-id;host;x-amz-date, Signature=c2edddb0b3072e5f11ebb852a3e99ad6871e89d8ab867b41d23e7d9b6ad7ed71", + }, + }, + }, + }, + } + + if got := getServiceName(ctx, in); got != "lambda" { + t.Errorf("want lambda, got %s", got) + } + }) + + t.Run("BadV4SigningMethod1", func(t *testing.T) { + ctx := awsmiddle.SetServiceID(context.Background(), "Lambda") + in := middleware.FinalizeInput{ + Request: &smithyhttp.Request{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{ + "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20231220/fake-moon-1", + }, + }, + }, + }, + } + + if got := getServiceName(ctx, in); got != "Lambda" { + t.Errorf("want Lambda, got %s", got) + } + }) + + t.Run("BadV4SigningMethod2", func(t *testing.T) { + ctx := awsmiddle.SetServiceID(context.Background(), "Lambda") + in := middleware.FinalizeInput{ + Request: &smithyhttp.Request{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{ + "AWS4-HMAC-SHA256 SignedHeaders=amz-sdk-invocation-id;host;x-amz-date, Signature=c2edddb0b3072e5f11ebb852a3e99ad6871e89d8ab867b41d23e7d9b6ad7ed71", + }, + }, + }, + }, + } + + if got := getServiceName(ctx, in); got != "Lambda" { + t.Errorf("want Lambda, got %s", got) + } + }) + + t.Run("BadV4SigningMethod3", func(t *testing.T) { + ctx := awsmiddle.SetServiceID(context.Background(), "Lambda") + in := middleware.FinalizeInput{ + Request: &smithyhttp.Request{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{ + "Bearer c2edddb0b3072e5f11ebb852a3e99ad6871e89d8ab867b41d23e7d9b6ad7ed71", + }, + }, + }, + }, + } + + if got := getServiceName(ctx, in); got != "Lambda" { + t.Errorf("want Lambda, got %s", got) + } + }) + + t.Run("ServiceID", func(t *testing.T) { + ctx := awsmiddle.SetServiceID(context.Background(), "Lambda") + if got := getServiceName(ctx, middleware.FinalizeInput{}); got != "Lambda" { + t.Errorf("want lambda, got %s", got) + } + }) +}