diff --git a/lambda/entry.go b/lambda/entry.go index f71ac455..c935e236 100644 --- a/lambda/entry.go +++ b/lambda/entry.go @@ -4,7 +4,6 @@ package lambda import ( "context" - "errors" "log" "os" ) @@ -70,20 +69,11 @@ type startFunction struct { } var ( - // This allows users to save a little bit of coldstart time in the download, by the dependencies brought in for RPC support. - // The tradeoff is dropping compatibility with the go1.x runtime, functions must be "Custom Runtime" instead. - // To drop the rpc dependencies, compile with `-tags lambda.norpc` - rpcStartFunction = &startFunction{ - env: "_LAMBDA_SERVER_PORT", - f: func(_ string, _ Handler) error { - return errors.New("_LAMBDA_SERVER_PORT was present but the function was compiled without RPC support") - }, - } runtimeAPIStartFunction = &startFunction{ env: "AWS_LAMBDA_RUNTIME_API", f: startRuntimeAPILoop, } - startFunctions = []*startFunction{rpcStartFunction, runtimeAPIStartFunction} + startFunctions = []*startFunction{runtimeAPIStartFunction} // This allows end to end testing of the Start functions, by tests overwriting this function to keep the program alive logFatalf = log.Fatalf diff --git a/lambda/entry_test.go b/lambda/entry_test.go index 8da99ca9..6f5304cd 100644 --- a/lambda/entry_test.go +++ b/lambda/entry_test.go @@ -4,15 +4,11 @@ package lambda import ( "context" - "fmt" "log" - "net" - "net/rpc" "os" "strings" "testing" - "github.com/aws/aws-lambda-go/lambda/messages" "github.com/stretchr/testify/assert" ) @@ -35,58 +31,3 @@ func TestStartRuntimeAPIWithContext(t *testing.T) { assert.Equal(t, expected, actual) } - -func TestStartRPCWithContext(t *testing.T) { - expected := "expected" - actual := "unexpected" - port := getFreeTCPPort() - os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port)) - defer os.Unsetenv("_LAMBDA_SERVER_PORT") - go StartWithContext(context.WithValue(context.Background(), ctxTestKey{}, expected), func(ctx context.Context) error { - actual, _ = ctx.Value(ctxTestKey{}).(string) - return nil - }) - - var client *rpc.Client - var pingResponse messages.PingResponse - var invokeResponse messages.InvokeResponse - var err error - for { - client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port)) - if err != nil { - continue - } - break - } - for { - if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil { - continue - } - break - } - if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil { - t.Logf("error invoking function: %v", err) - } - - assert.Equal(t, expected, actual) -} - -func getFreeTCPPort() int { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - log.Fatal("getFreeTCPPort failed: ", err) - } - defer l.Close() - - return l.Addr().(*net.TCPAddr).Port -} - -func TestStartNotInLambda(t *testing.T) { - actual := "unexpected" - logFatalf = func(format string, v ...interface{}) { - actual = fmt.Sprintf(format, v...) - } - - Start(func() error { return nil }) - assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual) -} diff --git a/lambda/entry_with_no_rpc_test.go b/lambda/entry_with_no_rpc_test.go new file mode 100644 index 00000000..d363d6c8 --- /dev/null +++ b/lambda/entry_with_no_rpc_test.go @@ -0,0 +1,23 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +//go:build lambda.norpc +// +build lambda.norpc + +package lambda + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStartNotInLambda(t *testing.T) { + actual := "unexpected" + logFatalf = func(format string, v ...interface{}) { + actual = fmt.Sprintf(format, v...) + } + + Start(func() error { return nil }) + assert.Equal(t, "expected AWS Lambda environment variables [AWS_LAMBDA_RUNTIME_API] are not defined", actual) +} diff --git a/lambda/entry_with_rpc_test.go b/lambda/entry_with_rpc_test.go new file mode 100644 index 00000000..92e9532f --- /dev/null +++ b/lambda/entry_with_rpc_test.go @@ -0,0 +1,74 @@ +// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved + +//go:build !lambda.norpc +// +build !lambda.norpc + +package lambda + +import ( + "context" + "fmt" + "log" + "net" + "net/rpc" + "os" + "testing" + + "github.com/aws/aws-lambda-go/lambda/messages" + "github.com/stretchr/testify/assert" +) + +func TestStartRPCWithContext(t *testing.T) { + expected := "expected" + actual := "unexpected" + port := getFreeTCPPort() + os.Setenv("_LAMBDA_SERVER_PORT", fmt.Sprintf("%d", port)) + defer os.Unsetenv("_LAMBDA_SERVER_PORT") + go StartWithContext(context.WithValue(context.Background(), ctxTestKey{}, expected), func(ctx context.Context) error { + actual, _ = ctx.Value(ctxTestKey{}).(string) + return nil + }) + + var client *rpc.Client + var pingResponse messages.PingResponse + var invokeResponse messages.InvokeResponse + var err error + for { + client, err = rpc.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + continue + } + break + } + for { + if err := client.Call("Function.Ping", &messages.PingRequest{}, &pingResponse); err != nil { + continue + } + break + } + if err := client.Call("Function.Invoke", &messages.InvokeRequest{}, &invokeResponse); err != nil { + t.Logf("error invoking function: %v", err) + } + + assert.Equal(t, expected, actual) +} + +func getFreeTCPPort() int { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + log.Fatal("getFreeTCPPort failed: ", err) + } + defer l.Close() + + return l.Addr().(*net.TCPAddr).Port +} + +func TestStartNotInLambda(t *testing.T) { + actual := "unexpected" + logFatalf = func(format string, v ...interface{}) { + actual = fmt.Sprintf(format, v...) + } + + Start(func() error { return nil }) + assert.Equal(t, "expected AWS Lambda environment variables [_LAMBDA_SERVER_PORT AWS_LAMBDA_RUNTIME_API] are not defined", actual) +} diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go index dca2bf68..f73689ba 100644 --- a/lambda/invoke_loop.go +++ b/lambda/invoke_loop.go @@ -3,13 +3,16 @@ package lambda import ( + "context" "encoding/json" "fmt" "log" + "os" "strconv" "time" "github.com/aws/aws-lambda-go/lambda/messages" + "github.com/aws/aws-lambda-go/lambdacontext" ) const ( @@ -17,87 +20,121 @@ const ( nsPerMS = int64(time.Millisecond / time.Nanosecond) ) +// TODO: replace with time.UnixMillis after dropping version <1.17 from CI workflows +func unixMS(ms int64) time.Time { + return time.Unix(ms/msPerS, (ms%msPerS)*nsPerMS) +} + // startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error func startRuntimeAPILoop(api string, handler Handler) error { client := newRuntimeAPIClient(api) - function := NewFunction(handler) + h := newHandler(handler) for { invoke, err := client.next() if err != nil { return err } - - err = handleInvoke(invoke, function) - if err != nil { + if err = handleInvoke(invoke, h); err != nil { return err } } } // handleInvoke returns an error if the function panics, or some other non-recoverable error occurred -func handleInvoke(invoke *invoke, function *Function) error { - functionRequest, err := convertInvokeRequest(invoke) +func handleInvoke(invoke *invoke, handler *handlerOptions) error { + // set the deadline + deadline, err := parseDeadline(invoke) if err != nil { - return fmt.Errorf("unexpected error occurred when parsing the invoke: %v", err) + return reportFailure(invoke, lambdaErrorResponse(err)) } + ctx, cancel := context.WithDeadline(handler.baseContext, deadline) + defer cancel() - functionResponse := &messages.InvokeResponse{} - if err := function.Invoke(functionRequest, functionResponse); err != nil { - return fmt.Errorf("unexpected error occurred when invoking the handler: %v", err) + // set the invoke metadata values + lc := lambdacontext.LambdaContext{ + AwsRequestID: invoke.id, + InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN), } - - if functionResponse.Error != nil { - errorPayload := safeMarshal(functionResponse.Error) - log.Printf("%s", errorPayload) - if err := invoke.failure(errorPayload, contentTypeJSON); err != nil { - return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err) + if err := parseClientContext(invoke, &lc.ClientContext); err != nil { + return reportFailure(invoke, lambdaErrorResponse(err)) + } + if err := parseCognitoIdentity(invoke, &lc.Identity); err != nil { + return reportFailure(invoke, lambdaErrorResponse(err)) + } + ctx = lambdacontext.NewContext(ctx, &lc) + + // set the trace id + traceID := invoke.headers.Get(headerTraceID) + os.Setenv("_X_AMZN_TRACE_ID", traceID) + // nolint:staticcheck + ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID) + + // call the handler, marshal any returned error + response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.Handler.Invoke) + if invokeErr != nil { + if err := reportFailure(invoke, invokeErr); err != nil { + return err } - if functionResponse.Error.ShouldExit { + if invokeErr.ShouldExit { return fmt.Errorf("calling the handler function resulted in a panic, the process should exit") } return nil } - - if err := invoke.success(functionResponse.Payload, contentTypeJSON); err != nil { + if err := invoke.success(response, contentTypeJSON); err != nil { return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %v", err) } return nil } -// convertInvokeRequest converts an invoke from the Runtime API, and unpacks it to be compatible with the shape of a `lambda.Function` InvokeRequest. -func convertInvokeRequest(invoke *invoke) (*messages.InvokeRequest, error) { - deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse contents of header: %s", headerDeadlineMS) +func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) error { + errorPayload := safeMarshal(invokeErr) + log.Printf("%s", errorPayload) + if err := invoke.failure(errorPayload, contentTypeJSON); err != nil { + return fmt.Errorf("unexpected error occurred when sending the function error to the API: %v", err) } - deadlineS := deadlineEpochMS / msPerS - deadlineNS := (deadlineEpochMS % msPerS) * nsPerMS + return nil +} - res := &messages.InvokeRequest{ - InvokedFunctionArn: invoke.headers.Get(headerInvokedFunctionARN), - XAmznTraceId: invoke.headers.Get(headerTraceID), - RequestId: invoke.id, - Deadline: messages.InvokeRequest_Timestamp{ - Seconds: deadlineS, - Nanos: deadlineNS, - }, - Payload: invoke.payload, +func callBytesHandlerFunc(ctx context.Context, payload []byte, handler bytesHandlerFunc) (response []byte, invokeErr *messages.InvokeResponse_Error) { + defer func() { + if err := recover(); err != nil { + invokeErr = lambdaPanicResponse(err) + } + }() + response, err := handler(ctx, payload) + if err != nil { + return nil, lambdaErrorResponse(err) } + return response, nil +} - clientContextJSON := invoke.headers.Get(headerClientContext) - if clientContextJSON != "" { - res.ClientContext = []byte(clientContextJSON) +func parseDeadline(invoke *invoke) (time.Time, error) { + deadlineEpochMS, err := strconv.ParseInt(invoke.headers.Get(headerDeadlineMS), 10, 64) + if err != nil { + return time.Time{}, fmt.Errorf("failed to parse deadline: %v", err) } + return unixMS(deadlineEpochMS), nil +} +func parseCognitoIdentity(invoke *invoke, out *lambdacontext.CognitoIdentity) error { cognitoIdentityJSON := invoke.headers.Get(headerCognitoIdentity) if cognitoIdentityJSON != "" { - if err := json.Unmarshal([]byte(invoke.headers.Get(headerCognitoIdentity)), res); err != nil { - return nil, fmt.Errorf("failed to unmarshal cognito identity json: %v", err) + if err := json.Unmarshal([]byte(cognitoIdentityJSON), out); err != nil { + return fmt.Errorf("failed to unmarshal cognito identity json: %v", err) } } + return nil +} - return res, nil +func parseClientContext(invoke *invoke, out *lambdacontext.ClientContext) error { + clientContextJSON := invoke.headers.Get(headerClientContext) + if clientContextJSON != "" { + if err := json.Unmarshal([]byte(clientContextJSON), out); err != nil { + return fmt.Errorf("failed to unmarshal client context json: %v", err) + } + } + return nil } func safeMarshal(v interface{}) []byte { diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go index 9e94d74d..54ec96cf 100644 --- a/lambda/invoke_loop_test.go +++ b/lambda/invoke_loop_test.go @@ -5,17 +5,21 @@ package lambda import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" + "os" "strings" "testing" "unicode/utf8" + "github.com/aws/aws-lambda-go/lambda/messages" "github.com/aws/aws-lambda-go/lambdacontext" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestFatalErrors(t *testing.T) { @@ -28,7 +32,12 @@ func TestFatalErrors(t *testing.T) { expectedErrorMessage := "calling the handler function resulted in a panic, the process should exit" assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedErrorMessage) assert.Equal(t, 1, record.nGets) - assert.Equal(t, 1, record.nGets) + var invokeErr messages.InvokeResponse_Error + err := json.Unmarshal(record.responses[0], &invokeErr) + assert.NoError(t, err) + assert.NotNil(t, invokeErr.StackTrace) + assert.Equal(t, "errorString", invokeErr.Type) + assert.Equal(t, "a fatal error", invokeErr.Message) } func TestRuntimeAPILoop(t *testing.T) { @@ -52,10 +61,49 @@ func TestRuntimeAPILoop(t *testing.T) { assert.Equal(t, nInvokes, record.nPosts) } +func TestCustomErrorMarshaling(t *testing.T) { + type CustomError struct{ error } + errors := []error{ + errors.New("boring"), + CustomError{errors.New("Something bad happened!")}, + messages.InvokeResponse_Error{Type: "yolo", Message: "hello"}, + } + expected := []string{ + `{ "errorType": "errorString", "errorMessage": "boring"}`, + `{ "errorType": "CustomError", "errorMessage": "Something bad happened!" }`, + `{ "errorType": "yolo", "errorMessage": "hello" }`, + } + require.Equal(t, len(errors), len(expected)) + ts, record := runtimeAPIServer(``, len(errors)) + defer ts.Close() + n := 0 + handler := NewHandler(func() error { + defer func() { n++ }() + return errors[n] + }) + endpoint := strings.Split(ts.URL, "://")[1] + expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint) + assert.EqualError(t, startRuntimeAPILoop(endpoint, handler), expectedError) + for i := range errors { + assert.JSONEq(t, expected[i], string(record.responses[i])) + } +} + func TestRuntimeAPIContextPlumbing(t *testing.T) { handler := NewHandler(func(ctx context.Context) (interface{}, error) { lc, _ := lambdacontext.FromContext(ctx) - return lc, nil + deadline, _ := ctx.Deadline() + return struct { + Context *lambdacontext.LambdaContext + TraceID string + EnvTraceID string + Deadline int64 + }{ + Context: lc, + TraceID: ctx.Value("x-amzn-trace-id").(string), + EnvTraceID: os.Getenv("_X_AMZN_TRACE_ID"), + Deadline: deadline.UnixNano() / nsPerMS, + }, nil }) ts, record := runtimeAPIServer(``, 1) @@ -67,22 +115,27 @@ func TestRuntimeAPIContextPlumbing(t *testing.T) { expected := ` { - "AwsRequestID": "dummyid", - "InvokedFunctionArn": "dummyarn", - "Identity": { - "CognitoIdentityID": "dummyident", - "CognitoIdentityPoolID": "dummypool" - }, - "ClientContext": { - "Client": { - "installation_id": "dummyinstallid", - "app_title": "dummytitle", - "app_version_code": "dummycode", - "app_package_name": "dummyname" + "Context": { + "AwsRequestID": "dummyid", + "InvokedFunctionArn": "dummyarn", + "Identity": { + "CognitoIdentityID": "dummyident", + "CognitoIdentityPoolID": "dummypool" }, - "env": null, - "custom": null - } + "ClientContext": { + "Client": { + "installation_id": "dummyinstallid", + "app_title": "dummytitle", + "app_version_code": "dummycode", + "app_package_name": "dummyname" + }, + "env": null, + "custom": null + } + }, + "TraceID": "its-xray-time", + "EnvTraceID": "its-xray-time", + "Deadline": 22 } ` assert.JSONEq(t, expected, string(record.responses[0])) @@ -106,6 +159,43 @@ func TestReadPayload(t *testing.T) { } +func TestContextDeserializationErrors(t *testing.T) { + badClientContext := defaultInvokeMetadata() + badClientContext.clientContext = `{ not json }` + + badCognito := defaultInvokeMetadata() + badCognito.cognito = `{ not json }` + + badDeadline := defaultInvokeMetadata() + badDeadline.deadline = `yolo` + + badMetadata := []eventMetadata{badClientContext, badCognito, badDeadline} + + ts, record := runtimeAPIServer(`{}`, len(badMetadata), badMetadata...) + defer ts.Close() + handler := NewHandler(func(ctx context.Context) (*lambdacontext.LambdaContext, error) { + lc, _ := lambdacontext.FromContext(ctx) + return lc, nil + }) + endpoint := strings.Split(ts.URL, "://")[1] + _ = startRuntimeAPILoop(endpoint, handler) + + assert.JSONEq(t, `{ + "errorMessage":"failed to unmarshal client context json: invalid character 'n' looking for beginning of object key string", + "errorType":"errorString" + }`, string(record.responses[0])) + + assert.JSONEq(t, `{ + "errorMessage":"failed to unmarshal cognito identity json: invalid character 'n' looking for beginning of object key string", + "errorType":"errorString" + }`, string(record.responses[1])) + + assert.JSONEq(t, `{ + "errorMessage":"failed to parse deadline: strconv.ParseInt: parsing \"yolo\": invalid syntax", + "errorType":"errorString" + }`, string(record.responses[2])) +} + type invalidPayload struct{} func (invalidPayload) MarshalJSON() ([]byte, error) { @@ -124,33 +214,59 @@ type requestRecord struct { responses [][]byte } -func runtimeAPIServer(eventPayload string, failAfter int) (*httptest.Server, *requestRecord) { +type eventMetadata struct { + clientContext string + cognito string + xray string + deadline string + requestID string + functionARN string +} + +func defaultInvokeMetadata() eventMetadata { + return eventMetadata{ + clientContext: `{ + "Client": { + "app_title": "dummytitle", + "installation_id": "dummyinstallid", + "app_version_code": "dummycode", + "app_package_name": "dummyname" + } + }`, + cognito: `{ + "cognitoIdentityId": "dummyident", + "cognitoIdentityPoolId": "dummypool" + }`, + xray: "its-xray-time", + requestID: "dummyid", + deadline: "22", + functionARN: "dummyarn", + } +} + +func runtimeAPIServer(eventPayload string, failAfter int, overrides ...eventMetadata) (*httptest.Server, *requestRecord) { numInvokesRequested := 0 record := &requestRecord{} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: + metadata := defaultInvokeMetadata() + if numInvokesRequested < len(overrides) { + metadata = overrides[numInvokesRequested] + } record.nGets++ numInvokesRequested++ if numInvokesRequested > failAfter { w.WriteHeader(http.StatusGone) _, _ = w.Write([]byte("END THE TEST!")) } - w.Header().Add(string(headerAWSRequestID), "dummyid") - w.Header().Add(string(headerDeadlineMS), "22") - w.Header().Add(string(headerInvokedFunctionARN), "dummyarn") - w.Header().Add(string(headerClientContext), `{ - "Client": { - "app_title": "dummytitle", - "installation_id": "dummyinstallid", - "app_version_code": "dummycode", - "app_package_name": "dummyname" - } - }`) - w.Header().Add(string(headerCognitoIdentity), `{ - "cognitoIdentityId": "dummyident", - "cognitoIdentityPoolId": "dummypool" - }`) + w.Header().Add(string(headerAWSRequestID), metadata.requestID) + w.Header().Add(string(headerDeadlineMS), metadata.deadline) + w.Header().Add(string(headerInvokedFunctionARN), metadata.functionARN) + w.Header().Add(string(headerClientContext), metadata.clientContext) + w.Header().Add(string(headerCognitoIdentity), metadata.cognito) + w.Header().Add(string(headerTraceID), metadata.xray) w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(eventPayload)) case http.MethodPost: diff --git a/lambda/panic_test.go b/lambda/panic_test.go index b2135ecd..76fdfd6f 100644 --- a/lambda/panic_test.go +++ b/lambda/panic_test.go @@ -1,6 +1,7 @@ package lambda import ( + "errors" "os" "runtime" "strings" @@ -34,7 +35,8 @@ func TestPanicFormattingIntValue(t *testing.T) { } func TestPanicFormattingCustomError(t *testing.T) { - customError := &CustomError{} + type CustomError struct{ error } + customError := &CustomError{errors.New("oh noooooo!")} assertPanicMessage(t, func() { panic(customError) }, customError.Error()) } diff --git a/lambda/rpc.go b/lambda/rpc.go deleted file mode 100644 index 8c232a98..00000000 --- a/lambda/rpc.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved - -//go:build !lambda.norpc -// +build !lambda.norpc - -package lambda - -import ( - "errors" - "log" - "net" - "net/rpc" -) - -func init() { - // Register `startFunctionRPC` to be run if the _LAMBDA_SERVER_PORT environment variable is set. - // This happens when the runtime for the function is configured as `go1.x`. - // The value of the environment variable will be passed as the first argument to `startFunctionRPC`. - rpcStartFunction.f = startFunctionRPC -} - -func startFunctionRPC(port string, handler Handler) error { - lis, err := net.Listen("tcp", "localhost:"+port) - if err != nil { - log.Fatal(err) - } - err = rpc.Register(NewFunction(handler)) - if err != nil { - log.Fatal("failed to register handler function") - } - rpc.Accept(lis) - return errors.New("accept should not have returned") -} diff --git a/lambda/function.go b/lambda/rpc_function.go similarity index 68% rename from lambda/function.go rename to lambda/rpc_function.go index e6fe464f..0c8e798e 100644 --- a/lambda/function.go +++ b/lambda/rpc_function.go @@ -1,10 +1,17 @@ // Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +//go:build !lambda.norpc +// +build !lambda.norpc + package lambda import ( "context" "encoding/json" + "errors" + "log" + "net" + "net/rpc" "os" "time" @@ -12,6 +19,32 @@ import ( "github.com/aws/aws-lambda-go/lambdacontext" ) +func init() { + // Register `startFunctionRPC` to be run if the _LAMBDA_SERVER_PORT environment variable is set. + // This happens when the runtime for the function is configured as `go1.x`. + // The value of the environment variable will be passed as the first argument to `startFunctionRPC`. + // This allows users to save a little bit of coldstart time in the download, by the dependencies brought in for RPC support. + // The tradeoff is dropping compatibility with the RPC mode of the go1.x runtime. + // To drop the rpc dependencies, compile with `-tags lambda.norpc` + startFunctions = append([]*startFunction{{ + env: "_LAMBDA_SERVER_PORT", + f: startFunctionRPC, + }}, startFunctions...) +} + +func startFunctionRPC(port string, handler Handler) error { + lis, err := net.Listen("tcp", "localhost:"+port) + if err != nil { + log.Fatal(err) + } + err = rpc.Register(NewFunction(handler)) + if err != nil { + log.Fatal("failed to register handler function") + } + rpc.Accept(lis) + return errors.New("accept should not have returned") +} + // Function struct which wrap the Handler // // Deprecated: The Function type is public for the go1.x runtime internal use of the net/rpc package diff --git a/lambda/function_test.go b/lambda/rpc_function_test.go similarity index 99% rename from lambda/function_test.go rename to lambda/rpc_function_test.go index a02c25b8..6935084d 100644 --- a/lambda/function_test.go +++ b/lambda/rpc_function_test.go @@ -1,5 +1,8 @@ // Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +//go:build !lambda.norpc +// +build !lambda.norpc + package lambda import ( diff --git a/lambda/testdata/.gitignore b/lambda/testdata/.gitignore new file mode 100644 index 00000000..e45ddf8c --- /dev/null +++ b/lambda/testdata/.gitignore @@ -0,0 +1,2 @@ +handler* +*.json diff --git a/lambda/testdata/bench.sh b/lambda/testdata/bench.sh new file mode 100755 index 00000000..49e9d24b --- /dev/null +++ b/lambda/testdata/bench.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +set -euo pipefail + +echo "\"$(od -N $((512 * 1024)) /dev/random | base64)\"" > data.json +echo "data payload for tests is: $(du -h data.json)" + +trap "docker kill rie-bench" 1 2 3 6 +bench () { + local handler_exe=$1 + local entrypoint=$2 + local image=$3 + echo "-------------------------------------------------" + echo $@ + echo "-------------------------------------------------" + docker run --name rie-bench --rm -d -p 9001:8080 -v "${handler_exe}:/var/task/bootstrap" --entrypoint aws-lambda-rie ${image} ${entrypoint} bootstrap + sleep 2 + echo "ensuring healthy function before starting test" + curl -s -XPOST http://localhost:9001/2015-03-31/functions/function/invocations -d '{"hello": "world"}' | jq + echo "-------------------------------------------------" + ab -p data.json -n 100 http://localhost:9001/2015-03-31/functions/function/invocations + docker kill rie-bench +} + +GOOS=linux GOARCH=amd64 go build -o handler echo_handler.go +GOOS=linux GOARCH=amd64 go build -tags lambda.norpc -o handler_norpc echo_handler.go +ls -lah handler* + +bench "$(pwd)/handler_norpc" /var/task/bootstrap public.ecr.aws/lambda/provided:alami +bench "$(pwd)/handler_norpc" /var/runtime/bootstrap public.ecr.aws/lambda/go +bench "$(pwd)/handler" /var/runtime/bootstrap public.ecr.aws/lambda/go diff --git a/lambda/testdata/echo_handler.go b/lambda/testdata/echo_handler.go new file mode 100644 index 00000000..afde842e --- /dev/null +++ b/lambda/testdata/echo_handler.go @@ -0,0 +1,17 @@ +package main + +import ( + "context" + + "github.com/aws/aws-lambda-go/lambda" +) + +type handler struct{} + +func (h handler) Invoke(_ context.Context, e []byte) ([]byte, error) { + return e, nil +} + +func main() { + lambda.Start(handler{}) +}