diff --git a/pkg/serverless/daemon/routes.go b/pkg/serverless/daemon/routes.go index 733b4a49b050c..fe78651ac4d2f 100644 --- a/pkg/serverless/daemon/routes.go +++ b/pkg/serverless/daemon/routes.go @@ -104,7 +104,17 @@ func (e *EndInvocation) ServeHTTP(w http.ResponseWriter, r *http.Request) { } errorMsg := r.Header.Get(invocationlifecycle.InvocationErrorMsgHeader) + if decodedMsg, err := base64.StdEncoding.DecodeString(errorMsg); err != nil { + log.Debug("Error message header may not be encoded, setting as is") + } else { + errorMsg = string(decodedMsg) + } errorType := r.Header.Get(invocationlifecycle.InvocationErrorTypeHeader) + if decodedType, err := base64.StdEncoding.DecodeString(errorType); err != nil { + log.Debug("Error type header may not be encoded, setting as is") + } else { + errorType = string(decodedType) + } errorStack := r.Header.Get(invocationlifecycle.InvocationErrorStackHeader) if decodedStack, err := base64.StdEncoding.DecodeString(errorStack); err != nil { log.Debug("Could not decode error stack header") diff --git a/pkg/serverless/daemon/routes_test.go b/pkg/serverless/daemon/routes_test.go index 25231204af6c8..e630d2829bd61 100644 --- a/pkg/serverless/daemon/routes_test.go +++ b/pkg/serverless/daemon/routes_test.go @@ -7,6 +7,7 @@ package daemon import ( "bytes" + "encoding/base64" "fmt" "io" "net/http" @@ -104,7 +105,7 @@ func TestEndInvocation(t *testing.T) { assert.Equal(m.lastEndDetails.Runtime, d.ExecutionContext.GetCurrentState().Runtime) } -func TestEndInvocationWithError(t *testing.T) { +func TestEndInvocationWithErrorEncodedHeaders(t *testing.T) { assert := assert.New(t) port := testutil.FreeTCPPort(t) d := StartDaemon(fmt.Sprintf("127.0.0.1:%d", port)) @@ -114,10 +115,52 @@ func TestEndInvocationWithError(t *testing.T) { m := &mockLifecycleProcessor{} d.InvocationProcessor = m + errorMessage := "Error message" + errorType := "System.Exception" + errorStack := "System.Exception: Error message \n at TestFunction.Handle(ILambdaContext context)" + + client := &http.Client{} + body := bytes.NewBuffer([]byte(`{}`)) + request, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/lambda/end-invocation", port), body) + request.Header.Set("x-datadog-invocation-error", "true") + request.Header.Set(invocationlifecycle.InvocationErrorMsgHeader, base64.StdEncoding.EncodeToString([]byte(errorMessage))) + request.Header.Set(invocationlifecycle.InvocationErrorTypeHeader, base64.StdEncoding.EncodeToString([]byte(errorType))) + request.Header.Set(invocationlifecycle.InvocationErrorStackHeader, base64.StdEncoding.EncodeToString([]byte(errorStack))) + assert.Nil(err) + res, err := client.Do(request) + assert.Nil(err) + if res != nil { + res.Body.Close() + assert.Equal(res.StatusCode, 200) + } + assert.True(m.OnInvokeEndCalled) + assert.True(m.isError) + assert.Equal(m.lastEndDetails.ErrorMsg, errorMessage) + assert.Equal(m.lastEndDetails.ErrorType, errorType) + assert.Equal(m.lastEndDetails.ErrorStack, errorStack) +} + +func TestEndInvocationWithErrorNonEncodedHeaders(t *testing.T) { + assert := assert.New(t) + port := testutil.FreeTCPPort(t) + d := StartDaemon(fmt.Sprintf("127.0.0.1:%d", port)) + time.Sleep(100 * time.Millisecond) + defer d.Stop() + + m := &mockLifecycleProcessor{} + d.InvocationProcessor = m + + errorMessage := "Error message" + errorType := "System.Exception" + errorStack := "System.Exception: Error message at TestFunction.Handle(ILambdaContext context)" + client := &http.Client{} body := bytes.NewBuffer([]byte(`{}`)) request, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/lambda/end-invocation", port), body) request.Header.Set("x-datadog-invocation-error", "true") + request.Header.Set(invocationlifecycle.InvocationErrorMsgHeader, errorMessage) + request.Header.Set(invocationlifecycle.InvocationErrorTypeHeader, errorType) + request.Header.Set(invocationlifecycle.InvocationErrorStackHeader, errorStack) assert.Nil(err) res, err := client.Do(request) assert.Nil(err) @@ -127,6 +170,9 @@ func TestEndInvocationWithError(t *testing.T) { } assert.True(m.OnInvokeEndCalled) assert.True(m.isError) + assert.Equal(m.lastEndDetails.ErrorMsg, errorMessage) + assert.Equal(m.lastEndDetails.ErrorType, errorType) + assert.Equal(m.lastEndDetails.ErrorStack, errorStack) } func TestTraceContext(t *testing.T) {