diff --git a/CHANGELOG.md b/CHANGELOG.md index a0a35130..dcd775b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Most recent version is listed first. # v0.1.7 - Update go version; https://github.com/komuw/ong/pull/469 - ong/cry: Replace scrypt with argon2id: https://github.com/komuw/ong/pull/471 +- ong/middleware: Give users control over what and how logging happens in the middlewares: https://github.com/komuw/ong/pull/472 # v0.1.6 - Bump versions of dependencies used diff --git a/config/config.go b/config/config.go index 4941faf7..edb619f1 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,7 @@ package config import ( "crypto/x509" + "errors" "fmt" "log/slog" "net/http" @@ -16,12 +17,6 @@ import ( "github.com/komuw/ong/internal/key" ) -// logging middleware. -const ( - // DefaultRateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. - DefaultRateShedSamplePercent = 10 -) - // ratelimit middleware. const ( // DefaultRateLimit is the maximum requests allowed (from one IP address) per second, by default. @@ -193,15 +188,17 @@ func (o Opts) GoString() string { // domain is the domain name of your website. It can be an exact domain, subdomain or wildcard. // port is the TLS port where the server will listen on. Http requests will also redirected to that port. // +// logger is an [slog.Logger] that will be used for logging. It is used in the server, it's use in middlewares is only if [logFunc] is nil. +// // secretKey is used for securing signed data. It should be unique & kept secret. // If it becomes compromised, generate a new one and restart your application using the new one. // // strategy is the algorithm to use when fetching the client's IP address; see [ClientIPstrategy]. // It is important to choose your strategy carefully, see the warning in [ClientIPstrategy]. // -// logger is an [slog.Logger] that will be used for logging. -// -// rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors. If it is less than 0, [DefaultRateShedSamplePercent] is used instead. +// logFunc is a function that dictates what/how middleware is going to log. It is also used to log any recovered panics in the middleware. +// This function is not used in the server. +// If it is nil, a suitable default(that utilizes [logger]) is used. To disable logging, use a function that does nothing. // // rateLimit is the maximum requests allowed (from one IP address) per second. If it is les than 1.0, [DefaultRateLimit] is used instead. // @@ -257,12 +254,12 @@ func New( // common domain string, port uint16, + logger *slog.Logger, // middleware secretKey string, strategy ClientIPstrategy, - logger *slog.Logger, - rateShedSamplePercent int, + logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), rateLimit float64, loadShedSamplingPeriod time.Duration, loadShedMinSampleSize int, @@ -293,10 +290,10 @@ func New( middlewareOpts, err := newMiddlewareOpts( domain, port, + logger, secretKey, strategy, - logger, - rateShedSamplePercent, + logFunc, rateLimit, loadShedSamplingPeriod, loadShedMinSampleSize, @@ -355,12 +352,12 @@ func WithOpts( // common domain, httpsPort, + logger, // middleware secretKey, strategy, - logger, - DefaultRateShedSamplePercent, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -404,12 +401,12 @@ func DevOpts(logger *slog.Logger, secretKey string) Opts { // common domain, httpsPort, + logger, // middleware secretKey, clientip.DirectIpStrategy, - logger, - DefaultRateShedSamplePercent, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -460,12 +457,12 @@ func CertOpts( // common domain, httpsPort, + logger, // middleware secretKey, clientip.DirectIpStrategy, - logger, - DefaultRateShedSamplePercent, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -519,12 +516,12 @@ func AcmeOpts( // common domain, httpsPort, + logger, // middleware secretKey, clientip.DirectIpStrategy, - logger, - DefaultRateShedSamplePercent, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -577,12 +574,12 @@ func LetsEncryptOpts( // common domain, httpsPort, + logger, // middleware secretKey, clientip.DirectIpStrategy, - logger, - DefaultRateShedSamplePercent, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -632,16 +629,15 @@ func (s secureKey) GoString() string { type middlewareOpts struct { Domain string HttpsPort uint16 + Logger *slog.Logger + // When printing a struct, fmt does not invoke custom formatting methods on unexported fields. // We thus need to make this field to be exported. // - https://pkg.go.dev/fmt#:~:text=When%20printing%20a%20struct // - https://go.dev/play/p/wL2gqumZ23b SecretKey secureKey Strategy ClientIPstrategy - Logger *slog.Logger - - // logger - RateShedSamplePercent int + LogFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) // ratelimit RateLimit float64 @@ -673,8 +669,6 @@ func (m middlewareOpts) String() string { HttpsPort: %d, SecretKey: %s, Strategy: %v, - Logger: %v, - RateShedSamplePercent: %v, RateLimit: %v, LoadShedSamplingPeriod: %v, LoadShedMinSampleSize: %v, @@ -692,8 +686,6 @@ func (m middlewareOpts) String() string { m.HttpsPort, m.SecretKey, m.Strategy, - m.Logger, - m.RateShedSamplePercent, m.RateLimit, m.LoadShedSamplingPeriod, m.LoadShedMinSampleSize, @@ -717,10 +709,10 @@ func (m middlewareOpts) GoString() string { func newMiddlewareOpts( domain string, httpsPort uint16, + logger *slog.Logger, secretKey string, strategy ClientIPstrategy, - logger *slog.Logger, - rateShedSamplePercent int, + logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), rateLimit float64, loadShedSamplingPeriod time.Duration, loadShedMinSampleSize int, @@ -743,6 +735,10 @@ func newMiddlewareOpts( domain = domain[2:] } + if logger == nil && logFunc == nil { + return middlewareOpts{}, errors.New("both logger and logFunc should not be nil at the same time") + } + if err := key.IsSecure(secretKey); err != nil { return middlewareOpts{}, err } @@ -771,12 +767,10 @@ func newMiddlewareOpts( return middlewareOpts{ Domain: domain, HttpsPort: httpsPort, + Logger: logger, SecretKey: secureKey(secretKey), Strategy: strategy, - Logger: logger, - - // logger - RateShedSamplePercent: rateShedSamplePercent, + LogFunc: logFunc, // ratelimiter RateLimit: rateLimit, @@ -1051,13 +1045,7 @@ func (o Opts) Equal(other Opts) bool { if o.Strategy != other.Strategy { return false } - if o.Logger != other.Logger { - return false - } - if o.RateShedSamplePercent != other.RateShedSamplePercent { - return false - } if int(o.RateLimit) != int(other.RateLimit) { return false } diff --git a/config/config_test.go b/config/config_test.go index a69425ce..7e22a716 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -29,14 +29,17 @@ func validOpts(t *testing.T) Opts { "example.com", // The https port that our application will be listening on. 443, + // Logger. + l, // The security key to use for securing signed data. "super-h@rd-Pas1word", // In this case, the actual client IP address is fetched from the given http header. SingleIpStrategy("CF-Connecting-IP"), - // Logger. - l, - // log 90% of all responses that are either rate-limited or loadshed. - 90, + // function to use for logging in middlewares. + func(_ http.ResponseWriter, r http.Request, statusCode int, fields []any) { + reqL := log.WithID(r.Context(), l) + reqL.Info("request-and-response", fields...) + }, // If a particular IP address sends more than 13 requests per second, throttle requests from that IP. 13.0, // Sample response latencies over a 5 minute window to determine if to loadshed. @@ -132,10 +135,10 @@ func TestNewMiddlewareOpts(t *testing.T) { o, err := newMiddlewareOpts( opt.Domain, opt.HttpsPort, + slog.Default(), string(opt.SecretKey), opt.Strategy, - opt.Logger, - opt.RateShedSamplePercent, + nil, opt.RateLimit, opt.LoadShedSamplingPeriod, opt.LoadShedMinSampleSize, @@ -193,10 +196,10 @@ func TestNewMiddlewareOptsDomain(t *testing.T) { _, err := newMiddlewareOpts( tt.domain, 443, + slog.Default(), tst.SecretKey(), clientip.DirectIpStrategy, - slog.Default(), - DefaultRateShedSamplePercent, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -215,10 +218,10 @@ func TestNewMiddlewareOptsDomain(t *testing.T) { _, err := newMiddlewareOpts( tt.domain, 443, + slog.Default(), tst.SecretKey(), clientip.DirectIpStrategy, - slog.Default(), - DefaultRateShedSamplePercent, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -253,8 +256,7 @@ func TestOpts(t *testing.T) { HttpsPort: 65081, SecretKey: secureKey(tst.SecretKey()), Strategy: clientip.DirectIpStrategy, - Logger: l, - RateShedSamplePercent: DefaultRateShedSamplePercent, + LogFunc: nil, RateLimit: DefaultRateLimit, LoadShedSamplingPeriod: DefaultLoadShedSamplingPeriod, LoadShedMinSampleSize: DefaultLoadShedMinSampleSize, diff --git a/config/example_test.go b/config/example_test.go index 4acaf8c1..6dfeab28 100644 --- a/config/example_test.go +++ b/config/example_test.go @@ -26,14 +26,21 @@ func ExampleNew() { "example.com", // The https port that our application will be listening on. 443, + // Logger. + l, // The security key to use for securing signed data. "super-h@rd-Pas1word", // In this case, the actual client IP address is fetched from the given http header. config.SingleIpStrategy("CF-Connecting-IP"), - // Logger. - l, - // log 90% of all responses that are either rate-limited or loadshed. - 90, + // function to use for logging in middlewares + func(_ http.ResponseWriter, r http.Request, statusCode int, fields []any) { + if statusCode >= http.StatusInternalServerError { + // Only log 500's + reqL := log.WithID(r.Context(), l) + fields = append(fields, "statusCode", statusCode) + reqL.Info("request-and-response", fields...) + } + }, // If a particular IP address sends more than 13 requests per second, throttle requests from that IP. 13.0, // Sample response latencies over a 5 minute window to determine if to loadshed. diff --git a/middleware/log.go b/middleware/log.go index 1331a4bf..ddbdb069 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -10,34 +10,28 @@ import ( "net/http" "time" - "github.com/komuw/ong/config" "github.com/komuw/ong/log" ) // logger is a middleware that logs http requests and responses using [log.Logger]. func logger( wrappedHandler http.Handler, + logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), l *slog.Logger, - rateShedSamplePercent int, ) http.HandlerFunc { - // We pass the logger as an argument so that the middleware can share the same logger as the app. + // The middleware should ideally share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. - // This makes debugging easier for developers. - // - // However, each request should get its own context. That's why we call `logger.WithCtx` for every request. - - // Note: a value of 0, disables logging of ratelimited and loadshed responses. - if rateShedSamplePercent < 0 { - rateShedSamplePercent = config.DefaultRateShedSamplePercent + // That's one reason why we pass in the logFunc.This makes debugging easier for developers. + // Another reason is so that app developers are in control of what(and how) exactly gets logged. + if logFunc == nil { + logFunc = defaultLogFunc(l) } return func(w http.ResponseWriter, r *http.Request) { start := time.Now() - lrw := &logRW{ - ResponseWriter: w, - } + lrw := &logRW{ResponseWriter: w} + defer func() { - msg := "http_server" flds := []any{ "clientIP", ClientIP(r), "clientFingerPrint", ClientFingerPrint(r), @@ -61,29 +55,7 @@ func logger( // 1xx class or the modified headers are trailers. lrw.Header().Del(ongMiddlewareErrorHeader) - // The logger should be in the defer block so that it uses the updated context containing the logID. - reqL := log.WithID(r.Context(), l) - - if (lrw.code == http.StatusServiceUnavailable || lrw.code == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { - // We are either in load shedding or rate-limiting. - // Only log (rateShedSamplePercent)% of the errors. - shouldLog := mathRand.IntN(100) <= rateShedSamplePercent - if shouldLog { - reqL.Error(msg, flds...) - } - } else if lrw.code >= http.StatusBadRequest { - // Both client and server errors. - if lrw.code == http.StatusNotFound || - lrw.code == http.StatusMethodNotAllowed || - lrw.code == http.StatusTeapot { - // These ones are more of an annoyance, than been actual errors. - reqL.Info(msg, flds...) - } else { - reqL.Error(msg, flds...) - } - } else { - reqL.Info(msg, flds...) - } + logFunc(w, *r, lrw.code, flds) }() wrappedHandler.ServeHTTP(lrw, r) @@ -180,3 +152,45 @@ func (lrw *logRW) ReadFrom(src io.Reader) (n int64, err error) { func (lrw *logRW) Unwrap() http.ResponseWriter { return lrw.ResponseWriter } + +// defaultLogFunc is the logging function used if the user did not explicitly provide one. +func defaultLogFunc(l *slog.Logger) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { + const ( + msg = "http_server" + // rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. + rateShedSamplePercent = 10 + ) + + if l == nil { + return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {} + } + + return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { + // Each request should get its own context. That's why we call `log.WithID` for every request. + reqL := log.WithID(r.Context(), l) + + if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { + // We are either in load shedding or rate-limiting. + // Only log (rateShedSamplePercent)% of the errors. + shouldLog := mathRand.IntN(100) <= rateShedSamplePercent + if shouldLog { + reqL.Error(msg, fields...) + return + } + } + + if statusCode >= http.StatusBadRequest { + // Both client and server errors. + if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot { + // These ones are more of an annoyance, than been actual errors. + reqL.Info(msg, fields...) + return + } + + reqL.Error(msg, fields...) + return + } + + reqL.Info(msg, fields...) + } +} diff --git a/middleware/log_test.go b/middleware/log_test.go index 29f23483..135e516b 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -12,7 +12,6 @@ import ( "testing" "time" - "github.com/komuw/ong/config" "github.com/komuw/ong/id" "github.com/komuw/ong/log" @@ -54,7 +53,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -77,7 +76,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} errorMsg := "someLogHandler failed" successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -113,7 +112,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" errorMsg := "someLogHandler failed" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) { // first request that succeds @@ -192,7 +191,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) someLogID := "hey-some-log-id:" + id.New() @@ -215,6 +214,28 @@ func TestLogMiddleware(t *testing.T) { attest.Zero(t, logOutput.String()) }) + t.Run("nil logger and logfunc", func(t *testing.T) { + t.Parallel() + + logOutput := &bytes.Buffer{} + successMsg := "hello" + wrappedHandler := logger(someLogHandler(successMsg), nil, nil) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodHead, "/someUri", nil) + wrappedHandler.ServeHTTP(rec, req) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), successMsg) + attest.Zero(t, logOutput.String()) + }) + t.Run("concurrency safe", func(t *testing.T) { t.Parallel() @@ -222,7 +243,7 @@ func TestLogMiddleware(t *testing.T) { successMsg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/middleware.go b/middleware/middleware.go index c29bd80e..e93def23 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -54,10 +54,10 @@ func allDefaultMiddlewares( httpsPort := o.HttpsPort secretKey := o.SecretKey strategy := o.Strategy - l := o.Logger // logger - rateShedSamplePercent := o.RateShedSamplePercent + l := o.Logger + logFunc := o.LogFunc // ratelimit rateLimit := o.RateLimit @@ -172,9 +172,10 @@ func allDefaultMiddlewares( rateLimit, ), ), + logFunc, l, - rateShedSamplePercent, ), + logFunc, l, ), ), diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 369a0f70..28b048c3 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -486,10 +486,10 @@ func BenchmarkAllMiddlewares(b *testing.B) { o := config.New( domain, httpsPort, + l, tst.SecretKey(), config.DirectIpStrategy, - l, - config.DefaultRateShedSamplePercent, + nil, rateLimit, config.DefaultLoadShedSamplingPeriod, config.DefaultLoadShedMinSampleSize, diff --git a/middleware/recoverer.go b/middleware/recoverer.go index a3535ddb..47804641 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/komuw/ong/errors" - "github.com/komuw/ong/log" ) // Some of the code here is inspired(or taken from) by: @@ -14,19 +13,24 @@ import ( // recoverer is a middleware that recovers from panics in wrappedHandler. // When/if a panic occurs, it logs the stack trace and returns an InternalServerError response. -func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { +func recoverer( + wrappedHandler http.Handler, + logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), + l *slog.Logger, +) http.HandlerFunc { + code := http.StatusInternalServerError + status := http.StatusText(code) + + if logFunc == nil { + logFunc = defaultLogFunc(l) + } + return func(w http.ResponseWriter, r *http.Request) { defer func() { errR := recover() if errR != nil { - reqL := log.WithID(r.Context(), l) - - code := http.StatusInternalServerError - status := http.StatusText(code) - msg := "http_server" flds := []any{ - "error", fmt.Sprint(errR), "clientIP", ClientIP(r), "clientFingerPrint", ClientFingerPrint(r), "method", r.Method, @@ -38,15 +42,23 @@ func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { extra := []any{"ongError", ongError} flds = append(flds, extra...) } - w.Header().Del(ongMiddlewareErrorHeader) // remove header so that users dont see it. + // Remove header so that users dont see it. + // + // Note that this may not actually work. + // According to: https://pkg.go.dev/net/http#ResponseWriter + // Changing the header map after a call to WriteHeader (or + // Write) has no effect unless the HTTP status code was of the + // 1xx class or the modified headers are trailers. + w.Header().Del(ongMiddlewareErrorHeader) + + extra := []any{"error", fmt.Sprint(errR)} if e, ok := errR.(error); ok { - extra := []any{"err", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. - flds = append(flds, extra...) - reqL.Error(msg, flds...) - } else { - reqL.Error(msg, flds...) + extra = []any{"error", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. } + flds = append(flds, extra...) + + logFunc(w, *r, code, flds) // respond. http.Error( diff --git a/middleware/recoverer_test.go b/middleware/recoverer_test.go index c9571136..aeb7a6fc 100644 --- a/middleware/recoverer_test.go +++ b/middleware/recoverer_test.go @@ -8,7 +8,6 @@ import ( "log/slog" "net/http" "net/http/httptest" - "os" "strings" "sync" "testing" @@ -57,7 +56,7 @@ func TestPanic(t *testing.T) { logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -74,7 +73,7 @@ func TestPanic(t *testing.T) { logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -104,7 +103,7 @@ func TestPanic(t *testing.T) { msg := "hello" errMsg := "99 problems" err := errors.New(errMsg) - wrappedHandler := recoverer(handlerThatPanics(msg, false, err), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -132,7 +131,7 @@ func TestPanic(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} - wrappedHandler := recoverer(anotherHandlerThatPanics(), getLogger(logOutput)) + wrappedHandler := recoverer(anotherHandlerThatPanics(), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -141,19 +140,19 @@ func TestPanic(t *testing.T) { res := rec.Result() defer res.Body.Close() attest.Equal(t, res.StatusCode, http.StatusInternalServerError) - attest.Subsequence(t, logOutput.String(), "middleware/recoverer_test.go:42") // line where panic happened. + attest.Subsequence(t, logOutput.String(), "middleware/recoverer_test.go:41") // line where panic happened. }) t.Run("concurrency safe", func(t *testing.T) { t.Parallel() - // &bytes.Buffer{} is not concurrency safe, so we use os.Stderr instead. - logOutput := os.Stderr + // If &bytes.Buffer{} is not concurrency safe, we can use os.Stderr instead. + logOutput := &bytes.Buffer{} msg := "hey" err := errors.New(msg) // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := recoverer(handlerThatPanics(msg, false, err), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), nil, getLogger(logOutput)) runhandler := func() { rec := httptest.NewRecorder()