From 1e10e693adffcbf0898e01eb2fa224799f621567 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Wed, 10 May 2023 21:05:15 +0900 Subject: [PATCH 01/24] introduce InvokeMode --- ridgenative.go | 55 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/ridgenative.go b/ridgenative.go index e886555..be56703 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -479,6 +479,36 @@ func newLambdaFunction(mux http.Handler) *lambdaFunction { } } +// InvokeMode is the mode that determines which API operation Lambda uses. +type InvokeMode string + +const ( + // InvokeModeBuffered indicates that your function is invoked using the Invoke API operation. + // Invocation results are available when the payload is complete. + InvokeModeBuffered InvokeMode = "BUFFERED" + + // InvokeModeResponseStreaming indicates that your function is invoked using + // the InvokeWithResponseStream API operation. + // It enables your function to stream payload results as they become available. + InvokeModeResponseStreaming InvokeMode = "RESPONSE_STREAMING" +) + +// Start starts the AWS Lambda function. +// The handler is typically nil, in which case the DefaultServeMux is used. +func Start(mux http.Handler, mode InvokeMode) error { + api := os.Getenv("AWS_LAMBDA_RUNTIME_API") + if mux == nil { + mux = http.DefaultServeMux + } + f := newLambdaFunction(mux) + c := newRuntimeAPIClient(api) + if err := c.start(f.lambdaHandler); err != nil { + log.Println(err) + return err + } + return nil +} + // ListenAndServe starts HTTP server. // // If AWS_LAMBDA_RUNTIME_API environment value is defined, it wait for new AWS Lambda events and handle it as HTTP requests. @@ -493,25 +523,30 @@ func newLambdaFunction(mux http.Handler) *lambdaFunction { // If AWS_LAMBDA_RUNTIME_API environment value is NOT defined, it just calls http.ListenAndServe. // // The handler is typically nil, in which case the DefaultServeMux is used. +// +// If AWS_LAMBDA_RUNTIME_API environment value is defined, ListenAndServe uses it as the invoke mode. +// The default is InvokeModeBuffered. func ListenAndServe(address string, mux http.Handler) error { if go1 := os.Getenv("AWS_EXECUTION_ENV"); go1 == "AWS_Lambda_go1.x" { // run on go1.x runtime return errors.New("ridgenative: go1.x runtime is not supported") } - api := os.Getenv("AWS_LAMBDA_RUNTIME_API") // run on provided or provided.al2 runtime + api := os.Getenv("AWS_LAMBDA_RUNTIME_API") if api == "" { // fall back to normal HTTP server. return http.ListenAndServe(address, mux) } - if mux == nil { - mux = http.DefaultServeMux - } - f := newLambdaFunction(mux) - c := newRuntimeAPIClient(api) - if err := c.start(f.lambdaHandler); err != nil { - log.Println(err) - return err + + // run on provided or provided.al2 runtime + mode := InvokeModeBuffered + switch os.Getenv("RIDGENATIVE_INVOKE_MODE") { + case "BUFFERED", "": + mode = InvokeModeBuffered + case "RESPONSE_STREAMING": + mode = InvokeModeResponseStreaming + default: + return errors.New("ridgenative: invalid RIDGENATIVE_INVOKE_MODE") } - return nil + return Start(mux, mode) } From 5e895b67d90d83a68361f005d762232e9de3fc3a Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Wed, 10 May 2023 21:12:17 +0900 Subject: [PATCH 02/24] fix --- ridgenative.go | 2 +- runtime_api_client.go | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/ridgenative.go b/ridgenative.go index be56703..7756edc 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -519,7 +519,7 @@ func Start(mux http.Handler, mode InvokeMode) error { // // https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html // -// If AWS_EXECUTION_ENV is AWS_Lambda_go1.x, it returns an error. +// If AWS_EXECUTION_ENV environment value is AWS_Lambda_go1.x, it returns an error. // If AWS_LAMBDA_RUNTIME_API environment value is NOT defined, it just calls http.ListenAndServe. // // The handler is typically nil, in which case the DefaultServeMux is used. diff --git a/runtime_api_client.go b/runtime_api_client.go index 592df7b..963c466 100644 --- a/runtime_api_client.go +++ b/runtime_api_client.go @@ -146,7 +146,7 @@ func (c *runtimeAPIClient) post(path string, body []byte, contentType string) er url := c.baseURL + path req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return fmt.Errorf("ridgenative: failed to construct POST request to %s: %v", url, err) + return fmt.Errorf("ridgenative: failed to construct POST request to %s: %w", url, err) } req.Header.Set("User-Agent", c.userAgent) req.Header.Set("Content-Type", contentType) @@ -155,18 +155,15 @@ func (c *runtimeAPIClient) post(path string, body []byte, contentType string) er if err != nil { return fmt.Errorf("ridgenative: failed to POST to %s: %v", url, err) } - defer func() { - if err := resp.Body.Close(); err != nil { - log.Printf("ridgenative: runtime API client failed to close %s response body: %v", url, err) - } - }() + defer resp.Body.Close() + if resp.StatusCode != http.StatusAccepted { return fmt.Errorf("ridgenative: failed to POST to %s: got unexpected status code: %d", url, resp.StatusCode) } _, err = io.Copy(io.Discard, resp.Body) if err != nil { - return fmt.Errorf("ridgenative: something went wrong reading the POST response from %s: %v", url, err) + return fmt.Errorf("ridgenative: something went wrong reading the POST response from %s: %w", url, err) } return nil From 3e08ad5bbeab7dabe1e1bb12150229ff49bdd1ab Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Wed, 10 May 2023 23:05:59 +0900 Subject: [PATCH 03/24] implement callHandlerFuncSteaming --- invoke.go | 20 +++++ ridgenative.go | 202 ++++++++++++++++++++++++------------------ ridgenative_test.go | 24 +++-- runtime_api_client.go | 143 ++++++++++++++++++++++++++++-- 4 files changed, 281 insertions(+), 108 deletions(-) diff --git a/invoke.go b/invoke.go index 6370790..dda98ef 100644 --- a/invoke.go +++ b/invoke.go @@ -3,6 +3,7 @@ package ridgenative import ( "context" "encoding/json" + "io" "net/http" ) @@ -29,3 +30,22 @@ func callBytesHandlerFunc(ctx context.Context, payload []byte, h handlerFunc) (r } return json.Marshal(resp) } + +func callHandlerFuncSteaming(ctx context.Context, payload []byte, h handlerFuncSteaming) (response io.ReadCloser, err error) { + defer func() { + if v := recover(); v != nil { + err = lambdaPanicResponse(v) + } + }() + + var req *request + if err := json.Unmarshal(payload, &req); err != nil { + return nil, err + } + + r, w := io.Pipe() + if err := h(ctx, req, w); err != nil { + return nil, err + } + return r, nil +} diff --git a/ridgenative.go b/ridgenative.go index 7756edc..0c27eba 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -1,10 +1,12 @@ package ridgenative import ( + "bufio" "bytes" "context" "encoding/base64" "errors" + "fmt" "io" "log" "net/http" @@ -18,12 +20,6 @@ import ( type lambdaFunction struct { mux http.Handler - // buffer for string data - builder strings.Builder - - // buffer for binary data - buffer bytes.Buffer - out []byte } type request struct { @@ -207,43 +203,34 @@ func (f *lambdaFunction) httpRequestV2(ctx context.Context, r *request) (*http.R } func (f *lambdaFunction) decodeBody(r *request) (body io.ReadCloser, contentLength int64, err error) { - if r.Body != "" { - var reader io.Reader - if r.IsBase64Encoded { - f.buffer.Reset() - f.buffer.WriteString(r.Body) - n := base64.StdEncoding.DecodedLen(len(r.Body)) - out := f.out - if cap(out) < n { - out = make([]byte, n) - } else { - out = out[:n] - } - n, err := base64.StdEncoding.Decode(out, f.buffer.Bytes()) - f.out = out - if err != nil { - return nil, 0, err - } - contentLength = int64(n) - reader = bytes.NewReader(out[:n]) - } else { - contentLength = int64(len(r.Body)) - reader = io.Reader(strings.NewReader(r.Body)) + if r.Body == "" { + body = http.NoBody + return + } + + var reader io.Reader + if r.IsBase64Encoded { + var b []byte + b, err = base64.StdEncoding.DecodeString(r.Body) + if err != nil { + return } - body = io.NopCloser(reader) + contentLength = int64(len(b)) + reader = bytes.NewReader(b) } else { - body = http.NoBody + contentLength = int64(len(r.Body)) + reader = strings.NewReader(r.Body) } + body = io.NopCloser(reader) return } type responseWriter struct { - w io.Writer + w bytes.Buffer isBinary bool wroteHeader bool header http.Header statusCode int - lambda *lambdaFunction } type response struct { @@ -255,12 +242,9 @@ type response struct { Cookies []string `json:"cookies,omitempty"` } -func (f *lambdaFunction) newResponseWriter() *responseWriter { - f.builder.Reset() - f.buffer.Reset() +func newResponseWriter() *responseWriter { return &responseWriter{ header: make(http.Header, 1), - lambda: f, } } @@ -295,33 +279,10 @@ func (rw *responseWriter) WriteHeader(code int) { } rw.statusCode = code rw.wroteHeader = true - if typ := rw.header.Get("Content-Type"); typ != "" { - rw.isBinary = isBinary(typ) - if rw.isBinary { - rw.w = &rw.lambda.buffer - } else { - rw.w = &rw.lambda.builder - } - } } func (rw *responseWriter) Write(data []byte) (int, error) { - if !rw.wroteHeader { - rw.WriteHeader(http.StatusOK) - } - if rw.w != nil { - return rw.w.Write(data) - } - - f := rw.lambda - rest := 512 - f.buffer.Len() - if len(data) < rest { - return f.buffer.Write(data) - } - n1, _ := f.buffer.Write(data[:rest]) - rw.detectContentType() - n2, _ := rw.w.Write(data[rest:]) - return n1 + n2, nil + return rw.w.Write(data) } func (rw *responseWriter) lambdaResponseV1() (*response, error) { @@ -377,38 +338,24 @@ func (rw *responseWriter) encodeBody() string { if !rw.wroteHeader { rw.WriteHeader(http.StatusOK) } - if rw.w == nil { + + if typ := rw.header.Get("Content-Type"); typ != "" { + rw.isBinary = isBinary(typ) + } else { rw.detectContentType() } - var body string if rw.isBinary { - out := rw.lambda.out - l := base64.StdEncoding.EncodedLen(rw.lambda.buffer.Len()) - if cap(out) < l { - out = make([]byte, l) - } else { - out = out[:l] - } - base64.StdEncoding.Encode(out, rw.lambda.buffer.Bytes()) - body = string(out) - rw.lambda.out = out + return base64.StdEncoding.EncodeToString(rw.w.Bytes()) } else { - body = rw.lambda.builder.String() + return rw.w.String() } - return body } func (rw *responseWriter) detectContentType() { - contentType := http.DetectContentType(rw.lambda.buffer.Bytes()) + contentType := http.DetectContentType(rw.w.Bytes()) rw.header.Set("Content-Type", contentType) rw.isBinary = isBinary(contentType) - if rw.isBinary { - rw.w = &rw.lambda.buffer - } else { - rw.w = &rw.lambda.builder - rw.lambda.buffer.WriteTo(rw.w) - } } // assume text/*, application/json, application/javascript, application/xml, */*+json, */*+xml as text @@ -456,23 +403,92 @@ func (f *lambdaFunction) lambdaHandler(ctx context.Context, req *request) (*resp // Lambda Function URLs or API Gateway v2 r, err := f.httpRequestV2(ctx, req) if err != nil { - return &response{}, err + return nil, err } - rw := f.newResponseWriter() + rw := newResponseWriter() f.mux.ServeHTTP(rw, r) return rw.lambdaResponseV2() } else { // API Gateway v1 or ALB r, err := f.httpRequestV1(ctx, req) if err != nil { - return &response{}, err + return nil, err } - rw := f.newResponseWriter() + rw := newResponseWriter() f.mux.ServeHTTP(rw, r) return rw.lambdaResponseV1() } } +type streamingResponseWriter struct { + w *io.PipeWriter + buf *bufio.Writer + wroteHeader bool + header http.Header + statusCode int +} + +func newStreamingResponseWriter(w *io.PipeWriter) *streamingResponseWriter { + return &streamingResponseWriter{ + w: w, + buf: bufio.NewWriter(w), + header: make(http.Header, 1), + } +} + +func (rw *streamingResponseWriter) Header() http.Header { + return rw.header +} + +func (rw *streamingResponseWriter) WriteHeader(code int) { + if rw.wroteHeader { + caller := relevantCaller() + log.Printf("ridgenative: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line) + return + } + rw.statusCode = code + rw.wroteHeader = true +} + +func (rw *streamingResponseWriter) Write(data []byte) (int, error) { + return rw.w.Write(data) +} + +func (rw *streamingResponseWriter) closeWithError(err error) error { + err0 := rw.buf.Flush() + if err1 := rw.w.CloseWithError(err); err0 == nil { + err0 = err1 + } + return err0 +} + +func (rw *streamingResponseWriter) close() error { + err0 := rw.buf.Flush() + if err1 := rw.w.Close(); err0 == nil { + err0 = err1 + } + return err0 +} + +func (f *lambdaFunction) lambdaHandlerStreaming(ctx context.Context, req *request, w *io.PipeWriter) error { + r, err := f.httpRequestV2(ctx, req) + if err != nil { + return err + } + go func() { + rw := newStreamingResponseWriter(w) + defer func() { + if v := recover(); v != nil { + rw.closeWithError(lambdaPanicResponse(v)) + } else { + rw.close() + } + }() + f.mux.ServeHTTP(rw, r) + }() + return nil +} + func newLambdaFunction(mux http.Handler) *lambdaFunction { return &lambdaFunction{ mux: mux, @@ -502,9 +518,19 @@ func Start(mux http.Handler, mode InvokeMode) error { } f := newLambdaFunction(mux) c := newRuntimeAPIClient(api) - if err := c.start(f.lambdaHandler); err != nil { - log.Println(err) - return err + switch mode { + case InvokeModeBuffered: + if err := c.start(f.lambdaHandler); err != nil { + log.Println(err) + return err + } + case InvokeModeResponseStreaming: + if err := c.startStreaming(f.lambdaHandlerStreaming); err != nil { + log.Println(err) + return err + } + default: + return fmt.Errorf("ridgenative: invalid InvokeMode: %s", mode) } return nil } diff --git a/ridgenative_test.go b/ridgenative_test.go index 6b9d65d..ad4d53e 100644 --- a/ridgenative_test.go +++ b/ridgenative_test.go @@ -402,9 +402,8 @@ func TestHTTPRequest(t *testing.T) { } func TestResponseV1(t *testing.T) { - l := &lambdaFunction{} t.Run("normal", func(t *testing.T) { - rw := l.newResponseWriter() + rw := newResponseWriter() // normal header fields rw.Header().Add("foo", "foo") @@ -459,7 +458,7 @@ func TestResponseV1(t *testing.T) { } }) t.Run("set content-type", func(t *testing.T) { - rw := l.newResponseWriter() + rw := newResponseWriter() rw.Header().Set("Content-Type", "text/plain; charset=utf-8") if _, err := io.WriteString(rw, "\n"); err != nil { t.Error(err) @@ -489,7 +488,7 @@ func TestResponseV1(t *testing.T) { } }) t.Run("redirect to example.com", func(t *testing.T) { - rw := l.newResponseWriter() + rw := newResponseWriter() rw.Header().Add("location", "http://example.com/") rw.WriteHeader(http.StatusFound) if _, err := io.WriteString(rw, "\n"); err != nil { @@ -517,7 +516,7 @@ func TestResponseV1(t *testing.T) { } }) t.Run("base64", func(t *testing.T) { - rw := l.newResponseWriter() + rw := newResponseWriter() // 1x1 PNG image if _, err := io.WriteString(rw, "\x89\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52"); err != nil { t.Error(err) @@ -552,9 +551,8 @@ func TestResponseV1(t *testing.T) { } func TestResponseV2(t *testing.T) { - l := &lambdaFunction{} t.Run("normal", func(t *testing.T) { - rw := l.newResponseWriter() + rw := newResponseWriter() // normal header fields rw.Header().Add("foo", "foo") @@ -609,7 +607,7 @@ func TestResponseV2(t *testing.T) { } }) t.Run("set content-type", func(t *testing.T) { - rw := l.newResponseWriter() + rw := newResponseWriter() rw.Header().Set("Content-Type", "text/plain; charset=utf-8") if _, err := io.WriteString(rw, "\n"); err != nil { t.Error(err) @@ -639,7 +637,7 @@ func TestResponseV2(t *testing.T) { } }) t.Run("redirect to example.com", func(t *testing.T) { - rw := l.newResponseWriter() + rw := newResponseWriter() rw.Header().Add("location", "http://example.com/") rw.WriteHeader(http.StatusFound) if _, err := io.WriteString(rw, "\n"); err != nil { @@ -667,7 +665,7 @@ func TestResponseV2(t *testing.T) { } }) t.Run("base64", func(t *testing.T) { - rw := l.newResponseWriter() + rw := newResponseWriter() // 1x1 PNG image if _, err := io.WriteString(rw, "\x89\x50\x4e\x47\x0d\x0a\x1a\x0a\x00\x00\x00\x0d\x49\x48\x44\x52"); err != nil { t.Error(err) @@ -738,25 +736,23 @@ func BenchmarkRequest_text(b *testing.B) { } func BenchmarkResponse_binary(b *testing.B) { - l := newLambdaFunction(nil) data := make([]byte, 1<<20) // 1MB: the maximum size of the response JSON in ALB b.ResetTimer() for i := 0; i < b.N; i++ { - rw := l.newResponseWriter() + rw := newResponseWriter() rw.Write(data) rw.lambdaResponseV1() } } func BenchmarkResponse_text(b *testing.B) { - l := newLambdaFunction(nil) data := make([]byte, 1<<20) // 1MB: the maximum size of the response JSON in ALB for i := 0; i < len(data); i++ { data[i] = 'a' } b.ResetTimer() for i := 0; i < b.N; i++ { - rw := l.newResponseWriter() + rw := newResponseWriter() rw.Write(data) rw.lambdaResponseV1() } diff --git a/runtime_api_client.go b/runtime_api_client.go index 963c466..cfb9dec 100644 --- a/runtime_api_client.go +++ b/runtime_api_client.go @@ -3,7 +3,9 @@ package ridgenative import ( "bytes" "context" + "encoding/base64" "encoding/json" + "errors" "fmt" "io" "log" @@ -21,11 +23,15 @@ const ( headerCognitoIdentity = "Lambda-Runtime-Cognito-Identity" headerClientContext = "Lambda-Runtime-Client-Context" headerInvokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" - trailerLambdaErrorType = "Lambda-Runtime-Function-Error-Type" - trailerLambdaErrorBody = "Lambda-Runtime-Function-Error-Body" - contentTypeJSON = "application/json" - contentTypeBytes = "application/octet-stream" - apiVersion = "2018-06-01" + + trailerLambdaErrorType = "Lambda-Runtime-Function-Error-Type" + trailerLambdaErrorBody = "Lambda-Runtime-Function-Error-Body" + + contentTypeJSON = "application/json" + contentTypeBytes = "application/octet-stream" + contentTypeHTTPIntegrationResponse = "application/vnd.awslambda.http-integration-response" + + apiVersion = "2018-06-01" ) type runtimeAPIClient struct { @@ -96,7 +102,7 @@ func (c *runtimeAPIClient) next() (*invoke, error) { }, nil } -// handleInvoke returns an error if the function panics, or some other non-recoverable error occurred +// handleInvoke handles an invoke. func (c *runtimeAPIClient) handleInvoke(invoke *invoke, h handlerFunc) error { // set the deadline deadline, err := parseDeadline(invoke) @@ -169,6 +175,7 @@ func (c *runtimeAPIClient) post(path string, body []byte, contentType string) er return nil } +// reportFailure reports the error to the Runtime API. func (c *runtimeAPIClient) reportFailure(invoke *invoke, invokeErr *invokeResponseError) error { body, err := json.Marshal(invokeErr) if err != nil { @@ -180,3 +187,127 @@ func (c *runtimeAPIClient) reportFailure(invoke *invoke, invokeErr *invokeRespon } return nil } + +type handlerFuncSteaming func(ctx context.Context, req *request, w *io.PipeWriter) error + +func (c *runtimeAPIClient) startStreaming(h handlerFuncSteaming) error { + for { + invoke, err := c.next() + if err != nil { + return err + } + if err := c.handleInvokeStreaming(invoke, nil); err != nil { + return err + } + } +} + +// handleInvoke handles an invoke. +func (c *runtimeAPIClient) handleInvokeStreaming(invoke *invoke, h handlerFuncSteaming) error { + // set the deadline + deadline, err := parseDeadline(invoke) + if err != nil { + return c.reportFailure(invoke, lambdaErrorResponse(err)) + } + ctx, cancel := context.WithDeadline(context.TODO(), deadline) + defer cancel() + + // set the trace id + traceID := invoke.headers.Get(headerTraceID) + os.Setenv("_X_AMZN_TRACE_ID", traceID) + // to keep compatibility with AWS Lambda X-Ray SDK, we need to set "x-amzn-trace-id" to the context. + // nolint:staticcheck + ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID) + + // call the handler, marshal any returned error + response, err := callHandlerFuncSteaming(ctx, invoke.payload, h) + if err != nil { + invokeErr := lambdaErrorResponse(err) + if err := c.reportFailure(invoke, invokeErr); err != nil { + return err + } + if invokeErr.ShouldExit { + return fmt.Errorf("calling the handler function resulted in a panic, the process should exit") + } + return nil + } + + if err := c.postStreaming(invoke.id+"/response", response, contentTypeHTTPIntegrationResponse); err != nil { + return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %w", err) + } + + return nil +} + +// postStreaming posts body to the Runtime API at the given path. +func (c *runtimeAPIClient) postStreaming(path string, body io.ReadCloser, contentType string) error { + b := newErrorCapturingReader(body) + url := c.baseURL + path + req, err := http.NewRequest(http.MethodPost, url, b) + if err != nil { + return fmt.Errorf("ridgenative: failed to construct POST request to %s: %w", url, err) + } + req.Trailer = b.trailer + req.Header.Set("User-Agent", c.userAgent) + req.Header.Set("Content-Type", contentType) + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("ridgenative: failed to POST to %s: %v", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("ridgenative: failed to POST to %s: got unexpected status code: %d", url, resp.StatusCode) + } + + _, err = io.Copy(io.Discard, resp.Body) + if err != nil { + return fmt.Errorf("ridgenative: something went wrong reading the POST response from %s: %w", url, err) + } + + return nil +} + +// errorCapturingReader is a reader that captures the first error returned by the underlying reader. +type errorCapturingReader struct { + reader io.ReadCloser + err error + trailer http.Header +} + +func newErrorCapturingReader(r io.ReadCloser) *errorCapturingReader { + return &errorCapturingReader{ + reader: r, + trailer: http.Header{}, + } +} + +func (r *errorCapturingReader) Read(p []byte) (int, error) { + if r.reader == nil { + return 0, io.EOF + } + if r.err != nil { + return 0, r.err + } + + n, err := r.reader.Read(p) + if err != nil && errors.Is(err, io.EOF) { + // capture the error + r.err = err + lambdaErr := lambdaErrorResponse(err) + body, err := json.Marshal(lambdaErr) + if err != nil { + // marshaling lambdaErr always succeeds + // because lambdaErr doesn't have any functions and channels. + panic(err) + } + r.trailer.Set(trailerLambdaErrorType, lambdaErr.Type) + r.trailer.Set(trailerLambdaErrorBody, base64.StdEncoding.EncodeToString(body)) + } + return n, err +} + +func (r *errorCapturingReader) Close() error { + return r.reader.Close() +} From 5e605a29c4e313c6f2a0b29e4bc5e95878446d9c Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Wed, 10 May 2023 23:26:43 +0900 Subject: [PATCH 04/24] implement WriteHeader --- ridgenative.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/ridgenative.go b/ridgenative.go index 0c27eba..5c410af 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/base64" + "encoding/json" "errors" "fmt" "io" @@ -420,6 +421,13 @@ func (f *lambdaFunction) lambdaHandler(ctx context.Context, req *request) (*resp } } +type streamingResponse struct { + StatusCode int `json:"statusCode"` + Headers map[string]string `json:"headers,omitempty"` + Cookies []string `json:"cookies,omitempty"` +} + +// streamingResponseWriter is a http.ResponseWriter that supports streaming. type streamingResponseWriter struct { w *io.PipeWriter buf *bufio.Writer @@ -447,14 +455,32 @@ func (rw *streamingResponseWriter) WriteHeader(code int) { return } rw.statusCode = code + r := &streamingResponse{ + StatusCode: code, + } + data, err := json.Marshal(r) + if err != nil { + log.Printf("ridgenative: %v", err) + return + } + rw.buf.Write(data) + rw.buf.WriteString("\x00\x00\x00\x00\x00\x00\x00\x00") + rw.buf.Flush() rw.wroteHeader = true } func (rw *streamingResponseWriter) Write(data []byte) (int, error) { + if !rw.wroteHeader { + // TODO: detect content type if it is not set. + rw.WriteHeader(http.StatusOK) + } return rw.w.Write(data) } func (rw *streamingResponseWriter) closeWithError(err error) error { + if !rw.wroteHeader { + rw.WriteHeader(http.StatusOK) + } err0 := rw.buf.Flush() if err1 := rw.w.CloseWithError(err); err0 == nil { err0 = err1 @@ -463,6 +489,9 @@ func (rw *streamingResponseWriter) closeWithError(err error) error { } func (rw *streamingResponseWriter) close() error { + if !rw.wroteHeader { + rw.WriteHeader(http.StatusOK) + } err0 := rw.buf.Flush() if err1 := rw.w.Close(); err0 == nil { err0 = err1 @@ -470,6 +499,13 @@ func (rw *streamingResponseWriter) close() error { return err0 } +func (rw *streamingResponseWriter) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(http.StatusOK) + } + rw.buf.Flush() +} + func (f *lambdaFunction) lambdaHandlerStreaming(ctx context.Context, req *request, w *io.PipeWriter) error { r, err := f.httpRequestV2(ctx, req) if err != nil { From b8f6fa580b3358976d93789e3e24d9d49f65ce63 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Thu, 11 May 2023 22:10:21 +0900 Subject: [PATCH 05/24] add an example for response streaming --- examples/function-urls/template.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/function-urls/template.yaml b/examples/function-urls/template.yaml index deb2ddf..dff20f1 100644 --- a/examples/function-urls/template.yaml +++ b/examples/function-urls/template.yaml @@ -10,4 +10,8 @@ Resources: Runtime: provided.al2 Timeout: 3 FunctionUrlConfig: - AuthType: NONE + AuthType: NONE + InvokeMode: RESPONSE_STREAM + Environment: + Variables: + RIDGENATIVE_INVOKE_MODE: RESPONSE_STREAM From e17e5588da6f17cbf50094062ee8114a997e3191 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Thu, 11 May 2023 22:24:56 +0900 Subject: [PATCH 06/24] fix invoke mode name --- ridgenative.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ridgenative.go b/ridgenative.go index 5c410af..91cf487 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -539,10 +539,10 @@ const ( // Invocation results are available when the payload is complete. InvokeModeBuffered InvokeMode = "BUFFERED" - // InvokeModeResponseStreaming indicates that your function is invoked using + // InvokeModeResponseStream indicates that your function is invoked using // the InvokeWithResponseStream API operation. // It enables your function to stream payload results as they become available. - InvokeModeResponseStreaming InvokeMode = "RESPONSE_STREAMING" + InvokeModeResponseStream InvokeMode = "RESPONSE_STREAM" ) // Start starts the AWS Lambda function. @@ -560,7 +560,7 @@ func Start(mux http.Handler, mode InvokeMode) error { log.Println(err) return err } - case InvokeModeResponseStreaming: + case InvokeModeResponseStream: if err := c.startStreaming(f.lambdaHandlerStreaming); err != nil { log.Println(err) return err @@ -605,8 +605,8 @@ func ListenAndServe(address string, mux http.Handler) error { switch os.Getenv("RIDGENATIVE_INVOKE_MODE") { case "BUFFERED", "": mode = InvokeModeBuffered - case "RESPONSE_STREAMING": - mode = InvokeModeResponseStreaming + case "RESPONSE_STREAM": + mode = InvokeModeResponseStream default: return errors.New("ridgenative: invalid RIDGENATIVE_INVOKE_MODE") } From c2245f9f0cc626eb35429dda73fc2a817108d38e Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Thu, 11 May 2023 22:55:24 +0900 Subject: [PATCH 07/24] add test for runtimeAPIClient.next --- runtime_api_client.go | 2 +- runtime_api_client_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 runtime_api_client_test.go diff --git a/runtime_api_client.go b/runtime_api_client.go index cfb9dec..bac3ecd 100644 --- a/runtime_api_client.go +++ b/runtime_api_client.go @@ -196,7 +196,7 @@ func (c *runtimeAPIClient) startStreaming(h handlerFuncSteaming) error { if err != nil { return err } - if err := c.handleInvokeStreaming(invoke, nil); err != nil { + if err := c.handleInvokeStreaming(invoke, h); err != nil { return err } } diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go new file mode 100644 index 0000000..99cde62 --- /dev/null +++ b/runtime_api_client_test.go @@ -0,0 +1,35 @@ +package ridgenative + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRuntimeAPIClient_next(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/next" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set(headerAWSRequestID, "request-id") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"key":"value"}`)) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke, err := client.next() + if err != nil { + t.Fatal(err) + } + if invoke.id != "request-id" { + t.Errorf("want id is %s, got %s", "request-id", invoke.id) + } + if string(invoke.payload) != `{"key":"value"}` { + t.Errorf("want payload is %s, got %s", `{"key":"value"}`, string(invoke.payload)) + } +} From 79db1200050b92c24e48606c79712f926ed8f6a4 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Thu, 11 May 2023 23:04:33 +0900 Subject: [PATCH 08/24] pass context --- ridgenative.go | 4 ++-- runtime_api_client.go | 48 +++++++++++++++++++------------------- runtime_api_client_test.go | 3 ++- 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/ridgenative.go b/ridgenative.go index 91cf487..32b9a6d 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -556,12 +556,12 @@ func Start(mux http.Handler, mode InvokeMode) error { c := newRuntimeAPIClient(api) switch mode { case InvokeModeBuffered: - if err := c.start(f.lambdaHandler); err != nil { + if err := c.start(context.Background(), f.lambdaHandler); err != nil { log.Println(err) return err } case InvokeModeResponseStream: - if err := c.startStreaming(f.lambdaHandlerStreaming); err != nil { + if err := c.startStreaming(context.Background(), f.lambdaHandlerStreaming); err != nil { log.Println(err) return err } diff --git a/runtime_api_client.go b/runtime_api_client.go index bac3ecd..1d4f4f7 100644 --- a/runtime_api_client.go +++ b/runtime_api_client.go @@ -58,22 +58,22 @@ func newRuntimeAPIClient(address string) *runtimeAPIClient { // handlerFunc is the type of the function that handles an invoke. type handlerFunc func(ctx context.Context, req *request) (*response, error) -func (c *runtimeAPIClient) start(h handlerFunc) error { +func (c *runtimeAPIClient) start(ctx context.Context, h handlerFunc) error { for { - invoke, err := c.next() + invoke, err := c.next(ctx) if err != nil { return err } - if err := c.handleInvoke(invoke, h); err != nil { + if err := c.handleInvoke(ctx, invoke, h); err != nil { return err } } } // next connects to the Runtime API and waits for a new invoke Request to be available. -func (c *runtimeAPIClient) next() (*invoke, error) { +func (c *runtimeAPIClient) next(ctx context.Context) (*invoke, error) { url := c.baseURL + "next" - req, err := http.NewRequest(http.MethodGet, url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("ridgenative: failed to construct GET request to %s: %w", url, err) } @@ -103,13 +103,13 @@ func (c *runtimeAPIClient) next() (*invoke, error) { } // handleInvoke handles an invoke. -func (c *runtimeAPIClient) handleInvoke(invoke *invoke, h handlerFunc) error { +func (c *runtimeAPIClient) handleInvoke(ctx context.Context, invoke *invoke, h handlerFunc) error { // set the deadline deadline, err := parseDeadline(invoke) if err != nil { - return c.reportFailure(invoke, lambdaErrorResponse(err)) + return c.reportFailure(ctx, invoke, lambdaErrorResponse(err)) } - ctx, cancel := context.WithDeadline(context.TODO(), deadline) + ctx, cancel := context.WithDeadline(ctx, deadline) defer cancel() // set the trace id @@ -123,7 +123,7 @@ func (c *runtimeAPIClient) handleInvoke(invoke *invoke, h handlerFunc) error { response, err := callBytesHandlerFunc(ctx, invoke.payload, h) if err != nil { invokeErr := lambdaErrorResponse(err) - if err := c.reportFailure(invoke, invokeErr); err != nil { + if err := c.reportFailure(ctx, invoke, invokeErr); err != nil { return err } if invokeErr.ShouldExit { @@ -132,7 +132,7 @@ func (c *runtimeAPIClient) handleInvoke(invoke *invoke, h handlerFunc) error { return nil } - if err := c.post(invoke.id+"/response", response, contentTypeJSON); err != nil { + if err := c.post(ctx, invoke.id+"/response", response, contentTypeJSON); err != nil { return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %w", err) } @@ -148,9 +148,9 @@ func parseDeadline(invoke *invoke) (time.Time, error) { } // post posts body to the Runtime API at the given path. -func (c *runtimeAPIClient) post(path string, body []byte, contentType string) error { +func (c *runtimeAPIClient) post(ctx context.Context, path string, body []byte, contentType string) error { url := c.baseURL + path - req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return fmt.Errorf("ridgenative: failed to construct POST request to %s: %w", url, err) } @@ -176,13 +176,13 @@ func (c *runtimeAPIClient) post(path string, body []byte, contentType string) er } // reportFailure reports the error to the Runtime API. -func (c *runtimeAPIClient) reportFailure(invoke *invoke, invokeErr *invokeResponseError) error { +func (c *runtimeAPIClient) reportFailure(ctx context.Context, invoke *invoke, invokeErr *invokeResponseError) error { body, err := json.Marshal(invokeErr) if err != nil { return fmt.Errorf("ridgenative: failed to marshal the function error: %w", err) } log.Printf("%s", body) - if err := c.post(invoke.id+"/error", body, contentTypeJSON); err != nil { + if err := c.post(ctx, invoke.id+"/error", body, contentTypeJSON); err != nil { return fmt.Errorf("ridgenative: unexpected error occurred when sending the function error to the API: %w", err) } return nil @@ -190,26 +190,26 @@ func (c *runtimeAPIClient) reportFailure(invoke *invoke, invokeErr *invokeRespon type handlerFuncSteaming func(ctx context.Context, req *request, w *io.PipeWriter) error -func (c *runtimeAPIClient) startStreaming(h handlerFuncSteaming) error { +func (c *runtimeAPIClient) startStreaming(ctx context.Context, h handlerFuncSteaming) error { for { - invoke, err := c.next() + invoke, err := c.next(ctx) if err != nil { return err } - if err := c.handleInvokeStreaming(invoke, h); err != nil { + if err := c.handleInvokeStreaming(ctx, invoke, h); err != nil { return err } } } // handleInvoke handles an invoke. -func (c *runtimeAPIClient) handleInvokeStreaming(invoke *invoke, h handlerFuncSteaming) error { +func (c *runtimeAPIClient) handleInvokeStreaming(ctx context.Context, invoke *invoke, h handlerFuncSteaming) error { // set the deadline deadline, err := parseDeadline(invoke) if err != nil { - return c.reportFailure(invoke, lambdaErrorResponse(err)) + return c.reportFailure(ctx, invoke, lambdaErrorResponse(err)) } - ctx, cancel := context.WithDeadline(context.TODO(), deadline) + ctx, cancel := context.WithDeadline(ctx, deadline) defer cancel() // set the trace id @@ -223,7 +223,7 @@ func (c *runtimeAPIClient) handleInvokeStreaming(invoke *invoke, h handlerFuncSt response, err := callHandlerFuncSteaming(ctx, invoke.payload, h) if err != nil { invokeErr := lambdaErrorResponse(err) - if err := c.reportFailure(invoke, invokeErr); err != nil { + if err := c.reportFailure(ctx, invoke, invokeErr); err != nil { return err } if invokeErr.ShouldExit { @@ -232,7 +232,7 @@ func (c *runtimeAPIClient) handleInvokeStreaming(invoke *invoke, h handlerFuncSt return nil } - if err := c.postStreaming(invoke.id+"/response", response, contentTypeHTTPIntegrationResponse); err != nil { + if err := c.postStreaming(ctx, invoke.id+"/response", response, contentTypeHTTPIntegrationResponse); err != nil { return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %w", err) } @@ -240,10 +240,10 @@ func (c *runtimeAPIClient) handleInvokeStreaming(invoke *invoke, h handlerFuncSt } // postStreaming posts body to the Runtime API at the given path. -func (c *runtimeAPIClient) postStreaming(path string, body io.ReadCloser, contentType string) error { +func (c *runtimeAPIClient) postStreaming(ctx context.Context, path string, body io.ReadCloser, contentType string) error { b := newErrorCapturingReader(body) url := c.baseURL + path - req, err := http.NewRequest(http.MethodPost, url, b) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, b) if err != nil { return fmt.Errorf("ridgenative: failed to construct POST request to %s: %w", url, err) } diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index 99cde62..af0f3f1 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -1,6 +1,7 @@ package ridgenative import ( + "context" "net/http" "net/http/httptest" "strings" @@ -22,7 +23,7 @@ func TestRuntimeAPIClient_next(t *testing.T) { address := strings.TrimPrefix(ts.URL, "http://") client := newRuntimeAPIClient(address) - invoke, err := client.next() + invoke, err := client.next(context.Background()) if err != nil { t.Fatal(err) } From 062a6a9c9d5ab3fc1db872fc2d38e4bc4354313c Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Thu, 18 May 2023 05:39:20 +0900 Subject: [PATCH 09/24] configure Lambda-Runtime-Function-Response-Mode header --- runtime_api_client.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/runtime_api_client.go b/runtime_api_client.go index 1d4f4f7..bacaf3b 100644 --- a/runtime_api_client.go +++ b/runtime_api_client.go @@ -17,12 +17,13 @@ import ( ) const ( - headerAWSRequestID = "Lambda-Runtime-Aws-Request-Id" - headerDeadlineMS = "Lambda-Runtime-Deadline-Ms" - headerTraceID = "Lambda-Runtime-Trace-Id" - headerCognitoIdentity = "Lambda-Runtime-Cognito-Identity" - headerClientContext = "Lambda-Runtime-Client-Context" - headerInvokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" + headerAWSRequestID = "Lambda-Runtime-Aws-Request-Id" + headerDeadlineMS = "Lambda-Runtime-Deadline-Ms" + headerTraceID = "Lambda-Runtime-Trace-Id" + headerCognitoIdentity = "Lambda-Runtime-Cognito-Identity" + headerClientContext = "Lambda-Runtime-Client-Context" + headerInvokedFunctionARN = "Lambda-Runtime-Invoked-Function-Arn" + headerFunctionResponseMode = "Lambda-Runtime-Function-Response-Mode" trailerLambdaErrorType = "Lambda-Runtime-Function-Error-Type" trailerLambdaErrorBody = "Lambda-Runtime-Function-Error-Body" @@ -250,6 +251,7 @@ func (c *runtimeAPIClient) postStreaming(ctx context.Context, path string, body req.Trailer = b.trailer req.Header.Set("User-Agent", c.userAgent) req.Header.Set("Content-Type", contentType) + req.Header.Set(headerFunctionResponseMode, "streaming") resp, err := c.httpClient.Do(req) if err != nil { From f79b2f469d681d66eae47d7f61e35852507988ab Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sat, 3 Jun 2023 23:03:05 +0900 Subject: [PATCH 10/24] fix: ineffectual assignment to mode (ineffassign) --- ridgenative.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ridgenative.go b/ridgenative.go index 32b9a6d..1a6888b 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -601,7 +601,7 @@ func ListenAndServe(address string, mux http.Handler) error { } // run on provided or provided.al2 runtime - mode := InvokeModeBuffered + var mode InvokeMode switch os.Getenv("RIDGENATIVE_INVOKE_MODE") { case "BUFFERED", "": mode = InvokeModeBuffered From 6a2bf9258d602a46e6ca0fda51627e0b38fc9dc6 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sat, 3 Jun 2023 23:38:03 +0900 Subject: [PATCH 11/24] add testcase for context deadline exceede --- runtime_api_client.go | 12 +++---- runtime_api_client_test.go | 66 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/runtime_api_client.go b/runtime_api_client.go index bacaf3b..77c0282 100644 --- a/runtime_api_client.go +++ b/runtime_api_client.go @@ -110,7 +110,7 @@ func (c *runtimeAPIClient) handleInvoke(ctx context.Context, invoke *invoke, h h if err != nil { return c.reportFailure(ctx, invoke, lambdaErrorResponse(err)) } - ctx, cancel := context.WithDeadline(ctx, deadline) + child, cancel := context.WithDeadline(ctx, deadline) defer cancel() // set the trace id @@ -118,10 +118,10 @@ func (c *runtimeAPIClient) handleInvoke(ctx context.Context, invoke *invoke, h h os.Setenv("_X_AMZN_TRACE_ID", traceID) // to keep compatibility with AWS Lambda X-Ray SDK, we need to set "x-amzn-trace-id" to the context. // nolint:staticcheck - ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID) + child = context.WithValue(child, "x-amzn-trace-id", traceID) // call the handler, marshal any returned error - response, err := callBytesHandlerFunc(ctx, invoke.payload, h) + response, err := callBytesHandlerFunc(child, invoke.payload, h) if err != nil { invokeErr := lambdaErrorResponse(err) if err := c.reportFailure(ctx, invoke, invokeErr); err != nil { @@ -210,7 +210,7 @@ func (c *runtimeAPIClient) handleInvokeStreaming(ctx context.Context, invoke *in if err != nil { return c.reportFailure(ctx, invoke, lambdaErrorResponse(err)) } - ctx, cancel := context.WithDeadline(ctx, deadline) + child, cancel := context.WithDeadline(ctx, deadline) defer cancel() // set the trace id @@ -218,10 +218,10 @@ func (c *runtimeAPIClient) handleInvokeStreaming(ctx context.Context, invoke *in os.Setenv("_X_AMZN_TRACE_ID", traceID) // to keep compatibility with AWS Lambda X-Ray SDK, we need to set "x-amzn-trace-id" to the context. // nolint:staticcheck - ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID) + child = context.WithValue(child, "x-amzn-trace-id", traceID) // call the handler, marshal any returned error - response, err := callHandlerFuncSteaming(ctx, invoke.payload, h) + response, err := callHandlerFuncSteaming(child, invoke.payload, h) if err != nil { invokeErr := lambdaErrorResponse(err) if err := c.reportFailure(ctx, invoke, invokeErr); err != nil { diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index af0f3f1..deb59e3 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -2,10 +2,14 @@ package ridgenative import ( "context" + "errors" + "io" "net/http" "net/http/httptest" + "strconv" "strings" "testing" + "time" ) func TestRuntimeAPIClient_next(t *testing.T) { @@ -34,3 +38,65 @@ func TestRuntimeAPIClient_next(t *testing.T) { t.Errorf("want payload is %s, got %s", `{"key":"value"}`, string(invoke.payload)) } } + +func TestRuntimeAPIClient_handleInvoke(t *testing.T) { + t.Run("context deadline exceeded", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != `{"errorMessage":"context deadline exceeded","errorType":"myError"}` { + t.Errorf("unexpected body: %s", string(body)) + } + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + }, + payload: []byte(`{}`), + } + err := client.handleInvoke(context.Background(), invoke, func(ctx context.Context, req *request) (*response, error) { + select { + // the handle takes a long time, so the deadline is exceeded. + case <-time.After(time.Second): + t.Error("deadline is too long") + return nil, errors.New("timeout") + + case <-ctx.Done(): + return nil, &myError{"context deadline exceeded"} + } + }) + if err != nil { + t.Fatal(err) + } + }) +} + +type myError struct { + msg string +} + +func (e *myError) Error() string { + return e.msg +} + +func encodeDeadline(t time.Time) string { + ms := t.UnixMilli() + return strconv.FormatInt(ms, 10) +} From 02e499dfc46deb024fc77c6e9846137806aa682e Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sun, 4 Jun 2023 00:03:44 +0900 Subject: [PATCH 12/24] add test case for succeeds --- runtime_api_client_test.go | 49 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index deb59e3..b4fc3db 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -40,6 +40,55 @@ func TestRuntimeAPIClient_next(t *testing.T) { } func TestRuntimeAPIClient_handleInvoke(t *testing.T) { + t.Run("succeeds", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/response" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != `{"statusCode":200,"body":"{\"key\":\"value\"}"}` { + t.Errorf("unexpected body: %s", string(body)) + } + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + "Lambda-Runtime-Trace-Id": {"trace-id"}, + }, + payload: []byte(`{}`), + } + err := client.handleInvoke(context.Background(), invoke, func(ctx context.Context, req *request) (*response, error) { + // test trace id + traceID := ctx.Value("x-amzn-trace-id").(string) + if traceID != "trace-id" { + t.Errorf("want trace id is %s, got %s", "trace-id", traceID) + } + + return &response{ + StatusCode: 200, + Body: `{"key":"value"}`, + }, nil + }) + if err != nil { + t.Fatal(err) + } + }) + t.Run("context deadline exceeded", func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { From 6035e558c776e98067d38d9adbd312bc1139eb80 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sun, 4 Jun 2023 21:17:32 +0900 Subject: [PATCH 13/24] test the request --- runtime_api_client_test.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index b4fc3db..0b42a6e 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -70,7 +70,7 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { }, "Lambda-Runtime-Trace-Id": {"trace-id"}, }, - payload: []byte(`{}`), + payload: []byte(`{"httpMethod":"GET","path":"/"}`), } err := client.handleInvoke(context.Background(), invoke, func(ctx context.Context, req *request) (*response, error) { // test trace id @@ -78,6 +78,12 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { if traceID != "trace-id" { t.Errorf("want trace id is %s, got %s", "trace-id", traceID) } + if req.HTTPMethod != "GET" { + t.Errorf("want method is %s, got %s", "GET", req.HTTPMethod) + } + if req.Path != "/" { + t.Errorf("want path is %s, got %s", "/", req.Path) + } return &response{ StatusCode: 200, From 085dd136a44394a950fb1f7e56b63cbdcc1541cd Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sun, 4 Jun 2023 21:24:37 +0900 Subject: [PATCH 14/24] add test cases for errors --- runtime_api_client_test.go | 82 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index 0b42a6e..4944d04 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -95,6 +95,88 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { } }) + t.Run("error", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != `{"errorMessage":"some errors","errorType":"myError"}` { + t.Errorf("unexpected body: %s", string(body)) + } + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + "Lambda-Runtime-Trace-Id": {"trace-id"}, + }, + payload: []byte(`{"httpMethod":"GET","path":"/"}`), + } + err := client.handleInvoke(context.Background(), invoke, func(ctx context.Context, req *request) (*response, error) { + return nil, &myError{"some errors"} + }) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("panic", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + // ignore stack traces because it has line numbers and it is not stable. + if !strings.HasPrefix(string(body), `{"errorMessage":"some errors","errorType":"string","stackTrace":`) { + t.Errorf("unexpected body: %s", string(body)) + } + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + "Lambda-Runtime-Trace-Id": {"trace-id"}, + }, + payload: []byte(`{"httpMethod":"GET","path":"/"}`), + } + err := client.handleInvoke(context.Background(), invoke, func(ctx context.Context, req *request) (*response, error) { + panic("some errors") + }) + if err == nil { + t.Error("want error, but got nil") + } + }) + t.Run("context deadline exceeded", func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { From bc103726ad914cdad4755ba058b9e79e66c505b0 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sun, 4 Jun 2023 23:30:55 +0900 Subject: [PATCH 15/24] test Content-Type header --- runtime_api_client_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index 4944d04..413a288 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -45,6 +45,9 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/response" { t.Errorf("unexpected path: %s", r.URL.Path) } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } body, err := io.ReadAll(r.Body) if err != nil { t.Error(err) @@ -100,6 +103,9 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { t.Errorf("unexpected path: %s", r.URL.Path) } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } body, err := io.ReadAll(r.Body) if err != nil { t.Error(err) @@ -140,6 +146,10 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { t.Errorf("unexpected path: %s", r.URL.Path) } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } + body, err := io.ReadAll(r.Body) if err != nil { t.Error(err) @@ -182,6 +192,10 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { t.Errorf("unexpected path: %s", r.URL.Path) } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } + body, err := io.ReadAll(r.Body) if err != nil { t.Error(err) @@ -225,6 +239,10 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { }) } +func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { + +} + type myError struct { msg string } From 08cfc73ae8b3faafcef78b12a07e1be3196ba448 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 5 Jun 2023 00:21:40 +0900 Subject: [PATCH 16/24] add test for handleInvokeStreaming --- runtime_api_client_test.go | 65 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index 413a288..fb6f83e 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -240,7 +240,72 @@ func TestRuntimeAPIClient_handleInvoke(t *testing.T) { } func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { + t.Run("succeeds", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/response" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Header.Get("Content-Type") != "application/vnd.awslambda.http-integration-response" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } + if r.Header.Get("Lambda-Runtime-Function-Response-Mode") != "streaming" { + t.Errorf("unexpected response mode: %s", r.Header.Get("Lambda-Runtime-Function-Response-Mode")) + } + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != `{"statusCode":200,"body":"{\"key\":\"value\"}"}` { + t.Errorf("unexpected body: %s", string(body)) + } + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + "Lambda-Runtime-Trace-Id": {"trace-id"}, + }, + payload: []byte(`{"httpMethod":"GET","path":"/"}`), + } + err := client.handleInvokeStreaming(context.Background(), invoke, func(ctx context.Context, req *request, w *io.PipeWriter) error { + traceID := ctx.Value("x-amzn-trace-id").(string) + if traceID != "trace-id" { + t.Errorf("want trace id is %s, got %s", "trace-id", traceID) + } + if req.HTTPMethod != "GET" { + t.Errorf("want method is %s, got %s", "GET", req.HTTPMethod) + } + if req.Path != "/" { + t.Errorf("want path is %s, got %s", "/", req.Path) + } + go func() { + if _, err := io.WriteString(w, `{"statusCode":200,"body":"{\"key\":\"value\"}"}`); err != nil { + t.Error(err) + } + if err := w.Close(); err != nil { + t.Error(err) + } + }() + + return nil + }) + if err != nil { + t.Fatal(err) + } + }) } type myError struct { From 478aeeb26f1dc775b757f13990743eed97c72315 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 5 Jun 2023 20:49:28 +0900 Subject: [PATCH 17/24] pass content-type --- invoke.go | 11 ++++++----- ridgenative.go | 6 +++--- runtime_api_client.go | 6 +++--- runtime_api_client_test.go | 4 ++-- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/invoke.go b/invoke.go index dda98ef..80a6791 100644 --- a/invoke.go +++ b/invoke.go @@ -31,7 +31,7 @@ func callBytesHandlerFunc(ctx context.Context, payload []byte, h handlerFunc) (r return json.Marshal(resp) } -func callHandlerFuncSteaming(ctx context.Context, payload []byte, h handlerFuncSteaming) (response io.ReadCloser, err error) { +func callHandlerFuncSteaming(ctx context.Context, payload []byte, h handlerFuncSteaming) (response io.ReadCloser, contentType string, err error) { defer func() { if v := recover(); v != nil { err = lambdaPanicResponse(v) @@ -40,12 +40,13 @@ func callHandlerFuncSteaming(ctx context.Context, payload []byte, h handlerFuncS var req *request if err := json.Unmarshal(payload, &req); err != nil { - return nil, err + return nil, "", err } r, w := io.Pipe() - if err := h(ctx, req, w); err != nil { - return nil, err + contentType, err = h(ctx, req, w) + if err != nil { + return nil, "", err } - return r, nil + return r, contentType, nil } diff --git a/ridgenative.go b/ridgenative.go index 1a6888b..ce86263 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -506,10 +506,10 @@ func (rw *streamingResponseWriter) Flush() { rw.buf.Flush() } -func (f *lambdaFunction) lambdaHandlerStreaming(ctx context.Context, req *request, w *io.PipeWriter) error { +func (f *lambdaFunction) lambdaHandlerStreaming(ctx context.Context, req *request, w *io.PipeWriter) (string, error) { r, err := f.httpRequestV2(ctx, req) if err != nil { - return err + return "", err } go func() { rw := newStreamingResponseWriter(w) @@ -522,7 +522,7 @@ func (f *lambdaFunction) lambdaHandlerStreaming(ctx context.Context, req *reques }() f.mux.ServeHTTP(rw, r) }() - return nil + return contentTypeHTTPIntegrationResponse, nil } func newLambdaFunction(mux http.Handler) *lambdaFunction { diff --git a/runtime_api_client.go b/runtime_api_client.go index 77c0282..dbc941a 100644 --- a/runtime_api_client.go +++ b/runtime_api_client.go @@ -189,7 +189,7 @@ func (c *runtimeAPIClient) reportFailure(ctx context.Context, invoke *invoke, in return nil } -type handlerFuncSteaming func(ctx context.Context, req *request, w *io.PipeWriter) error +type handlerFuncSteaming func(ctx context.Context, req *request, w *io.PipeWriter) (contentType string, err error) func (c *runtimeAPIClient) startStreaming(ctx context.Context, h handlerFuncSteaming) error { for { @@ -221,7 +221,7 @@ func (c *runtimeAPIClient) handleInvokeStreaming(ctx context.Context, invoke *in child = context.WithValue(child, "x-amzn-trace-id", traceID) // call the handler, marshal any returned error - response, err := callHandlerFuncSteaming(child, invoke.payload, h) + response, contentType, err := callHandlerFuncSteaming(child, invoke.payload, h) if err != nil { invokeErr := lambdaErrorResponse(err) if err := c.reportFailure(ctx, invoke, invokeErr); err != nil { @@ -233,7 +233,7 @@ func (c *runtimeAPIClient) handleInvokeStreaming(ctx context.Context, invoke *in return nil } - if err := c.postStreaming(ctx, invoke.id+"/response", response, contentTypeHTTPIntegrationResponse); err != nil { + if err := c.postStreaming(ctx, invoke.id+"/response", response, contentType); err != nil { return fmt.Errorf("unexpected error occurred when sending the function functionResponse to the API: %w", err) } diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index fb6f83e..d6f0d2d 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -279,7 +279,7 @@ func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { }, payload: []byte(`{"httpMethod":"GET","path":"/"}`), } - err := client.handleInvokeStreaming(context.Background(), invoke, func(ctx context.Context, req *request, w *io.PipeWriter) error { + err := client.handleInvokeStreaming(context.Background(), invoke, func(ctx context.Context, req *request, w *io.PipeWriter) (string, error) { traceID := ctx.Value("x-amzn-trace-id").(string) if traceID != "trace-id" { t.Errorf("want trace id is %s, got %s", "trace-id", traceID) @@ -300,7 +300,7 @@ func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { } }() - return nil + return "application/vnd.awslambda.http-integration-response", nil }) if err != nil { t.Fatal(err) From 01853e990a56ee2ce0e7e4759bc512bf03d4b1d1 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 5 Jun 2023 20:50:49 +0900 Subject: [PATCH 18/24] check errors --- runtime_api_client_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index d6f0d2d..99f1b93 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -20,7 +20,9 @@ func TestRuntimeAPIClient_next(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.Header().Set(headerAWSRequestID, "request-id") w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"key":"value"}`)) + if _, err := w.Write([]byte(`{"key":"value"}`)); err != nil { + t.Error(err) + } })) defer ts.Close() From 019f52e0882f8d45b5beb1aca6acb42fbadf338b Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Tue, 6 Jun 2023 21:32:59 +0900 Subject: [PATCH 19/24] test error cases of handleInvokeStreaming --- runtime_api_client_test.go | 141 +++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index 99f1b93..1144e57 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -308,6 +308,147 @@ func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { t.Fatal(err) } }) + + t.Run("error", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != `{"errorMessage":"some errors","errorType":"myError"}` { + t.Errorf("unexpected body: %s", string(body)) + } + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + "Lambda-Runtime-Trace-Id": {"trace-id"}, + }, + payload: []byte(`{"httpMethod":"GET","path":"/"}`), + } + err := client.handleInvokeStreaming(context.Background(), invoke, func(ctx context.Context, req *request, w *io.PipeWriter) (string, error) { + return "", &myError{"some errors"} + }) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("panic", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + // ignore stack traces because it has line numbers and it is not stable. + if !strings.HasPrefix(string(body), `{"errorMessage":"some errors","errorType":"string","stackTrace":`) { + t.Errorf("unexpected body: %s", string(body)) + } + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + "Lambda-Runtime-Trace-Id": {"trace-id"}, + }, + payload: []byte(`{"httpMethod":"GET","path":"/"}`), + } + err := client.handleInvokeStreaming(context.Background(), invoke, func(ctx context.Context, req *request, w *io.PipeWriter) (string, error) { + panic("some errors") + }) + if err == nil { + t.Error("want error, but got nil") + } + }) + + t.Run("context deadline exceeded", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != `{"errorMessage":"some errors","errorType":"myError"}` { + t.Errorf("unexpected body: %s", string(body)) + } + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + "Lambda-Runtime-Trace-Id": {"trace-id"}, + }, + payload: []byte(`{"httpMethod":"GET","path":"/"}`), + } + err := client.handleInvokeStreaming(context.Background(), invoke, func(ctx context.Context, req *request, w *io.PipeWriter) (string, error) { + select { + // the handle takes a long time, so the deadline is exceeded. + case <-time.After(time.Second): + t.Error("deadline is too long") + return "", errors.New("timeout") + + case <-ctx.Done(): + return "", &myError{"context deadline exceeded"} + } + }) + if err != nil { + t.Fatal(err) + } + }) + } type myError struct { From ed71e1559e8f3ebdab124a3f677152c16e8e9a24 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sun, 18 Jun 2023 23:09:58 +0900 Subject: [PATCH 20/24] add test: error during streaming --- runtime_api_client.go | 8 +++-- runtime_api_client_test.go | 69 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/runtime_api_client.go b/runtime_api_client.go index dbc941a..0b042a3 100644 --- a/runtime_api_client.go +++ b/runtime_api_client.go @@ -294,9 +294,8 @@ func (r *errorCapturingReader) Read(p []byte) (int, error) { } n, err := r.reader.Read(p) - if err != nil && errors.Is(err, io.EOF) { + if err != nil && !errors.Is(err, io.EOF) { // capture the error - r.err = err lambdaErr := lambdaErrorResponse(err) body, err := json.Marshal(lambdaErr) if err != nil { @@ -306,10 +305,15 @@ func (r *errorCapturingReader) Read(p []byte) (int, error) { } r.trailer.Set(trailerLambdaErrorType, lambdaErr.Type) r.trailer.Set(trailerLambdaErrorBody, base64.StdEncoding.EncodeToString(body)) + r.err = io.EOF + return n, io.EOF } return n, err } func (r *errorCapturingReader) Close() error { + if r.reader == nil { + return nil + } return r.reader.Close() } diff --git a/runtime_api_client_test.go b/runtime_api_client_test.go index 1144e57..9c8e83b 100644 --- a/runtime_api_client_test.go +++ b/runtime_api_client_test.go @@ -2,6 +2,7 @@ package ridgenative import ( "context" + "encoding/base64" "errors" "io" "net/http" @@ -263,6 +264,12 @@ func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { if string(body) != `{"statusCode":200,"body":"{\"key\":\"value\"}"}` { t.Errorf("unexpected body: %s", string(body)) } + if len(r.Trailer.Values("Lambda-Runtime-Function-Error-Type")) != 0 { + t.Errorf("unexpected error type: %s", r.Trailer.Values("Lambda-Runtime-Function-Error-Type")) + } + if len(r.Trailer.Values("Lambda-Runtime-Function-Error-Body")) != 0 { + t.Errorf("unexpected error body: %s", r.Trailer.Values("Lambda-Runtime-Function-Error-Body")) + } w.WriteHeader(http.StatusAccepted) })) defer ts.Close() @@ -309,7 +316,7 @@ func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { } }) - t.Run("error", func(t *testing.T) { + t.Run("error before start streaming", func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { t.Errorf("unexpected path: %s", r.URL.Path) @@ -352,6 +359,64 @@ func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { } }) + t.Run("error during streaming", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/response" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Header.Get("Content-Type") != "application/vnd.awslambda.http-integration-response" { + t.Errorf("unexpected content type: %s", r.Header.Get("Content-Type")) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != "" { + t.Errorf("unexpected body: %s", string(body)) + } + + if r.Trailer.Get("Lambda-Runtime-Function-Error-Type") != "myError" { + t.Errorf("unexpected error type: %s", r.Trailer.Get("Lambda-Runtime-Function-Error-Type")) + } + wantErr := base64.StdEncoding.EncodeToString([]byte(`{"errorMessage":"some errors","errorType":"myError"}`)) + if r.Trailer.Get("Lambda-Runtime-Function-Error-Body") != wantErr { + t.Errorf("unexpected error: %s", r.Trailer.Get("Lambda-Runtime-Function-Error-Body")) + } + + w.WriteHeader(http.StatusAccepted) + })) + defer ts.Close() + + address := strings.TrimPrefix(ts.URL, "http://") + client := newRuntimeAPIClient(address) + + invoke := &invoke{ + id: "request-id", + headers: map[string][]string{ + "Lambda-Runtime-Deadline-Ms": { + // the deadline is 100ms + encodeDeadline(time.Now().Add(100 * time.Millisecond)), + }, + "Lambda-Runtime-Trace-Id": {"trace-id"}, + }, + payload: []byte(`{"httpMethod":"GET","path":"/"}`), + } + err := client.handleInvokeStreaming(context.Background(), invoke, func(ctx context.Context, req *request, w *io.PipeWriter) (string, error) { + go func() { + if err := w.CloseWithError(&myError{"some errors"}); err != nil { + t.Error(err) + } + }() + + return "application/vnd.awslambda.http-integration-response", nil + }) + if err != nil { + t.Fatal(err) + } + }) + t.Run("panic", func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/2018-06-01/runtime/invocation/request-id/error" { @@ -412,7 +477,7 @@ func TestRuntimeAPIClient_handleInvokeStreaming(t *testing.T) { w.WriteHeader(http.StatusInternalServerError) return } - if string(body) != `{"errorMessage":"some errors","errorType":"myError"}` { + if string(body) != `{"errorMessage":"context deadline exceeded","errorType":"myError"}` { t.Errorf("unexpected body: %s", string(body)) } w.WriteHeader(http.StatusAccepted) From 2ba7395f255f6ea4ea42d5708f9553ed2e837933 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sun, 18 Jun 2023 23:43:19 +0900 Subject: [PATCH 21/24] add test for normal streaming response --- ridgenative_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ridgenative_test.go b/ridgenative_test.go index ad4d53e..35aee96 100644 --- a/ridgenative_test.go +++ b/ridgenative_test.go @@ -757,3 +757,33 @@ func BenchmarkResponse_text(b *testing.B) { rw.lambdaResponseV1() } } + +func TestLambdaHandlerStreaming(t *testing.T) { + t.Run("normal", func(t *testing.T) { + l := newLambdaFunction(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"hello":"world"}`) + })) + r, w := io.Pipe() + contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{ + RequestContext: requestContext{ + HTTP: &requestContextHTTP{ + Path: "/", + }, + }, + }, w) + if err != nil { + t.Fatal(err) + } + if got, want := contentType, "application/vnd.awslambda.http-integration-response"; got != want { + t.Errorf("unexpected content type: want %q, got %q", want, got) + } + data, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if got, want := string(data), "{\"statusCode\":200}\x00\x00\x00\x00\x00\x00\x00\x00{\"hello\":\"world\"}"; got != want { + t.Errorf("unexpected body: want %q, got %q", want, got) + } + }) +} From 4bbeab2d20a3803ba7ee5ba70feba909d2b068d7 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 19 Jun 2023 00:03:19 +0900 Subject: [PATCH 22/24] add test for WriteHeader --- ridgenative.go | 2 +- ridgenative_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/ridgenative.go b/ridgenative.go index ce86263..abfda84 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -474,7 +474,7 @@ func (rw *streamingResponseWriter) Write(data []byte) (int, error) { // TODO: detect content type if it is not set. rw.WriteHeader(http.StatusOK) } - return rw.w.Write(data) + return rw.buf.Write(data) } func (rw *streamingResponseWriter) closeWithError(err error) error { diff --git a/ridgenative_test.go b/ridgenative_test.go index 35aee96..f9adeab 100644 --- a/ridgenative_test.go +++ b/ridgenative_test.go @@ -786,4 +786,59 @@ func TestLambdaHandlerStreaming(t *testing.T) { t.Errorf("unexpected body: want %q, got %q", want, got) } }) + + t.Run("WriteHeader", func(t *testing.T) { + l := newLambdaFunction(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + // Writes to ResponseWriter are buffered, + // so multiple writes to ResponseWriter become a single write to the pipe + io.WriteString(w, `{"hello":`) + io.WriteString(w, `"world"}`) + })) + r, w := io.Pipe() + contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{ + RequestContext: requestContext{ + HTTP: &requestContextHTTP{ + Path: "/", + }, + }, + }, w) + if err != nil { + t.Fatal(err) + } + if got, want := contentType, "application/vnd.awslambda.http-integration-response"; got != want { + t.Errorf("unexpected content type: want %q, got %q", want, got) + } + + // Reads and Writes on the pipe are matched one to one, + // so we get only the header on first read. + buf := make([]byte, 1024) + n, err := r.Read(buf) + if err != nil { + t.Fatal(err) + } + if got, want := string(buf[:n]), "{\"statusCode\":200}\x00\x00\x00\x00\x00\x00\x00\x00"; got != want { + t.Errorf("unexpected body: want %q, got %q", want, got) + } + + // The second read gets the body. + n, err = r.Read(buf) + if err != nil { + t.Fatal(err) + } + if got, want := string(buf[:n]), "{\"hello\":\"world\"}"; got != want { + t.Errorf("unexpected body: want %q, got %q", want, got) + } + + // The third read gets EOF. + n, err = r.Read(buf) + if err != io.EOF { + t.Errorf("unexpected error: want %v, got %v", io.EOF, err) + } + if n != 0 { + t.Errorf("unexpected read size: want %d, got %d", 0, n) + } + }) } From 423bf884a0fbd899bd63ff2452a042db379340c6 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 19 Jun 2023 21:11:01 +0900 Subject: [PATCH 23/24] detecting Content-Type --- ridgenative.go | 105 ++++++++++++++++++++++++++++++++++---------- ridgenative_test.go | 99 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 179 insertions(+), 25 deletions(-) diff --git a/ridgenative.go b/ridgenative.go index abfda84..5ba49f5 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -434,13 +434,19 @@ type streamingResponseWriter struct { wroteHeader bool header http.Header statusCode int + err error + + // prelude is the first part of the body. + // it is used for detecting content-type. + prelude []byte } func newStreamingResponseWriter(w *io.PipeWriter) *streamingResponseWriter { return &streamingResponseWriter{ - w: w, - buf: bufio.NewWriter(w), - header: make(http.Header, 1), + w: w, + buf: bufio.NewWriter(w), + header: make(http.Header, 1), + prelude: make([]byte, 0, 512), } } @@ -454,49 +460,102 @@ func (rw *streamingResponseWriter) WriteHeader(code int) { log.Printf("ridgenative: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line) return } + if rw.err != nil { + return + } + + if !rw.hasContentType() { + rw.header.Set("Content-Type", http.DetectContentType(rw.prelude)) + } + + rw.wroteHeader = true rw.statusCode = code + + // build the prelude + h := make(map[string]string, len(rw.header)) + for key, value := range rw.header { + if key == "Set-Cookie" { + continue + } + h[key] = strings.Join(value, ", ") + } + cookies := rw.header.Values("Set-Cookie") r := &streamingResponse{ StatusCode: code, + Headers: h, + Cookies: cookies, } + data, err := json.Marshal(r) if err != nil { - log.Printf("ridgenative: %v", err) + rw.err = fmt.Errorf("ridgenative: failed to marshal response: %w", err) return } - rw.buf.Write(data) - rw.buf.WriteString("\x00\x00\x00\x00\x00\x00\x00\x00") - rw.buf.Flush() - rw.wroteHeader = true + if _, err := rw.buf.Write(data); err != nil { + rw.err = err + return + } + if _, err := rw.buf.WriteString("\x00\x00\x00\x00\x00\x00\x00\x00"); err != nil { + rw.err = err + return + } + if len(rw.prelude) != 0 { + if _, err := rw.buf.Write(rw.prelude); err != nil { + rw.err = err + return + } + } + if err := rw.buf.Flush(); err != nil { + rw.err = err + } +} + +func (rw *streamingResponseWriter) hasContentType() bool { + return rw.header.Get("Content-Type") != "" } func (rw *streamingResponseWriter) Write(data []byte) (int, error) { + var m int if !rw.wroteHeader { - // TODO: detect content type if it is not set. - rw.WriteHeader(http.StatusOK) + if rw.hasContentType() { + rw.WriteHeader(http.StatusOK) + } else { + // save the first part of the body for detecting content-type. + data0 := data + if len(rw.prelude)+len(data0) > cap(rw.prelude) { + data0 = data0[:cap(rw.prelude)-len(rw.prelude)] + } + rw.prelude = append(rw.prelude, data0...) + + if len(rw.prelude) == cap(rw.prelude) { + rw.WriteHeader(http.StatusOK) + } + m = len(data0) + data = data[m:] + if len(data) == 0 { + return m, nil + } + } } - return rw.buf.Write(data) + n, err := rw.buf.Write(data) + return n + m, err } func (rw *streamingResponseWriter) closeWithError(err error) error { if !rw.wroteHeader { rw.WriteHeader(http.StatusOK) } - err0 := rw.buf.Flush() - if err1 := rw.w.CloseWithError(err); err0 == nil { - err0 = err1 + if rw.err != nil { + err = rw.err } - return err0 + if err0 := rw.buf.Flush(); err0 != nil { + err = err0 + } + return rw.w.CloseWithError(err) } func (rw *streamingResponseWriter) close() error { - if !rw.wroteHeader { - rw.WriteHeader(http.StatusOK) - } - err0 := rw.buf.Flush() - if err1 := rw.w.Close(); err0 == nil { - err0 = err1 - } - return err0 + return rw.closeWithError(nil) } func (rw *streamingResponseWriter) Flush() { diff --git a/ridgenative_test.go b/ridgenative_test.go index f9adeab..6faf92c 100644 --- a/ridgenative_test.go +++ b/ridgenative_test.go @@ -782,7 +782,7 @@ func TestLambdaHandlerStreaming(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := string(data), "{\"statusCode\":200}\x00\x00\x00\x00\x00\x00\x00\x00{\"hello\":\"world\"}"; got != want { + if got, want := string(data), "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"application/json\"}}\x00\x00\x00\x00\x00\x00\x00\x00{\"hello\":\"world\"}"; got != want { t.Errorf("unexpected body: want %q, got %q", want, got) } }) @@ -819,7 +819,7 @@ func TestLambdaHandlerStreaming(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := string(buf[:n]), "{\"statusCode\":200}\x00\x00\x00\x00\x00\x00\x00\x00"; got != want { + if got, want := string(buf[:n]), "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"application/json\"}}\x00\x00\x00\x00\x00\x00\x00\x00"; got != want { t.Errorf("unexpected body: want %q, got %q", want, got) } @@ -841,4 +841,99 @@ func TestLambdaHandlerStreaming(t *testing.T) { t.Errorf("unexpected read size: want %d, got %d", 0, n) } }) + + t.Run("flush", func(t *testing.T) { + l := newLambdaFunction(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + f, ok := w.(http.Flusher) + if !ok { + t.Error("http.ResponseWriter doesn't implement http.Flusher") + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + io.WriteString(w, `{"hello":`) + f.Flush() + io.WriteString(w, `"world"}`) + })) + r, w := io.Pipe() + contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{ + RequestContext: requestContext{ + HTTP: &requestContextHTTP{ + Path: "/", + }, + }, + }, w) + if err != nil { + t.Fatal(err) + } + if got, want := contentType, "application/vnd.awslambda.http-integration-response"; got != want { + t.Errorf("unexpected content type: want %q, got %q", want, got) + } + + // Reads and Writes on the pipe are matched one to one, + // so we get only the header on first read. + buf := make([]byte, 1024) + n, err := r.Read(buf) + if err != nil { + t.Fatal(err) + } + if got, want := string(buf[:n]), "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"application/json\"}}\x00\x00\x00\x00\x00\x00\x00\x00"; got != want { + t.Errorf("unexpected body: want %q, got %q", want, got) + } + + // The second read gets the half of the body. + n, err = r.Read(buf) + if err != nil { + t.Fatal(err) + } + if got, want := string(buf[:n]), "{\"hello\":"; got != want { + t.Errorf("unexpected body: want %q, got %q", want, got) + } + + // The third read gets the rest of the body. + n, err = r.Read(buf) + if err != nil { + t.Fatal(err) + } + if got, want := string(buf[:n]), "\"world\"}"; got != want { + t.Errorf("unexpected body: want %q, got %q", want, got) + } + + // The forth read gets EOF. + n, err = r.Read(buf) + if err != io.EOF { + t.Errorf("unexpected error: want %v, got %v", io.EOF, err) + } + if n != 0 { + t.Errorf("unexpected read size: want %d, got %d", 0, n) + } + }) + + t.Run("detect content-type", func(t *testing.T) { + l := newLambdaFunction(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, ``) + })) + r, w := io.Pipe() + contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{ + RequestContext: requestContext{ + HTTP: &requestContextHTTP{ + Path: "/", + }, + }, + }, w) + if err != nil { + t.Fatal(err) + } + if got, want := contentType, "application/vnd.awslambda.http-integration-response"; got != want { + t.Errorf("unexpected content type: want %q, got %q", want, got) + } + + data, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if got, want := string(data), "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"text/html; charset=utf-8\"}}\x00\x00\x00\x00\x00\x00\x00\x00"; got != want { + t.Errorf("unexpected body: want %q, got %q", want, got) + } + }) } From 3bc187e68ee01dbb4f1ecfb45dadf3c86179d88d Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 19 Jun 2023 22:55:12 +0900 Subject: [PATCH 24/24] check the errors --- ridgenative.go | 4 ++-- ridgenative_test.go | 24 ++++++++++++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/ridgenative.go b/ridgenative.go index 5ba49f5..cf52a18 100644 --- a/ridgenative.go +++ b/ridgenative.go @@ -574,9 +574,9 @@ func (f *lambdaFunction) lambdaHandlerStreaming(ctx context.Context, req *reques rw := newStreamingResponseWriter(w) defer func() { if v := recover(); v != nil { - rw.closeWithError(lambdaPanicResponse(v)) + _ = rw.closeWithError(lambdaPanicResponse(v)) } else { - rw.close() + _ = rw.close() } }() f.mux.ServeHTTP(rw, r) diff --git a/ridgenative_test.go b/ridgenative_test.go index 6faf92c..787e497 100644 --- a/ridgenative_test.go +++ b/ridgenative_test.go @@ -762,7 +762,9 @@ func TestLambdaHandlerStreaming(t *testing.T) { t.Run("normal", func(t *testing.T) { l := newLambdaFunction(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - io.WriteString(w, `{"hello":"world"}`) + if _, err := io.WriteString(w, `{"hello":"world"}`); err != nil { + t.Error(err) + } })) r, w := io.Pipe() contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{ @@ -794,8 +796,12 @@ func TestLambdaHandlerStreaming(t *testing.T) { // Writes to ResponseWriter are buffered, // so multiple writes to ResponseWriter become a single write to the pipe - io.WriteString(w, `{"hello":`) - io.WriteString(w, `"world"}`) + if _, err := io.WriteString(w, `{"hello":`); err != nil { + t.Error(err) + } + if _, err := io.WriteString(w, `"world"}`); err != nil { + t.Error(err) + } })) r, w := io.Pipe() contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{ @@ -851,9 +857,13 @@ func TestLambdaHandlerStreaming(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - io.WriteString(w, `{"hello":`) + if _, err := io.WriteString(w, `{"hello":`); err != nil { + t.Error(err) + } f.Flush() - io.WriteString(w, `"world"}`) + if _, err := io.WriteString(w, `"world"}`); err != nil { + t.Error(err) + } })) r, w := io.Pipe() contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{ @@ -911,7 +921,9 @@ func TestLambdaHandlerStreaming(t *testing.T) { t.Run("detect content-type", func(t *testing.T) { l := newLambdaFunction(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, ``) + if _, err := io.WriteString(w, ``); err != nil { + t.Error(err) + } })) r, w := io.Pipe() contentType, err := l.lambdaHandlerStreaming(context.Background(), &request{