diff --git a/internal/api/router.go b/internal/api/router.go index 63ed681e..c2242800 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -98,22 +98,31 @@ func (r *Router) Routes(rg *echo.Group) { func errorMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - err := next(c) + origErr := next(c) - if err == nil { + if origErr == nil { return nil } - // If error is an echo.HTTPError simply return it - if _, ok := err.(*echo.HTTPError); ok { - return err + var ( + checkErr = origErr + echoMsg []any + ) + + // If error is an echo.HTTPError, extract it's message to be reused if status is rewritten. + // Additionally we unwrap the internal error which is then checked instead of the echo error. + if eerr, ok := origErr.(*echo.HTTPError); ok { + echoMsg = []any{eerr.Message} + checkErr = eerr.Internal } switch { - case errors.Is(err, context.Canceled): - return echo.ErrUnprocessableEntity.WithInternal(err) + // Only if the error is a context canceled error and the request context has been canceled. + // If the request was not canceled, then the context canceled error probably came from the service. + case errors.Is(checkErr, context.Canceled) && errors.Is(c.Request().Context().Err(), context.Canceled): + return echo.NewHTTPError(http.StatusUnprocessableEntity, echoMsg...).WithInternal(checkErr) default: - return err + return origErr } } } diff --git a/internal/api/router_test.go b/internal/api/router_test.go index cd8366e2..2a296b03 100644 --- a/internal/api/router_test.go +++ b/internal/api/router_test.go @@ -2,6 +2,7 @@ package api import ( "context" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -19,6 +20,8 @@ func TestErrorMiddleware(t *testing.T) { ctx := context.Background() e := echo.New() + e.Debug = true + e.Use(echoTestLogger(t, e)) e.Use(errorMiddleware) @@ -27,7 +30,15 @@ func TestErrorMiddleware(t *testing.T) { select { case <-c.Request().Context().Done(): - return c.Request().Context().Err() + err := c.Request().Context().Err() + + switch errType { + case "": + case "echo": + return echo.NewHTTPError(http.StatusInternalServerError, "some message").WithInternal(err) + } + + return err case <-time.After(time.Second): } @@ -37,6 +48,8 @@ func TestErrorMiddleware(t *testing.T) { return echo.ErrTeapot case "other": return io.ErrUnexpectedEOF + case "internalCancel": + return echo.NewHTTPError(http.StatusInternalServerError, "service error").WithInternal(context.Canceled) } return nil @@ -97,6 +110,55 @@ func TestErrorMiddleware(t *testing.T) { assert.Equal(t, http.StatusUnprocessableEntity, res.Success.Code) }, }, + { + Name: "Canceled echo", + Input: testinput{ + path: "/test?error=echo", + delay: time.Second / 2, + }, + CheckFn: func(_ context.Context, t *testing.T, res testingx.TestResult[*httptest.ResponseRecorder]) { + require.NoError(t, res.Err) + require.NotNil(t, res.Success) + + require.Equal(t, http.StatusUnprocessableEntity, res.Success.Code) + + resp := map[string]string{} + + err := json.Unmarshal(res.Success.Body.Bytes(), &resp) + require.NoError(t, err, "no error expected decoding response body") + + expect := map[string]string{ + "error": "code=422, message=some message, internal=context canceled", + "message": "some message", + } + + assert.Equal(t, expect, resp, "unexpected response") + }, + }, + { + Name: "Canceled echo internal", + Input: testinput{ + path: "/test?error=internalCancel", + }, + CheckFn: func(_ context.Context, t *testing.T, res testingx.TestResult[*httptest.ResponseRecorder]) { + require.NoError(t, res.Err) + require.NotNil(t, res.Success) + + require.Equal(t, http.StatusInternalServerError, res.Success.Code) + + resp := map[string]string{} + + err := json.Unmarshal(res.Success.Body.Bytes(), &resp) + require.NoError(t, err, "no error expected decoding response body") + + expect := map[string]string{ + "error": "code=500, message=service error, internal=context canceled", + "message": "service error", + } + + assert.Equal(t, expect, resp, "unexpected response") + }, + }, } testFn := func(ctx context.Context, input testinput) testingx.TestResult[*httptest.ResponseRecorder] {