From 71ad75c75beb9768bae1486efd12084687635e55 Mon Sep 17 00:00:00 2001 From: komuw Date: Thu, 15 Aug 2024 10:16:58 +0300 Subject: [PATCH 01/20] g --- middleware/log.go | 31 +++++++++++++++++++++++++++++++ middleware/recoverer.go | 10 +++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/middleware/log.go b/middleware/log.go index 1331a4bf..67da94ac 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -90,6 +90,37 @@ func logger( } } +// TODO: +func doLog(w http.ResponseWriter, r http.Request, statusCode int, l *slog.Logger, fields []any) { + reqL := log.WithID(r.Context(), l) + msg := "http_server" + + var rateShedSamplePercent int = 0 // TODO + if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { + // We are either in load shedding or rate-limiting. + // Only log (rateShedSamplePercent)% of the errors. + shouldLog := mathRand.IntN(100) <= rateShedSamplePercent + if shouldLog { + reqL.Error(msg, fields...) + return + } + } + + if statusCode < http.StatusBadRequest { + reqL.Info(msg, fields...) + return + } + + // Both client and server errors. + if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot { + // These ones are more of an annoyance, than been actual errors. + reqL.Info(msg, fields...) + return + } + + reqL.Info(msg, fields...) +} + // logRW provides an http.ResponseWriter interface, which logs requests/responses. type logRW struct { http.ResponseWriter diff --git a/middleware/recoverer.go b/middleware/recoverer.go index a3535ddb..23e9fbac 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -38,7 +38,15 @@ func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { extra := []any{"ongError", ongError} flds = append(flds, extra...) } - w.Header().Del(ongMiddlewareErrorHeader) // remove header so that users dont see it. + + // Remove header so that users dont see it. + // + // Note that this may not actually work. + // According to: https://pkg.go.dev/net/http#ResponseWriter + // Changing the header map after a call to WriteHeader (or + // Write) has no effect unless the HTTP status code was of the + // 1xx class or the modified headers are trailers. + w.Header().Del(ongMiddlewareErrorHeader) if e, ok := errR.(error); ok { extra := []any{"err", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. From 0c3e9527b5da1d421e1184c00e89852f3af00492 Mon Sep 17 00:00:00 2001 From: komuw Date: Thu, 15 Aug 2024 10:28:20 +0300 Subject: [PATCH 02/20] g --- middleware/log.go | 25 +------------------------ middleware/recoverer.go | 15 +++++---------- 2 files changed, 6 insertions(+), 34 deletions(-) diff --git a/middleware/log.go b/middleware/log.go index 67da94ac..d685e48b 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -37,7 +37,6 @@ func logger( ResponseWriter: w, } defer func() { - msg := "http_server" flds := []any{ "clientIP", ClientIP(r), "clientFingerPrint", ClientFingerPrint(r), @@ -61,29 +60,7 @@ func logger( // 1xx class or the modified headers are trailers. lrw.Header().Del(ongMiddlewareErrorHeader) - // The logger should be in the defer block so that it uses the updated context containing the logID. - reqL := log.WithID(r.Context(), l) - - if (lrw.code == http.StatusServiceUnavailable || lrw.code == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { - // We are either in load shedding or rate-limiting. - // Only log (rateShedSamplePercent)% of the errors. - shouldLog := mathRand.IntN(100) <= rateShedSamplePercent - if shouldLog { - reqL.Error(msg, flds...) - } - } else if lrw.code >= http.StatusBadRequest { - // Both client and server errors. - if lrw.code == http.StatusNotFound || - lrw.code == http.StatusMethodNotAllowed || - lrw.code == http.StatusTeapot { - // These ones are more of an annoyance, than been actual errors. - reqL.Info(msg, flds...) - } else { - reqL.Error(msg, flds...) - } - } else { - reqL.Info(msg, flds...) - } + doLog(w, *r, lrw.code, l, flds) }() wrappedHandler.ServeHTTP(lrw, r) diff --git a/middleware/recoverer.go b/middleware/recoverer.go index 23e9fbac..110f70af 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/komuw/ong/errors" - "github.com/komuw/ong/log" ) // Some of the code here is inspired(or taken from) by: @@ -19,14 +18,10 @@ func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { defer func() { errR := recover() if errR != nil { - reqL := log.WithID(r.Context(), l) - code := http.StatusInternalServerError status := http.StatusText(code) - msg := "http_server" flds := []any{ - "error", fmt.Sprint(errR), "clientIP", ClientIP(r), "clientFingerPrint", ClientFingerPrint(r), "method", r.Method, @@ -48,13 +43,13 @@ func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { // 1xx class or the modified headers are trailers. w.Header().Del(ongMiddlewareErrorHeader) + extra := []any{"error", fmt.Sprint(errR)} if e, ok := errR.(error); ok { - extra := []any{"err", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. - flds = append(flds, extra...) - reqL.Error(msg, flds...) - } else { - reqL.Error(msg, flds...) + extra = []any{"error", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. } + flds = append(flds, extra...) + + doLog(w, *r, code, l, flds) // respond. http.Error( From bfaf80be439a7b5da05a76640378070ec135f850 Mon Sep 17 00:00:00 2001 From: komuw Date: Thu, 15 Aug 2024 10:40:36 +0300 Subject: [PATCH 03/20] g --- middleware/log.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/middleware/log.go b/middleware/log.go index d685e48b..035ab359 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -83,15 +83,15 @@ func doLog(w http.ResponseWriter, r http.Request, statusCode int, l *slog.Logger } } - if statusCode < http.StatusBadRequest { - reqL.Info(msg, fields...) - return - } + if statusCode >= http.StatusBadRequest { + // Both client and server errors. + if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot { + // These ones are more of an annoyance, than been actual errors. + reqL.Info(msg, fields...) + return + } - // Both client and server errors. - if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot { - // These ones are more of an annoyance, than been actual errors. - reqL.Info(msg, fields...) + reqL.Error(msg, fields...) return } From 7933bfea1320eacd96689c1ff739840851b3e305 Mon Sep 17 00:00:00 2001 From: komuw Date: Thu, 15 Aug 2024 10:46:55 +0300 Subject: [PATCH 04/20] g --- middleware/log.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/middleware/log.go b/middleware/log.go index 035ab359..e61a90b0 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -23,8 +23,6 @@ func logger( // We pass the logger as an argument so that the middleware can share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. // This makes debugging easier for developers. - // - // However, each request should get its own context. That's why we call `logger.WithCtx` for every request. // Note: a value of 0, disables logging of ratelimited and loadshed responses. if rateShedSamplePercent < 0 { @@ -69,6 +67,7 @@ func logger( // TODO: func doLog(w http.ResponseWriter, r http.Request, statusCode int, l *slog.Logger, fields []any) { + // Each request should get its own context. That's why we call `log.WithID` for every request. reqL := log.WithID(r.Context(), l) msg := "http_server" From d5ce346a98414d14256f049b03b128a8fe84b531 Mon Sep 17 00:00:00 2001 From: komuw Date: Thu, 15 Aug 2024 10:47:10 +0300 Subject: [PATCH 05/20] g --- middleware/log.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/middleware/log.go b/middleware/log.go index e61a90b0..58b5a64f 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -31,9 +31,8 @@ func logger( return func(w http.ResponseWriter, r *http.Request) { start := time.Now() - lrw := &logRW{ - ResponseWriter: w, - } + lrw := &logRW{ResponseWriter: w} + defer func() { flds := []any{ "clientIP", ClientIP(r), From d3744570ea349b1415187592361608499be0a15d Mon Sep 17 00:00:00 2001 From: komuw Date: Thu, 15 Aug 2024 10:54:42 +0300 Subject: [PATCH 06/20] g --- config/config.go | 27 --------------------------- config/config_test.go | 6 ------ config/example_test.go | 2 -- middleware/log.go | 10 ++-------- middleware/log_test.go | 11 +++++------ middleware/middleware.go | 4 ---- middleware/middleware_test.go | 1 - 7 files changed, 7 insertions(+), 54 deletions(-) diff --git a/config/config.go b/config/config.go index 4941faf7..ffd2c233 100644 --- a/config/config.go +++ b/config/config.go @@ -16,12 +16,6 @@ import ( "github.com/komuw/ong/internal/key" ) -// logging middleware. -const ( - // DefaultRateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. - DefaultRateShedSamplePercent = 10 -) - // ratelimit middleware. const ( // DefaultRateLimit is the maximum requests allowed (from one IP address) per second, by default. @@ -201,8 +195,6 @@ func (o Opts) GoString() string { // // logger is an [slog.Logger] that will be used for logging. // -// rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors. If it is less than 0, [DefaultRateShedSamplePercent] is used instead. -// // rateLimit is the maximum requests allowed (from one IP address) per second. If it is les than 1.0, [DefaultRateLimit] is used instead. // // loadShedSamplingPeriod is the duration over which we calculate response latencies for purposes of determining whether to loadshed. If it is less than 1second, [DefaultLoadShedSamplingPeriod] is used instead. @@ -262,7 +254,6 @@ func New( secretKey string, strategy ClientIPstrategy, logger *slog.Logger, - rateShedSamplePercent int, rateLimit float64, loadShedSamplingPeriod time.Duration, loadShedMinSampleSize int, @@ -296,7 +287,6 @@ func New( secretKey, strategy, logger, - rateShedSamplePercent, rateLimit, loadShedSamplingPeriod, loadShedMinSampleSize, @@ -360,7 +350,6 @@ func WithOpts( secretKey, strategy, logger, - DefaultRateShedSamplePercent, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -409,7 +398,6 @@ func DevOpts(logger *slog.Logger, secretKey string) Opts { secretKey, clientip.DirectIpStrategy, logger, - DefaultRateShedSamplePercent, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -465,7 +453,6 @@ func CertOpts( secretKey, clientip.DirectIpStrategy, logger, - DefaultRateShedSamplePercent, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -524,7 +511,6 @@ func AcmeOpts( secretKey, clientip.DirectIpStrategy, logger, - DefaultRateShedSamplePercent, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -582,7 +568,6 @@ func LetsEncryptOpts( secretKey, clientip.DirectIpStrategy, logger, - DefaultRateShedSamplePercent, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -640,9 +625,6 @@ type middlewareOpts struct { Strategy ClientIPstrategy Logger *slog.Logger - // logger - RateShedSamplePercent int - // ratelimit RateLimit float64 @@ -674,7 +656,6 @@ func (m middlewareOpts) String() string { SecretKey: %s, Strategy: %v, Logger: %v, - RateShedSamplePercent: %v, RateLimit: %v, LoadShedSamplingPeriod: %v, LoadShedMinSampleSize: %v, @@ -693,7 +674,6 @@ func (m middlewareOpts) String() string { m.SecretKey, m.Strategy, m.Logger, - m.RateShedSamplePercent, m.RateLimit, m.LoadShedSamplingPeriod, m.LoadShedMinSampleSize, @@ -720,7 +700,6 @@ func newMiddlewareOpts( secretKey string, strategy ClientIPstrategy, logger *slog.Logger, - rateShedSamplePercent int, rateLimit float64, loadShedSamplingPeriod time.Duration, loadShedMinSampleSize int, @@ -775,9 +754,6 @@ func newMiddlewareOpts( Strategy: strategy, Logger: logger, - // logger - RateShedSamplePercent: rateShedSamplePercent, - // ratelimiter RateLimit: rateLimit, @@ -1055,9 +1031,6 @@ func (o Opts) Equal(other Opts) bool { return false } - if o.RateShedSamplePercent != other.RateShedSamplePercent { - return false - } if int(o.RateLimit) != int(other.RateLimit) { return false } diff --git a/config/config_test.go b/config/config_test.go index a69425ce..dba24f7b 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -35,8 +35,6 @@ func validOpts(t *testing.T) Opts { SingleIpStrategy("CF-Connecting-IP"), // Logger. l, - // log 90% of all responses that are either rate-limited or loadshed. - 90, // If a particular IP address sends more than 13 requests per second, throttle requests from that IP. 13.0, // Sample response latencies over a 5 minute window to determine if to loadshed. @@ -135,7 +133,6 @@ func TestNewMiddlewareOpts(t *testing.T) { string(opt.SecretKey), opt.Strategy, opt.Logger, - opt.RateShedSamplePercent, opt.RateLimit, opt.LoadShedSamplingPeriod, opt.LoadShedMinSampleSize, @@ -196,7 +193,6 @@ func TestNewMiddlewareOptsDomain(t *testing.T) { tst.SecretKey(), clientip.DirectIpStrategy, slog.Default(), - DefaultRateShedSamplePercent, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -218,7 +214,6 @@ func TestNewMiddlewareOptsDomain(t *testing.T) { tst.SecretKey(), clientip.DirectIpStrategy, slog.Default(), - DefaultRateShedSamplePercent, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -254,7 +249,6 @@ func TestOpts(t *testing.T) { SecretKey: secureKey(tst.SecretKey()), Strategy: clientip.DirectIpStrategy, Logger: l, - RateShedSamplePercent: DefaultRateShedSamplePercent, RateLimit: DefaultRateLimit, LoadShedSamplingPeriod: DefaultLoadShedSamplingPeriod, LoadShedMinSampleSize: DefaultLoadShedMinSampleSize, diff --git a/config/example_test.go b/config/example_test.go index 4acaf8c1..49f3a412 100644 --- a/config/example_test.go +++ b/config/example_test.go @@ -32,8 +32,6 @@ func ExampleNew() { config.SingleIpStrategy("CF-Connecting-IP"), // Logger. l, - // log 90% of all responses that are either rate-limited or loadshed. - 90, // If a particular IP address sends more than 13 requests per second, throttle requests from that IP. 13.0, // Sample response latencies over a 5 minute window to determine if to loadshed. diff --git a/middleware/log.go b/middleware/log.go index 58b5a64f..fba208e5 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -10,7 +10,6 @@ import ( "net/http" "time" - "github.com/komuw/ong/config" "github.com/komuw/ong/log" ) @@ -18,17 +17,11 @@ import ( func logger( wrappedHandler http.Handler, l *slog.Logger, - rateShedSamplePercent int, ) http.HandlerFunc { // We pass the logger as an argument so that the middleware can share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. // This makes debugging easier for developers. - // Note: a value of 0, disables logging of ratelimited and loadshed responses. - if rateShedSamplePercent < 0 { - rateShedSamplePercent = config.DefaultRateShedSamplePercent - } - return func(w http.ResponseWriter, r *http.Request) { start := time.Now() lrw := &logRW{ResponseWriter: w} @@ -69,8 +62,9 @@ func doLog(w http.ResponseWriter, r http.Request, statusCode int, l *slog.Logger // Each request should get its own context. That's why we call `log.WithID` for every request. reqL := log.WithID(r.Context(), l) msg := "http_server" + // rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. + rateShedSamplePercent := 10 - var rateShedSamplePercent int = 0 // TODO if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { // We are either in load shedding or rate-limiting. // Only log (rateShedSamplePercent)% of the errors. diff --git a/middleware/log_test.go b/middleware/log_test.go index 29f23483..7cb7049b 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -12,7 +12,6 @@ import ( "testing" "time" - "github.com/komuw/ong/config" "github.com/komuw/ong/id" "github.com/komuw/ong/log" @@ -54,7 +53,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -77,7 +76,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} errorMsg := "someLogHandler failed" successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -113,7 +112,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" errorMsg := "someLogHandler failed" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) { // first request that succeds @@ -192,7 +191,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) someLogID := "hey-some-log-id:" + id.New() @@ -222,7 +221,7 @@ func TestLogMiddleware(t *testing.T) { successMsg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput), config.DefaultRateShedSamplePercent) + wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/middleware.go b/middleware/middleware.go index c29bd80e..f91d8116 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -56,9 +56,6 @@ func allDefaultMiddlewares( strategy := o.Strategy l := o.Logger - // logger - rateShedSamplePercent := o.RateShedSamplePercent - // ratelimit rateLimit := o.RateLimit @@ -173,7 +170,6 @@ func allDefaultMiddlewares( ), ), l, - rateShedSamplePercent, ), l, ), diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 369a0f70..68ada4b1 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -489,7 +489,6 @@ func BenchmarkAllMiddlewares(b *testing.B) { tst.SecretKey(), config.DirectIpStrategy, l, - config.DefaultRateShedSamplePercent, rateLimit, config.DefaultLoadShedSamplingPeriod, config.DefaultLoadShedMinSampleSize, From df26041a81340c676b14054a7249babdf6b22ba7 Mon Sep 17 00:00:00 2001 From: komuw Date: Thu, 15 Aug 2024 14:50:55 +0300 Subject: [PATCH 07/20] g --- middleware/log.go | 41 ++-------------------------------------- middleware/log_test.go | 37 ++++++++++++++++++++++++++++++++++++ middleware/middleware.go | 3 --- middleware/recoverer.go | 8 +++++--- 4 files changed, 44 insertions(+), 45 deletions(-) diff --git a/middleware/log.go b/middleware/log.go index fba208e5..9c62ac5d 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -4,19 +4,15 @@ import ( "bufio" "fmt" "io" - "log/slog" - mathRand "math/rand/v2" "net" "net/http" "time" - - "github.com/komuw/ong/log" ) // logger is a middleware that logs http requests and responses using [log.Logger]. func logger( wrappedHandler http.Handler, - l *slog.Logger, + doLog func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), ) http.HandlerFunc { // We pass the logger as an argument so that the middleware can share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. @@ -50,46 +46,13 @@ func logger( // 1xx class or the modified headers are trailers. lrw.Header().Del(ongMiddlewareErrorHeader) - doLog(w, *r, lrw.code, l, flds) + doLog(w, *r, lrw.code, flds) }() wrappedHandler.ServeHTTP(lrw, r) } } -// TODO: -func doLog(w http.ResponseWriter, r http.Request, statusCode int, l *slog.Logger, fields []any) { - // Each request should get its own context. That's why we call `log.WithID` for every request. - reqL := log.WithID(r.Context(), l) - msg := "http_server" - // rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. - rateShedSamplePercent := 10 - - if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { - // We are either in load shedding or rate-limiting. - // Only log (rateShedSamplePercent)% of the errors. - shouldLog := mathRand.IntN(100) <= rateShedSamplePercent - if shouldLog { - reqL.Error(msg, fields...) - return - } - } - - if statusCode >= http.StatusBadRequest { - // Both client and server errors. - if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot { - // These ones are more of an annoyance, than been actual errors. - reqL.Info(msg, fields...) - return - } - - reqL.Error(msg, fields...) - return - } - - reqL.Info(msg, fields...) -} - // logRW provides an http.ResponseWriter interface, which logs requests/responses. type logRW struct { http.ResponseWriter diff --git a/middleware/log_test.go b/middleware/log_test.go index 7cb7049b..aeea4be4 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log/slog" + mathRand "math/rand/v2" "net/http" "net/http/httptest" "sync" @@ -41,6 +42,42 @@ func someLogHandler(successMsg string) http.HandlerFunc { } } +func toLog(t *testing.T, l *slog.Logger) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { + t.Helper() + + return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { + // Each request should get its own context. That's why we call `log.WithID` for every request. + reqL := log.WithID(r.Context(), l) + msg := "http_server" + // rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. + rateShedSamplePercent := 10 + + if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { + // We are either in load shedding or rate-limiting. + // Only log (rateShedSamplePercent)% of the errors. + shouldLog := mathRand.IntN(100) <= rateShedSamplePercent + if shouldLog { + reqL.Error(msg, fields...) + return + } + } + + if statusCode >= http.StatusBadRequest { + // Both client and server errors. + if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot { + // These ones are more of an annoyance, than been actual errors. + reqL.Info(msg, fields...) + return + } + + reqL.Error(msg, fields...) + return + } + + reqL.Info(msg, fields...) + } +} + func TestLogMiddleware(t *testing.T) { t.Parallel() diff --git a/middleware/middleware.go b/middleware/middleware.go index f91d8116..f3cc809b 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -54,7 +54,6 @@ func allDefaultMiddlewares( httpsPort := o.HttpsPort secretKey := o.SecretKey strategy := o.Strategy - l := o.Logger // ratelimit rateLimit := o.RateLimit @@ -169,9 +168,7 @@ func allDefaultMiddlewares( rateLimit, ), ), - l, ), - l, ), ), strategy, diff --git a/middleware/recoverer.go b/middleware/recoverer.go index 110f70af..a0e62c2a 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -2,7 +2,6 @@ package middleware import ( "fmt" - "log/slog" "net/http" "github.com/komuw/ong/errors" @@ -13,7 +12,10 @@ import ( // recoverer is a middleware that recovers from panics in wrappedHandler. // When/if a panic occurs, it logs the stack trace and returns an InternalServerError response. -func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { +func recoverer( + wrappedHandler http.Handler, + doLog func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), +) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { defer func() { errR := recover() @@ -49,7 +51,7 @@ func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { } flds = append(flds, extra...) - doLog(w, *r, code, l, flds) + doLog(w, *r, code, flds) // respond. http.Error( From 163fa9c0b39c8ba9fe92f102da922c4f287aef12 Mon Sep 17 00:00:00 2001 From: komuw Date: Fri, 16 Aug 2024 17:56:23 +0300 Subject: [PATCH 08/20] g --- middleware/log_test.go | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/middleware/log_test.go b/middleware/log_test.go index aeea4be4..dd7baea0 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -24,24 +24,6 @@ const ( someLatencyMS = 3 ) -func someLogHandler(successMsg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // sleep so that the log middleware has some useful duration metrics to report. - time.Sleep(someLatencyMS * time.Millisecond) - if r.Header.Get(someLogHandlerHeader) != "" { - http.Error( - w, - r.Header.Get(someLogHandlerHeader), - http.StatusInternalServerError, - ) - return - } else { - fmt.Fprint(w, successMsg) - return - } - } -} - func toLog(t *testing.T, l *slog.Logger) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { t.Helper() @@ -78,6 +60,24 @@ func toLog(t *testing.T, l *slog.Logger) func(w http.ResponseWriter, r http.Requ } } +func someLogHandler(successMsg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // sleep so that the log middleware has some useful duration metrics to report. + time.Sleep(someLatencyMS * time.Millisecond) + if r.Header.Get(someLogHandlerHeader) != "" { + http.Error( + w, + r.Header.Get(someLogHandlerHeader), + http.StatusInternalServerError, + ) + return + } else { + fmt.Fprint(w, successMsg) + return + } + } +} + func TestLogMiddleware(t *testing.T) { t.Parallel() From 2cf73be5c3672df005ef7507dd7c01a779186073 Mon Sep 17 00:00:00 2001 From: komuw Date: Fri, 16 Aug 2024 18:17:03 +0300 Subject: [PATCH 09/20] g --- middleware/log_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/middleware/log_test.go b/middleware/log_test.go index dd7baea0..d35cb41b 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -27,12 +27,15 @@ const ( func toLog(t *testing.T, l *slog.Logger) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { t.Helper() + const ( + msg = "http_server" + // rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. + rateShedSamplePercent = 10 + ) + return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { // Each request should get its own context. That's why we call `log.WithID` for every request. reqL := log.WithID(r.Context(), l) - msg := "http_server" - // rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. - rateShedSamplePercent := 10 if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { // We are either in load shedding or rate-limiting. From 0460c9c7015180c8d90357a863b431e4fde7d10b Mon Sep 17 00:00:00 2001 From: komuw Date: Fri, 16 Aug 2024 21:56:04 +0300 Subject: [PATCH 10/20] g --- middleware/log_test.go | 19 ++++++++----------- middleware/middleware.go | 7 +++++++ middleware/recoverer_test.go | 25 ++++++++----------------- 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/middleware/log_test.go b/middleware/log_test.go index d35cb41b..88786a07 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "log/slog" mathRand "math/rand/v2" "net/http" "net/http/httptest" @@ -24,7 +23,7 @@ const ( someLatencyMS = 3 ) -func toLog(t *testing.T, l *slog.Logger) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { +func toLog(t *testing.T, buf *bytes.Buffer) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { t.Helper() const ( @@ -33,6 +32,8 @@ func toLog(t *testing.T, l *slog.Logger) func(w http.ResponseWriter, r http.Requ rateShedSamplePercent = 10 ) + l := log.New(context.Background(), buf, 500) + return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { // Each request should get its own context. That's why we call `log.WithID` for every request. reqL := log.WithID(r.Context(), l) @@ -84,16 +85,12 @@ func someLogHandler(successMsg string) http.HandlerFunc { func TestLogMiddleware(t *testing.T) { t.Parallel() - getLogger := func(w io.Writer) *slog.Logger { - return log.New(context.Background(), w, 500) - } - t.Run("success", func(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -116,7 +113,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} errorMsg := "someLogHandler failed" successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -152,7 +149,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" errorMsg := "someLogHandler failed" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) { // first request that succeds @@ -231,7 +228,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) someLogID := "hey-some-log-id:" + id.New() @@ -261,7 +258,7 @@ func TestLogMiddleware(t *testing.T) { successMsg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/middleware.go b/middleware/middleware.go index f3cc809b..c57fcc24 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -41,6 +41,11 @@ const ( allowHeader = "Allow" ) +// TODO: remove +func todoRemove(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { + return +} + // allDefaultMiddlewares is a middleware that bundles all the default/core middlewares into one. // // example usage: @@ -168,7 +173,9 @@ func allDefaultMiddlewares( rateLimit, ), ), + todoRemove, ), + todoRemove, ), ), strategy, diff --git a/middleware/recoverer_test.go b/middleware/recoverer_test.go index c9571136..2ba83a6f 100644 --- a/middleware/recoverer_test.go +++ b/middleware/recoverer_test.go @@ -2,19 +2,14 @@ package middleware import ( "bytes" - "context" "fmt" - "io" - "log/slog" "net/http" "net/http/httptest" - "os" "strings" "sync" "testing" "github.com/komuw/ong/errors" - "github.com/komuw/ong/log" "go.akshayshah.org/attest" ) @@ -48,16 +43,12 @@ func anotherHandlerThatPanics() http.HandlerFunc { func TestPanic(t *testing.T) { t.Parallel() - getLogger := func(w io.Writer) *slog.Logger { - return log.New(context.Background(), w, 500) - } - t.Run("ok if no panic", func(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), toLog(t, logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -74,7 +65,7 @@ func TestPanic(t *testing.T) { logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), toLog(t, logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -104,7 +95,7 @@ func TestPanic(t *testing.T) { msg := "hello" errMsg := "99 problems" err := errors.New(errMsg) - wrappedHandler := recoverer(handlerThatPanics(msg, false, err), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -132,7 +123,7 @@ func TestPanic(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} - wrappedHandler := recoverer(anotherHandlerThatPanics(), getLogger(logOutput)) + wrappedHandler := recoverer(anotherHandlerThatPanics(), toLog(t, logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -141,19 +132,19 @@ func TestPanic(t *testing.T) { res := rec.Result() defer res.Body.Close() attest.Equal(t, res.StatusCode, http.StatusInternalServerError) - attest.Subsequence(t, logOutput.String(), "middleware/recoverer_test.go:42") // line where panic happened. + attest.Subsequence(t, logOutput.String(), "middleware/recoverer_test.go:37") // line where panic happened. }) t.Run("concurrency safe", func(t *testing.T) { t.Parallel() - // &bytes.Buffer{} is not concurrency safe, so we use os.Stderr instead. - logOutput := os.Stderr + // If &bytes.Buffer{} is not concurrency safe, we can use os.Stderr instead. + logOutput := &bytes.Buffer{} msg := "hey" err := errors.New(msg) // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := recoverer(handlerThatPanics(msg, false, err), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput)) runhandler := func() { rec := httptest.NewRecorder() From e82feae63978fb1a2ea8dffc38529e416cbdf4c4 Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 11:45:20 +0300 Subject: [PATCH 11/20] g --- config/config.go | 37 +++++++++++++++++++++-------------- config/config_test.go | 17 ++++++++++------ config/example_test.go | 9 +++++++-- middleware/middleware_test.go | 3 ++- 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/config/config.go b/config/config.go index ffd2c233..ae24b219 100644 --- a/config/config.go +++ b/config/config.go @@ -253,7 +253,7 @@ func New( // middleware secretKey string, strategy ClientIPstrategy, - logger *slog.Logger, + logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), rateLimit float64, loadShedSamplingPeriod time.Duration, loadShedMinSampleSize int, @@ -267,6 +267,7 @@ func New( sessionCookieDuration time.Duration, sessionAntiReplayFunc func(r http.Request) string, // server + logger *slog.Logger, maxBodyBytes uint64, serverLogLevel slog.Level, readHeaderTimeout time.Duration, @@ -286,7 +287,7 @@ func New( port, secretKey, strategy, - logger, + logFunc, rateLimit, loadShedSamplingPeriod, loadShedMinSampleSize, @@ -309,6 +310,7 @@ func New( serverOpts: newServerOpts( domain, port, + logger, maxBodyBytes, serverLogLevel, readHeaderTimeout, @@ -349,7 +351,7 @@ func WithOpts( // middleware secretKey, strategy, - logger, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -363,6 +365,7 @@ func WithOpts( DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server + logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -397,7 +400,7 @@ func DevOpts(logger *slog.Logger, secretKey string) Opts { // middleware secretKey, clientip.DirectIpStrategy, - logger, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -411,6 +414,7 @@ func DevOpts(logger *slog.Logger, secretKey string) Opts { DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server + logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -452,7 +456,7 @@ func CertOpts( // middleware secretKey, clientip.DirectIpStrategy, - logger, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -466,6 +470,7 @@ func CertOpts( DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server + logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -510,7 +515,7 @@ func AcmeOpts( // middleware secretKey, clientip.DirectIpStrategy, - logger, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -524,6 +529,7 @@ func AcmeOpts( DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server + logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -567,7 +573,7 @@ func LetsEncryptOpts( // middleware secretKey, clientip.DirectIpStrategy, - logger, + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -581,6 +587,7 @@ func LetsEncryptOpts( DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server + logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -623,7 +630,7 @@ type middlewareOpts struct { // - https://go.dev/play/p/wL2gqumZ23b SecretKey secureKey Strategy ClientIPstrategy - Logger *slog.Logger + LogFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) // ratelimit RateLimit float64 @@ -655,7 +662,6 @@ func (m middlewareOpts) String() string { HttpsPort: %d, SecretKey: %s, Strategy: %v, - Logger: %v, RateLimit: %v, LoadShedSamplingPeriod: %v, LoadShedMinSampleSize: %v, @@ -673,7 +679,6 @@ func (m middlewareOpts) String() string { m.HttpsPort, m.SecretKey, m.Strategy, - m.Logger, m.RateLimit, m.LoadShedSamplingPeriod, m.LoadShedMinSampleSize, @@ -699,7 +704,7 @@ func newMiddlewareOpts( httpsPort uint16, secretKey string, strategy ClientIPstrategy, - logger *slog.Logger, + logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), rateLimit float64, loadShedSamplingPeriod time.Duration, loadShedMinSampleSize int, @@ -752,7 +757,7 @@ func newMiddlewareOpts( HttpsPort: httpsPort, SecretKey: secureKey(secretKey), Strategy: strategy, - Logger: logger, + LogFunc: logFunc, // ratelimiter RateLimit: rateLimit, @@ -821,6 +826,7 @@ func (t tlsOpts) GoString() string { // serverOpts are the various parameters(optionals) that can be used to configure a HTTP server. type serverOpts struct { port uint16 // tcp port is a 16bit unsigned integer. + Logger *slog.Logger MaxBodyBytes uint64 // max size of request body allowed. ServerLogLevel slog.Level ReadHeaderTimeout time.Duration @@ -842,6 +848,7 @@ type serverOpts struct { func newServerOpts( domain string, port uint16, + logger *slog.Logger, maxBodyBytes uint64, serverLogLevel slog.Level, readHeaderTimeout time.Duration, @@ -887,6 +894,7 @@ func newServerOpts( return serverOpts{ port: port, + Logger: logger, MaxBodyBytes: maxBodyBytes, ServerLogLevel: serverLogLevel, ReadHeaderTimeout: readHeaderTimeout, @@ -917,6 +925,7 @@ func newServerOpts( func (s serverOpts) String() string { return fmt.Sprintf(`serverOpts{ port: %v, + Logger: %v, MaxBodyBytes: %v, ServerLogLevel: %v, ReadHeaderTimeout: %v, @@ -932,6 +941,7 @@ func (s serverOpts) String() string { HttpPort: %v, }`, s.port, + s.Logger, s.MaxBodyBytes, s.ServerLogLevel, s.ReadHeaderTimeout, @@ -1027,9 +1037,6 @@ func (o Opts) Equal(other Opts) bool { if o.Strategy != other.Strategy { return false } - if o.Logger != other.Logger { - return false - } if int(o.RateLimit) != int(other.RateLimit) { return false diff --git a/config/config_test.go b/config/config_test.go index dba24f7b..f724f74a 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -33,8 +33,11 @@ func validOpts(t *testing.T) Opts { "super-h@rd-Pas1word", // In this case, the actual client IP address is fetched from the given http header. SingleIpStrategy("CF-Connecting-IP"), - // Logger. - l, + // function to log in middlewares. + func(_ http.ResponseWriter, r http.Request, statusCode int, fields []any) { + reqL := log.WithID(r.Context(), l) + reqL.Info("request-and-response", fields...) + }, // If a particular IP address sends more than 13 requests per second, throttle requests from that IP. 13.0, // Sample response latencies over a 5 minute window to determine if to loadshed. @@ -60,6 +63,8 @@ func validOpts(t *testing.T) Opts { // Use a given header to try and mitigate against replay-attacks. func(r http.Request) string { return r.Header.Get("Anti-Replay") }, // + // Logger. + l, // The maximum size in bytes for incoming request bodies. 2*1024*1024, // Log level of the logger that will be passed into [http.Server.ErrorLog] @@ -132,7 +137,7 @@ func TestNewMiddlewareOpts(t *testing.T) { opt.HttpsPort, string(opt.SecretKey), opt.Strategy, - opt.Logger, + nil, opt.RateLimit, opt.LoadShedSamplingPeriod, opt.LoadShedMinSampleSize, @@ -192,7 +197,7 @@ func TestNewMiddlewareOptsDomain(t *testing.T) { 443, tst.SecretKey(), clientip.DirectIpStrategy, - slog.Default(), + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -213,7 +218,7 @@ func TestNewMiddlewareOptsDomain(t *testing.T) { 443, tst.SecretKey(), clientip.DirectIpStrategy, - slog.Default(), + nil, DefaultRateLimit, DefaultLoadShedSamplingPeriod, DefaultLoadShedMinSampleSize, @@ -248,7 +253,7 @@ func TestOpts(t *testing.T) { HttpsPort: 65081, SecretKey: secureKey(tst.SecretKey()), Strategy: clientip.DirectIpStrategy, - Logger: l, + LogFunc: nil, RateLimit: DefaultRateLimit, LoadShedSamplingPeriod: DefaultLoadShedSamplingPeriod, LoadShedMinSampleSize: DefaultLoadShedMinSampleSize, diff --git a/config/example_test.go b/config/example_test.go index 49f3a412..750bf39c 100644 --- a/config/example_test.go +++ b/config/example_test.go @@ -30,8 +30,11 @@ func ExampleNew() { "super-h@rd-Pas1word", // In this case, the actual client IP address is fetched from the given http header. config.SingleIpStrategy("CF-Connecting-IP"), - // Logger. - l, + // function to log in middlewares. + func(_ http.ResponseWriter, r http.Request, statusCode int, fields []any) { + reqL := log.WithID(r.Context(), l) + reqL.Info("request-and-response", fields...) + }, // If a particular IP address sends more than 13 requests per second, throttle requests from that IP. 13.0, // Sample response latencies over a 5 minute window to determine if to loadshed. @@ -57,6 +60,8 @@ func ExampleNew() { // Use a given header to try and mitigate against replay-attacks. func(r http.Request) string { return r.Header.Get("Anti-Replay") }, // + // Logger. + l, // The maximum size in bytes for incoming request bodies. 2*1024*1024, // Log level of the logger that will be passed into [http.Server.ErrorLog] diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 68ada4b1..396fe274 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -488,7 +488,7 @@ func BenchmarkAllMiddlewares(b *testing.B) { httpsPort, tst.SecretKey(), config.DirectIpStrategy, - l, + nil, rateLimit, config.DefaultLoadShedSamplingPeriod, config.DefaultLoadShedMinSampleSize, @@ -501,6 +501,7 @@ func BenchmarkAllMiddlewares(b *testing.B) { config.DefaultCsrfCookieDuration, config.DefaultSessionCookieDuration, config.DefaultSessionAntiReplayFunc, + l, 20*1024*1024, slog.LevelDebug, 1*time.Second, From 7d45042d2f4ceea39e7aa15304177ecc98e45955 Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 11:47:15 +0300 Subject: [PATCH 12/20] g --- middleware/middleware.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index c57fcc24..05f2b487 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -41,11 +41,6 @@ const ( allowHeader = "Allow" ) -// TODO: remove -func todoRemove(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { - return -} - // allDefaultMiddlewares is a middleware that bundles all the default/core middlewares into one. // // example usage: @@ -60,6 +55,9 @@ func allDefaultMiddlewares( secretKey := o.SecretKey strategy := o.Strategy + // logger + logFunc := o.LogFunc + // ratelimit rateLimit := o.RateLimit @@ -173,9 +171,9 @@ func allDefaultMiddlewares( rateLimit, ), ), - todoRemove, + logFunc, ), - todoRemove, + logFunc, ), ), strategy, From cbb1ea9a6651b77ce01b4efab56942f0805851f4 Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 12:00:24 +0300 Subject: [PATCH 13/20] g --- config/config.go | 5 +++++ middleware/log.go | 11 +++++++---- middleware/recoverer.go | 11 +++++++---- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/config/config.go b/config/config.go index ae24b219..c5e9cc43 100644 --- a/config/config.go +++ b/config/config.go @@ -195,6 +195,9 @@ func (o Opts) GoString() string { // // logger is an [slog.Logger] that will be used for logging. // +// logFunc is a function that dictates what/how middleware is going to log. It is also used to log any recovered panics in the middleware. +// If it is nil, no logging happens in the middleware. For server logging, see [logger] +// // rateLimit is the maximum requests allowed (from one IP address) per second. If it is les than 1.0, [DefaultRateLimit] is used instead. // // loadShedSamplingPeriod is the duration over which we calculate response latencies for purposes of determining whether to loadshed. If it is less than 1second, [DefaultLoadShedSamplingPeriod] is used instead. @@ -216,6 +219,8 @@ func (o Opts) GoString() string { // sessionAntiReplayFunc is the function used to return a token that will be used to try and mitigate against [replay attacks]. This mitigation not foolproof. // If it is nil, [DefaultSessionAntiReplayFunc] is used instead. // +// logger is an [slog.Logger] that will be used for logging in the server but not middleware. For middleware see [logFunc] +// // maxBodyBytes is the maximum size in bytes for incoming request bodies. If this is zero, a reasonable default is used. // // serverLogLevel is the log level of the logger that will be passed into [http.Server.ErrorLog] diff --git a/middleware/log.go b/middleware/log.go index 9c62ac5d..4c200dfd 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -12,11 +12,12 @@ import ( // logger is a middleware that logs http requests and responses using [log.Logger]. func logger( wrappedHandler http.Handler, - doLog func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), + logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), ) http.HandlerFunc { - // We pass the logger as an argument so that the middleware can share the same logger as the app. + // The middleware should ideally share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. - // This makes debugging easier for developers. + // That's one reason why we pass in the logFunc.This makes debugging easier for developers. + // Another reason is so that app developers are in control of what(and how) exactly gets logged. return func(w http.ResponseWriter, r *http.Request) { start := time.Now() @@ -46,7 +47,9 @@ func logger( // 1xx class or the modified headers are trailers. lrw.Header().Del(ongMiddlewareErrorHeader) - doLog(w, *r, lrw.code, flds) + if logFunc != nil { + logFunc(w, *r, lrw.code, flds) + } }() wrappedHandler.ServeHTTP(lrw, r) diff --git a/middleware/recoverer.go b/middleware/recoverer.go index a0e62c2a..f4c882bf 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -14,14 +14,15 @@ import ( // When/if a panic occurs, it logs the stack trace and returns an InternalServerError response. func recoverer( wrappedHandler http.Handler, - doLog func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), + logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), ) http.HandlerFunc { + code := http.StatusInternalServerError + status := http.StatusText(code) + return func(w http.ResponseWriter, r *http.Request) { defer func() { errR := recover() if errR != nil { - code := http.StatusInternalServerError - status := http.StatusText(code) flds := []any{ "clientIP", ClientIP(r), @@ -51,7 +52,9 @@ func recoverer( } flds = append(flds, extra...) - doLog(w, *r, code, flds) + if logFunc != nil { + logFunc(w, *r, code, flds) + } // respond. http.Error( From 6e311a66bb5360a421aaad5560a253a291a8c1e0 Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 12:19:50 +0300 Subject: [PATCH 14/20] g --- config/config.go | 37 +++++++++++++----------- config/config_test.go | 7 +++-- config/example_test.go | 4 +-- middleware/log.go | 54 +++++++++++++++++++++++++++++++++-- middleware/log_test.go | 10 +++---- middleware/middleware.go | 3 ++ middleware/middleware_test.go | 2 +- middleware/recoverer.go | 10 +++++-- middleware/recoverer_test.go | 10 +++---- 9 files changed, 99 insertions(+), 38 deletions(-) diff --git a/config/config.go b/config/config.go index c5e9cc43..957200a7 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,7 @@ package config import ( "crypto/x509" + "errors" "fmt" "log/slog" "net/http" @@ -187,16 +188,17 @@ func (o Opts) GoString() string { // domain is the domain name of your website. It can be an exact domain, subdomain or wildcard. // port is the TLS port where the server will listen on. Http requests will also redirected to that port. // +// logger is an [slog.Logger] that will be used for logging. +// // secretKey is used for securing signed data. It should be unique & kept secret. // If it becomes compromised, generate a new one and restart your application using the new one. // // strategy is the algorithm to use when fetching the client's IP address; see [ClientIPstrategy]. // It is important to choose your strategy carefully, see the warning in [ClientIPstrategy]. // -// logger is an [slog.Logger] that will be used for logging. -// // logFunc is a function that dictates what/how middleware is going to log. It is also used to log any recovered panics in the middleware. -// If it is nil, no logging happens in the middleware. For server logging, see [logger] +// This function is not used in the server. +// If it is nil, a suitable default is used. To disable logging, use a function that does nothing. // // rateLimit is the maximum requests allowed (from one IP address) per second. If it is les than 1.0, [DefaultRateLimit] is used instead. // @@ -219,8 +221,6 @@ func (o Opts) GoString() string { // sessionAntiReplayFunc is the function used to return a token that will be used to try and mitigate against [replay attacks]. This mitigation not foolproof. // If it is nil, [DefaultSessionAntiReplayFunc] is used instead. // -// logger is an [slog.Logger] that will be used for logging in the server but not middleware. For middleware see [logFunc] -// // maxBodyBytes is the maximum size in bytes for incoming request bodies. If this is zero, a reasonable default is used. // // serverLogLevel is the log level of the logger that will be passed into [http.Server.ErrorLog] @@ -254,6 +254,7 @@ func New( // common domain string, port uint16, + logger *slog.Logger, // middleware secretKey string, @@ -272,7 +273,6 @@ func New( sessionCookieDuration time.Duration, sessionAntiReplayFunc func(r http.Request) string, // server - logger *slog.Logger, maxBodyBytes uint64, serverLogLevel slog.Level, readHeaderTimeout time.Duration, @@ -290,6 +290,7 @@ func New( middlewareOpts, err := newMiddlewareOpts( domain, port, + logger, secretKey, strategy, logFunc, @@ -315,7 +316,6 @@ func New( serverOpts: newServerOpts( domain, port, - logger, maxBodyBytes, serverLogLevel, readHeaderTimeout, @@ -352,6 +352,7 @@ func WithOpts( // common domain, httpsPort, + logger, // middleware secretKey, @@ -370,7 +371,6 @@ func WithOpts( DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server - logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -401,6 +401,7 @@ func DevOpts(logger *slog.Logger, secretKey string) Opts { // common domain, httpsPort, + logger, // middleware secretKey, @@ -419,7 +420,6 @@ func DevOpts(logger *slog.Logger, secretKey string) Opts { DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server - logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -457,6 +457,7 @@ func CertOpts( // common domain, httpsPort, + logger, // middleware secretKey, @@ -475,7 +476,6 @@ func CertOpts( DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server - logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -516,6 +516,7 @@ func AcmeOpts( // common domain, httpsPort, + logger, // middleware secretKey, @@ -534,7 +535,6 @@ func AcmeOpts( DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server - logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -574,6 +574,7 @@ func LetsEncryptOpts( // common domain, httpsPort, + logger, // middleware secretKey, @@ -592,7 +593,6 @@ func LetsEncryptOpts( DefaultSessionCookieDuration, DefaultSessionAntiReplayFunc, // server - logger, DefaultMaxBodyBytes, DefaultServerLogLevel, defaultReadHeaderTimeout, @@ -629,6 +629,8 @@ func (s secureKey) GoString() string { type middlewareOpts struct { Domain string HttpsPort uint16 + Logger *slog.Logger + // When printing a struct, fmt does not invoke custom formatting methods on unexported fields. // We thus need to make this field to be exported. // - https://pkg.go.dev/fmt#:~:text=When%20printing%20a%20struct @@ -707,6 +709,7 @@ func (m middlewareOpts) GoString() string { func newMiddlewareOpts( domain string, httpsPort uint16, + logger *slog.Logger, secretKey string, strategy ClientIPstrategy, logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), @@ -732,6 +735,10 @@ func newMiddlewareOpts( domain = domain[2:] } + if logger == nil && logFunc == nil { + return middlewareOpts{}, errors.New("both logger and logFunc should not be nil at the same time") + } + if err := key.IsSecure(secretKey); err != nil { return middlewareOpts{}, err } @@ -760,6 +767,7 @@ func newMiddlewareOpts( return middlewareOpts{ Domain: domain, HttpsPort: httpsPort, + Logger: logger, SecretKey: secureKey(secretKey), Strategy: strategy, LogFunc: logFunc, @@ -831,7 +839,6 @@ func (t tlsOpts) GoString() string { // serverOpts are the various parameters(optionals) that can be used to configure a HTTP server. type serverOpts struct { port uint16 // tcp port is a 16bit unsigned integer. - Logger *slog.Logger MaxBodyBytes uint64 // max size of request body allowed. ServerLogLevel slog.Level ReadHeaderTimeout time.Duration @@ -853,7 +860,6 @@ type serverOpts struct { func newServerOpts( domain string, port uint16, - logger *slog.Logger, maxBodyBytes uint64, serverLogLevel slog.Level, readHeaderTimeout time.Duration, @@ -899,7 +905,6 @@ func newServerOpts( return serverOpts{ port: port, - Logger: logger, MaxBodyBytes: maxBodyBytes, ServerLogLevel: serverLogLevel, ReadHeaderTimeout: readHeaderTimeout, @@ -930,7 +935,6 @@ func newServerOpts( func (s serverOpts) String() string { return fmt.Sprintf(`serverOpts{ port: %v, - Logger: %v, MaxBodyBytes: %v, ServerLogLevel: %v, ReadHeaderTimeout: %v, @@ -946,7 +950,6 @@ func (s serverOpts) String() string { HttpPort: %v, }`, s.port, - s.Logger, s.MaxBodyBytes, s.ServerLogLevel, s.ReadHeaderTimeout, diff --git a/config/config_test.go b/config/config_test.go index f724f74a..f1931748 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -29,6 +29,8 @@ func validOpts(t *testing.T) Opts { "example.com", // The https port that our application will be listening on. 443, + // Logger. + l, // The security key to use for securing signed data. "super-h@rd-Pas1word", // In this case, the actual client IP address is fetched from the given http header. @@ -63,8 +65,6 @@ func validOpts(t *testing.T) Opts { // Use a given header to try and mitigate against replay-attacks. func(r http.Request) string { return r.Header.Get("Anti-Replay") }, // - // Logger. - l, // The maximum size in bytes for incoming request bodies. 2*1024*1024, // Log level of the logger that will be passed into [http.Server.ErrorLog] @@ -135,6 +135,7 @@ func TestNewMiddlewareOpts(t *testing.T) { o, err := newMiddlewareOpts( opt.Domain, opt.HttpsPort, + slog.Default(), string(opt.SecretKey), opt.Strategy, nil, @@ -195,6 +196,7 @@ func TestNewMiddlewareOptsDomain(t *testing.T) { _, err := newMiddlewareOpts( tt.domain, 443, + slog.Default(), tst.SecretKey(), clientip.DirectIpStrategy, nil, @@ -216,6 +218,7 @@ func TestNewMiddlewareOptsDomain(t *testing.T) { _, err := newMiddlewareOpts( tt.domain, 443, + slog.Default(), tst.SecretKey(), clientip.DirectIpStrategy, nil, diff --git a/config/example_test.go b/config/example_test.go index 750bf39c..bceb0a0c 100644 --- a/config/example_test.go +++ b/config/example_test.go @@ -26,6 +26,8 @@ func ExampleNew() { "example.com", // The https port that our application will be listening on. 443, + // Logger. + l, // The security key to use for securing signed data. "super-h@rd-Pas1word", // In this case, the actual client IP address is fetched from the given http header. @@ -60,8 +62,6 @@ func ExampleNew() { // Use a given header to try and mitigate against replay-attacks. func(r http.Request) string { return r.Header.Get("Anti-Replay") }, // - // Logger. - l, // The maximum size in bytes for incoming request bodies. 2*1024*1024, // Log level of the logger that will be passed into [http.Server.ErrorLog] diff --git a/middleware/log.go b/middleware/log.go index 4c200dfd..ddbdb069 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -4,20 +4,28 @@ import ( "bufio" "fmt" "io" + "log/slog" + mathRand "math/rand/v2" "net" "net/http" "time" + + "github.com/komuw/ong/log" ) // logger is a middleware that logs http requests and responses using [log.Logger]. func logger( wrappedHandler http.Handler, logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), + l *slog.Logger, ) http.HandlerFunc { // The middleware should ideally share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. // That's one reason why we pass in the logFunc.This makes debugging easier for developers. // Another reason is so that app developers are in control of what(and how) exactly gets logged. + if logFunc == nil { + logFunc = defaultLogFunc(l) + } return func(w http.ResponseWriter, r *http.Request) { start := time.Now() @@ -47,9 +55,7 @@ func logger( // 1xx class or the modified headers are trailers. lrw.Header().Del(ongMiddlewareErrorHeader) - if logFunc != nil { - logFunc(w, *r, lrw.code, flds) - } + logFunc(w, *r, lrw.code, flds) }() wrappedHandler.ServeHTTP(lrw, r) @@ -146,3 +152,45 @@ func (lrw *logRW) ReadFrom(src io.Reader) (n int64, err error) { func (lrw *logRW) Unwrap() http.ResponseWriter { return lrw.ResponseWriter } + +// defaultLogFunc is the logging function used if the user did not explicitly provide one. +func defaultLogFunc(l *slog.Logger) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { + const ( + msg = "http_server" + // rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. + rateShedSamplePercent = 10 + ) + + if l == nil { + return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {} + } + + return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { + // Each request should get its own context. That's why we call `log.WithID` for every request. + reqL := log.WithID(r.Context(), l) + + if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { + // We are either in load shedding or rate-limiting. + // Only log (rateShedSamplePercent)% of the errors. + shouldLog := mathRand.IntN(100) <= rateShedSamplePercent + if shouldLog { + reqL.Error(msg, fields...) + return + } + } + + if statusCode >= http.StatusBadRequest { + // Both client and server errors. + if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot { + // These ones are more of an annoyance, than been actual errors. + reqL.Info(msg, fields...) + return + } + + reqL.Error(msg, fields...) + return + } + + reqL.Info(msg, fields...) + } +} diff --git a/middleware/log_test.go b/middleware/log_test.go index 88786a07..178b6028 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -90,7 +90,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -113,7 +113,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} errorMsg := "someLogHandler failed" successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -149,7 +149,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" errorMsg := "someLogHandler failed" - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) { // first request that succeds @@ -228,7 +228,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) someLogID := "hey-some-log-id:" + id.New() @@ -258,7 +258,7 @@ func TestLogMiddleware(t *testing.T) { successMsg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/middleware.go b/middleware/middleware.go index 05f2b487..e93def23 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -56,6 +56,7 @@ func allDefaultMiddlewares( strategy := o.Strategy // logger + l := o.Logger logFunc := o.LogFunc // ratelimit @@ -172,8 +173,10 @@ func allDefaultMiddlewares( ), ), logFunc, + l, ), logFunc, + l, ), ), strategy, diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 396fe274..28b048c3 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -486,6 +486,7 @@ func BenchmarkAllMiddlewares(b *testing.B) { o := config.New( domain, httpsPort, + l, tst.SecretKey(), config.DirectIpStrategy, nil, @@ -501,7 +502,6 @@ func BenchmarkAllMiddlewares(b *testing.B) { config.DefaultCsrfCookieDuration, config.DefaultSessionCookieDuration, config.DefaultSessionAntiReplayFunc, - l, 20*1024*1024, slog.LevelDebug, 1*time.Second, diff --git a/middleware/recoverer.go b/middleware/recoverer.go index f4c882bf..47804641 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "log/slog" "net/http" "github.com/komuw/ong/errors" @@ -15,10 +16,15 @@ import ( func recoverer( wrappedHandler http.Handler, logFunc func(w http.ResponseWriter, r http.Request, statusCode int, fields []any), + l *slog.Logger, ) http.HandlerFunc { code := http.StatusInternalServerError status := http.StatusText(code) + if logFunc == nil { + logFunc = defaultLogFunc(l) + } + return func(w http.ResponseWriter, r *http.Request) { defer func() { errR := recover() @@ -52,9 +58,7 @@ func recoverer( } flds = append(flds, extra...) - if logFunc != nil { - logFunc(w, *r, code, flds) - } + logFunc(w, *r, code, flds) // respond. http.Error( diff --git a/middleware/recoverer_test.go b/middleware/recoverer_test.go index 2ba83a6f..64f77d77 100644 --- a/middleware/recoverer_test.go +++ b/middleware/recoverer_test.go @@ -48,7 +48,7 @@ func TestPanic(t *testing.T) { logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), toLog(t, logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), toLog(t, logOutput), nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -65,7 +65,7 @@ func TestPanic(t *testing.T) { logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), toLog(t, logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), toLog(t, logOutput), nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -95,7 +95,7 @@ func TestPanic(t *testing.T) { msg := "hello" errMsg := "99 problems" err := errors.New(errMsg) - wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput), nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -123,7 +123,7 @@ func TestPanic(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} - wrappedHandler := recoverer(anotherHandlerThatPanics(), toLog(t, logOutput)) + wrappedHandler := recoverer(anotherHandlerThatPanics(), toLog(t, logOutput), nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -144,7 +144,7 @@ func TestPanic(t *testing.T) { err := errors.New(msg) // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput), nil) runhandler := func() { rec := httptest.NewRecorder() From c0b9021da346a634e1c80a0447f4a27a9a9b04a7 Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 12:22:45 +0300 Subject: [PATCH 15/20] g --- middleware/log_test.go | 57 +++++++----------------------------- middleware/recoverer_test.go | 18 ++++++++---- 2 files changed, 23 insertions(+), 52 deletions(-) diff --git a/middleware/log_test.go b/middleware/log_test.go index 178b6028..c61bc1ae 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -5,7 +5,7 @@ import ( "context" "fmt" "io" - mathRand "math/rand/v2" + "log/slog" "net/http" "net/http/httptest" "sync" @@ -23,47 +23,6 @@ const ( someLatencyMS = 3 ) -func toLog(t *testing.T, buf *bytes.Buffer) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { - t.Helper() - - const ( - msg = "http_server" - // rateShedSamplePercent is the percentage of rate limited or loadshed responses that will be logged as errors, by default. - rateShedSamplePercent = 10 - ) - - l := log.New(context.Background(), buf, 500) - - return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) { - // Each request should get its own context. That's why we call `log.WithID` for every request. - reqL := log.WithID(r.Context(), l) - - if (statusCode == http.StatusServiceUnavailable || statusCode == http.StatusTooManyRequests) && w.Header().Get(retryAfterHeader) != "" { - // We are either in load shedding or rate-limiting. - // Only log (rateShedSamplePercent)% of the errors. - shouldLog := mathRand.IntN(100) <= rateShedSamplePercent - if shouldLog { - reqL.Error(msg, fields...) - return - } - } - - if statusCode >= http.StatusBadRequest { - // Both client and server errors. - if statusCode == http.StatusNotFound || statusCode == http.StatusMethodNotAllowed || statusCode == http.StatusTeapot { - // These ones are more of an annoyance, than been actual errors. - reqL.Info(msg, fields...) - return - } - - reqL.Error(msg, fields...) - return - } - - reqL.Info(msg, fields...) - } -} - func someLogHandler(successMsg string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // sleep so that the log middleware has some useful duration metrics to report. @@ -85,12 +44,16 @@ func someLogHandler(successMsg string) http.HandlerFunc { func TestLogMiddleware(t *testing.T) { t.Parallel() + getLogger := func(w io.Writer) *slog.Logger { + return log.New(context.Background(), w, 500) + } + t.Run("success", func(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -113,7 +76,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} errorMsg := "someLogHandler failed" successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -149,7 +112,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" errorMsg := "someLogHandler failed" - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) { // first request that succeds @@ -228,7 +191,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) someLogID := "hey-some-log-id:" + id.New() @@ -258,7 +221,7 @@ func TestLogMiddleware(t *testing.T) { successMsg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput), nil) + wrappedHandler := logger(someLogHandler(successMsg), nil, getLogger(logOutput)) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/recoverer_test.go b/middleware/recoverer_test.go index 64f77d77..bb527268 100644 --- a/middleware/recoverer_test.go +++ b/middleware/recoverer_test.go @@ -2,7 +2,10 @@ package middleware import ( "bytes" + "context" "fmt" + "io" + "log/slog" "net/http" "net/http/httptest" "strings" @@ -10,6 +13,7 @@ import ( "testing" "github.com/komuw/ong/errors" + "github.com/komuw/ong/log" "go.akshayshah.org/attest" ) @@ -43,12 +47,16 @@ func anotherHandlerThatPanics() http.HandlerFunc { func TestPanic(t *testing.T) { t.Parallel() + getLogger := func(w io.Writer) *slog.Logger { + return log.New(context.Background(), w, 500) + } + t.Run("ok if no panic", func(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), toLog(t, logOutput), nil) + wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -65,7 +73,7 @@ func TestPanic(t *testing.T) { logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), toLog(t, logOutput), nil) + wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -95,7 +103,7 @@ func TestPanic(t *testing.T) { msg := "hello" errMsg := "99 problems" err := errors.New(errMsg) - wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput), nil) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -123,7 +131,7 @@ func TestPanic(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} - wrappedHandler := recoverer(anotherHandlerThatPanics(), toLog(t, logOutput), nil) + wrappedHandler := recoverer(anotherHandlerThatPanics(), nil, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -144,7 +152,7 @@ func TestPanic(t *testing.T) { err := errors.New(msg) // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput), nil) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), nil, getLogger(logOutput)) runhandler := func() { rec := httptest.NewRecorder() From fbb0399b391f27d1d09de09e42c1d9df16556d8d Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 12:24:39 +0300 Subject: [PATCH 16/20] g --- middleware/log_test.go | 22 ++++++++++++++++++++++ middleware/recoverer_test.go | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/middleware/log_test.go b/middleware/log_test.go index c61bc1ae..135e516b 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -214,6 +214,28 @@ func TestLogMiddleware(t *testing.T) { attest.Zero(t, logOutput.String()) }) + t.Run("nil logger and logfunc", func(t *testing.T) { + t.Parallel() + + logOutput := &bytes.Buffer{} + successMsg := "hello" + wrappedHandler := logger(someLogHandler(successMsg), nil, nil) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodHead, "/someUri", nil) + wrappedHandler.ServeHTTP(rec, req) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), successMsg) + attest.Zero(t, logOutput.String()) + }) + t.Run("concurrency safe", func(t *testing.T) { t.Parallel() diff --git a/middleware/recoverer_test.go b/middleware/recoverer_test.go index bb527268..aeb7a6fc 100644 --- a/middleware/recoverer_test.go +++ b/middleware/recoverer_test.go @@ -140,7 +140,7 @@ func TestPanic(t *testing.T) { res := rec.Result() defer res.Body.Close() attest.Equal(t, res.StatusCode, http.StatusInternalServerError) - attest.Subsequence(t, logOutput.String(), "middleware/recoverer_test.go:37") // line where panic happened. + attest.Subsequence(t, logOutput.String(), "middleware/recoverer_test.go:41") // line where panic happened. }) t.Run("concurrency safe", func(t *testing.T) { From 802c30bae458cb66d594c88c2e6231cbf821a1b8 Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 12:28:23 +0300 Subject: [PATCH 17/20] g --- config/example_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/config/example_test.go b/config/example_test.go index bceb0a0c..64327e1a 100644 --- a/config/example_test.go +++ b/config/example_test.go @@ -34,8 +34,12 @@ func ExampleNew() { config.SingleIpStrategy("CF-Connecting-IP"), // function to log in middlewares. func(_ http.ResponseWriter, r http.Request, statusCode int, fields []any) { - reqL := log.WithID(r.Context(), l) - reqL.Info("request-and-response", fields...) + if statusCode >= http.StatusInternalServerError { + // Only log 500's + reqL := log.WithID(r.Context(), l) + fields = append(fields, "statusCode", statusCode) + reqL.Info("request-and-response", fields...) + } }, // If a particular IP address sends more than 13 requests per second, throttle requests from that IP. 13.0, From 45fb0fd0aad9d0a2fd5f13d524b0d4845eaf5161 Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 12:29:31 +0300 Subject: [PATCH 18/20] g --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0a35130..dcd775b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Most recent version is listed first. # v0.1.7 - Update go version; https://github.com/komuw/ong/pull/469 - ong/cry: Replace scrypt with argon2id: https://github.com/komuw/ong/pull/471 +- ong/middleware: Give users control over what and how logging happens in the middlewares: https://github.com/komuw/ong/pull/472 # v0.1.6 - Bump versions of dependencies used From 43f46f8b70b0e7f9c69ec90775dd19b6fa20f394 Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 12:31:23 +0300 Subject: [PATCH 19/20] g --- config/config_test.go | 2 +- config/example_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config_test.go b/config/config_test.go index f1931748..7e22a716 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -35,7 +35,7 @@ func validOpts(t *testing.T) Opts { "super-h@rd-Pas1word", // In this case, the actual client IP address is fetched from the given http header. SingleIpStrategy("CF-Connecting-IP"), - // function to log in middlewares. + // function to use for logging in middlewares. func(_ http.ResponseWriter, r http.Request, statusCode int, fields []any) { reqL := log.WithID(r.Context(), l) reqL.Info("request-and-response", fields...) diff --git a/config/example_test.go b/config/example_test.go index 64327e1a..6dfeab28 100644 --- a/config/example_test.go +++ b/config/example_test.go @@ -32,7 +32,7 @@ func ExampleNew() { "super-h@rd-Pas1word", // In this case, the actual client IP address is fetched from the given http header. config.SingleIpStrategy("CF-Connecting-IP"), - // function to log in middlewares. + // function to use for logging in middlewares func(_ http.ResponseWriter, r http.Request, statusCode int, fields []any) { if statusCode >= http.StatusInternalServerError { // Only log 500's From 4624c651ceeecddea0099efa8e8a930d18c4388a Mon Sep 17 00:00:00 2001 From: komuw Date: Sat, 17 Aug 2024 12:34:58 +0300 Subject: [PATCH 20/20] g --- config/config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/config.go b/config/config.go index 957200a7..edb619f1 100644 --- a/config/config.go +++ b/config/config.go @@ -188,7 +188,7 @@ func (o Opts) GoString() string { // domain is the domain name of your website. It can be an exact domain, subdomain or wildcard. // port is the TLS port where the server will listen on. Http requests will also redirected to that port. // -// logger is an [slog.Logger] that will be used for logging. +// logger is an [slog.Logger] that will be used for logging. It is used in the server, it's use in middlewares is only if [logFunc] is nil. // // secretKey is used for securing signed data. It should be unique & kept secret. // If it becomes compromised, generate a new one and restart your application using the new one. @@ -198,7 +198,7 @@ func (o Opts) GoString() string { // // logFunc is a function that dictates what/how middleware is going to log. It is also used to log any recovered panics in the middleware. // This function is not used in the server. -// If it is nil, a suitable default is used. To disable logging, use a function that does nothing. +// If it is nil, a suitable default(that utilizes [logger]) is used. To disable logging, use a function that does nothing. // // rateLimit is the maximum requests allowed (from one IP address) per second. If it is les than 1.0, [DefaultRateLimit] is used instead. //