Skip to content

Commit

Permalink
Merge pull request #640 from notjustmoney/refactoring/use-new-error-w…
Browse files Browse the repository at this point in the history
…ith-context

refactor: use WriteErr function for handling error returned by handler
  • Loading branch information
danielgtaylor authored Nov 12, 2024
2 parents 4218123 + f628ea0 commit bea7c1a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 52 deletions.
23 changes: 6 additions & 17 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
)

// ErrorDetailer returns error details for responses & debugging. This enables
Expand Down Expand Up @@ -262,22 +261,12 @@ func WriteErr(api API, ctx Context, status int, msg string, errs ...error) error
// If it was not modified then this is a no-op.
status = err.GetStatus()

ct, negotiateErr := api.Negotiate(ctx.Header("Accept"))
if negotiateErr != nil {
return negotiateErr
writeErr := writeResponse(api, ctx, status, "", err)
if writeErr != nil {
// If we can't write the error, log it so we know what happened.
fmt.Printf("could not write error: %s\n", writeErr)
}

if ctf, ok := err.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
ctx.SetStatus(status)
tval, terr := api.Transform(ctx, strconv.Itoa(status), err)
if terr != nil {
return terr
}
return api.Marshal(ctx.BodyWriter(), ct, tval)
return writeErr
}

// Status304NotModified returns a 304. This is not really an error, but
Expand Down Expand Up @@ -372,4 +361,4 @@ func Error504GatewayTimeout(msg string, errs ...error) StatusError {
}

// ErrorFormatter is a function that formats an error message
var ErrorFormatter func(format string, a ...any) string = fmt.Sprintf
var ErrorFormatter = fmt.Sprintf
4 changes: 2 additions & 2 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestNegotiateError(t *testing.T) {

req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := humatest.NewContext(nil, req, resp)
ctx := humatest.NewContext(&huma.Operation{}, req, resp)
require.Error(t, huma.WriteErr(api, ctx, 400, "bad request"))
}

Expand All @@ -98,7 +98,7 @@ func TestTransformError(t *testing.T) {

req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
ctx := humatest.NewContext(nil, req, resp)
ctx := humatest.NewContext(&huma.Operation{}, req, resp)

require.Error(t, huma.WriteErr(api, ctx, 400, "bad request"))
}
Expand Down
78 changes: 46 additions & 32 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,41 @@ var bufPool = sync.Pool{
},
}

func writeResponse(api API, ctx Context, status int, ct string, body any) error {
if ct == "" {
// If no content type was provided, try to negotiate one with the client.
var err error
ct, err = api.Negotiate(ctx.Header("Accept"))
if err != nil {
notAccept := NewErrorWithContext(ctx, http.StatusNotAcceptable, "unable to marshal response", err)
if e := transformAndWrite(api, ctx, http.StatusNotAcceptable, "application/json", notAccept); e != nil {
return e
}
return err
}

if ctf, ok := body.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
}

if err := transformAndWrite(api, ctx, status, ct, body); err != nil {
return err
}
return nil
}

func writeResponseWithPanic(api API, ctx Context, status int, ct string, body any) {
if err := writeResponse(api, ctx, status, ct, body); err != nil {
panic(err)
}
}

// transformAndWrite is a utility function to transform and write a response.
// It is best-effort as the status code and headers may have already been sent.
func transformAndWrite(api API, ctx Context, status int, ct string, body any) {
func transformAndWrite(api API, ctx Context, status int, ct string, body any) error {
// Try to transform and then marshal/write the response.
// Status code was already sent, so just log the error if something fails,
// and do our best to stuff it into the body of the response.
Expand All @@ -475,17 +507,18 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) {
ctx.BodyWriter().Write([]byte("error transforming response"))
// When including tval in the panic message, the server may become unresponsive for some time if the value is very large
// therefore, it has been removed from the panic message
panic(fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr))
return fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr)
}
ctx.SetStatus(status)
if status != http.StatusNoContent && status != http.StatusNotModified {
if merr := api.Marshal(ctx.BodyWriter(), ct, tval); merr != nil {
ctx.BodyWriter().Write([]byte("error marshaling response"))
// When including tval in the panic message, the server may become unresponsive for some time if the value is very large
// therefore, it has been removed from the panic message
panic(fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr))
return fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr)
}
}
return nil
}

func parseArrElement[T any](values []string, parse func(string) (T, error)) ([]T, error) {
Expand Down Expand Up @@ -963,7 +996,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if f.Type() == reflect.TypeOf(values) {
f.Set(reflect.ValueOf(values))
} else {
//Change element type to support slice of string subtypes (enums)
// Change element type to support slice of string subtypes (enums)
enumValues := reflect.New(f.Type()).Elem()
for _, val := range values {
enumVal := reflect.New(f.Type().Elem()).Elem()
Expand Down Expand Up @@ -1403,21 +1436,16 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}

status := http.StatusInternalServerError

// handle status error
var se StatusError
if errors.As(err, &se) {
status = se.GetStatus()
err = se
} else {
err = NewError(http.StatusInternalServerError, err.Error())
}

ct, _ := api.Negotiate(ctx.Header("Accept"))
if ctf, ok := err.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
writeResponseWithPanic(api, ctx, se.GetStatus(), "", se)
return
}

ctx.SetHeader("Content-Type", ct)
transformAndWrite(api, ctx, status, ct, err)
se = NewErrorWithContext(ctx, status, "unexpected error occurred", err)
writeResponseWithPanic(api, ctx, se.GetStatus(), "", se)
return
}

Expand All @@ -1442,7 +1470,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
} else {
if f.Kind() == reflect.String && info.Name == "Content-Type" {
// Track custom content type.
// Track custom content type. This overrides any content negotiation
// that would happen when writing the response.
ct = f.String()
}
writeHeader(ctx.SetHeader, info, f)
Expand All @@ -1469,22 +1498,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
return
}

// Only write a content type if one wasn't already written by the
// response headers handled above.
if ct == "" {
ct, err = api.Negotiate(ctx.Header("Accept"))
if err != nil {
WriteErr(api, ctx, http.StatusNotAcceptable, "unable to marshal response", err)
return
}
if ctf, ok := body.(ContentTypeFilter); ok {
ct = ctf.ContentType(ct)
}

ctx.SetHeader("Content-Type", ct)
}

transformAndWrite(api, ctx, status, ct, body)
writeResponseWithPanic(api, ctx, status, ct, body)
} else {
ctx.SetStatus(status)
}
Expand Down
16 changes: 16 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,22 @@ Content of example2.txt.
assert.Equal(t, http.StatusForbidden, resp.Code)
},
},
{
Name: "handler-generic-error",
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/error",
}, func(ctx context.Context, input *struct{}) (*struct{}, error) {
return nil, errors.New("whoops")
})
},
Method: http.MethodGet,
URL: "/error",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, resp.Code)
},
},
{
Name: "response-headers",
Register: func(t *testing.T, api huma.API) {
Expand Down
3 changes: 2 additions & 1 deletion openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,8 @@ type Operation struct {
// This is a convenience for handlers that return a fixed set of errors
// where you do not wish to provide each one as an OpenAPI response object.
// Each error specified here is expanded into a response object with the
// schema generated from the type returned by `huma.NewError()`.
// schema generated from the type returned by `huma.NewError()`
// or `huma.NewErrorWithContext`.
Errors []int `yaml:"-"`

// SkipValidateParams disables validation of path, query, and header
Expand Down

0 comments on commit bea7c1a

Please sign in to comment.