diff --git a/.github/docs/openapi3filter.txt b/.github/docs/openapi3filter.txt index 0fd7027c8..43540c416 100644 --- a/.github/docs/openapi3filter.txt +++ b/.github/docs/openapi3filter.txt @@ -182,7 +182,7 @@ type ErrCode int occur during validation. These may be used to write an appropriate response in ErrFunc. -type ErrFunc func(w http.ResponseWriter, status int, code ErrCode, err error) +type ErrFunc func(ctx context.Context, w http.ResponseWriter, status int, code ErrCode, err error) ErrFunc handles errors that may occur during validation. type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) @@ -198,7 +198,7 @@ type Headerer interface { Headerer, the provided headers will be applied to the response writer, after the Content-Type is set. -type LogFunc func(message string, err error) +type LogFunc func(ctx context.Context, message string, err error) LogFunc handles log messages that may occur during validation. type Options struct { diff --git a/openapi3filter/middleware.go b/openapi3filter/middleware.go index 0009d61c4..d20889ed9 100644 --- a/openapi3filter/middleware.go +++ b/openapi3filter/middleware.go @@ -2,6 +2,7 @@ package openapi3filter import ( "bytes" + "context" "io" "log" "net/http" @@ -19,10 +20,10 @@ type Validator struct { } // ErrFunc handles errors that may occur during validation. -type ErrFunc func(w http.ResponseWriter, status int, code ErrCode, err error) +type ErrFunc func(ctx context.Context, w http.ResponseWriter, status int, code ErrCode, err error) // LogFunc handles log messages that may occur during validation. -type LogFunc func(message string, err error) +type LogFunc func(ctx context.Context, message string, err error) // ErrCode is used for classification of different types of errors that may // occur during validation. These may be used to write an appropriate response @@ -61,10 +62,10 @@ func (e ErrCode) responseText() string { func NewValidator(router routers.Router, options ...ValidatorOption) *Validator { v := &Validator{ router: router, - errFunc: func(w http.ResponseWriter, status int, code ErrCode, _ error) { + errFunc: func(_ context.Context, w http.ResponseWriter, status int, code ErrCode, _ error) { http.Error(w, code.responseText(), status) }, - logFunc: func(message string, err error) { + logFunc: func(_ context.Context, message string, err error) { log.Printf("%s: %v", message, err) }, } @@ -117,10 +118,11 @@ func ValidationOptions(options Options) ValidatorOption { // request and response validation. func (v *Validator) Middleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() route, pathParams, err := v.router.FindRoute(r) if err != nil { - v.logFunc("validation error: failed to find route for "+r.URL.String(), err) - v.errFunc(w, http.StatusNotFound, ErrCodeCannotFindRoute, err) + v.logFunc(ctx, "validation error: failed to find route for "+r.URL.String(), err) + v.errFunc(ctx, w, http.StatusNotFound, ErrCodeCannotFindRoute, err) return } requestValidationInput := &RequestValidationInput{ @@ -129,9 +131,9 @@ func (v *Validator) Middleware(h http.Handler) http.Handler { Route: route, Options: &v.options, } - if err = ValidateRequest(r.Context(), requestValidationInput); err != nil { - v.logFunc("invalid request", err) - v.errFunc(w, http.StatusBadRequest, ErrCodeRequestInvalid, err) + if err = ValidateRequest(ctx, requestValidationInput); err != nil { + v.logFunc(ctx, "invalid request", err) + v.errFunc(ctx, w, http.StatusBadRequest, ErrCodeRequestInvalid, err) return } @@ -144,22 +146,22 @@ func (v *Validator) Middleware(h http.Handler) http.Handler { h.ServeHTTP(wr, r) - if err = ValidateResponse(r.Context(), &ResponseValidationInput{ + if err = ValidateResponse(ctx, &ResponseValidationInput{ RequestValidationInput: requestValidationInput, Status: wr.statusCode(), Header: wr.Header(), Body: io.NopCloser(bytes.NewBuffer(wr.bodyContents())), Options: &v.options, }); err != nil { - v.logFunc("invalid response", err) + v.logFunc(ctx, "invalid response", err) if v.strict { - v.errFunc(w, http.StatusInternalServerError, ErrCodeResponseInvalid, err) + v.errFunc(ctx, w, http.StatusInternalServerError, ErrCodeResponseInvalid, err) } return } if err = wr.flushBodyContents(); err != nil { - v.logFunc("failed to write response", err) + v.logFunc(ctx, "failed to write response", err) } }) } diff --git a/openapi3filter/middleware_test.go b/openapi3filter/middleware_test.go index 1260ac54c..137228cd1 100644 --- a/openapi3filter/middleware_test.go +++ b/openapi3filter/middleware_test.go @@ -2,6 +2,7 @@ package openapi3filter_test import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -489,7 +490,7 @@ paths: // testing a service against its spec in development and CI. In production, // availability may be more important than strictness. v := openapi3filter.NewValidator(router, openapi3filter.Strict(true), - openapi3filter.OnErr(func(w http.ResponseWriter, status int, code openapi3filter.ErrCode, err error) { + openapi3filter.OnErr(func(_ context.Context, w http.ResponseWriter, status int, code openapi3filter.ErrCode, err error) { // Customize validation error responses to use JSON w.Header().Set("Content-Type", "application/json") w.WriteHeader(status)