From ded42db67e1294b2cd0d603d8c80d0e1e5fda0c6 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 7 Jun 2023 22:59:34 +0300 Subject: [PATCH 01/13] func oneMiddleware(wrappedHandler http.HandlerFunc) http.HandlerFunc { count := 0 return func(w http.ResponseWriter, r *http.Request) { count = count + 1 r.Header.Set("MY-COUNT", fmt.Sprint(count)) wrappedHandler(w, r) } } func twoMiddleware(wrappedHandler http.Handler) http.HandlerFunc { count := 0 return func(w http.ResponseWriter, r *http.Request) { count = count + 1 r.Header.Set("MY-COUNT", fmt.Sprint(count)) wrappedHandler.ServeHTTP(w, r) } } func threeMiddleware(wrappedHandler http.Handler) http.Handler { count := 0 return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { count = count + 1 r.Header.Set("MY-COUNT", fmt.Sprint(count)) wrappedHandler.ServeHTTP(w, r) }, ) } func BenchmarkMid(b *testing.B) { b.Run("oneMiddleware", func(b *testing.B) { h := oneMiddleware(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "ok") }) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { h.ServeHTTP(rec, req) } }) b.Run("twoMiddleware", func(b *testing.B) { h := twoMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "ok") })) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { h.ServeHTTP(rec, req) } }) b.Run("threeMiddleware", func(b *testing.B) { h := threeMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "ok") })) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { h.ServeHTTP(rec, req) } }) } go test -benchmem -run=^$ -bench ^BenchmarkMid$ github.com/komuw/ong/mux -count=5 BenchmarkMid/oneMiddleware-8 358.7 ns/op BenchmarkMid/oneMiddleware-8 371.8 ns/op BenchmarkMid/oneMiddleware-8 343.6 ns/op BenchmarkMid/oneMiddleware-8 289.0 ns/op BenchmarkMid/oneMiddleware-8 291.4 ns/op BenchmarkMid/twoMiddleware-8 296.7 ns/op BenchmarkMid/twoMiddleware-8 353.3 ns/op BenchmarkMid/twoMiddleware-8 363.2 ns/op BenchmarkMid/twoMiddleware-8 366.2 ns/op BenchmarkMid/twoMiddleware-8 315.2 ns/op BenchmarkMid/threeMiddleware-8 299.0 ns/op BenchmarkMid/threeMiddleware-8 307.6 ns/op BenchmarkMid/threeMiddleware-8 324.9 ns/op BenchmarkMid/threeMiddleware-8 367.6 ns/op BenchmarkMid/threeMiddleware-8 373.3 ns/op The change we want is `threeMiddleware` and it seems competitive with the others. --- server/tls_conf_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server/tls_conf_test.go b/server/tls_conf_test.go index a9b06c29..1bfc4fa1 100644 --- a/server/tls_conf_test.go +++ b/server/tls_conf_test.go @@ -26,6 +26,7 @@ func TestCustomHostWhitelist(t *testing.T) { {"three.example.net", false}, {"dummy", false}, } + for i, test := range tt { err := policy(nil, test.host) if err != nil && test.allow { From f91036c848cb09cef2212a3803d96cc815230c88 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 7 Jun 2023 23:24:04 +0300 Subject: [PATCH 02/13] m --- middleware/auth.go | 42 ++--- middleware/auth_test.go | 7 +- middleware/client_ip.go | 40 ++--- middleware/client_ip_test.go | 14 +- middleware/cors.go | 34 ++-- middleware/cors_test.go | 10 +- middleware/csrf.go | 268 +++++++++++++++--------------- middleware/csrf_test.go | 12 +- middleware/example_test.go | 40 +++-- middleware/fingerprint.go | 36 ++-- middleware/gzip.go | 60 +++---- middleware/gzip_test.go | 88 +++++----- middleware/loadshed.go | 80 ++++----- middleware/loadshed_test.go | 36 ++-- middleware/log.go | 94 ++++++----- middleware/log_test.go | 34 ++-- middleware/middleware.go | 180 ++++++++++---------- middleware/middleware_test.go | 40 +++-- middleware/ratelimiter.go | 52 +++--- middleware/ratelimiter_test.go | 10 +- middleware/recoverer.go | 88 +++++----- middleware/recoverer_test.go | 48 +++--- middleware/redirect.go | 76 ++++----- middleware/redirect_test.go | 24 +-- middleware/reload_protect.go | 80 ++++----- middleware/reload_protect_test.go | 40 ++--- middleware/security.go | 142 ++++++++-------- middleware/security_test.go | 12 +- middleware/session.go | 22 +-- middleware/session_test.go | 40 +++-- middleware/trace.go | 56 ++++--- middleware/trace_test.go | 34 ++-- mux/example_test.go | 22 ++- mux/mux.go | 6 +- mux/mux_test.go | 32 ++-- mux/route.go | 12 +- 36 files changed, 1006 insertions(+), 905 deletions(-) diff --git a/middleware/auth.go b/middleware/auth.go index e19f68cd..b21dba74 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -9,7 +9,7 @@ import ( const minPasswdSize = 16 // BasicAuth is a middleware that protects wrappedHandler using basic authentication. -func BasicAuth(wrappedHandler http.HandlerFunc, user, passwd string) http.HandlerFunc { +func BasicAuth(wrappedHandler http.Handler, user, passwd string) http.Handler { const realm = "enter username and password" if len(passwd) < minPasswdSize { @@ -23,23 +23,25 @@ func BasicAuth(wrappedHandler http.HandlerFunc, user, passwd string) http.Handle http.Error(w, "Unauthorized", http.StatusUnauthorized) } - return func(w http.ResponseWriter, r *http.Request) { - u, p, ok := r.BasicAuth() - if u == "" || p == "" || !ok { - e(w) - return - } - - if subtle.ConstantTimeCompare([]byte(u), []byte(user)) != 1 { - e(w) - return - } - - if subtle.ConstantTimeCompare([]byte(p), []byte(passwd)) != 1 { - e(w) - return - } - - wrappedHandler(w, r) - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + u, p, ok := r.BasicAuth() + if u == "" || p == "" || !ok { + e(w) + return + } + + if subtle.ConstantTimeCompare([]byte(u), []byte(user)) != 1 { + e(w) + return + } + + if subtle.ConstantTimeCompare([]byte(p), []byte(passwd)) != 1 { + e(w) + return + } + + wrappedHandler.ServeHTTP(w, r) + }, + ) } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index d0f01e4d..29ba167e 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -12,10 +12,11 @@ import ( "go.akshayshah.org/attest" ) -func protectedHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func protectedHandler(msg string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, msg) - } + }, + ) } func TestBasicAuth(t *testing.T) { diff --git a/middleware/client_ip.go b/middleware/client_ip.go index 47721f7b..005ae819 100644 --- a/middleware/client_ip.go +++ b/middleware/client_ip.go @@ -68,24 +68,26 @@ func ClientIP(r *http.Request) string { // Fetching the "real" client is done in a best-effort basis and can be [grossly inaccurate & precarious]. // // [grossly inaccurate & precarious]: https://adam-p.ca/blog/2022/03/x-forwarded-for/ -func clientIP(wrappedHandler http.HandlerFunc, strategy ClientIPstrategy) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - var clientAddr string - switch v := strategy; v { - case DirectIpStrategy: - clientAddr = clientip.DirectAddress(r.RemoteAddr) - case LeftIpStrategy: - clientAddr = clientip.Leftmost(r.Header) - case RightIpStrategy: - clientAddr = clientip.Rightmost(r.Header) - case ProxyStrategy: - clientAddr = clientip.ProxyHeader(r.Header) - default: - // treat everything else as a `singleIP` strategy - clientAddr = clientip.SingleIPHeader(string(v), r.Header) - } +func clientIP(wrappedHandler http.Handler, strategy ClientIPstrategy) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var clientAddr string + switch v := strategy; v { + case DirectIpStrategy: + clientAddr = clientip.DirectAddress(r.RemoteAddr) + case LeftIpStrategy: + clientAddr = clientip.Leftmost(r.Header) + case RightIpStrategy: + clientAddr = clientip.Rightmost(r.Header) + case ProxyStrategy: + clientAddr = clientip.ProxyHeader(r.Header) + default: + // treat everything else as a `singleIP` strategy + clientAddr = clientip.SingleIPHeader(string(v), r.Header) + } - r = clientip.With(r, clientAddr) - wrappedHandler(w, r) - } + r = clientip.With(r, clientAddr) + wrappedHandler.ServeHTTP(w, r) + }, + ) } diff --git a/middleware/client_ip_test.go b/middleware/client_ip_test.go index 5339b655..33b86064 100644 --- a/middleware/client_ip_test.go +++ b/middleware/client_ip_test.go @@ -17,12 +17,14 @@ const ( proxyHeader = "PROXY" ) -func someClientIpHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ip := ClientIP(r) - res := fmt.Sprintf("message: %s, ip: %s", msg, ip) - fmt.Fprint(w, res) - } +func someClientIpHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + ip := ClientIP(r) + res := fmt.Sprintf("message: %s, ip: %s", msg, ip) + fmt.Fprint(w, res) + }, + ) } func TestClientIP(t *testing.T) { diff --git a/middleware/cors.go b/middleware/cors.go index 4631ab52..aa36c31e 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -60,29 +60,31 @@ const ( // If allowedMethods is nil, "GET", "POST", "HEAD" are allowed. Use * to allow all. // If allowedHeaders is nil, "Origin", "Accept", "Content-Type", "X-Requested-With" are allowed. Use * to allow all. func cors( - wrappedHandler http.HandlerFunc, + wrappedHandler http.Handler, allowedOrigins []string, allowedMethods []string, allowedHeaders []string, -) http.HandlerFunc { +) http.Handler { allowedOrigins, allowedWildcardOrigins := getOrigins(allowedOrigins) allowedMethods = getMethods(allowedMethods) allowedHeaders = getHeaders(allowedHeaders) - return func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions && r.Header.Get(acrmHeader) != "" { - // handle preflight request - handlePreflight(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods, allowedHeaders) - // Preflight requests are standalone and should stop the chain as some other - // middleware may not handle OPTIONS requests correctly. One typical example - // is authentication middleware ; OPTIONS requests won't carry authentication headers. - w.WriteHeader(http.StatusNoContent) - } else { - // handle actual request - handleActualRequest(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods) - wrappedHandler(w, r) - } - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions && r.Header.Get(acrmHeader) != "" { + // handle preflight request + handlePreflight(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods, allowedHeaders) + // Preflight requests are standalone and should stop the chain as some other + // middleware may not handle OPTIONS requests correctly. One typical example + // is authentication middleware ; OPTIONS requests won't carry authentication headers. + w.WriteHeader(http.StatusNoContent) + } else { + // handle actual request + handleActualRequest(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods) + wrappedHandler.ServeHTTP(w, r) + } + }, + ) } func handlePreflight( diff --git a/middleware/cors_test.go b/middleware/cors_test.go index e61dc4e8..79004a34 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -11,10 +11,12 @@ import ( "go.akshayshah.org/attest" ) -func someCorsHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - } +func someCorsHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + }, + ) } func TestCorsPreflight(t *testing.T) { diff --git a/middleware/csrf.go b/middleware/csrf.go index 3a218038..b50f1dfd 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -67,149 +67,151 @@ const ( // csrf is a middleware that provides protection against Cross Site Request Forgeries. // // If a csrf token is not provided(or is not valid), when it ought to have been; this middleware will issue a http GET redirect to the same url. -func csrf(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc { +func csrf(wrappedHandler http.Handler, secretKey, domain string) http.Handler { once.Do(func() { enc = cry.New(secretKey) }) msgToEncrypt := id.Random(16) - return func(w http.ResponseWriter, r *http.Request) { - // - https://docs.djangoproject.com/en/4.0/ref/csrf/ - // - https://github.com/django/django/blob/4.0.5/django/middleware/csrf.py - // - https://github.com/gofiber/fiber/blob/v2.34.1/middleware/csrf/csrf.go - - // 1. check http method. - // - if it is a 'safe' method like GET, try and get `actualToken` from request. - // - if it is not a 'safe' method, try and get `actualToken` from header/cookies/httpForm - // - take the found token and try to get it from memory store. - // - if not found in memory store, delete the cookie & return an error. - - ctx := r.Context() - - switch r.Method { - // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 - case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: - break - default: - // For POST requests, we insist on a CSRF cookie, and in this way we can avoid all CSRF attacks, including login CSRF. - actualToken := getToken(r) - - ct, _, err := mime.ParseMediaType(r.Header.Get(ctHeader)) - if err == nil && - ct != formUrlEncoded && - ct != multiformData && - r.Header.Get(clientCookieHeader) == "" && - r.Header.Get(authorizationHeader) == "" && - r.Header.Get(proxyAuthorizationHeader) == "" { - // For POST requests that; - // - are not form data. - // - have no cookies. - // - are not using http authentication. - // then it is okay to not validate csrf for them. - // This is especially useful for REST API endpoints. - // see: https://github.com/komuw/ong/issues/76 - break - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // - https://docs.djangoproject.com/en/4.0/ref/csrf/ + // - https://github.com/django/django/blob/4.0.5/django/middleware/csrf.py + // - https://github.com/gofiber/fiber/blob/v2.34.1/middleware/csrf/csrf.go - tokVal, errN := enc.DecryptDecode(actualToken) - if errN != nil { - // We should redirect the request since it means that the server is not aware of such a token. - // It shoulbe be a temporary redirect to the same page but this time send a http GET request. - // - // To test using curl, use; - // curl -kL \ - // -H "Content-Type: application/x-www-form-urlencoded" \ - // -d "firstName=john&csrftoken=bogusToken" https://localhost:65081/login/ - // Do NOT use `-X POST`, see: https://stackoverflow.com/a/41890653/2768067 - // - cookie.Delete(w, csrfCookieName, domain) - w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenNotFound.Error()) - http.Redirect( - w, - r, - r.URL.String(), - // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 - http.StatusSeeOther, - ) - return - } + // 1. check http method. + // - if it is a 'safe' method like GET, try and get `actualToken` from request. + // - if it is not a 'safe' method, try and get `actualToken` from header/cookies/httpForm + // - take the found token and try to get it from memory store. + // - if not found in memory store, delete the cookie & return an error. - res := strings.Split(tokVal, sep) - if len(res) != 2 { - cookie.Delete(w, csrfCookieName, domain) - w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenWrongFormat.Error()) - http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) - return - } + ctx := r.Context() - expires, errP := strconv.ParseInt(res[1], 10, 64) - if errP != nil { - cookie.Delete(w, csrfCookieName, domain) - w.Header().Set(ongMiddlewareErrorHeader, errP.Error()) - http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) - return - } - - diff := expires - time.Now().UTC().Unix() - if diff <= 0 { - cookie.Delete(w, csrfCookieName, domain) - w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenExpired.Error()) - http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) - return + switch r.Method { + // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + break + default: + // For POST requests, we insist on a CSRF cookie, and in this way we can avoid all CSRF attacks, including login CSRF. + actualToken := getToken(r) + + ct, _, err := mime.ParseMediaType(r.Header.Get(ctHeader)) + if err == nil && + ct != formUrlEncoded && + ct != multiformData && + r.Header.Get(clientCookieHeader) == "" && + r.Header.Get(authorizationHeader) == "" && + r.Header.Get(proxyAuthorizationHeader) == "" { + // For POST requests that; + // - are not form data. + // - have no cookies. + // - are not using http authentication. + // then it is okay to not validate csrf for them. + // This is especially useful for REST API endpoints. + // see: https://github.com/komuw/ong/issues/76 + break + } + + tokVal, errN := enc.DecryptDecode(actualToken) + if errN != nil { + // We should redirect the request since it means that the server is not aware of such a token. + // It shoulbe be a temporary redirect to the same page but this time send a http GET request. + // + // To test using curl, use; + // curl -kL \ + // -H "Content-Type: application/x-www-form-urlencoded" \ + // -d "firstName=john&csrftoken=bogusToken" https://localhost:65081/login/ + // Do NOT use `-X POST`, see: https://stackoverflow.com/a/41890653/2768067 + // + cookie.Delete(w, csrfCookieName, domain) + w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenNotFound.Error()) + http.Redirect( + w, + r, + r.URL.String(), + // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 + http.StatusSeeOther, + ) + return + } + + res := strings.Split(tokVal, sep) + if len(res) != 2 { + cookie.Delete(w, csrfCookieName, domain) + w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenWrongFormat.Error()) + http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) + return + } + + expires, errP := strconv.ParseInt(res[1], 10, 64) + if errP != nil { + cookie.Delete(w, csrfCookieName, domain) + w.Header().Set(ongMiddlewareErrorHeader, errP.Error()) + http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) + return + } + + diff := expires - time.Now().UTC().Unix() + if diff <= 0 { + cookie.Delete(w, csrfCookieName, domain) + w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenExpired.Error()) + http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) + return + } } - } - // 2. generate a new token. - /* - We need to try and protect against BreachAttack[1]. See[2] for a refresher on how it works. - The mitigations against the attack in order of effectiveness are: - (a) Disabling HTTP compression - (b) Separating secrets from user input - (c) Randomizing secrets per request - (d) Masking secrets (effectively randomizing by XORing with a random secret per request) - (e) Protecting vulnerable pages with CSRF - (f) Length hiding (by adding random number of bytes to the responses) - (g) Rate-limiting the requests - Most csrf implementation use (d). Here, we'll use (c) - The [encrypt] func uses a random nonce everytime it is called. - - 1. http://breachattack.com/ - 2. https://security.stackexchange.com/a/172646 - */ - expires := strconv.FormatInt( - time.Now().UTC().Add(tokenMaxAge).Unix(), - 10, - ) - tokenToIssue := enc.EncryptEncode( - // see: https://github.com/golang/net/blob/v0.8.0/xsrftoken/xsrf.go#L33-L46 - fmt.Sprintf("%s%s%s", msgToEncrypt, sep, expires), - ) - - // 3. create cookie - cookie.Set( - w, - csrfCookieName, - tokenToIssue, - domain, - tokenMaxAge, - true, // accessible to javascript - ) - - // 4. set cookie header - w.Header().Set( - CsrfHeader, - tokenToIssue, - ) - - // 5. update Vary header. - w.Header().Add(varyHeader, clientCookieHeader) - - // 6. store tokenToIssue in context - r = r.WithContext(context.WithValue(ctx, csrfCtxKey, tokenToIssue)) - - wrappedHandler(w, r) - } + // 2. generate a new token. + /* + We need to try and protect against BreachAttack[1]. See[2] for a refresher on how it works. + The mitigations against the attack in order of effectiveness are: + (a) Disabling HTTP compression + (b) Separating secrets from user input + (c) Randomizing secrets per request + (d) Masking secrets (effectively randomizing by XORing with a random secret per request) + (e) Protecting vulnerable pages with CSRF + (f) Length hiding (by adding random number of bytes to the responses) + (g) Rate-limiting the requests + Most csrf implementation use (d). Here, we'll use (c) + The [encrypt] func uses a random nonce everytime it is called. + + 1. http://breachattack.com/ + 2. https://security.stackexchange.com/a/172646 + */ + expires := strconv.FormatInt( + time.Now().UTC().Add(tokenMaxAge).Unix(), + 10, + ) + tokenToIssue := enc.EncryptEncode( + // see: https://github.com/golang/net/blob/v0.8.0/xsrftoken/xsrf.go#L33-L46 + fmt.Sprintf("%s%s%s", msgToEncrypt, sep, expires), + ) + + // 3. create cookie + cookie.Set( + w, + csrfCookieName, + tokenToIssue, + domain, + tokenMaxAge, + true, // accessible to javascript + ) + + // 4. set cookie header + w.Header().Set( + CsrfHeader, + tokenToIssue, + ) + + // 5. update Vary header. + w.Header().Add(varyHeader, clientCookieHeader) + + // 6. store tokenToIssue in context + r = r.WithContext(context.WithValue(ctx, csrfCtxKey, tokenToIssue)) + + wrappedHandler.ServeHTTP(w, r) + }, + ) } // GetCsrfToken returns the csrf token that was set for the http request in question. diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 8033f01b..5a258f88 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -115,11 +115,13 @@ func TestGetToken(t *testing.T) { const tokenHeader = "CUSTOM-CSRF-TOKEN-TEST-HEADER" -func someCsrfHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(tokenHeader, GetCsrfToken(r.Context())) - fmt.Fprint(w, msg) - } +func someCsrfHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(tokenHeader, GetCsrfToken(r.Context())) + fmt.Fprint(w, msg) + }, + ) } func TestCsrf(t *testing.T) { diff --git a/middleware/example_test.go b/middleware/example_test.go index c4cf8584..df2fb1a1 100644 --- a/middleware/example_test.go +++ b/middleware/example_test.go @@ -13,22 +13,26 @@ import ( "golang.org/x/exp/slog" ) -func loginHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - cspNonce := middleware.GetCspNonce(r.Context()) - _ = cspNonce // use CSP nonce - - _, _ = fmt.Fprint(w, "welcome to your favorite website.") - } +func loginHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cspNonce := middleware.GetCspNonce(r.Context()) + _ = cspNonce // use CSP nonce + + _, _ = fmt.Fprint(w, "welcome to your favorite website.") + }, + ) } -func welcomeHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - csrfToken := middleware.GetCsrfToken(r.Context()) - _ = csrfToken // use CSRF token +func welcomeHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + csrfToken := middleware.GetCsrfToken(r.Context()) + _ = csrfToken // use CSRF token - _, _ = fmt.Fprint(w, "welcome.") - } + _, _ = fmt.Fprint(w, "welcome.") + }, + ) } func Example_getCspNonce() { @@ -66,14 +70,16 @@ func ExampleAll() { l := log.New(os.Stdout, 100)(context.Background()) opts := middleware.WithOpts("example.com", 443, "secretKey", middleware.DirectIpStrategy, l) - myHandler := func(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, "Hello from a HandleFunc \n") - } + myHandler := http.HandlerFunc( + func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "Hello from a HandleFunc \n") + }, + ) handler := middleware.All(myHandler, opts) mux := http.NewServeMux() - mux.HandleFunc("/", handler) + mux.Handle("/", handler) // Output: } diff --git a/middleware/fingerprint.go b/middleware/fingerprint.go index 42436c06..ae2b5735 100644 --- a/middleware/fingerprint.go +++ b/middleware/fingerprint.go @@ -10,28 +10,30 @@ import ( // fingerprint is a middleware that adds the client's TLS fingerprint to the request context. // The fingerprint can be fetched using [ClientFingerPrint] -func fingerprint(wrappedHandler http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - fHash := "" +func fingerprint(wrappedHandler http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + fHash := "" - if vCtx := ctx.Value(octx.FingerPrintCtxKey); vCtx != nil { - if s, ok := vCtx.(*finger.Print); ok { - if hash := s.Hash.Load(); hash != nil { - fHash = *hash + if vCtx := ctx.Value(octx.FingerPrintCtxKey); vCtx != nil { + if s, ok := vCtx.(*finger.Print); ok { + if hash := s.Hash.Load(); hash != nil { + fHash = *hash + } } } - } - ctx = context.WithValue( - ctx, - octx.FingerPrintCtxKey, - fHash, - ) - r = r.WithContext(ctx) + ctx = context.WithValue( + ctx, + octx.FingerPrintCtxKey, + fHash, + ) + r = r.WithContext(ctx) - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } // ClientFingerPrint returns the [TLS fingerprint] of the client. diff --git a/middleware/gzip.go b/middleware/gzip.go index 10e8836d..4c3da9c1 100644 --- a/middleware/gzip.go +++ b/middleware/gzip.go @@ -29,35 +29,37 @@ const ( ) // gzip is a middleware that transparently gzips the http response body, for clients that support it. -func gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add(varyHeader, acceptEncodingHeader) - - if !shouldGzipReq(r) { - wrappedHandler(w, r) - return - } - - gzipWriter, _ := stdGzip.NewWriterLevel(w, stdGzip.BestSpeed) - grw := &gzipRW{ - ResponseWriter: w, - // Bytes written during ServeHTTP are redirected to this gzip writer - // before being written to the underlying response. - gw: gzipWriter, - } - defer func() { _ = grw.Close() }() // errcheck made me do this. - - // We do not handle range requests when compression is used, as the - // range specified applies to the compressed data, not to the uncompressed one. - // see: https://github.com/nytimes/gziphandler/issues/83 - r.Header.Del(rangeHeader) - - // todo: we could detect if `w` is a `http.CloseNotifier` and do something special here. - // see: https://github.com/klauspost/compress/blob/4a97174a615ed745c450077edf0e1f7e97aabd58/gzhttp/compress.go#L383-L385 - // However `http.CloseNotifier` has been deprecated sinc Go v1.11(year 2018) - - wrappedHandler(grw, r) - } +func gzip(wrappedHandler http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Add(varyHeader, acceptEncodingHeader) + + if !shouldGzipReq(r) { + wrappedHandler.ServeHTTP(w, r) + return + } + + gzipWriter, _ := stdGzip.NewWriterLevel(w, stdGzip.BestSpeed) + grw := &gzipRW{ + ResponseWriter: w, + // Bytes written during ServeHTTP are redirected to this gzip writer + // before being written to the underlying response. + gw: gzipWriter, + } + defer func() { _ = grw.Close() }() // errcheck made me do this. + + // We do not handle range requests when compression is used, as the + // range specified applies to the compressed data, not to the uncompressed one. + // see: https://github.com/nytimes/gziphandler/issues/83 + r.Header.Del(rangeHeader) + + // todo: we could detect if `w` is a `http.CloseNotifier` and do something special here. + // see: https://github.com/klauspost/compress/blob/4a97174a615ed745c450077edf0e1f7e97aabd58/gzhttp/compress.go#L383-L385 + // However `http.CloseNotifier` has been deprecated sinc Go v1.11(year 2018) + + wrappedHandler.ServeHTTP(grw, r) + }, + ) } // gzipRW provides an http.ResponseWriter interface, which gzips diff --git a/middleware/gzip_test.go b/middleware/gzip_test.go index 93805d3f..6c28f58b 100644 --- a/middleware/gzip_test.go +++ b/middleware/gzip_test.go @@ -19,31 +19,35 @@ import ( tmthrgd "github.com/tmthrgd/gziphandler" ) -func someGzipHandler(msg string) http.HandlerFunc { +func someGzipHandler(msg string) http.Handler { // bound stack growth. // see: https://github.com/komuw/ong/issues/54 fMsg := strings.Repeat(msg, 3) - return func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, fMsg) - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, fMsg) + }, + ) } -func handlerImplementingFlush(msg string) http.HandlerFunc { +func handlerImplementingFlush(msg string) http.Handler { iterations := 3 - return func(w http.ResponseWriter, r *http.Request) { - if f, ok := w.(http.Flusher); ok { - msg = "FlusherCalled::" + strings.Repeat(msg, iterations) - fmt.Fprint(w, msg) - - f.Flush() - } else { - msg = strings.Repeat(msg, iterations) - fmt.Fprint(w, msg) - } - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if f, ok := w.(http.Flusher); ok { + msg = "FlusherCalled::" + strings.Repeat(msg, iterations) + fmt.Fprint(w, msg) + + f.Flush() + } else { + msg = strings.Repeat(msg, iterations) + fmt.Fprint(w, msg) + } + }, + ) } -func login() http.HandlerFunc { +func login() http.Handler { tmpl, err := template.New("myTpl").Parse(` @@ -70,29 +74,31 @@ func login() http.HandlerFunc { panic(err) } - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - data := struct { - CsrfTokenName string - CsrfTokenValue string - CspNonceValue string - }{ - CsrfTokenName: CsrfTokenFormName, - CsrfTokenValue: GetCsrfToken(r.Context()), - CspNonceValue: GetCspNonce(r.Context()), + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + data := struct { + CsrfTokenName string + CsrfTokenValue string + CspNonceValue string + }{ + CsrfTokenName: CsrfTokenFormName, + CsrfTokenValue: GetCsrfToken(r.Context()), + CspNonceValue: GetCspNonce(r.Context()), + } + if err = tmpl.Execute(w, data); err != nil { + panic(err) + } + return } - if err = tmpl.Execute(w, data); err != nil { + + if err = r.ParseForm(); err != nil { panic(err) } - return - } - - if err = r.ParseForm(); err != nil { - panic(err) - } - _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) - } + _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) + }, + ) } func readBody(t *testing.T, res *http.Response) (strBody string) { @@ -302,15 +308,17 @@ BenchmarkNytimesGzip-8 4 315_386_476 ns/op 3_813_934 B/op 116 BenchmarkTmthrgdGzip-8 4 319_786_254 ns/op 3_527_012 B/op 116 allocs/op */ -func gzipBenchmarkHandler() http.HandlerFunc { +func gzipBenchmarkHandler() http.Handler { bin, err := os.ReadFile("testdata/benchmark.json") if err != nil { panic(err) } - return func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, bin) - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, bin) + }, + ) } var result int //nolint:gochecknoglobals diff --git a/middleware/loadshed.go b/middleware/loadshed.go index 16419da7..eca702e2 100644 --- a/middleware/loadshed.go +++ b/middleware/loadshed.go @@ -44,51 +44,53 @@ const ( ) // loadShedder is a middleware that sheds load based on http response latencies. -func loadShedder(wrappedHandler http.HandlerFunc) http.HandlerFunc { +func loadShedder(wrappedHandler http.Handler) http.Handler { // lq should not be a global variable, we want it to be per handler. // This is because different handlers(URIs) could have different latencies and we want each to be loadshed independently. lq := newLatencyQueue() loadShedCheckStart := time.Now().UTC() - return func(w http.ResponseWriter, r *http.Request) { - startReq := time.Now().UTC() - defer func() { - endReq := time.Now().UTC() - durReq := endReq.Sub(startReq) - lq.add(durReq) - - // we do not want to reduce size of `lq` before a period `> samplingPeriod` otherwise `lq.getP99()` will always return zero. - if endReq.Sub(loadShedCheckStart) > resizePeriod { - // lets reduce the size of latencyQueue - lq.reSize() - loadShedCheckStart = endReq + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + startReq := time.Now().UTC() + defer func() { + endReq := time.Now().UTC() + durReq := endReq.Sub(startReq) + lq.add(durReq) + + // we do not want to reduce size of `lq` before a period `> samplingPeriod` otherwise `lq.getP99()` will always return zero. + if endReq.Sub(loadShedCheckStart) > resizePeriod { + // lets reduce the size of latencyQueue + lq.reSize() + loadShedCheckStart = endReq + } + }() + + sendProbe := false + { + // Even if the server is overloaded, we want to send a percentage of the requests through. + // These requests act as a probe. If the server eventually recovers, + // these requests will re-populate latencyQueue(`lq`) with lower latencies and thus end the load-shed. + sendProbe = mathRand.Intn(100) == 1 // let 1% of requests through. NB: Intn(100) is `0-99` ie, 100 is not included. } - }() - - sendProbe := false - { - // Even if the server is overloaded, we want to send a percentage of the requests through. - // These requests act as a probe. If the server eventually recovers, - // these requests will re-populate latencyQueue(`lq`) with lower latencies and thus end the load-shed. - sendProbe = mathRand.Intn(100) == 1 // let 1% of requests through. NB: Intn(100) is `0-99` ie, 100 is not included. - } - - p99 := lq.getP99(minSampleSize) - if p99.Milliseconds() > breachLatency.Milliseconds() && !sendProbe { - // drop request - err := fmt.Errorf("ong/middleware: server is overloaded, retry after %s", retryAfter) - w.Header().Set(ongMiddlewareErrorHeader, fmt.Sprintf("%s. p99latency: %s. breachLatency: %s", err.Error(), p99, breachLatency)) - w.Header().Set(retryAfterHeader, fmt.Sprintf("%d", int(retryAfter.Seconds()))) // header should be in seconds(decimal-integer). - http.Error( - w, - err.Error(), - http.StatusServiceUnavailable, - ) - return - } - - wrappedHandler(w, r) - } + + p99 := lq.getP99(minSampleSize) + if p99.Milliseconds() > breachLatency.Milliseconds() && !sendProbe { + // drop request + err := fmt.Errorf("ong/middleware: server is overloaded, retry after %s", retryAfter) + w.Header().Set(ongMiddlewareErrorHeader, fmt.Sprintf("%s. p99latency: %s. breachLatency: %s", err.Error(), p99, breachLatency)) + w.Header().Set(retryAfterHeader, fmt.Sprintf("%d", int(retryAfter.Seconds()))) // header should be in seconds(decimal-integer). + http.Error( + w, + err.Error(), + http.StatusServiceUnavailable, + ) + return + } + + wrappedHandler.ServeHTTP(w, r) + }, + ) } type latencyQueue struct { diff --git a/middleware/loadshed_test.go b/middleware/loadshed_test.go index 9fb1af6f..2d4b127a 100644 --- a/middleware/loadshed_test.go +++ b/middleware/loadshed_test.go @@ -16,16 +16,18 @@ import ( const loadShedderTestHeader = "LoadShedderTestHeader" -func someLoadShedderHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - lat := r.Header.Get(loadShedderTestHeader) - latency, err := strconv.Atoi(lat) - if err != nil { - panic(err) - } - time.Sleep(time.Duration(latency) * time.Millisecond) - fmt.Fprint(w, msg) - } +func someLoadShedderHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + lat := r.Header.Get(loadShedderTestHeader) + latency, err := strconv.Atoi(lat) + if err != nil { + panic(err) + } + time.Sleep(time.Duration(latency) * time.Millisecond) + fmt.Fprint(w, msg) + }, + ) } func TestLoadShedder(t *testing.T) { @@ -240,12 +242,14 @@ func TestLatencyQueue(t *testing.T) { }) } -func loadShedderBenchmarkHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - latency := time.Duration(rand.Intn(100)+1) * time.Millisecond - time.Sleep(latency) - fmt.Fprint(w, "hey") - } +func loadShedderBenchmarkHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + latency := time.Duration(rand.Intn(100)+1) * time.Millisecond + time.Sleep(latency) + fmt.Fprint(w, "hey") + }, + ) } func BenchmarkLoadShedder(b *testing.B) { diff --git a/middleware/log.go b/middleware/log.go index 4e19ed78..f9bfce94 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -15,63 +15,65 @@ import ( ) // logger is a middleware that logs http requests and responses using [log.Logger]. -func logger(wrappedHandler http.HandlerFunc, l *slog.Logger) http.HandlerFunc { +func logger(wrappedHandler http.Handler, l *slog.Logger) http.Handler { // We pass the logger as an argument so that the middleware can share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. // This makes debugging easier for developers. // // However, each request should get its own context. That's why we call `logger.WithCtx` for every request. - return func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - lrw := &logRW{ - ResponseWriter: w, - } - defer func() { - msg := "http_server" - flds := []any{ - "clientIP", ClientIP(r), - "clientFingerPrint", ClientFingerPrint(r), - "method", r.Method, - "path", r.URL.Redacted(), - "code", lrw.code, - "status", http.StatusText(lrw.code), - "durationMS", time.Since(start).Milliseconds(), - } - if ongError := lrw.Header().Get(ongMiddlewareErrorHeader); ongError != "" { - extra := []any{"ongError", ongError} - flds = append(flds, extra...) + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + lrw := &logRW{ + ResponseWriter: w, } + defer func() { + msg := "http_server" + flds := []any{ + "clientIP", ClientIP(r), + "clientFingerPrint", ClientFingerPrint(r), + "method", r.Method, + "path", r.URL.Redacted(), + "code", lrw.code, + "status", http.StatusText(lrw.code), + "durationMS", time.Since(start).Milliseconds(), + } + if ongError := lrw.Header().Get(ongMiddlewareErrorHeader); ongError != "" { + extra := []any{"ongError", ongError} + flds = append(flds, extra...) + } - // Remove header so that users dont see it. - // - // Note that this may not actually work. - // According to: https://pkg.go.dev/net/http#ResponseWriter - // Changing the header map after a call to WriteHeader (or - // Write) has no effect unless the HTTP status code was of the - // 1xx class or the modified headers are trailers. - lrw.Header().Del(ongMiddlewareErrorHeader) - - // The logger should be in the defer block so that it uses the updated context containing the logID. - reqL := log.WithID(r.Context(), l) - - if lrw.code == http.StatusServiceUnavailable || lrw.code == http.StatusTooManyRequests && w.Header().Get(retryAfterHeader) != "" { - // We are either in load shedding or rate-limiting. - // Only log 10% of the errors. - shouldLog := mathRand.Intn(100) > 90 - if shouldLog { + // Remove header so that users dont see it. + // + // Note that this may not actually work. + // According to: https://pkg.go.dev/net/http#ResponseWriter + // Changing the header map after a call to WriteHeader (or + // Write) has no effect unless the HTTP status code was of the + // 1xx class or the modified headers are trailers. + lrw.Header().Del(ongMiddlewareErrorHeader) + + // The logger should be in the defer block so that it uses the updated context containing the logID. + reqL := log.WithID(r.Context(), l) + + if lrw.code == http.StatusServiceUnavailable || lrw.code == http.StatusTooManyRequests && w.Header().Get(retryAfterHeader) != "" { + // We are either in load shedding or rate-limiting. + // Only log 10% of the errors. + shouldLog := mathRand.Intn(100) > 90 + if shouldLog { + reqL.Error(msg, flds...) + } + } else if lrw.code >= http.StatusBadRequest { + // both client and server errors. reqL.Error(msg, flds...) + } else { + reqL.Info(msg, flds...) } - } else if lrw.code >= http.StatusBadRequest { - // both client and server errors. - reqL.Error(msg, flds...) - } else { - reqL.Info(msg, flds...) - } - }() + }() - wrappedHandler(lrw, r) - } + wrappedHandler.ServeHTTP(lrw, r) + }, + ) } // logRW provides an http.ResponseWriter interface, which logs requests/responses. diff --git a/middleware/log_test.go b/middleware/log_test.go index e8213bdc..af9653e0 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -23,22 +23,24 @@ const ( someLatencyMS = 3 ) -func someLogHandler(successMsg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // sleep so that the log middleware has some useful duration metrics to report. - time.Sleep(someLatencyMS * time.Millisecond) - if r.Header.Get(someLogHandlerHeader) != "" { - http.Error( - w, - r.Header.Get(someLogHandlerHeader), - http.StatusInternalServerError, - ) - return - } else { - fmt.Fprint(w, successMsg) - return - } - } +func someLogHandler(successMsg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // sleep so that the log middleware has some useful duration metrics to report. + time.Sleep(someLatencyMS * time.Millisecond) + if r.Header.Get(someLogHandlerHeader) != "" { + http.Error( + w, + r.Header.Get(someLogHandlerHeader), + http.StatusInternalServerError, + ) + return + } else { + fmt.Fprint(w, successMsg) + return + } + }, + ) } func TestLogMiddleware(t *testing.T) { diff --git a/middleware/middleware.go b/middleware/middleware.go index 58dae2bf..2057c822 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -105,9 +105,9 @@ func WithOpts( // // allDefaultMiddlewares(wh, WithOpts("example.com", 443, "secretKey", RightIpStrategy, log.New(os.Stdout, 10))) func allDefaultMiddlewares( - wrappedHandler http.HandlerFunc, + wrappedHandler http.Handler, o Opts, -) http.HandlerFunc { +) http.Handler { domain := o.domain httpsPort := o.httpsPort allowedOrigins := o.allowedOrigins @@ -200,137 +200,147 @@ func allDefaultMiddlewares( // All is a middleware that allows all http methods. // // See the package documentation for the additional functionality provided by this middleware. -func All(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { +func All(wrappedHandler http.Handler, o Opts) http.Handler { return allDefaultMiddlewares( all(wrappedHandler), o, ) } -func all(wrappedHandler http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - wrappedHandler(w, r) - } +func all(wrappedHandler http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + wrappedHandler.ServeHTTP(w, r) + }, + ) } // Get is a middleware that only allows http GET requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Get(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { +func Get(wrappedHandler http.Handler, o Opts) http.Handler { return allDefaultMiddlewares( get(wrappedHandler), o, ) } -func get(wrappedHandler http.HandlerFunc) http.HandlerFunc { +func get(wrappedHandler http.Handler) http.Handler { msg := "http method: %s not allowed. only allows http GET" - return func(w http.ResponseWriter, r *http.Request) { - // We do not need to allow `http.MethodOptions` here. - // This is coz, the cors middleware has already handled that for us and it comes before the Get middleware. - if r.Method != http.MethodGet { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // We do not need to allow `http.MethodOptions` here. + // This is coz, the cors middleware has already handled that for us and it comes before the Get middleware. + if r.Method != http.MethodGet { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } // Post is a middleware that only allows http POST requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Post(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { +func Post(wrappedHandler http.Handler, o Opts) http.Handler { return allDefaultMiddlewares( post(wrappedHandler), o, ) } -func post(wrappedHandler http.HandlerFunc) http.HandlerFunc { +func post(wrappedHandler http.Handler) http.Handler { msg := "http method: %s not allowed. only allows http POST" - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } // Head is a middleware that only allows http HEAD requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Head(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { +func Head(wrappedHandler http.Handler, o Opts) http.Handler { return allDefaultMiddlewares( head(wrappedHandler), o, ) } -func head(wrappedHandler http.HandlerFunc) http.HandlerFunc { +func head(wrappedHandler http.Handler) http.Handler { msg := "http method: %s not allowed. only allows http HEAD" - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodHead { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodHead { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } // Put is a middleware that only allows http PUT requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Put(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { +func Put(wrappedHandler http.Handler, o Opts) http.Handler { return allDefaultMiddlewares( put(wrappedHandler), o, ) } -func put(wrappedHandler http.HandlerFunc) http.HandlerFunc { +func put(wrappedHandler http.Handler) http.Handler { msg := "http method: %s not allowed. only allows http PUT" - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPut { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } // Delete is a middleware that only allows http DELETE requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Delete(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { +func Delete(wrappedHandler http.Handler, o Opts) http.Handler { return allDefaultMiddlewares( deleteH(wrappedHandler), o, @@ -338,20 +348,22 @@ func Delete(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { } // this is not called `delete` since that is a Go builtin func for deleting from maps. -func deleteH(wrappedHandler http.HandlerFunc) http.HandlerFunc { +func deleteH(wrappedHandler http.Handler) http.Handler { msg := "http method: %s not allowed. only allows http DELETE" - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 5d76a0d6..f5b1022e 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -20,21 +20,23 @@ import ( "go.uber.org/goleak" ) -func someMiddlewareTestHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - b, e := io.ReadAll(r.Body) - if e != nil { - panic(e) +func someMiddlewareTestHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + b, e := io.ReadAll(r.Body) + if e != nil { + panic(e) + } + if len(b) > 1 { + _, _ = w.Write(b) + return + } } - if len(b) > 1 { - _, _ = w.Write(b) - return - } - } - fmt.Fprint(w, msg) - } + fmt.Fprint(w, msg) + }, + ) } func TestMain(m *testing.M) { @@ -57,7 +59,7 @@ func TestAllMiddleware(t *testing.T) { errMsg := "not allowed. only allows http" tests := []struct { name string - middleware func(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc + middleware func(wrappedHandler http.Handler, o Opts) http.Handler httpMethod string expectedStatusCode int expectedMsg string @@ -355,13 +357,15 @@ func TestMiddlewareServer(t *testing.T) { }) } -func someBenchmarkAllMiddlewaresHandler() http.HandlerFunc { +func someBenchmarkAllMiddlewaresHandler() http.Handler { // bound stack growth. // see: https://github.com/komuw/ong/issues/54 msg := strings.Repeat("hello world", 2) - return func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + }, + ) } var resultBenchmarkAllMiddlewares int //nolint:gochecknoglobals diff --git a/middleware/ratelimiter.go b/middleware/ratelimiter.go index e19eadd5..ad42e595 100644 --- a/middleware/ratelimiter.go +++ b/middleware/ratelimiter.go @@ -25,34 +25,36 @@ import ( var rateLimiterSendRate = 100.00 //nolint:gochecknoglobals // rateLimiter is a middleware that limits requests by IP address. -func rateLimiter(wrappedHandler http.HandlerFunc) http.HandlerFunc { +func rateLimiter(wrappedHandler http.Handler) http.Handler { rl := newRl() const retryAfter = 15 * time.Minute - return func(w http.ResponseWriter, r *http.Request) { - rl.reSize() - - host := ClientIP(r) - tb := rl.get(host, rateLimiterSendRate) - - if !tb.allow() { - err := fmt.Errorf("ong/middleware: rate limited, retry after %s", retryAfter) - w.Header().Set(ongMiddlewareErrorHeader, err.Error()) - w.Header().Set(retryAfterHeader, fmt.Sprintf("%d", int(retryAfter.Seconds()))) // header should be in seconds(decimal-integer). - http.Error( - w, - err.Error(), - http.StatusTooManyRequests, - ) - return - } - - // todo: maybe also limit max body size using something like `http.MaxBytesHandler` - // todo: also maybe add another limiter for IP subnet. - // see limitation: https://github.com/komuw/ong/issues/17#issuecomment-1114551281 - - wrappedHandler(w, r) - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + rl.reSize() + + host := ClientIP(r) + tb := rl.get(host, rateLimiterSendRate) + + if !tb.allow() { + err := fmt.Errorf("ong/middleware: rate limited, retry after %s", retryAfter) + w.Header().Set(ongMiddlewareErrorHeader, err.Error()) + w.Header().Set(retryAfterHeader, fmt.Sprintf("%d", int(retryAfter.Seconds()))) // header should be in seconds(decimal-integer). + http.Error( + w, + err.Error(), + http.StatusTooManyRequests, + ) + return + } + + // todo: maybe also limit max body size using something like `http.MaxBytesHandler` + // todo: also maybe add another limiter for IP subnet. + // see limitation: https://github.com/komuw/ong/issues/17#issuecomment-1114551281 + + wrappedHandler.ServeHTTP(w, r) + }, + ) } // rl is a ratelimiter per IP address. diff --git a/middleware/ratelimiter_test.go b/middleware/ratelimiter_test.go index 588228c6..d15c80c7 100644 --- a/middleware/ratelimiter_test.go +++ b/middleware/ratelimiter_test.go @@ -13,10 +13,12 @@ import ( "golang.org/x/exp/slices" ) -func someRateLimiterHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - } +func someRateLimiterHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + }, + ) } func TestRateLimiter(t *testing.T) { diff --git a/middleware/recoverer.go b/middleware/recoverer.go index 038454c6..43a51154 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -15,49 +15,51 @@ import ( // recoverer is a middleware that recovers from panics in wrappedHandler. // When/if a panic occurs, it logs the stack trace and returns an InternalServerError response. -func recoverer(wrappedHandler http.HandlerFunc, l *slog.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - defer func() { - errR := recover() - if errR != nil { - reqL := log.WithID(r.Context(), l) - - code := http.StatusInternalServerError - status := http.StatusText(code) - - msg := "http_server" - flds := []any{ - "error", fmt.Sprint(errR), - "clientIP", ClientIP(r), - "clientFingerPrint", ClientFingerPrint(r), - "method", r.Method, - "path", r.URL.Redacted(), - "code", code, - "status", status, - } - if ongError := w.Header().Get(ongMiddlewareErrorHeader); ongError != "" { - extra := []any{"ongError", ongError} - flds = append(flds, extra...) - } - w.Header().Del(ongMiddlewareErrorHeader) // remove header so that users dont see it. - - if e, ok := errR.(error); ok { - extra := []any{"err", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. - flds = append(flds, extra...) - reqL.Error(msg, flds...) - } else { - reqL.Error(msg, flds...) +func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + defer func() { + errR := recover() + if errR != nil { + reqL := log.WithID(r.Context(), l) + + code := http.StatusInternalServerError + status := http.StatusText(code) + + msg := "http_server" + flds := []any{ + "error", fmt.Sprint(errR), + "clientIP", ClientIP(r), + "clientFingerPrint", ClientFingerPrint(r), + "method", r.Method, + "path", r.URL.Redacted(), + "code", code, + "status", status, + } + if ongError := w.Header().Get(ongMiddlewareErrorHeader); ongError != "" { + extra := []any{"ongError", ongError} + flds = append(flds, extra...) + } + w.Header().Del(ongMiddlewareErrorHeader) // remove header so that users dont see it. + + if e, ok := errR.(error); ok { + extra := []any{"err", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. + flds = append(flds, extra...) + reqL.Error(msg, flds...) + } else { + reqL.Error(msg, flds...) + } + + // respond. + http.Error( + w, + status, + code, + ) } + }() - // respond. - http.Error( - w, - status, - code, - ) - } - }() - - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } diff --git a/middleware/recoverer_test.go b/middleware/recoverer_test.go index 028c09ad..b1f713b3 100644 --- a/middleware/recoverer_test.go +++ b/middleware/recoverer_test.go @@ -19,30 +19,34 @@ import ( "golang.org/x/exp/slog" ) -func handlerThatPanics(msg string, shouldPanic bool, err error) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - x := 3 + 9 - _ = x - if shouldPanic { - panic(msg) - } - if err != nil { - panic(err) - } - - fmt.Fprint(w, msg) - } +func handlerThatPanics(msg string, shouldPanic bool, err error) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + x := 3 + 9 + _ = x + if shouldPanic { + panic(msg) + } + if err != nil { + panic(err) + } + + fmt.Fprint(w, msg) + }, + ) } -func anotherHandlerThatPanics() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _ = 90 - someSlice := []string{"zero", "one", "two"} - _ = "kilo" - _ = someSlice[16] // panic - - fmt.Fprint(w, "anotherHandlerThatPanics") - } +func anotherHandlerThatPanics() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _ = 90 + someSlice := []string{"zero", "one", "two"} + _ = "kilo" + _ = someSlice[16] // panic + + fmt.Fprint(w, "anotherHandlerThatPanics") + }, + ) } func TestPanic(t *testing.T) { diff --git a/middleware/redirect.go b/middleware/redirect.go index e22a0bcb..7c46f99f 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -12,49 +12,51 @@ import ( // // domain is the domain name of your website. // httpsPort is the tls port where http requests will be redirected to. -func httpsRedirector(wrappedHandler http.HandlerFunc, httpsPort uint16, domain string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - isTls := strings.EqualFold(r.URL.Scheme, "https") || r.TLS != nil - if !isTls { - url := r.URL - url.Scheme = "https" - url.Host = joinHostPort(domain, fmt.Sprint(httpsPort)) - path := url.String() +func httpsRedirector(wrappedHandler http.Handler, httpsPort uint16, domain string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + isTls := strings.EqualFold(r.URL.Scheme, "https") || r.TLS != nil + if !isTls { + url := r.URL + url.Scheme = "https" + url.Host = joinHostPort(domain, fmt.Sprint(httpsPort)) + path := url.String() - http.Redirect(w, r, path, http.StatusPermanentRedirect) - return - } + http.Redirect(w, r, path, http.StatusPermanentRedirect) + return + } - // A Host header field must be sent in all HTTP/1.1 request messages. - // Thus we expect `r.Host[0]` to always have a value. - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host - isHostBareIP := unicode.IsDigit(rune(r.Host[0])) - if isHostBareIP { - /* - the request has tried to access us via an IP address, redirect them to our domain. + // A Host header field must be sent in all HTTP/1.1 request messages. + // Thus we expect `r.Host[0]` to always have a value. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host + isHostBareIP := unicode.IsDigit(rune(r.Host[0])) + if isHostBareIP { + /* + the request has tried to access us via an IP address, redirect them to our domain. - curl -vkIL 172.217.170.174 #google - HEAD / HTTP/1.1 - Host: 172.217.170.174 + curl -vkIL 172.217.170.174 #google + HEAD / HTTP/1.1 + Host: 172.217.170.174 - HTTP/1.1 301 Moved Permanently - Location: http://www.google.com/ - */ - url := r.URL - url.Scheme = "https" - _, port, err := net.SplitHostPort(r.Host) - if err != nil { - port = fmt.Sprint(httpsPort) - } - url.Host = joinHostPort(domain, port) - path := url.String() + HTTP/1.1 301 Moved Permanently + Location: http://www.google.com/ + */ + url := r.URL + url.Scheme = "https" + _, port, err := net.SplitHostPort(r.Host) + if err != nil { + port = fmt.Sprint(httpsPort) + } + url.Host = joinHostPort(domain, port) + path := url.String() - http.Redirect(w, r, path, http.StatusPermanentRedirect) - return - } + http.Redirect(w, r, path, http.StatusPermanentRedirect) + return + } - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } // joinHostPort is like `net.JoinHostPort` except suited for this package. diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 0e89b1aa..f3c849ac 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -13,19 +13,21 @@ import ( "go.akshayshah.org/attest" ) -func someHttpsRedirectorHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - p := make([]byte, 16) - _, err := r.Body.Read(p) - if err == nil || err == io.EOF { - _, _ = w.Write(p) - return +func someHttpsRedirectorHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + p := make([]byte, 16) + _, err := r.Body.Read(p) + if err == nil || err == io.EOF { + _, _ = w.Write(p) + return + } } - } - fmt.Fprint(w, msg) - } + fmt.Fprint(w, msg) + }, + ) } const locationHeader = "Location" diff --git a/middleware/reload_protect.go b/middleware/reload_protect.go index 2cd70c84..e2597fa9 100644 --- a/middleware/reload_protect.go +++ b/middleware/reload_protect.go @@ -16,7 +16,7 @@ const reloadProtectCookiePrefix = "ong_form_reload_protect" // reloadProtector is a middleware that attempts to provides protection against a form re-submission when a user reloads/refreshes an already submitted web page/form. // // If such a situation is detected; this middleware will issue a http GET redirect to the same url. -func reloadProtector(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { +func reloadProtector(wrappedHandler http.Handler, domain string) http.Handler { safeMethods := []string{ // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 http.MethodGet, @@ -24,47 +24,49 @@ func reloadProtector(wrappedHandler http.HandlerFunc, domain string) http.Handle http.MethodOptions, http.MethodTrace, } - return func(w http.ResponseWriter, r *http.Request) { - // It is possible for one to send a form without having added the requiste form http header. - if !slices.Contains(safeMethods, r.Method) { - // This could be a http POST/DELETE/etc + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // It is possible for one to send a form without having added the requiste form http header. + if !slices.Contains(safeMethods, r.Method) { + // This could be a http POST/DELETE/etc - theCookie := fmt.Sprintf("%s-%s", - reloadProtectCookiePrefix, - strings.ReplaceAll(r.URL.EscapedPath(), "/", ""), - ) + theCookie := fmt.Sprintf("%s-%s", + reloadProtectCookiePrefix, + strings.ReplaceAll(r.URL.EscapedPath(), "/", ""), + ) - // todo: should we check if gotCookie.MaxAge > 0 - gotCookie, err := r.Cookie(theCookie) - if err == nil && gotCookie != nil { - // It means that the form had been submitted before. + // todo: should we check if gotCookie.MaxAge > 0 + gotCookie, err := r.Cookie(theCookie) + if err == nil && gotCookie != nil { + // It means that the form had been submitted before. - cookie.Delete( - w, - theCookie, - domain, - ) - http.Redirect( - w, - r, - r.URL.String(), - // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 - http.StatusSeeOther, - ) - return - } else { - cookie.Set( - w, - theCookie, - "YES", - domain, - 1*time.Hour, - false, - ) + cookie.Delete( + w, + theCookie, + domain, + ) + http.Redirect( + w, + r, + r.URL.String(), + // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 + http.StatusSeeOther, + ) + return + } else { + cookie.Set( + w, + theCookie, + "YES", + domain, + 1*time.Hour, + false, + ) + } } - } - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } diff --git a/middleware/reload_protect_test.go b/middleware/reload_protect_test.go index bef3918f..2a359ca6 100644 --- a/middleware/reload_protect_test.go +++ b/middleware/reload_protect_test.go @@ -11,30 +11,32 @@ import ( "go.akshayshah.org/attest" ) -func someReloadProtectorHandler(msg, expectedFormName, expectedFormValue string) http.HandlerFunc { +func someReloadProtectorHandler(msg, expectedFormName, expectedFormValue string) http.Handler { // count is state that is affected by form submission. // eg, when a form is submitted; we create a new user. count := 0 - return func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - err := r.ParseForm() - if err != nil { - panic(err) + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + err := r.ParseForm() + if err != nil { + panic(err) + } + val := r.Form.Get(expectedFormName) + if val != expectedFormValue { + panic(fmt.Sprintf("expected = %v got = %v", expectedFormValue, val)) + } + + count = count + 1 + if count > 1 { + // form re-submission happened + panic("form re-submission happened") + } } - val := r.Form.Get(expectedFormName) - if val != expectedFormValue { - panic(fmt.Sprintf("expected = %v got = %v", expectedFormValue, val)) - } - - count = count + 1 - if count > 1 { - // form re-submission happened - panic("form re-submission happened") - } - } - fmt.Fprint(w, msg) - } + fmt.Fprint(w, msg) + }, + ) } func TestReloadProtector(t *testing.T) { diff --git a/middleware/security.go b/middleware/security.go index f0ceee49..742d6c00 100644 --- a/middleware/security.go +++ b/middleware/security.go @@ -35,80 +35,82 @@ const ( // securityHeaders is a middleware that adds some important HTTP security headers and assigns them sensible default values. // // Some of the headers set are Permissions-Policy, Content-securityHeaders-Policy, X-Content-Type-Options, X-Frame-Options, Cross-Origin-Resource-Policy, Cross-Origin-Opener-Policy, Referrer-Policy & Strict-Transport-securityHeaders -func securityHeaders(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - w.Header().Set( - permissionsPolicyHeader, - // flocOptOut disables floc which is otherwise ON by default - // see: https://github.com/WICG/floc#opting-out-of-computation - "interest-cohort=()", - ) - - // The nonce should be generated per request & propagated to the html of the page. - // The nonce can be fetched in middlewares using the GetCspNonce func - // - // eg; - // - nonce := id.Random(cspBytesTokenLength) - r = r.WithContext(context.WithValue(ctx, cspCtxKey, nonce)) - w.Header().Set( - cspHeader, - // - https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP - // - https://web.dev/security-headers/ - // - https://stackoverflow.com/a/66955464/2768067 - // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src - // - https://web.dev/security-headers/#tt +func securityHeaders(wrappedHandler http.Handler, domain string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + w.Header().Set( + permissionsPolicyHeader, + // flocOptOut disables floc which is otherwise ON by default + // see: https://github.com/WICG/floc#opting-out-of-computation + "interest-cohort=()", + ) + + // The nonce should be generated per request & propagated to the html of the page. + // The nonce can be fetched in middlewares using the GetCspNonce func // - // content is only permitted from: - // - the document's origin(and subdomains) - // - images may load from anywhere - // - media is allowed from domain(and its subdomains) - // - executable scripts is only allowed from self(& subdomains). - // - DOM xss(eg setting innerHtml) is blocked by require-trusted-types. - getCsp(domain, nonce), - ) - - w.Header().Set( - xContentOptionsHeader, - "nosniff", - ) - - w.Header().Set( - xFrameHeader, - "DENY", - ) - - w.Header().Set( - corpHeader, - "same-site", - ) - - w.Header().Set( - coopHeader, - "same-origin", - ) - - w.Header().Set( - referrerHeader, - // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy - "strict-origin-when-cross-origin", - ) - - if r.TLS != nil { + // eg; + // + nonce := id.Random(cspBytesTokenLength) + r = r.WithContext(context.WithValue(ctx, cspCtxKey, nonce)) w.Header().Set( - stsHeader, - // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security - // A max-age(in seconds) of 2yrs is recommended - getSts(15*24*time.Hour), // 15 days + cspHeader, + // - https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP + // - https://web.dev/security-headers/ + // - https://stackoverflow.com/a/66955464/2768067 + // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src + // - https://web.dev/security-headers/#tt + // + // content is only permitted from: + // - the document's origin(and subdomains) + // - images may load from anywhere + // - media is allowed from domain(and its subdomains) + // - executable scripts is only allowed from self(& subdomains). + // - DOM xss(eg setting innerHtml) is blocked by require-trusted-types. + getCsp(domain, nonce), ) - } - wrappedHandler(w, r) - } + w.Header().Set( + xContentOptionsHeader, + "nosniff", + ) + + w.Header().Set( + xFrameHeader, + "DENY", + ) + + w.Header().Set( + corpHeader, + "same-site", + ) + + w.Header().Set( + coopHeader, + "same-origin", + ) + + w.Header().Set( + referrerHeader, + // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy + "strict-origin-when-cross-origin", + ) + + if r.TLS != nil { + w.Header().Set( + stsHeader, + // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security + // A max-age(in seconds) of 2yrs is recommended + getSts(15*24*time.Hour), // 15 days + ) + } + + wrappedHandler.ServeHTTP(w, r) + }, + ) } // GetCspNonce returns the Content-Security-Policy nonce that was set for the http request in question. diff --git a/middleware/security_test.go b/middleware/security_test.go index 6bad00a3..0f42f6fb 100644 --- a/middleware/security_test.go +++ b/middleware/security_test.go @@ -16,11 +16,13 @@ import ( const nonceHeader = "CUSTOM-CSP-NONCE-TEST-HEADER" // echoHandler echos back in the response, the msg that was passed in. -func echoHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(nonceHeader, GetCspNonce(r.Context())) - fmt.Fprint(w, msg) - } +func echoHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(nonceHeader, GetCspNonce(r.Context())) + fmt.Fprint(w, msg) + }, + ) } func TestSecurity(t *testing.T) { diff --git a/middleware/session.go b/middleware/session.go index b360e650..9bff4596 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -21,16 +21,18 @@ const ( // It lets you store and retrieve arbitrary data on a per-site-visitor basis. // // This middleware works best when used together with the [sess] package. -func session(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // 1. Read from cookies and check for session cookie. - // 2. Get that cookie and save it to r.context - r = sess.Initialise(r, secretKey) - - srw := newSessRW(w, r, domain, secretKey) - - wrappedHandler(srw, r) - } +func session(wrappedHandler http.Handler, secretKey, domain string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // 1. Read from cookies and check for session cookie. + // 2. Get that cookie and save it to r.context + r = sess.Initialise(r, secretKey) + + srw := newSessRW(w, r, domain, secretKey) + + wrappedHandler.ServeHTTP(srw, r) + }, + ) } // sessRW provides an http.ResponseWriter interface, which provides http session functionality. diff --git a/middleware/session_test.go b/middleware/session_test.go index b444e728..adc66773 100644 --- a/middleware/session_test.go +++ b/middleware/session_test.go @@ -25,16 +25,18 @@ func bigMap() map[string]string { return y } -func someSessionHandler(msg, key, value string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - sess.Set(r, key, value) - sess.SetM(r, bigMap()) - fmt.Fprint(w, msg) - } +func someSessionHandler(msg, key, value string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + sess.Set(r, key, value) + sess.SetM(r, bigMap()) + fmt.Fprint(w, msg) + }, + ) } // See https://github.com/komuw/ong/issues/205 -func templateVarsHandler(t *testing.T, name string) http.HandlerFunc { +func templateVarsHandler(t *testing.T, name string) http.Handler { tmpl, err := template.New("myTpl").Parse(` @@ -47,17 +49,19 @@ func templateVarsHandler(t *testing.T, name string) http.HandlerFunc { t.Fatal(err) } - return func(w http.ResponseWriter, r *http.Request) { - sess.Set(r, "name", name) - - data := struct { - Name string - }{Name: name} - if err = tmpl.Execute(w, data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + sess.Set(r, "name", name) + + data := struct { + Name string + }{Name: name} + if err = tmpl.Execute(w, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }, + ) } func TestSession(t *testing.T) { diff --git a/middleware/trace.go b/middleware/trace.go index 348b07c9..7998d4d1 100644 --- a/middleware/trace.go +++ b/middleware/trace.go @@ -13,35 +13,37 @@ import ( const logIDKey = string(octx.LogCtxKey) // trace is a middleware that adds logID to request and response. -func trace(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +func trace(wrappedHandler http.Handler, domain string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - // set cookie/headers/ctx for logID. - logID := getLogId(r) - ctx = context.WithValue( - ctx, - // using this custom key is important, instead of using `logIDKey` - octx.LogCtxKey, - logID, - ) - r = r.WithContext(ctx) - r.Header.Set(logIDKey, logID) - w.Header().Set(logIDKey, logID) - cookie.Set( - w, - logIDKey, - logID, - domain, - // Hopefully 15mins is enough. - // Google considers a session to be 30mins. - // https://support.google.com/analytics/answer/2731565?hl=en#time-based-expiration - 15*time.Minute, - false, - ) + // set cookie/headers/ctx for logID. + logID := getLogId(r) + ctx = context.WithValue( + ctx, + // using this custom key is important, instead of using `logIDKey` + octx.LogCtxKey, + logID, + ) + r = r.WithContext(ctx) + r.Header.Set(logIDKey, logID) + w.Header().Set(logIDKey, logID) + cookie.Set( + w, + logIDKey, + logID, + domain, + // Hopefully 15mins is enough. + // Google considers a session to be 30mins. + // https://support.google.com/analytics/answer/2731565?hl=en#time-based-expiration + 15*time.Minute, + false, + ) - wrappedHandler(w, r) - } + wrappedHandler.ServeHTTP(w, r) + }, + ) } // getLogId returns a logID from the request or autogenerated if not available from the request. diff --git a/middleware/trace_test.go b/middleware/trace_test.go index 9c919b80..402bda88 100644 --- a/middleware/trace_test.go +++ b/middleware/trace_test.go @@ -18,22 +18,24 @@ import ( const someTraceHandlerHeader = "someTraceHandlerHeader" -func someTraceHandler(successMsg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // sleep so that the trace middleware has some useful duration metrics to report. - time.Sleep(3 * time.Millisecond) - if r.Header.Get(someTraceHandlerHeader) != "" { - http.Error( - w, - r.Header.Get(someTraceHandlerHeader), - http.StatusInternalServerError, - ) - return - } else { - fmt.Fprint(w, successMsg) - return - } - } +func someTraceHandler(successMsg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // sleep so that the trace middleware has some useful duration metrics to report. + time.Sleep(3 * time.Millisecond) + if r.Header.Get(someTraceHandlerHeader) != "" { + http.Error( + w, + r.Header.Get(someTraceHandlerHeader), + http.StatusInternalServerError, + ) + return + } else { + fmt.Fprint(w, successMsg) + return + } + }, + ) } func TestTraceMiddleware(t *testing.T) { diff --git a/mux/example_test.go b/mux/example_test.go index 58a634cb..d4540893 100644 --- a/mux/example_test.go +++ b/mux/example_test.go @@ -11,17 +11,21 @@ import ( "github.com/komuw/ong/mux" ) -func LoginHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprint(w, "welcome to your favorite website.") - } +func LoginHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, "welcome to your favorite website.") + }, + ) } -func BooksByAuthorHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - author := mux.Param(r.Context(), "author") - _, _ = fmt.Fprintf(w, "fetching books by author: %s", author) - } +func BooksByAuthorHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + author := mux.Param(r.Context(), "author") + _, _ = fmt.Fprintf(w, "fetching books by author: %s", author) + }, + ) } func ExampleMux() { diff --git a/mux/mux.go b/mux/mux.go index bd3a3fdf..df22263c 100644 --- a/mux/mux.go +++ b/mux/mux.go @@ -32,7 +32,7 @@ const ( func NewRoute( pattern string, method string, - handler http.HandlerFunc, + handler http.Handler, ) Route { h := getfunc(handler) if strings.Contains(h, "ong/middleware/") && @@ -70,7 +70,7 @@ type Mux struct { // Typically, an application should only have one Mux. // // It panics with a helpful error message if it detects conflicting routes. -func New(l *slog.Logger, opt middleware.Opts, notFoundHandler http.HandlerFunc, routes ...Route) Mux { +func New(l *slog.Logger, opt middleware.Opts, notFoundHandler http.Handler, routes ...Route) Mux { m := Mux{ l: l, router: newRouter(notFoundHandler), @@ -106,7 +106,7 @@ func New(l *slog.Logger, opt middleware.Opts, notFoundHandler http.HandlerFunc, return m } -func (m Mux) addPattern(method, pattern string, originalHandler, wrappingHandler http.HandlerFunc) { +func (m Mux) addPattern(method, pattern string, originalHandler, wrappingHandler http.Handler) { m.router.handle(method, pattern, originalHandler, wrappingHandler) } diff --git a/mux/mux_test.go b/mux/mux_test.go index 75f76b88..c11d9486 100644 --- a/mux/mux_test.go +++ b/mux/mux_test.go @@ -22,23 +22,29 @@ func getSecretKey() string { return key } -func someMuxHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - } +func someMuxHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + }, + ) } -func thisIsAnotherMuxHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "thisIsAnotherMuxHandler") - } +func thisIsAnotherMuxHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "thisIsAnotherMuxHandler") + }, + ) } -func checkAgeHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - age := Param(r.Context(), "age") - _, _ = fmt.Fprintf(w, "Age is %s", age) - } +func checkAgeHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + age := Param(r.Context(), "age") + _, _ = fmt.Fprintf(w, "Age is %s", age) + }, + ) } func TestMain(m *testing.M) { diff --git a/mux/route.go b/mux/route.go index 5d9a404c..aa68a253 100644 --- a/mux/route.go +++ b/mux/route.go @@ -23,8 +23,8 @@ type Route struct { method string pattern string segments []string - originalHandler http.HandlerFunc // This is only needed to enhance the debug/panic message when conflicting routes are detected. - wrappingHandler http.HandlerFunc + originalHandler http.Handler // This is only needed to enhance the debug/panic message when conflicting routes are detected. + wrappingHandler http.Handler } func (r Route) String() string { @@ -76,13 +76,13 @@ func (r Route) match(ctx context.Context, segs []string) (context.Context, bool) type router struct { routes []Route // notFoundHandler is the handler to call when no routes match. - notFoundHandler http.HandlerFunc + notFoundHandler http.Handler } // NewRouter makes a new Router. -func newRouter(notFoundHandler http.HandlerFunc) *router { +func newRouter(notFoundHandler http.Handler) *router { if notFoundHandler == nil { - notFoundHandler = http.NotFound + notFoundHandler = http.NotFoundHandler() } return &router{notFoundHandler: notFoundHandler} @@ -95,7 +95,7 @@ func pathSegments(p string) []string { // handle adds a handler with the specified method and pattern. // Pattern can contain path segments such as: /item/:id which is // accessible via the Param function. -func (r *router) handle(method, pattern string, originalHandler, wrappingHandler http.HandlerFunc) { +func (r *router) handle(method, pattern string, originalHandler, wrappingHandler http.Handler) { if !strings.HasSuffix(pattern, "/") { // this will make the mux send requests for; // - localhost:80/check From 0f9aadc0b03f9bdc9efbfc6bbd8f8ce6e0a4d90d Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 7 Jun 2023 23:34:30 +0300 Subject: [PATCH 03/13] m --- cookie/cookie_test.go | 38 +++-- cookie/example_test.go | 48 +++--- example/handlers.go | 277 ++++++++++++++++++---------------- middleware/auth_test.go | 7 +- middleware/middleware.go | 2 +- middleware/middleware_test.go | 2 +- mux/route_test.go | 20 ++- server/example_test.go | 30 ++-- server/server_test.go | 36 +++-- sess/example_test.go | 24 +-- 10 files changed, 259 insertions(+), 225 deletions(-) diff --git a/cookie/cookie_test.go b/cookie/cookie_test.go index 455523b4..745d9db4 100644 --- a/cookie/cookie_test.go +++ b/cookie/cookie_test.go @@ -11,18 +11,22 @@ import ( "go.uber.org/goleak" ) -func setHandler(name, value, domain string, mAge time.Duration, jsAccess bool) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - Set(w, name, value, domain, mAge, jsAccess) - fmt.Fprint(w, "hello") - } +func setHandler(name, value, domain string, mAge time.Duration, jsAccess bool) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + Set(w, name, value, domain, mAge, jsAccess) + fmt.Fprint(w, "hello") + }, + ) } -func setEncryptedHandler(name, value, domain string, mAge time.Duration, secretKey string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - SetEncrypted(r, w, name, value, domain, mAge, secretKey) - fmt.Fprint(w, "hello") - } +func setEncryptedHandler(name, value, domain string, mAge time.Duration, secretKey string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + SetEncrypted(r, w, name, value, domain, mAge, secretKey) + fmt.Fprint(w, "hello") + }, + ) } func TestMain(m *testing.M) { @@ -185,12 +189,14 @@ func TestCookies(t *testing.T) { }) } -func deleteHandler(name, value, domain string, mAge time.Duration) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - Set(w, name, value, domain, mAge, false) - Delete(w, name, domain) - fmt.Fprint(w, "hello") - } +func deleteHandler(name, value, domain string, mAge time.Duration) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + Set(w, name, value, domain, mAge, false) + Delete(w, name, domain) + fmt.Fprint(w, "hello") + }, + ) } func TestDelete(t *testing.T) { diff --git a/cookie/example_test.go b/cookie/example_test.go index af50e034..4b31bc4d 100644 --- a/cookie/example_test.go +++ b/cookie/example_test.go @@ -15,29 +15,31 @@ type shoppingCart struct { Price uint8 } -func shoppingCartHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - cookieName := "cart" - secretKey := "superSecret" - item := shoppingCart{ItemName: "shoe", Price: 89} - - b, err := json.Marshal(item) - if err != nil { - panic(err) - } - - cookie.SetEncrypted( - r, - w, - cookieName, - string(b), - "example.com", - 2*time.Hour, - secretKey, - ) - - fmt.Fprint(w, "thanks for shopping!") - } +func shoppingCartHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cookieName := "cart" + secretKey := "superSecret" + item := shoppingCart{ItemName: "shoe", Price: 89} + + b, err := json.Marshal(item) + if err != nil { + panic(err) + } + + cookie.SetEncrypted( + r, + w, + cookieName, + string(b), + "example.com", + 2*time.Hour, + secretKey, + ) + + fmt.Fprint(w, "thanks for shopping!") + }, + ) } func ExampleSetEncrypted() { diff --git a/example/handlers.go b/example/handlers.go index d0a8dd1c..03f3cb90 100644 --- a/example/handlers.go +++ b/example/handlers.go @@ -50,7 +50,7 @@ func NewApp(d db, l *slog.Logger) app { // health handler showcases the use of: // - encryption/decryption. // - random id. -func (a app) health(secretKey string) http.HandlerFunc { +func (a app) health(secretKey string) http.Handler { var ( once sync.Once serverBoot time.Time = time.Now().UTC() @@ -59,17 +59,19 @@ func (a app) health(secretKey string) http.HandlerFunc { enc = cry.New(secretKey) ) - return func(w http.ResponseWriter, r *http.Request) { - // intialize somethings only once for perfomance. - once.Do(func() { - serverStart = time.Now().UTC() - }) + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // intialize somethings only once for perfomance. + once.Do(func() { + serverStart = time.Now().UTC() + }) - encryptedSrvID := enc.EncryptEncode(serverID) + encryptedSrvID := enc.EncryptEncode(serverID) - res := fmt.Sprintf("serverBoot=%s, serverStart=%s, serverId=%s\n", serverBoot, serverStart, encryptedSrvID) - _, _ = io.WriteString(w, res) - } + res := fmt.Sprintf("serverBoot=%s, serverStart=%s, serverId=%s\n", serverBoot, serverStart, encryptedSrvID) + _, _ = io.WriteString(w, res) + }, + ) } // check handler showcases the use of: @@ -78,60 +80,62 @@ func (a app) health(secretKey string) http.HandlerFunc { // - xcontext.Detach. // - safe http client. // - error wrapping. -func (a app) check(msg string) http.HandlerFunc { +func (a app) check(msg string) http.Handler { cli := client.Safe(a.l) - return func(w http.ResponseWriter, r *http.Request) { - cartID := "afakHda8eqL" - - age := mux.Param(r.Context(), "age") - sess.SetM(r, sess.M{ - "name": "John Doe", - "age": age, - "cart_id": cartID, - }) - - if sess.Get(r, "cart_id") != "" { - if sess.Get(r, "cart_id") != cartID { - http.Error(w, "wrong cartID", http.StatusBadRequest) - return + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cartID := "afakHda8eqL" + + age := mux.Param(r.Context(), "age") + sess.SetM(r, sess.M{ + "name": "John Doe", + "age": age, + "cart_id": cartID, + }) + + if sess.Get(r, "cart_id") != "" { + if sess.Get(r, "cart_id") != cartID { + http.Error(w, "wrong cartID", http.StatusBadRequest) + return + } } - } - go func(ctx context.Context) { - ctx, cancel := context.WithTimeout(ctx, 2*time.Second) - defer cancel() + go func(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() - makeReq := func(url string) (code int, errp error) { - defer errors.Dwrap(&errp) + makeReq := func(url string) (code int, errp error) { + defer errors.Dwrap(&errp) - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return 0, err + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return 0, err + } + resp, err := cli.Do(req) + if err != nil { + return 0, err + } + defer func() { _ = resp.Body.Close() }() + + return resp.StatusCode, nil } - resp, err := cli.Do(req) + + l := log.WithID(ctx, a.l) + code, err := makeReq("https://example.com") if err != nil { - return 0, err + l.Error("handler error", err) } - defer func() { _ = resp.Body.Close() }() - - return resp.StatusCode, nil - } - - l := log.WithID(ctx, a.l) - code, err := makeReq("https://example.com") - if err != nil { - l.Error("handler error", err) - } - l.Info("req succeded", "code", code) - }( - // we need to detach context, - // since this goroutine can outlive the http request lifecycle. - xcontext.Detach(r.Context()), - ) - - _, _ = fmt.Fprintf(w, "hello %s. Age is %s", msg, age) - } + l.Info("req succeded", "code", code) + }( + // we need to detach context, + // since this goroutine can outlive the http request lifecycle. + xcontext.Detach(r.Context()), + ) + + _, _ = fmt.Fprintf(w, "hello %s. Age is %s", msg, age) + }, + ) } // login handler showcases the use of: @@ -139,7 +143,7 @@ func (a app) check(msg string) http.HandlerFunc { // - csp tokens. // - encrypted cookies // - hashing passwords. -func (a app) login(secretKey string) http.HandlerFunc { +func (a app) login(secretKey string) http.Handler { tmpl, err := template.New("myTpl").Parse(` @@ -195,81 +199,83 @@ func (a app) login(secretKey string) http.HandlerFunc { Name string } - return func(w http.ResponseWriter, r *http.Request) { - reqL := log.WithID(r.Context(), a.l) - - if r.Method != http.MethodPost { - data := struct { - CsrfTokenName string - CsrfTokenValue string - CspNonceValue string - }{ - CsrfTokenName: middleware.CsrfTokenFormName, - CsrfTokenValue: middleware.GetCsrfToken(r.Context()), - CspNonceValue: middleware.GetCspNonce(r.Context()), + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + reqL := log.WithID(r.Context(), a.l) + + if r.Method != http.MethodPost { + data := struct { + CsrfTokenName string + CsrfTokenValue string + CspNonceValue string + }{ + CsrfTokenName: middleware.CsrfTokenFormName, + CsrfTokenValue: middleware.GetCsrfToken(r.Context()), + CspNonceValue: middleware.GetCspNonce(r.Context()), + } + + if errE := tmpl.Execute(w, data); errE != nil { + http.Error(w, errE.Error(), http.StatusInternalServerError) + return + } + return } - if errE := tmpl.Execute(w, data); errE != nil { - http.Error(w, errE.Error(), http.StatusInternalServerError) + if errP := r.ParseForm(); errP != nil { + http.Error(w, errP.Error(), http.StatusInternalServerError) return } - return - } - - if errP := r.ParseForm(); errP != nil { - http.Error(w, errP.Error(), http.StatusInternalServerError) - return - } - - email := r.FormValue("email") - firstName := r.FormValue("firstName") - password := r.FormValue("password") - - u := &User{Email: email, Name: firstName} - - s, errM := json.Marshal(u) - if errM != nil { - http.Error(w, errM.Error(), http.StatusInternalServerError) - return - } - - cookieName := "example_session_cookie" - c, errM := cookie.GetEncrypted(r, cookieName, secretKey) - reqL.Info("login handler log cookie", - "err", errM, - "cookie", c, - ) - - cookie.SetEncrypted( - r, - w, - cookieName, - string(s), - "localhost", - 23*24*time.Hour, - secretKey, - ) - - existingPasswdHash := a.db.Get("passwd") - if e := cry.Eql(password, existingPasswdHash); e != nil { - // passwd did not exist before. - hashedPasswd, errH := cry.Hash(password) - if errH != nil { - http.Error(w, errH.Error(), http.StatusInternalServerError) + + email := r.FormValue("email") + firstName := r.FormValue("firstName") + password := r.FormValue("password") + + u := &User{Email: email, Name: firstName} + + s, errM := json.Marshal(u) + if errM != nil { + http.Error(w, errM.Error(), http.StatusInternalServerError) return } - a.db.Set("passwd", hashedPasswd) - } - _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) - } + cookieName := "example_session_cookie" + c, errM := cookie.GetEncrypted(r, cookieName, secretKey) + reqL.Info("login handler log cookie", + "err", errM, + "cookie", c, + ) + + cookie.SetEncrypted( + r, + w, + cookieName, + string(s), + "localhost", + 23*24*time.Hour, + secretKey, + ) + + existingPasswdHash := a.db.Get("passwd") + if e := cry.Eql(password, existingPasswdHash); e != nil { + // passwd did not exist before. + hashedPasswd, errH := cry.Hash(password) + if errH != nil { + http.Error(w, errH.Error(), http.StatusInternalServerError) + return + } + a.db.Set("passwd", hashedPasswd) + } + + _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) + }, + ) } // handleFileServer handler showcases the use of: // - middleware.ClientIP // - middleware.ClientFingerPrint // - logging -func (a app) handleFileServer() http.HandlerFunc { +func (a app) handleFileServer() http.Handler { // Do NOT let `http.FileServer` be able to serve your root directory. // Otherwise, your .git folder and other sensitive info(including http://localhost:65080/main.go) may be available // instead create a folder that only has your templates and server that. @@ -283,25 +289,30 @@ func (a app) handleFileServer() http.HandlerFunc { fs := http.FileServer(http.Dir(dir)) realHandler := http.StripPrefix("/staticAssets/", fs).ServeHTTP - return func(w http.ResponseWriter, r *http.Request) { - reqL := log.WithID(r.Context(), a.l) - reqL.Info("handleFileServer", "clientIP", middleware.ClientIP(r), "clientFingerPrint", middleware.ClientFingerPrint(r)) - slog.NewLogLogger(reqL.Handler(), log.LevelImmediate). - Println("this is now a Go standard library logger") + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + reqL := log.WithID(r.Context(), a.l) + reqL.Info("handleFileServer", "clientIP", middleware.ClientIP(r), "clientFingerPrint", middleware.ClientFingerPrint(r)) - realHandler(w, r) - } + slog.NewLogLogger(reqL.Handler(), log.LevelImmediate). + Println("this is now a Go standard library logger") + + realHandler(w, r) + }, + ) } // panic handler showcases the use of: // - recoverer middleware. -func (a app) panic() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - names := []string{"John", "Jane", "Kamau"} - _ = 93 - msg := "hey" - n := names[934] - _, _ = io.WriteString(w, fmt.Sprintf("%s %s", msg, n)) - } +func (a app) panic() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + names := []string{"John", "Jane", "Kamau"} + _ = 93 + msg := "hey" + n := names[934] + _, _ = io.WriteString(w, fmt.Sprintf("%s %s", msg, n)) + }, + ) } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 29ba167e..d75a8a23 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -13,9 +13,10 @@ import ( ) func protectedHandler(msg string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - }, + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + }, ) } diff --git a/middleware/middleware.go b/middleware/middleware.go index 2057c822..e7b8d353 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,5 +1,5 @@ // Package middleware provides helpful functions that implement some common functionalities in http servers. -// A middleware is a function that returns a [http.HandlerFunc] +// A middleware is a function that returns a [http.Handler] // // The middlewares [All], [Get], [Post], [Head], [Put] & [Delete] wrap other internal middleware. // The effect of this is that the aforementioned middleware, in addition to their specialised functionality, will: diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index f5b1022e..8ecb4acb 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,5 +1,5 @@ // Package middleware provides helpful functions that implement some common functionalities in http servers. -// A middleware is a func that returns a http.HandlerFunc +// A middleware is a func that returns a http.Handler package middleware import ( diff --git a/mux/route_test.go b/mux/route_test.go index ae4379fb..d3ba591b 100644 --- a/mux/route_test.go +++ b/mux/route_test.go @@ -302,16 +302,20 @@ func TestMultipleRoutesDifferentMethods(t *testing.T) { attest.Equal(t, match, "POST") } -func firstRoute(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprint(w, msg) - } +func firstRoute(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, msg) + }, + ) } -func secondRoute(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprint(w, msg) - } +func secondRoute(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, msg) + }, + ) } func TestConflicts(t *testing.T) { diff --git a/server/example_test.go b/server/example_test.go index df3a5761..489ec35f 100644 --- a/server/example_test.go +++ b/server/example_test.go @@ -47,20 +47,24 @@ func ExampleRun() { } } -func hello(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - cspNonce := middleware.GetCspNonce(r.Context()) - csrfToken := middleware.GetCsrfToken(r.Context()) - fmt.Printf("hello called cspNonce: %s, csrfToken: %s", cspNonce, csrfToken) +func hello(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + cspNonce := middleware.GetCspNonce(r.Context()) + csrfToken := middleware.GetCsrfToken(r.Context()) + fmt.Printf("hello called cspNonce: %s, csrfToken: %s", cspNonce, csrfToken) - // use msg, which is a dependency specific to this handler - fmt.Fprint(w, msg) - } + // use msg, which is a dependency specific to this handler + fmt.Fprint(w, msg) + }, + ) } -func check() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - age := mux.Param(r.Context(), "age") - _, _ = fmt.Fprintf(w, "Age is %s", age) - } +func check() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + age := mux.Param(r.Context(), "age") + _, _ = fmt.Fprintf(w, "Age is %s", age) + }, + ) } diff --git a/server/server_test.go b/server/server_test.go index aafc08ee..9ae4a8dc 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -136,19 +136,21 @@ func TestOpts(t *testing.T) { const tlsFingerPrintKey = "TlsFingerPrintKey" -func someServerTestHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(tlsFingerPrintKey, middleware.ClientFingerPrint(r)) - - if _, err := io.ReadAll(r.Body); err != nil { - // This is where the error produced by `http.MaxBytesHandler` is produced at. - // ie, its produced when a read is made. - fmt.Fprint(w, err.Error()) - return - } +func someServerTestHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(tlsFingerPrintKey, middleware.ClientFingerPrint(r)) + + if _, err := io.ReadAll(r.Body); err != nil { + // This is where the error produced by `http.MaxBytesHandler` is produced at. + // ie, its produced when a read is made. + fmt.Fprint(w, err.Error()) + return + } - fmt.Fprint(w, msg) - } + fmt.Fprint(w, msg) + }, + ) } func TestServer(t *testing.T) { @@ -387,10 +389,12 @@ func TestServer(t *testing.T) { }) } -func benchmarkServerHandler(msg string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - } +func benchmarkServerHandler(msg string) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + }, + ) } var result int //nolint:gochecknoglobals diff --git a/sess/example_test.go b/sess/example_test.go index 275a2dc9..7e4c65c8 100644 --- a/sess/example_test.go +++ b/sess/example_test.go @@ -12,17 +12,19 @@ import ( "github.com/komuw/ong/sess" ) -func loginHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - mySession := map[string]string{ - "name": "John Doe", - "favorite_color": "red", - "height": "5 feet 6 inches", - } - sess.SetM(r, mySession) - - fmt.Fprint(w, "welcome again.") - } +func loginHandler() http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + mySession := map[string]string{ + "name": "John Doe", + "favorite_color": "red", + "height": "5 feet 6 inches", + } + sess.SetM(r, mySession) + + fmt.Fprint(w, "welcome again.") + }, + ) } func ExampleSetM() { From de2001464dca0f3043a3a9a85c057d3e1b4ac955 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 7 Jun 2023 23:38:52 +0300 Subject: [PATCH 04/13] m --- mux/mux_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mux/mux_test.go b/mux/mux_test.go index c11d9486..18abc347 100644 --- a/mux/mux_test.go +++ b/mux/mux_test.go @@ -213,8 +213,8 @@ func TestMux(t *testing.T) { rStr := fmt.Sprintf("%v", r) attest.Subsequence(t, rStr, uri2) attest.Subsequence(t, rStr, method) - attest.Subsequence(t, rStr, "ong/mux/mux_test.go:26") // location where `someMuxHandler` is declared. - attest.Subsequence(t, rStr, "ong/mux/mux_test.go:32") // location where `thisIsAnotherMuxHandler` is declared. + attest.Subsequence(t, rStr, "ong/mux/mux_test.go:27") // location where `someMuxHandler` is declared. + attest.Subsequence(t, rStr, "ong/mux/mux_test.go:35") // location where `thisIsAnotherMuxHandler` is declared. }() _ = New( @@ -267,28 +267,28 @@ func TestMux(t *testing.T) { "api", "/api/", MethodGet, - "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. + "ong/mux/mux_test.go:27", // location where `someMuxHandler` is declared. }, { "success with prefix slash", "/api", "/api/", MethodGet, - "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. + "ong/mux/mux_test.go:27", // location where `someMuxHandler` is declared. }, { "success with suffix slash", "api/", "/api/", MethodGet, - "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. + "ong/mux/mux_test.go:27", // location where `someMuxHandler` is declared. }, { "success with all slashes", "/api/", "/api/", MethodGet, - "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. + "ong/mux/mux_test.go:27", // location where `someMuxHandler` is declared. }, { "failure", @@ -302,14 +302,14 @@ func TestMux(t *testing.T) { "check/2625", "/check/:age/", MethodAll, - "ong/mux/mux_test.go:38", // location where `checkAgeHandler` is declared. + "ong/mux/mux_test.go:43", // location where `checkAgeHandler` is declared. }, { "url with domain name", "https://localhost/check/2625", "/check/:age/", MethodAll, - "ong/mux/mux_test.go:38", // location where `checkAgeHandler` is declared. + "ong/mux/mux_test.go:43", // location where `checkAgeHandler` is declared. }, } From cd845500b68ca4a22c4498e08af1411561f8948f Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 7 Jun 2023 23:41:16 +0300 Subject: [PATCH 05/13] m --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b6a0d83..77b57dd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Most recent version is listed first. # v0.0.49 - Add mux Resolve function: https://github.com/komuw/ong/pull/268 +- Use http.Handler as the http middleware instead of http.HandlerFunc: https://github.com/komuw/ong/pull/269 ## v0.0.48 - Change attest import path: https://github.com/komuw/ong/pull/265 From 63f8c37b477197e5a0f4a453a499782e12680f19 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 7 Jun 2023 23:55:02 +0300 Subject: [PATCH 06/13] m --- cookie/cookie_test.go | 38 ++-- cookie/example_test.go | 48 +++--- example/handlers.go | 276 ++++++++++++++---------------- middleware/auth.go | 42 +++-- middleware/auth_test.go | 10 +- middleware/client_ip.go | 40 ++--- middleware/client_ip_test.go | 14 +- middleware/cors.go | 32 ++-- middleware/cors_test.go | 10 +- middleware/csrf.go | 268 ++++++++++++++--------------- middleware/csrf_test.go | 12 +- middleware/example_test.go | 28 ++- middleware/fingerprint.go | 36 ++-- middleware/gzip.go | 60 ++++--- middleware/gzip_test.go | 88 +++++----- middleware/loadshed.go | 80 +++++---- middleware/loadshed_test.go | 36 ++-- middleware/log.go | 9 +- middleware/log_test.go | 34 ++-- middleware/middleware.go | 26 +-- middleware/middleware_test.go | 38 ++-- middleware/ratelimiter.go | 52 +++--- middleware/ratelimiter_test.go | 10 +- middleware/recoverer.go | 88 +++++----- middleware/recoverer_test.go | 48 +++--- middleware/redirect.go | 76 ++++---- middleware/redirect_test.go | 24 ++- middleware/reload_protect.go | 80 +++++---- middleware/reload_protect_test.go | 40 ++--- middleware/security.go | 142 ++++++++------- middleware/security_test.go | 12 +- middleware/session.go | 22 ++- middleware/session_test.go | 40 ++--- middleware/trace.go | 56 +++--- middleware/trace_test.go | 34 ++-- mux/example_test.go | 22 +-- mux/mux_test.go | 32 ++-- mux/route_test.go | 20 +-- server/example_test.go | 30 ++-- server/server_test.go | 36 ++-- sess/example_test.go | 24 ++- 41 files changed, 996 insertions(+), 1117 deletions(-) diff --git a/cookie/cookie_test.go b/cookie/cookie_test.go index 745d9db4..455523b4 100644 --- a/cookie/cookie_test.go +++ b/cookie/cookie_test.go @@ -11,22 +11,18 @@ import ( "go.uber.org/goleak" ) -func setHandler(name, value, domain string, mAge time.Duration, jsAccess bool) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - Set(w, name, value, domain, mAge, jsAccess) - fmt.Fprint(w, "hello") - }, - ) +func setHandler(name, value, domain string, mAge time.Duration, jsAccess bool) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + Set(w, name, value, domain, mAge, jsAccess) + fmt.Fprint(w, "hello") + } } -func setEncryptedHandler(name, value, domain string, mAge time.Duration, secretKey string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - SetEncrypted(r, w, name, value, domain, mAge, secretKey) - fmt.Fprint(w, "hello") - }, - ) +func setEncryptedHandler(name, value, domain string, mAge time.Duration, secretKey string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + SetEncrypted(r, w, name, value, domain, mAge, secretKey) + fmt.Fprint(w, "hello") + } } func TestMain(m *testing.M) { @@ -189,14 +185,12 @@ func TestCookies(t *testing.T) { }) } -func deleteHandler(name, value, domain string, mAge time.Duration) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - Set(w, name, value, domain, mAge, false) - Delete(w, name, domain) - fmt.Fprint(w, "hello") - }, - ) +func deleteHandler(name, value, domain string, mAge time.Duration) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + Set(w, name, value, domain, mAge, false) + Delete(w, name, domain) + fmt.Fprint(w, "hello") + } } func TestDelete(t *testing.T) { diff --git a/cookie/example_test.go b/cookie/example_test.go index 4b31bc4d..af50e034 100644 --- a/cookie/example_test.go +++ b/cookie/example_test.go @@ -15,31 +15,29 @@ type shoppingCart struct { Price uint8 } -func shoppingCartHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - cookieName := "cart" - secretKey := "superSecret" - item := shoppingCart{ItemName: "shoe", Price: 89} - - b, err := json.Marshal(item) - if err != nil { - panic(err) - } - - cookie.SetEncrypted( - r, - w, - cookieName, - string(b), - "example.com", - 2*time.Hour, - secretKey, - ) - - fmt.Fprint(w, "thanks for shopping!") - }, - ) +func shoppingCartHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + cookieName := "cart" + secretKey := "superSecret" + item := shoppingCart{ItemName: "shoe", Price: 89} + + b, err := json.Marshal(item) + if err != nil { + panic(err) + } + + cookie.SetEncrypted( + r, + w, + cookieName, + string(b), + "example.com", + 2*time.Hour, + secretKey, + ) + + fmt.Fprint(w, "thanks for shopping!") + } } func ExampleSetEncrypted() { diff --git a/example/handlers.go b/example/handlers.go index 03f3cb90..4117c75c 100644 --- a/example/handlers.go +++ b/example/handlers.go @@ -50,7 +50,7 @@ func NewApp(d db, l *slog.Logger) app { // health handler showcases the use of: // - encryption/decryption. // - random id. -func (a app) health(secretKey string) http.Handler { +func (a app) health(secretKey string) http.HandlerFunc { var ( once sync.Once serverBoot time.Time = time.Now().UTC() @@ -59,19 +59,17 @@ func (a app) health(secretKey string) http.Handler { enc = cry.New(secretKey) ) - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // intialize somethings only once for perfomance. - once.Do(func() { - serverStart = time.Now().UTC() - }) + return func(w http.ResponseWriter, r *http.Request) { + // intialize somethings only once for perfomance. + once.Do(func() { + serverStart = time.Now().UTC() + }) - encryptedSrvID := enc.EncryptEncode(serverID) + encryptedSrvID := enc.EncryptEncode(serverID) - res := fmt.Sprintf("serverBoot=%s, serverStart=%s, serverId=%s\n", serverBoot, serverStart, encryptedSrvID) - _, _ = io.WriteString(w, res) - }, - ) + res := fmt.Sprintf("serverBoot=%s, serverStart=%s, serverId=%s\n", serverBoot, serverStart, encryptedSrvID) + _, _ = io.WriteString(w, res) + } } // check handler showcases the use of: @@ -80,62 +78,60 @@ func (a app) health(secretKey string) http.Handler { // - xcontext.Detach. // - safe http client. // - error wrapping. -func (a app) check(msg string) http.Handler { +func (a app) check(msg string) http.HandlerFunc { cli := client.Safe(a.l) - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - cartID := "afakHda8eqL" - - age := mux.Param(r.Context(), "age") - sess.SetM(r, sess.M{ - "name": "John Doe", - "age": age, - "cart_id": cartID, - }) - - if sess.Get(r, "cart_id") != "" { - if sess.Get(r, "cart_id") != cartID { - http.Error(w, "wrong cartID", http.StatusBadRequest) - return - } - } + return func(w http.ResponseWriter, r *http.Request) { + cartID := "afakHda8eqL" - go func(ctx context.Context) { - ctx, cancel := context.WithTimeout(ctx, 2*time.Second) - defer cancel() + age := mux.Param(r.Context(), "age") + sess.SetM(r, sess.M{ + "name": "John Doe", + "age": age, + "cart_id": cartID, + }) - makeReq := func(url string) (code int, errp error) { - defer errors.Dwrap(&errp) + if sess.Get(r, "cart_id") != "" { + if sess.Get(r, "cart_id") != cartID { + http.Error(w, "wrong cartID", http.StatusBadRequest) + return + } + } - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return 0, err - } - resp, err := cli.Do(req) - if err != nil { - return 0, err - } - defer func() { _ = resp.Body.Close() }() + go func(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() - return resp.StatusCode, nil - } + makeReq := func(url string) (code int, errp error) { + defer errors.Dwrap(&errp) - l := log.WithID(ctx, a.l) - code, err := makeReq("https://example.com") + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { - l.Error("handler error", err) + return 0, err } - l.Info("req succeded", "code", code) - }( - // we need to detach context, - // since this goroutine can outlive the http request lifecycle. - xcontext.Detach(r.Context()), - ) - - _, _ = fmt.Fprintf(w, "hello %s. Age is %s", msg, age) - }, - ) + resp, err := cli.Do(req) + if err != nil { + return 0, err + } + defer func() { _ = resp.Body.Close() }() + + return resp.StatusCode, nil + } + + l := log.WithID(ctx, a.l) + code, err := makeReq("https://example.com") + if err != nil { + l.Error("handler error", err) + } + l.Info("req succeded", "code", code) + }( + // we need to detach context, + // since this goroutine can outlive the http request lifecycle. + xcontext.Detach(r.Context()), + ) + + _, _ = fmt.Fprintf(w, "hello %s. Age is %s", msg, age) + } } // login handler showcases the use of: @@ -143,7 +139,7 @@ func (a app) check(msg string) http.Handler { // - csp tokens. // - encrypted cookies // - hashing passwords. -func (a app) login(secretKey string) http.Handler { +func (a app) login(secretKey string) http.HandlerFunc { tmpl, err := template.New("myTpl").Parse(` @@ -199,83 +195,81 @@ func (a app) login(secretKey string) http.Handler { Name string } - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - reqL := log.WithID(r.Context(), a.l) - - if r.Method != http.MethodPost { - data := struct { - CsrfTokenName string - CsrfTokenValue string - CspNonceValue string - }{ - CsrfTokenName: middleware.CsrfTokenFormName, - CsrfTokenValue: middleware.GetCsrfToken(r.Context()), - CspNonceValue: middleware.GetCspNonce(r.Context()), - } - - if errE := tmpl.Execute(w, data); errE != nil { - http.Error(w, errE.Error(), http.StatusInternalServerError) - return - } - return + return func(w http.ResponseWriter, r *http.Request) { + reqL := log.WithID(r.Context(), a.l) + + if r.Method != http.MethodPost { + data := struct { + CsrfTokenName string + CsrfTokenValue string + CspNonceValue string + }{ + CsrfTokenName: middleware.CsrfTokenFormName, + CsrfTokenValue: middleware.GetCsrfToken(r.Context()), + CspNonceValue: middleware.GetCspNonce(r.Context()), } - if errP := r.ParseForm(); errP != nil { - http.Error(w, errP.Error(), http.StatusInternalServerError) + if errE := tmpl.Execute(w, data); errE != nil { + http.Error(w, errE.Error(), http.StatusInternalServerError) return } - - email := r.FormValue("email") - firstName := r.FormValue("firstName") - password := r.FormValue("password") - - u := &User{Email: email, Name: firstName} - - s, errM := json.Marshal(u) - if errM != nil { - http.Error(w, errM.Error(), http.StatusInternalServerError) + return + } + + if errP := r.ParseForm(); errP != nil { + http.Error(w, errP.Error(), http.StatusInternalServerError) + return + } + + email := r.FormValue("email") + firstName := r.FormValue("firstName") + password := r.FormValue("password") + + u := &User{Email: email, Name: firstName} + + s, errM := json.Marshal(u) + if errM != nil { + http.Error(w, errM.Error(), http.StatusInternalServerError) + return + } + + cookieName := "example_session_cookie" + c, errM := cookie.GetEncrypted(r, cookieName, secretKey) + reqL.Info("login handler log cookie", + "err", errM, + "cookie", c, + ) + + cookie.SetEncrypted( + r, + w, + cookieName, + string(s), + "localhost", + 23*24*time.Hour, + secretKey, + ) + + existingPasswdHash := a.db.Get("passwd") + if e := cry.Eql(password, existingPasswdHash); e != nil { + // passwd did not exist before. + hashedPasswd, errH := cry.Hash(password) + if errH != nil { + http.Error(w, errH.Error(), http.StatusInternalServerError) return } + a.db.Set("passwd", hashedPasswd) + } - cookieName := "example_session_cookie" - c, errM := cookie.GetEncrypted(r, cookieName, secretKey) - reqL.Info("login handler log cookie", - "err", errM, - "cookie", c, - ) - - cookie.SetEncrypted( - r, - w, - cookieName, - string(s), - "localhost", - 23*24*time.Hour, - secretKey, - ) - - existingPasswdHash := a.db.Get("passwd") - if e := cry.Eql(password, existingPasswdHash); e != nil { - // passwd did not exist before. - hashedPasswd, errH := cry.Hash(password) - if errH != nil { - http.Error(w, errH.Error(), http.StatusInternalServerError) - return - } - a.db.Set("passwd", hashedPasswd) - } - - _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) - }, - ) + _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) + } } // handleFileServer handler showcases the use of: // - middleware.ClientIP // - middleware.ClientFingerPrint // - logging -func (a app) handleFileServer() http.Handler { +func (a app) handleFileServer() http.HandlerFunc { // Do NOT let `http.FileServer` be able to serve your root directory. // Otherwise, your .git folder and other sensitive info(including http://localhost:65080/main.go) may be available // instead create a folder that only has your templates and server that. @@ -290,29 +284,25 @@ func (a app) handleFileServer() http.Handler { fs := http.FileServer(http.Dir(dir)) realHandler := http.StripPrefix("/staticAssets/", fs).ServeHTTP - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - reqL := log.WithID(r.Context(), a.l) - reqL.Info("handleFileServer", "clientIP", middleware.ClientIP(r), "clientFingerPrint", middleware.ClientFingerPrint(r)) + return func(w http.ResponseWriter, r *http.Request) { + reqL := log.WithID(r.Context(), a.l) + reqL.Info("handleFileServer", "clientIP", middleware.ClientIP(r), "clientFingerPrint", middleware.ClientFingerPrint(r)) - slog.NewLogLogger(reqL.Handler(), log.LevelImmediate). - Println("this is now a Go standard library logger") + slog.NewLogLogger(reqL.Handler(), log.LevelImmediate). + Println("this is now a Go standard library logger") - realHandler(w, r) - }, - ) + realHandler(w, r) + } } // panic handler showcases the use of: // - recoverer middleware. -func (a app) panic() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - names := []string{"John", "Jane", "Kamau"} - _ = 93 - msg := "hey" - n := names[934] - _, _ = io.WriteString(w, fmt.Sprintf("%s %s", msg, n)) - }, - ) +func (a app) panic() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + names := []string{"John", "Jane", "Kamau"} + _ = 93 + msg := "hey" + n := names[934] + _, _ = io.WriteString(w, fmt.Sprintf("%s %s", msg, n)) + } } diff --git a/middleware/auth.go b/middleware/auth.go index b21dba74..5a0c8c7a 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -9,7 +9,7 @@ import ( const minPasswdSize = 16 // BasicAuth is a middleware that protects wrappedHandler using basic authentication. -func BasicAuth(wrappedHandler http.Handler, user, passwd string) http.Handler { +func BasicAuth(wrappedHandler http.Handler, user, passwd string) http.HandlerFunc { const realm = "enter username and password" if len(passwd) < minPasswdSize { @@ -23,25 +23,23 @@ func BasicAuth(wrappedHandler http.Handler, user, passwd string) http.Handler { http.Error(w, "Unauthorized", http.StatusUnauthorized) } - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - u, p, ok := r.BasicAuth() - if u == "" || p == "" || !ok { - e(w) - return - } - - if subtle.ConstantTimeCompare([]byte(u), []byte(user)) != 1 { - e(w) - return - } - - if subtle.ConstantTimeCompare([]byte(p), []byte(passwd)) != 1 { - e(w) - return - } - - wrappedHandler.ServeHTTP(w, r) - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + u, p, ok := r.BasicAuth() + if u == "" || p == "" || !ok { + e(w) + return + } + + if subtle.ConstantTimeCompare([]byte(u), []byte(user)) != 1 { + e(w) + return + } + + if subtle.ConstantTimeCompare([]byte(p), []byte(passwd)) != 1 { + e(w) + return + } + + wrappedHandler.ServeHTTP(w, r) + } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index d75a8a23..d0f01e4d 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -12,12 +12,10 @@ import ( "go.akshayshah.org/attest" ) -func protectedHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - }, - ) +func protectedHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + } } func TestBasicAuth(t *testing.T) { diff --git a/middleware/client_ip.go b/middleware/client_ip.go index 005ae819..c9fd81e0 100644 --- a/middleware/client_ip.go +++ b/middleware/client_ip.go @@ -68,26 +68,24 @@ func ClientIP(r *http.Request) string { // Fetching the "real" client is done in a best-effort basis and can be [grossly inaccurate & precarious]. // // [grossly inaccurate & precarious]: https://adam-p.ca/blog/2022/03/x-forwarded-for/ -func clientIP(wrappedHandler http.Handler, strategy ClientIPstrategy) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - var clientAddr string - switch v := strategy; v { - case DirectIpStrategy: - clientAddr = clientip.DirectAddress(r.RemoteAddr) - case LeftIpStrategy: - clientAddr = clientip.Leftmost(r.Header) - case RightIpStrategy: - clientAddr = clientip.Rightmost(r.Header) - case ProxyStrategy: - clientAddr = clientip.ProxyHeader(r.Header) - default: - // treat everything else as a `singleIP` strategy - clientAddr = clientip.SingleIPHeader(string(v), r.Header) - } +func clientIP(wrappedHandler http.Handler, strategy ClientIPstrategy) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var clientAddr string + switch v := strategy; v { + case DirectIpStrategy: + clientAddr = clientip.DirectAddress(r.RemoteAddr) + case LeftIpStrategy: + clientAddr = clientip.Leftmost(r.Header) + case RightIpStrategy: + clientAddr = clientip.Rightmost(r.Header) + case ProxyStrategy: + clientAddr = clientip.ProxyHeader(r.Header) + default: + // treat everything else as a `singleIP` strategy + clientAddr = clientip.SingleIPHeader(string(v), r.Header) + } - r = clientip.With(r, clientAddr) - wrappedHandler.ServeHTTP(w, r) - }, - ) + r = clientip.With(r, clientAddr) + wrappedHandler.ServeHTTP(w, r) + } } diff --git a/middleware/client_ip_test.go b/middleware/client_ip_test.go index 33b86064..5339b655 100644 --- a/middleware/client_ip_test.go +++ b/middleware/client_ip_test.go @@ -17,14 +17,12 @@ const ( proxyHeader = "PROXY" ) -func someClientIpHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - ip := ClientIP(r) - res := fmt.Sprintf("message: %s, ip: %s", msg, ip) - fmt.Fprint(w, res) - }, - ) +func someClientIpHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ip := ClientIP(r) + res := fmt.Sprintf("message: %s, ip: %s", msg, ip) + fmt.Fprint(w, res) + } } func TestClientIP(t *testing.T) { diff --git a/middleware/cors.go b/middleware/cors.go index aa36c31e..2f1b9ea9 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -64,27 +64,25 @@ func cors( allowedOrigins []string, allowedMethods []string, allowedHeaders []string, -) http.Handler { +) http.HandlerFunc { allowedOrigins, allowedWildcardOrigins := getOrigins(allowedOrigins) allowedMethods = getMethods(allowedMethods) allowedHeaders = getHeaders(allowedHeaders) - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions && r.Header.Get(acrmHeader) != "" { - // handle preflight request - handlePreflight(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods, allowedHeaders) - // Preflight requests are standalone and should stop the chain as some other - // middleware may not handle OPTIONS requests correctly. One typical example - // is authentication middleware ; OPTIONS requests won't carry authentication headers. - w.WriteHeader(http.StatusNoContent) - } else { - // handle actual request - handleActualRequest(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods) - wrappedHandler.ServeHTTP(w, r) - } - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions && r.Header.Get(acrmHeader) != "" { + // handle preflight request + handlePreflight(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods, allowedHeaders) + // Preflight requests are standalone and should stop the chain as some other + // middleware may not handle OPTIONS requests correctly. One typical example + // is authentication middleware ; OPTIONS requests won't carry authentication headers. + w.WriteHeader(http.StatusNoContent) + } else { + // handle actual request + handleActualRequest(w, r, allowedOrigins, allowedWildcardOrigins, allowedMethods) + wrappedHandler.ServeHTTP(w, r) + } + } } func handlePreflight( diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 79004a34..e61dc4e8 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -11,12 +11,10 @@ import ( "go.akshayshah.org/attest" ) -func someCorsHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - }, - ) +func someCorsHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + } } func TestCorsPreflight(t *testing.T) { diff --git a/middleware/csrf.go b/middleware/csrf.go index b50f1dfd..5e116c79 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -67,151 +67,149 @@ const ( // csrf is a middleware that provides protection against Cross Site Request Forgeries. // // If a csrf token is not provided(or is not valid), when it ought to have been; this middleware will issue a http GET redirect to the same url. -func csrf(wrappedHandler http.Handler, secretKey, domain string) http.Handler { +func csrf(wrappedHandler http.Handler, secretKey, domain string) http.HandlerFunc { once.Do(func() { enc = cry.New(secretKey) }) msgToEncrypt := id.Random(16) - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // - https://docs.djangoproject.com/en/4.0/ref/csrf/ - // - https://github.com/django/django/blob/4.0.5/django/middleware/csrf.py - // - https://github.com/gofiber/fiber/blob/v2.34.1/middleware/csrf/csrf.go + return func(w http.ResponseWriter, r *http.Request) { + // - https://docs.djangoproject.com/en/4.0/ref/csrf/ + // - https://github.com/django/django/blob/4.0.5/django/middleware/csrf.py + // - https://github.com/gofiber/fiber/blob/v2.34.1/middleware/csrf/csrf.go + + // 1. check http method. + // - if it is a 'safe' method like GET, try and get `actualToken` from request. + // - if it is not a 'safe' method, try and get `actualToken` from header/cookies/httpForm + // - take the found token and try to get it from memory store. + // - if not found in memory store, delete the cookie & return an error. + + ctx := r.Context() + + switch r.Method { + // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + break + default: + // For POST requests, we insist on a CSRF cookie, and in this way we can avoid all CSRF attacks, including login CSRF. + actualToken := getToken(r) + + ct, _, err := mime.ParseMediaType(r.Header.Get(ctHeader)) + if err == nil && + ct != formUrlEncoded && + ct != multiformData && + r.Header.Get(clientCookieHeader) == "" && + r.Header.Get(authorizationHeader) == "" && + r.Header.Get(proxyAuthorizationHeader) == "" { + // For POST requests that; + // - are not form data. + // - have no cookies. + // - are not using http authentication. + // then it is okay to not validate csrf for them. + // This is especially useful for REST API endpoints. + // see: https://github.com/komuw/ong/issues/76 + break + } - // 1. check http method. - // - if it is a 'safe' method like GET, try and get `actualToken` from request. - // - if it is not a 'safe' method, try and get `actualToken` from header/cookies/httpForm - // - take the found token and try to get it from memory store. - // - if not found in memory store, delete the cookie & return an error. + tokVal, errN := enc.DecryptDecode(actualToken) + if errN != nil { + // We should redirect the request since it means that the server is not aware of such a token. + // It shoulbe be a temporary redirect to the same page but this time send a http GET request. + // + // To test using curl, use; + // curl -kL \ + // -H "Content-Type: application/x-www-form-urlencoded" \ + // -d "firstName=john&csrftoken=bogusToken" https://localhost:65081/login/ + // Do NOT use `-X POST`, see: https://stackoverflow.com/a/41890653/2768067 + // + cookie.Delete(w, csrfCookieName, domain) + w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenNotFound.Error()) + http.Redirect( + w, + r, + r.URL.String(), + // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 + http.StatusSeeOther, + ) + return + } - ctx := r.Context() + res := strings.Split(tokVal, sep) + if len(res) != 2 { + cookie.Delete(w, csrfCookieName, domain) + w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenWrongFormat.Error()) + http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) + return + } - switch r.Method { - // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 - case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: - break - default: - // For POST requests, we insist on a CSRF cookie, and in this way we can avoid all CSRF attacks, including login CSRF. - actualToken := getToken(r) - - ct, _, err := mime.ParseMediaType(r.Header.Get(ctHeader)) - if err == nil && - ct != formUrlEncoded && - ct != multiformData && - r.Header.Get(clientCookieHeader) == "" && - r.Header.Get(authorizationHeader) == "" && - r.Header.Get(proxyAuthorizationHeader) == "" { - // For POST requests that; - // - are not form data. - // - have no cookies. - // - are not using http authentication. - // then it is okay to not validate csrf for them. - // This is especially useful for REST API endpoints. - // see: https://github.com/komuw/ong/issues/76 - break - } - - tokVal, errN := enc.DecryptDecode(actualToken) - if errN != nil { - // We should redirect the request since it means that the server is not aware of such a token. - // It shoulbe be a temporary redirect to the same page but this time send a http GET request. - // - // To test using curl, use; - // curl -kL \ - // -H "Content-Type: application/x-www-form-urlencoded" \ - // -d "firstName=john&csrftoken=bogusToken" https://localhost:65081/login/ - // Do NOT use `-X POST`, see: https://stackoverflow.com/a/41890653/2768067 - // - cookie.Delete(w, csrfCookieName, domain) - w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenNotFound.Error()) - http.Redirect( - w, - r, - r.URL.String(), - // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 - http.StatusSeeOther, - ) - return - } - - res := strings.Split(tokVal, sep) - if len(res) != 2 { - cookie.Delete(w, csrfCookieName, domain) - w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenWrongFormat.Error()) - http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) - return - } - - expires, errP := strconv.ParseInt(res[1], 10, 64) - if errP != nil { - cookie.Delete(w, csrfCookieName, domain) - w.Header().Set(ongMiddlewareErrorHeader, errP.Error()) - http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) - return - } - - diff := expires - time.Now().UTC().Unix() - if diff <= 0 { - cookie.Delete(w, csrfCookieName, domain) - w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenExpired.Error()) - http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) - return - } + expires, errP := strconv.ParseInt(res[1], 10, 64) + if errP != nil { + cookie.Delete(w, csrfCookieName, domain) + w.Header().Set(ongMiddlewareErrorHeader, errP.Error()) + http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) + return } - // 2. generate a new token. - /* - We need to try and protect against BreachAttack[1]. See[2] for a refresher on how it works. - The mitigations against the attack in order of effectiveness are: - (a) Disabling HTTP compression - (b) Separating secrets from user input - (c) Randomizing secrets per request - (d) Masking secrets (effectively randomizing by XORing with a random secret per request) - (e) Protecting vulnerable pages with CSRF - (f) Length hiding (by adding random number of bytes to the responses) - (g) Rate-limiting the requests - Most csrf implementation use (d). Here, we'll use (c) - The [encrypt] func uses a random nonce everytime it is called. - - 1. http://breachattack.com/ - 2. https://security.stackexchange.com/a/172646 - */ - expires := strconv.FormatInt( - time.Now().UTC().Add(tokenMaxAge).Unix(), - 10, - ) - tokenToIssue := enc.EncryptEncode( - // see: https://github.com/golang/net/blob/v0.8.0/xsrftoken/xsrf.go#L33-L46 - fmt.Sprintf("%s%s%s", msgToEncrypt, sep, expires), - ) - - // 3. create cookie - cookie.Set( - w, - csrfCookieName, - tokenToIssue, - domain, - tokenMaxAge, - true, // accessible to javascript - ) - - // 4. set cookie header - w.Header().Set( - CsrfHeader, - tokenToIssue, - ) - - // 5. update Vary header. - w.Header().Add(varyHeader, clientCookieHeader) - - // 6. store tokenToIssue in context - r = r.WithContext(context.WithValue(ctx, csrfCtxKey, tokenToIssue)) - - wrappedHandler.ServeHTTP(w, r) - }, - ) + diff := expires - time.Now().UTC().Unix() + if diff <= 0 { + cookie.Delete(w, csrfCookieName, domain) + w.Header().Set(ongMiddlewareErrorHeader, errCsrfTokenExpired.Error()) + http.Redirect(w, r, r.URL.String(), http.StatusSeeOther) + return + } + } + + // 2. generate a new token. + /* + We need to try and protect against BreachAttack[1]. See[2] for a refresher on how it works. + The mitigations against the attack in order of effectiveness are: + (a) Disabling HTTP compression + (b) Separating secrets from user input + (c) Randomizing secrets per request + (d) Masking secrets (effectively randomizing by XORing with a random secret per request) + (e) Protecting vulnerable pages with CSRF + (f) Length hiding (by adding random number of bytes to the responses) + (g) Rate-limiting the requests + Most csrf implementation use (d). Here, we'll use (c) + The [encrypt] func uses a random nonce everytime it is called. + + 1. http://breachattack.com/ + 2. https://security.stackexchange.com/a/172646 + */ + expires := strconv.FormatInt( + time.Now().UTC().Add(tokenMaxAge).Unix(), + 10, + ) + tokenToIssue := enc.EncryptEncode( + // see: https://github.com/golang/net/blob/v0.8.0/xsrftoken/xsrf.go#L33-L46 + fmt.Sprintf("%s%s%s", msgToEncrypt, sep, expires), + ) + + // 3. create cookie + cookie.Set( + w, + csrfCookieName, + tokenToIssue, + domain, + tokenMaxAge, + true, // accessible to javascript + ) + + // 4. set cookie header + w.Header().Set( + CsrfHeader, + tokenToIssue, + ) + + // 5. update Vary header. + w.Header().Add(varyHeader, clientCookieHeader) + + // 6. store tokenToIssue in context + r = r.WithContext(context.WithValue(ctx, csrfCtxKey, tokenToIssue)) + + wrappedHandler.ServeHTTP(w, r) + } } // GetCsrfToken returns the csrf token that was set for the http request in question. diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 5a258f88..8033f01b 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -115,13 +115,11 @@ func TestGetToken(t *testing.T) { const tokenHeader = "CUSTOM-CSRF-TOKEN-TEST-HEADER" -func someCsrfHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(tokenHeader, GetCsrfToken(r.Context())) - fmt.Fprint(w, msg) - }, - ) +func someCsrfHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(tokenHeader, GetCsrfToken(r.Context())) + fmt.Fprint(w, msg) + } } func TestCsrf(t *testing.T) { diff --git a/middleware/example_test.go b/middleware/example_test.go index df2fb1a1..23772d62 100644 --- a/middleware/example_test.go +++ b/middleware/example_test.go @@ -13,26 +13,22 @@ import ( "golang.org/x/exp/slog" ) -func loginHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - cspNonce := middleware.GetCspNonce(r.Context()) - _ = cspNonce // use CSP nonce +func loginHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + cspNonce := middleware.GetCspNonce(r.Context()) + _ = cspNonce // use CSP nonce - _, _ = fmt.Fprint(w, "welcome to your favorite website.") - }, - ) + _, _ = fmt.Fprint(w, "welcome to your favorite website.") + } } -func welcomeHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - csrfToken := middleware.GetCsrfToken(r.Context()) - _ = csrfToken // use CSRF token +func welcomeHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + csrfToken := middleware.GetCsrfToken(r.Context()) + _ = csrfToken // use CSRF token - _, _ = fmt.Fprint(w, "welcome.") - }, - ) + _, _ = fmt.Fprint(w, "welcome.") + } } func Example_getCspNonce() { diff --git a/middleware/fingerprint.go b/middleware/fingerprint.go index ae2b5735..dfdca004 100644 --- a/middleware/fingerprint.go +++ b/middleware/fingerprint.go @@ -10,30 +10,28 @@ import ( // fingerprint is a middleware that adds the client's TLS fingerprint to the request context. // The fingerprint can be fetched using [ClientFingerPrint] -func fingerprint(wrappedHandler http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - fHash := "" +func fingerprint(wrappedHandler http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + fHash := "" - if vCtx := ctx.Value(octx.FingerPrintCtxKey); vCtx != nil { - if s, ok := vCtx.(*finger.Print); ok { - if hash := s.Hash.Load(); hash != nil { - fHash = *hash - } + if vCtx := ctx.Value(octx.FingerPrintCtxKey); vCtx != nil { + if s, ok := vCtx.(*finger.Print); ok { + if hash := s.Hash.Load(); hash != nil { + fHash = *hash } } + } - ctx = context.WithValue( - ctx, - octx.FingerPrintCtxKey, - fHash, - ) - r = r.WithContext(ctx) + ctx = context.WithValue( + ctx, + octx.FingerPrintCtxKey, + fHash, + ) + r = r.WithContext(ctx) - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } // ClientFingerPrint returns the [TLS fingerprint] of the client. diff --git a/middleware/gzip.go b/middleware/gzip.go index 4c3da9c1..96e61c18 100644 --- a/middleware/gzip.go +++ b/middleware/gzip.go @@ -29,37 +29,35 @@ const ( ) // gzip is a middleware that transparently gzips the http response body, for clients that support it. -func gzip(wrappedHandler http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - w.Header().Add(varyHeader, acceptEncodingHeader) - - if !shouldGzipReq(r) { - wrappedHandler.ServeHTTP(w, r) - return - } - - gzipWriter, _ := stdGzip.NewWriterLevel(w, stdGzip.BestSpeed) - grw := &gzipRW{ - ResponseWriter: w, - // Bytes written during ServeHTTP are redirected to this gzip writer - // before being written to the underlying response. - gw: gzipWriter, - } - defer func() { _ = grw.Close() }() // errcheck made me do this. - - // We do not handle range requests when compression is used, as the - // range specified applies to the compressed data, not to the uncompressed one. - // see: https://github.com/nytimes/gziphandler/issues/83 - r.Header.Del(rangeHeader) - - // todo: we could detect if `w` is a `http.CloseNotifier` and do something special here. - // see: https://github.com/klauspost/compress/blob/4a97174a615ed745c450077edf0e1f7e97aabd58/gzhttp/compress.go#L383-L385 - // However `http.CloseNotifier` has been deprecated sinc Go v1.11(year 2018) - - wrappedHandler.ServeHTTP(grw, r) - }, - ) +func gzip(wrappedHandler http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Add(varyHeader, acceptEncodingHeader) + + if !shouldGzipReq(r) { + wrappedHandler.ServeHTTP(w, r) + return + } + + gzipWriter, _ := stdGzip.NewWriterLevel(w, stdGzip.BestSpeed) + grw := &gzipRW{ + ResponseWriter: w, + // Bytes written during ServeHTTP are redirected to this gzip writer + // before being written to the underlying response. + gw: gzipWriter, + } + defer func() { _ = grw.Close() }() // errcheck made me do this. + + // We do not handle range requests when compression is used, as the + // range specified applies to the compressed data, not to the uncompressed one. + // see: https://github.com/nytimes/gziphandler/issues/83 + r.Header.Del(rangeHeader) + + // todo: we could detect if `w` is a `http.CloseNotifier` and do something special here. + // see: https://github.com/klauspost/compress/blob/4a97174a615ed745c450077edf0e1f7e97aabd58/gzhttp/compress.go#L383-L385 + // However `http.CloseNotifier` has been deprecated sinc Go v1.11(year 2018) + + wrappedHandler.ServeHTTP(grw, r) + } } // gzipRW provides an http.ResponseWriter interface, which gzips diff --git a/middleware/gzip_test.go b/middleware/gzip_test.go index 6c28f58b..93805d3f 100644 --- a/middleware/gzip_test.go +++ b/middleware/gzip_test.go @@ -19,35 +19,31 @@ import ( tmthrgd "github.com/tmthrgd/gziphandler" ) -func someGzipHandler(msg string) http.Handler { +func someGzipHandler(msg string) http.HandlerFunc { // bound stack growth. // see: https://github.com/komuw/ong/issues/54 fMsg := strings.Repeat(msg, 3) - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, fMsg) - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, fMsg) + } } -func handlerImplementingFlush(msg string) http.Handler { +func handlerImplementingFlush(msg string) http.HandlerFunc { iterations := 3 - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if f, ok := w.(http.Flusher); ok { - msg = "FlusherCalled::" + strings.Repeat(msg, iterations) - fmt.Fprint(w, msg) - - f.Flush() - } else { - msg = strings.Repeat(msg, iterations) - fmt.Fprint(w, msg) - } - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + if f, ok := w.(http.Flusher); ok { + msg = "FlusherCalled::" + strings.Repeat(msg, iterations) + fmt.Fprint(w, msg) + + f.Flush() + } else { + msg = strings.Repeat(msg, iterations) + fmt.Fprint(w, msg) + } + } } -func login() http.Handler { +func login() http.HandlerFunc { tmpl, err := template.New("myTpl").Parse(` @@ -74,31 +70,29 @@ func login() http.Handler { panic(err) } - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - data := struct { - CsrfTokenName string - CsrfTokenValue string - CspNonceValue string - }{ - CsrfTokenName: CsrfTokenFormName, - CsrfTokenValue: GetCsrfToken(r.Context()), - CspNonceValue: GetCspNonce(r.Context()), - } - if err = tmpl.Execute(w, data); err != nil { - panic(err) - } - return + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + data := struct { + CsrfTokenName string + CsrfTokenValue string + CspNonceValue string + }{ + CsrfTokenName: CsrfTokenFormName, + CsrfTokenValue: GetCsrfToken(r.Context()), + CspNonceValue: GetCspNonce(r.Context()), } - - if err = r.ParseForm(); err != nil { + if err = tmpl.Execute(w, data); err != nil { panic(err) } + return + } + + if err = r.ParseForm(); err != nil { + panic(err) + } - _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) - }, - ) + _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) + } } func readBody(t *testing.T, res *http.Response) (strBody string) { @@ -308,17 +302,15 @@ BenchmarkNytimesGzip-8 4 315_386_476 ns/op 3_813_934 B/op 116 BenchmarkTmthrgdGzip-8 4 319_786_254 ns/op 3_527_012 B/op 116 allocs/op */ -func gzipBenchmarkHandler() http.Handler { +func gzipBenchmarkHandler() http.HandlerFunc { bin, err := os.ReadFile("testdata/benchmark.json") if err != nil { panic(err) } - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, bin) - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, bin) + } } var result int //nolint:gochecknoglobals diff --git a/middleware/loadshed.go b/middleware/loadshed.go index eca702e2..f6f4215e 100644 --- a/middleware/loadshed.go +++ b/middleware/loadshed.go @@ -44,53 +44,51 @@ const ( ) // loadShedder is a middleware that sheds load based on http response latencies. -func loadShedder(wrappedHandler http.Handler) http.Handler { +func loadShedder(wrappedHandler http.Handler) http.HandlerFunc { // lq should not be a global variable, we want it to be per handler. // This is because different handlers(URIs) could have different latencies and we want each to be loadshed independently. lq := newLatencyQueue() loadShedCheckStart := time.Now().UTC() - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - startReq := time.Now().UTC() - defer func() { - endReq := time.Now().UTC() - durReq := endReq.Sub(startReq) - lq.add(durReq) - - // we do not want to reduce size of `lq` before a period `> samplingPeriod` otherwise `lq.getP99()` will always return zero. - if endReq.Sub(loadShedCheckStart) > resizePeriod { - // lets reduce the size of latencyQueue - lq.reSize() - loadShedCheckStart = endReq - } - }() - - sendProbe := false - { - // Even if the server is overloaded, we want to send a percentage of the requests through. - // These requests act as a probe. If the server eventually recovers, - // these requests will re-populate latencyQueue(`lq`) with lower latencies and thus end the load-shed. - sendProbe = mathRand.Intn(100) == 1 // let 1% of requests through. NB: Intn(100) is `0-99` ie, 100 is not included. + return func(w http.ResponseWriter, r *http.Request) { + startReq := time.Now().UTC() + defer func() { + endReq := time.Now().UTC() + durReq := endReq.Sub(startReq) + lq.add(durReq) + + // we do not want to reduce size of `lq` before a period `> samplingPeriod` otherwise `lq.getP99()` will always return zero. + if endReq.Sub(loadShedCheckStart) > resizePeriod { + // lets reduce the size of latencyQueue + lq.reSize() + loadShedCheckStart = endReq } - - p99 := lq.getP99(minSampleSize) - if p99.Milliseconds() > breachLatency.Milliseconds() && !sendProbe { - // drop request - err := fmt.Errorf("ong/middleware: server is overloaded, retry after %s", retryAfter) - w.Header().Set(ongMiddlewareErrorHeader, fmt.Sprintf("%s. p99latency: %s. breachLatency: %s", err.Error(), p99, breachLatency)) - w.Header().Set(retryAfterHeader, fmt.Sprintf("%d", int(retryAfter.Seconds()))) // header should be in seconds(decimal-integer). - http.Error( - w, - err.Error(), - http.StatusServiceUnavailable, - ) - return - } - - wrappedHandler.ServeHTTP(w, r) - }, - ) + }() + + sendProbe := false + { + // Even if the server is overloaded, we want to send a percentage of the requests through. + // These requests act as a probe. If the server eventually recovers, + // these requests will re-populate latencyQueue(`lq`) with lower latencies and thus end the load-shed. + sendProbe = mathRand.Intn(100) == 1 // let 1% of requests through. NB: Intn(100) is `0-99` ie, 100 is not included. + } + + p99 := lq.getP99(minSampleSize) + if p99.Milliseconds() > breachLatency.Milliseconds() && !sendProbe { + // drop request + err := fmt.Errorf("ong/middleware: server is overloaded, retry after %s", retryAfter) + w.Header().Set(ongMiddlewareErrorHeader, fmt.Sprintf("%s. p99latency: %s. breachLatency: %s", err.Error(), p99, breachLatency)) + w.Header().Set(retryAfterHeader, fmt.Sprintf("%d", int(retryAfter.Seconds()))) // header should be in seconds(decimal-integer). + http.Error( + w, + err.Error(), + http.StatusServiceUnavailable, + ) + return + } + + wrappedHandler.ServeHTTP(w, r) + } } type latencyQueue struct { diff --git a/middleware/loadshed_test.go b/middleware/loadshed_test.go index 2d4b127a..9fb1af6f 100644 --- a/middleware/loadshed_test.go +++ b/middleware/loadshed_test.go @@ -16,18 +16,16 @@ import ( const loadShedderTestHeader = "LoadShedderTestHeader" -func someLoadShedderHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - lat := r.Header.Get(loadShedderTestHeader) - latency, err := strconv.Atoi(lat) - if err != nil { - panic(err) - } - time.Sleep(time.Duration(latency) * time.Millisecond) - fmt.Fprint(w, msg) - }, - ) +func someLoadShedderHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + lat := r.Header.Get(loadShedderTestHeader) + latency, err := strconv.Atoi(lat) + if err != nil { + panic(err) + } + time.Sleep(time.Duration(latency) * time.Millisecond) + fmt.Fprint(w, msg) + } } func TestLoadShedder(t *testing.T) { @@ -242,14 +240,12 @@ func TestLatencyQueue(t *testing.T) { }) } -func loadShedderBenchmarkHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - latency := time.Duration(rand.Intn(100)+1) * time.Millisecond - time.Sleep(latency) - fmt.Fprint(w, "hey") - }, - ) +func loadShedderBenchmarkHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + latency := time.Duration(rand.Intn(100)+1) * time.Millisecond + time.Sleep(latency) + fmt.Fprint(w, "hey") + } } func BenchmarkLoadShedder(b *testing.B) { diff --git a/middleware/log.go b/middleware/log.go index f9bfce94..8e6b693c 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -15,15 +15,14 @@ import ( ) // logger is a middleware that logs http requests and responses using [log.Logger]. -func logger(wrappedHandler http.Handler, l *slog.Logger) http.Handler { +func logger(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { // We pass the logger as an argument so that the middleware can share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. // This makes debugging easier for developers. // // However, each request should get its own context. That's why we call `logger.WithCtx` for every request. - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { start := time.Now() lrw := &logRW{ ResponseWriter: w, @@ -72,9 +71,7 @@ func logger(wrappedHandler http.Handler, l *slog.Logger) http.Handler { }() wrappedHandler.ServeHTTP(lrw, r) - }, - ) -} + } // logRW provides an http.ResponseWriter interface, which logs requests/responses. type logRW struct { diff --git a/middleware/log_test.go b/middleware/log_test.go index af9653e0..e8213bdc 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -23,24 +23,22 @@ const ( someLatencyMS = 3 ) -func someLogHandler(successMsg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // sleep so that the log middleware has some useful duration metrics to report. - time.Sleep(someLatencyMS * time.Millisecond) - if r.Header.Get(someLogHandlerHeader) != "" { - http.Error( - w, - r.Header.Get(someLogHandlerHeader), - http.StatusInternalServerError, - ) - return - } else { - fmt.Fprint(w, successMsg) - return - } - }, - ) +func someLogHandler(successMsg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // sleep so that the log middleware has some useful duration metrics to report. + time.Sleep(someLatencyMS * time.Millisecond) + if r.Header.Get(someLogHandlerHeader) != "" { + http.Error( + w, + r.Header.Get(someLogHandlerHeader), + http.StatusInternalServerError, + ) + return + } else { + fmt.Fprint(w, successMsg) + return + } + } } func TestLogMiddleware(t *testing.T) { diff --git a/middleware/middleware.go b/middleware/middleware.go index e7b8d353..fc90b780 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -107,7 +107,7 @@ func WithOpts( func allDefaultMiddlewares( wrappedHandler http.Handler, o Opts, -) http.Handler { +) http.HandlerFunc { domain := o.domain httpsPort := o.httpsPort allowedOrigins := o.allowedOrigins @@ -200,14 +200,14 @@ func allDefaultMiddlewares( // All is a middleware that allows all http methods. // // See the package documentation for the additional functionality provided by this middleware. -func All(wrappedHandler http.Handler, o Opts) http.Handler { +func All(wrappedHandler http.Handler, o Opts) http.HandlerFunc { return allDefaultMiddlewares( all(wrappedHandler), o, ) } -func all(wrappedHandler http.Handler) http.Handler { +func all(wrappedHandler http.Handler) http.HandlerFunc { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { wrappedHandler.ServeHTTP(w, r) @@ -218,14 +218,14 @@ func all(wrappedHandler http.Handler) http.Handler { // Get is a middleware that only allows http GET requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Get(wrappedHandler http.Handler, o Opts) http.Handler { +func Get(wrappedHandler http.Handler, o Opts) http.HandlerFunc { return allDefaultMiddlewares( get(wrappedHandler), o, ) } -func get(wrappedHandler http.Handler) http.Handler { +func get(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http GET" return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { @@ -250,14 +250,14 @@ func get(wrappedHandler http.Handler) http.Handler { // Post is a middleware that only allows http POST requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Post(wrappedHandler http.Handler, o Opts) http.Handler { +func Post(wrappedHandler http.Handler, o Opts) http.HandlerFunc { return allDefaultMiddlewares( post(wrappedHandler), o, ) } -func post(wrappedHandler http.Handler) http.Handler { +func post(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http POST" return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { @@ -280,14 +280,14 @@ func post(wrappedHandler http.Handler) http.Handler { // Head is a middleware that only allows http HEAD requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Head(wrappedHandler http.Handler, o Opts) http.Handler { +func Head(wrappedHandler http.Handler, o Opts) http.HandlerFunc { return allDefaultMiddlewares( head(wrappedHandler), o, ) } -func head(wrappedHandler http.Handler) http.Handler { +func head(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http HEAD" return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { @@ -310,14 +310,14 @@ func head(wrappedHandler http.Handler) http.Handler { // Put is a middleware that only allows http PUT requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Put(wrappedHandler http.Handler, o Opts) http.Handler { +func Put(wrappedHandler http.Handler, o Opts) http.HandlerFunc { return allDefaultMiddlewares( put(wrappedHandler), o, ) } -func put(wrappedHandler http.Handler) http.Handler { +func put(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http PUT" return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { @@ -340,7 +340,7 @@ func put(wrappedHandler http.Handler) http.Handler { // Delete is a middleware that only allows http DELETE requests and http OPTIONS requests. // // See the package documentation for the additional functionality provided by this middleware. -func Delete(wrappedHandler http.Handler, o Opts) http.Handler { +func Delete(wrappedHandler http.Handler, o Opts) http.HandlerFunc { return allDefaultMiddlewares( deleteH(wrappedHandler), o, @@ -348,7 +348,7 @@ func Delete(wrappedHandler http.Handler, o Opts) http.Handler { } // this is not called `delete` since that is a Go builtin func for deleting from maps. -func deleteH(wrappedHandler http.Handler) http.Handler { +func deleteH(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http DELETE" return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 8ecb4acb..46c5327c 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -20,23 +20,21 @@ import ( "go.uber.org/goleak" ) -func someMiddlewareTestHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - b, e := io.ReadAll(r.Body) - if e != nil { - panic(e) - } - if len(b) > 1 { - _, _ = w.Write(b) - return - } +func someMiddlewareTestHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + b, e := io.ReadAll(r.Body) + if e != nil { + panic(e) } + if len(b) > 1 { + _, _ = w.Write(b) + return + } + } - fmt.Fprint(w, msg) - }, - ) + fmt.Fprint(w, msg) + } } func TestMain(m *testing.M) { @@ -357,15 +355,13 @@ func TestMiddlewareServer(t *testing.T) { }) } -func someBenchmarkAllMiddlewaresHandler() http.Handler { +func someBenchmarkAllMiddlewaresHandler() http.HandlerFunc { // bound stack growth. // see: https://github.com/komuw/ong/issues/54 msg := strings.Repeat("hello world", 2) - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + } } var resultBenchmarkAllMiddlewares int //nolint:gochecknoglobals diff --git a/middleware/ratelimiter.go b/middleware/ratelimiter.go index ad42e595..b1ca3f9d 100644 --- a/middleware/ratelimiter.go +++ b/middleware/ratelimiter.go @@ -25,36 +25,34 @@ import ( var rateLimiterSendRate = 100.00 //nolint:gochecknoglobals // rateLimiter is a middleware that limits requests by IP address. -func rateLimiter(wrappedHandler http.Handler) http.Handler { +func rateLimiter(wrappedHandler http.Handler) http.HandlerFunc { rl := newRl() const retryAfter = 15 * time.Minute - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - rl.reSize() - - host := ClientIP(r) - tb := rl.get(host, rateLimiterSendRate) - - if !tb.allow() { - err := fmt.Errorf("ong/middleware: rate limited, retry after %s", retryAfter) - w.Header().Set(ongMiddlewareErrorHeader, err.Error()) - w.Header().Set(retryAfterHeader, fmt.Sprintf("%d", int(retryAfter.Seconds()))) // header should be in seconds(decimal-integer). - http.Error( - w, - err.Error(), - http.StatusTooManyRequests, - ) - return - } - - // todo: maybe also limit max body size using something like `http.MaxBytesHandler` - // todo: also maybe add another limiter for IP subnet. - // see limitation: https://github.com/komuw/ong/issues/17#issuecomment-1114551281 - - wrappedHandler.ServeHTTP(w, r) - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + rl.reSize() + + host := ClientIP(r) + tb := rl.get(host, rateLimiterSendRate) + + if !tb.allow() { + err := fmt.Errorf("ong/middleware: rate limited, retry after %s", retryAfter) + w.Header().Set(ongMiddlewareErrorHeader, err.Error()) + w.Header().Set(retryAfterHeader, fmt.Sprintf("%d", int(retryAfter.Seconds()))) // header should be in seconds(decimal-integer). + http.Error( + w, + err.Error(), + http.StatusTooManyRequests, + ) + return + } + + // todo: maybe also limit max body size using something like `http.MaxBytesHandler` + // todo: also maybe add another limiter for IP subnet. + // see limitation: https://github.com/komuw/ong/issues/17#issuecomment-1114551281 + + wrappedHandler.ServeHTTP(w, r) + } } // rl is a ratelimiter per IP address. diff --git a/middleware/ratelimiter_test.go b/middleware/ratelimiter_test.go index d15c80c7..588228c6 100644 --- a/middleware/ratelimiter_test.go +++ b/middleware/ratelimiter_test.go @@ -13,12 +13,10 @@ import ( "golang.org/x/exp/slices" ) -func someRateLimiterHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - }, - ) +func someRateLimiterHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + } } func TestRateLimiter(t *testing.T) { diff --git a/middleware/recoverer.go b/middleware/recoverer.go index 43a51154..e4a2af3d 100644 --- a/middleware/recoverer.go +++ b/middleware/recoverer.go @@ -15,51 +15,49 @@ import ( // recoverer is a middleware that recovers from panics in wrappedHandler. // When/if a panic occurs, it logs the stack trace and returns an InternalServerError response. -func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - defer func() { - errR := recover() - if errR != nil { - reqL := log.WithID(r.Context(), l) - - code := http.StatusInternalServerError - status := http.StatusText(code) - - msg := "http_server" - flds := []any{ - "error", fmt.Sprint(errR), - "clientIP", ClientIP(r), - "clientFingerPrint", ClientFingerPrint(r), - "method", r.Method, - "path", r.URL.Redacted(), - "code", code, - "status", status, - } - if ongError := w.Header().Get(ongMiddlewareErrorHeader); ongError != "" { - extra := []any{"ongError", ongError} - flds = append(flds, extra...) - } - w.Header().Del(ongMiddlewareErrorHeader) // remove header so that users dont see it. - - if e, ok := errR.(error); ok { - extra := []any{"err", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. - flds = append(flds, extra...) - reqL.Error(msg, flds...) - } else { - reqL.Error(msg, flds...) - } - - // respond. - http.Error( - w, - status, - code, - ) +func recoverer(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + defer func() { + errR := recover() + if errR != nil { + reqL := log.WithID(r.Context(), l) + + code := http.StatusInternalServerError + status := http.StatusText(code) + + msg := "http_server" + flds := []any{ + "error", fmt.Sprint(errR), + "clientIP", ClientIP(r), + "clientFingerPrint", ClientFingerPrint(r), + "method", r.Method, + "path", r.URL.Redacted(), + "code", code, + "status", status, + } + if ongError := w.Header().Get(ongMiddlewareErrorHeader); ongError != "" { + extra := []any{"ongError", ongError} + flds = append(flds, extra...) + } + w.Header().Del(ongMiddlewareErrorHeader) // remove header so that users dont see it. + + if e, ok := errR.(error); ok { + extra := []any{"err", errors.Wrap(e)} // wrap with ong/errors so that the log will have a stacktrace. + flds = append(flds, extra...) + reqL.Error(msg, flds...) + } else { + reqL.Error(msg, flds...) } - }() - wrappedHandler.ServeHTTP(w, r) - }, - ) + // respond. + http.Error( + w, + status, + code, + ) + } + }() + + wrappedHandler.ServeHTTP(w, r) + } } diff --git a/middleware/recoverer_test.go b/middleware/recoverer_test.go index b1f713b3..028c09ad 100644 --- a/middleware/recoverer_test.go +++ b/middleware/recoverer_test.go @@ -19,34 +19,30 @@ import ( "golang.org/x/exp/slog" ) -func handlerThatPanics(msg string, shouldPanic bool, err error) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - x := 3 + 9 - _ = x - if shouldPanic { - panic(msg) - } - if err != nil { - panic(err) - } - - fmt.Fprint(w, msg) - }, - ) +func handlerThatPanics(msg string, shouldPanic bool, err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + x := 3 + 9 + _ = x + if shouldPanic { + panic(msg) + } + if err != nil { + panic(err) + } + + fmt.Fprint(w, msg) + } } -func anotherHandlerThatPanics() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - _ = 90 - someSlice := []string{"zero", "one", "two"} - _ = "kilo" - _ = someSlice[16] // panic - - fmt.Fprint(w, "anotherHandlerThatPanics") - }, - ) +func anotherHandlerThatPanics() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _ = 90 + someSlice := []string{"zero", "one", "two"} + _ = "kilo" + _ = someSlice[16] // panic + + fmt.Fprint(w, "anotherHandlerThatPanics") + } } func TestPanic(t *testing.T) { diff --git a/middleware/redirect.go b/middleware/redirect.go index 7c46f99f..1072ba57 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -12,51 +12,49 @@ import ( // // domain is the domain name of your website. // httpsPort is the tls port where http requests will be redirected to. -func httpsRedirector(wrappedHandler http.Handler, httpsPort uint16, domain string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - isTls := strings.EqualFold(r.URL.Scheme, "https") || r.TLS != nil - if !isTls { - url := r.URL - url.Scheme = "https" - url.Host = joinHostPort(domain, fmt.Sprint(httpsPort)) - path := url.String() +func httpsRedirector(wrappedHandler http.Handler, httpsPort uint16, domain string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + isTls := strings.EqualFold(r.URL.Scheme, "https") || r.TLS != nil + if !isTls { + url := r.URL + url.Scheme = "https" + url.Host = joinHostPort(domain, fmt.Sprint(httpsPort)) + path := url.String() - http.Redirect(w, r, path, http.StatusPermanentRedirect) - return - } - - // A Host header field must be sent in all HTTP/1.1 request messages. - // Thus we expect `r.Host[0]` to always have a value. - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host - isHostBareIP := unicode.IsDigit(rune(r.Host[0])) - if isHostBareIP { - /* - the request has tried to access us via an IP address, redirect them to our domain. + http.Redirect(w, r, path, http.StatusPermanentRedirect) + return + } - curl -vkIL 172.217.170.174 #google - HEAD / HTTP/1.1 - Host: 172.217.170.174 + // A Host header field must be sent in all HTTP/1.1 request messages. + // Thus we expect `r.Host[0]` to always have a value. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host + isHostBareIP := unicode.IsDigit(rune(r.Host[0])) + if isHostBareIP { + /* + the request has tried to access us via an IP address, redirect them to our domain. - HTTP/1.1 301 Moved Permanently - Location: http://www.google.com/ - */ - url := r.URL - url.Scheme = "https" - _, port, err := net.SplitHostPort(r.Host) - if err != nil { - port = fmt.Sprint(httpsPort) - } - url.Host = joinHostPort(domain, port) - path := url.String() + curl -vkIL 172.217.170.174 #google + HEAD / HTTP/1.1 + Host: 172.217.170.174 - http.Redirect(w, r, path, http.StatusPermanentRedirect) - return + HTTP/1.1 301 Moved Permanently + Location: http://www.google.com/ + */ + url := r.URL + url.Scheme = "https" + _, port, err := net.SplitHostPort(r.Host) + if err != nil { + port = fmt.Sprint(httpsPort) } + url.Host = joinHostPort(domain, port) + path := url.String() - wrappedHandler.ServeHTTP(w, r) - }, - ) + http.Redirect(w, r, path, http.StatusPermanentRedirect) + return + } + + wrappedHandler.ServeHTTP(w, r) + } } // joinHostPort is like `net.JoinHostPort` except suited for this package. diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index f3c849ac..0e89b1aa 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -13,21 +13,19 @@ import ( "go.akshayshah.org/attest" ) -func someHttpsRedirectorHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - p := make([]byte, 16) - _, err := r.Body.Read(p) - if err == nil || err == io.EOF { - _, _ = w.Write(p) - return - } +func someHttpsRedirectorHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + p := make([]byte, 16) + _, err := r.Body.Read(p) + if err == nil || err == io.EOF { + _, _ = w.Write(p) + return } + } - fmt.Fprint(w, msg) - }, - ) + fmt.Fprint(w, msg) + } } const locationHeader = "Location" diff --git a/middleware/reload_protect.go b/middleware/reload_protect.go index e2597fa9..4c5b513d 100644 --- a/middleware/reload_protect.go +++ b/middleware/reload_protect.go @@ -16,7 +16,7 @@ const reloadProtectCookiePrefix = "ong_form_reload_protect" // reloadProtector is a middleware that attempts to provides protection against a form re-submission when a user reloads/refreshes an already submitted web page/form. // // If such a situation is detected; this middleware will issue a http GET redirect to the same url. -func reloadProtector(wrappedHandler http.Handler, domain string) http.Handler { +func reloadProtector(wrappedHandler http.Handler, domain string) http.HandlerFunc { safeMethods := []string{ // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 http.MethodGet, @@ -24,49 +24,47 @@ func reloadProtector(wrappedHandler http.Handler, domain string) http.Handler { http.MethodOptions, http.MethodTrace, } - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // It is possible for one to send a form without having added the requiste form http header. - if !slices.Contains(safeMethods, r.Method) { - // This could be a http POST/DELETE/etc + return func(w http.ResponseWriter, r *http.Request) { + // It is possible for one to send a form without having added the requiste form http header. + if !slices.Contains(safeMethods, r.Method) { + // This could be a http POST/DELETE/etc - theCookie := fmt.Sprintf("%s-%s", - reloadProtectCookiePrefix, - strings.ReplaceAll(r.URL.EscapedPath(), "/", ""), - ) + theCookie := fmt.Sprintf("%s-%s", + reloadProtectCookiePrefix, + strings.ReplaceAll(r.URL.EscapedPath(), "/", ""), + ) - // todo: should we check if gotCookie.MaxAge > 0 - gotCookie, err := r.Cookie(theCookie) - if err == nil && gotCookie != nil { - // It means that the form had been submitted before. + // todo: should we check if gotCookie.MaxAge > 0 + gotCookie, err := r.Cookie(theCookie) + if err == nil && gotCookie != nil { + // It means that the form had been submitted before. - cookie.Delete( - w, - theCookie, - domain, - ) - http.Redirect( - w, - r, - r.URL.String(), - // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 - http.StatusSeeOther, - ) - return - } else { - cookie.Set( - w, - theCookie, - "YES", - domain, - 1*time.Hour, - false, - ) - } + cookie.Delete( + w, + theCookie, + domain, + ) + http.Redirect( + w, + r, + r.URL.String(), + // http 303(StatusSeeOther) is guaranteed by the spec to always use http GET. + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303 + http.StatusSeeOther, + ) + return + } else { + cookie.Set( + w, + theCookie, + "YES", + domain, + 1*time.Hour, + false, + ) } + } - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } diff --git a/middleware/reload_protect_test.go b/middleware/reload_protect_test.go index 2a359ca6..bef3918f 100644 --- a/middleware/reload_protect_test.go +++ b/middleware/reload_protect_test.go @@ -11,32 +11,30 @@ import ( "go.akshayshah.org/attest" ) -func someReloadProtectorHandler(msg, expectedFormName, expectedFormValue string) http.Handler { +func someReloadProtectorHandler(msg, expectedFormName, expectedFormValue string) http.HandlerFunc { // count is state that is affected by form submission. // eg, when a form is submitted; we create a new user. count := 0 - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - err := r.ParseForm() - if err != nil { - panic(err) - } - val := r.Form.Get(expectedFormName) - if val != expectedFormValue { - panic(fmt.Sprintf("expected = %v got = %v", expectedFormValue, val)) - } - - count = count + 1 - if count > 1 { - // form re-submission happened - panic("form re-submission happened") - } + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + err := r.ParseForm() + if err != nil { + panic(err) } + val := r.Form.Get(expectedFormName) + if val != expectedFormValue { + panic(fmt.Sprintf("expected = %v got = %v", expectedFormValue, val)) + } + + count = count + 1 + if count > 1 { + // form re-submission happened + panic("form re-submission happened") + } + } - fmt.Fprint(w, msg) - }, - ) + fmt.Fprint(w, msg) + } } func TestReloadProtector(t *testing.T) { diff --git a/middleware/security.go b/middleware/security.go index 742d6c00..84afa2a2 100644 --- a/middleware/security.go +++ b/middleware/security.go @@ -35,82 +35,80 @@ const ( // securityHeaders is a middleware that adds some important HTTP security headers and assigns them sensible default values. // // Some of the headers set are Permissions-Policy, Content-securityHeaders-Policy, X-Content-Type-Options, X-Frame-Options, Cross-Origin-Resource-Policy, Cross-Origin-Opener-Policy, Referrer-Policy & Strict-Transport-securityHeaders -func securityHeaders(wrappedHandler http.Handler, domain string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - w.Header().Set( - permissionsPolicyHeader, - // flocOptOut disables floc which is otherwise ON by default - // see: https://github.com/WICG/floc#opting-out-of-computation - "interest-cohort=()", - ) - - // The nonce should be generated per request & propagated to the html of the page. - // The nonce can be fetched in middlewares using the GetCspNonce func +func securityHeaders(wrappedHandler http.Handler, domain string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + w.Header().Set( + permissionsPolicyHeader, + // flocOptOut disables floc which is otherwise ON by default + // see: https://github.com/WICG/floc#opting-out-of-computation + "interest-cohort=()", + ) + + // The nonce should be generated per request & propagated to the html of the page. + // The nonce can be fetched in middlewares using the GetCspNonce func + // + // eg; + // + nonce := id.Random(cspBytesTokenLength) + r = r.WithContext(context.WithValue(ctx, cspCtxKey, nonce)) + w.Header().Set( + cspHeader, + // - https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP + // - https://web.dev/security-headers/ + // - https://stackoverflow.com/a/66955464/2768067 + // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src + // - https://web.dev/security-headers/#tt // - // eg; - // - nonce := id.Random(cspBytesTokenLength) - r = r.WithContext(context.WithValue(ctx, cspCtxKey, nonce)) + // content is only permitted from: + // - the document's origin(and subdomains) + // - images may load from anywhere + // - media is allowed from domain(and its subdomains) + // - executable scripts is only allowed from self(& subdomains). + // - DOM xss(eg setting innerHtml) is blocked by require-trusted-types. + getCsp(domain, nonce), + ) + + w.Header().Set( + xContentOptionsHeader, + "nosniff", + ) + + w.Header().Set( + xFrameHeader, + "DENY", + ) + + w.Header().Set( + corpHeader, + "same-site", + ) + + w.Header().Set( + coopHeader, + "same-origin", + ) + + w.Header().Set( + referrerHeader, + // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy + "strict-origin-when-cross-origin", + ) + + if r.TLS != nil { w.Header().Set( - cspHeader, - // - https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP - // - https://web.dev/security-headers/ - // - https://stackoverflow.com/a/66955464/2768067 - // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src - // - https://web.dev/security-headers/#tt - // - // content is only permitted from: - // - the document's origin(and subdomains) - // - images may load from anywhere - // - media is allowed from domain(and its subdomains) - // - executable scripts is only allowed from self(& subdomains). - // - DOM xss(eg setting innerHtml) is blocked by require-trusted-types. - getCsp(domain, nonce), - ) - - w.Header().Set( - xContentOptionsHeader, - "nosniff", - ) - - w.Header().Set( - xFrameHeader, - "DENY", - ) - - w.Header().Set( - corpHeader, - "same-site", - ) - - w.Header().Set( - coopHeader, - "same-origin", - ) - - w.Header().Set( - referrerHeader, - // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy - "strict-origin-when-cross-origin", + stsHeader, + // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security + // A max-age(in seconds) of 2yrs is recommended + getSts(15*24*time.Hour), // 15 days ) + } - if r.TLS != nil { - w.Header().Set( - stsHeader, - // - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security - // A max-age(in seconds) of 2yrs is recommended - getSts(15*24*time.Hour), // 15 days - ) - } - - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } // GetCspNonce returns the Content-Security-Policy nonce that was set for the http request in question. diff --git a/middleware/security_test.go b/middleware/security_test.go index 0f42f6fb..6bad00a3 100644 --- a/middleware/security_test.go +++ b/middleware/security_test.go @@ -16,13 +16,11 @@ import ( const nonceHeader = "CUSTOM-CSP-NONCE-TEST-HEADER" // echoHandler echos back in the response, the msg that was passed in. -func echoHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(nonceHeader, GetCspNonce(r.Context())) - fmt.Fprint(w, msg) - }, - ) +func echoHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(nonceHeader, GetCspNonce(r.Context())) + fmt.Fprint(w, msg) + } } func TestSecurity(t *testing.T) { diff --git a/middleware/session.go b/middleware/session.go index 9bff4596..57382050 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -21,18 +21,16 @@ const ( // It lets you store and retrieve arbitrary data on a per-site-visitor basis. // // This middleware works best when used together with the [sess] package. -func session(wrappedHandler http.Handler, secretKey, domain string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // 1. Read from cookies and check for session cookie. - // 2. Get that cookie and save it to r.context - r = sess.Initialise(r, secretKey) - - srw := newSessRW(w, r, domain, secretKey) - - wrappedHandler.ServeHTTP(srw, r) - }, - ) +func session(wrappedHandler http.Handler, secretKey, domain string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // 1. Read from cookies and check for session cookie. + // 2. Get that cookie and save it to r.context + r = sess.Initialise(r, secretKey) + + srw := newSessRW(w, r, domain, secretKey) + + wrappedHandler.ServeHTTP(srw, r) + } } // sessRW provides an http.ResponseWriter interface, which provides http session functionality. diff --git a/middleware/session_test.go b/middleware/session_test.go index adc66773..b444e728 100644 --- a/middleware/session_test.go +++ b/middleware/session_test.go @@ -25,18 +25,16 @@ func bigMap() map[string]string { return y } -func someSessionHandler(msg, key, value string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - sess.Set(r, key, value) - sess.SetM(r, bigMap()) - fmt.Fprint(w, msg) - }, - ) +func someSessionHandler(msg, key, value string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + sess.Set(r, key, value) + sess.SetM(r, bigMap()) + fmt.Fprint(w, msg) + } } // See https://github.com/komuw/ong/issues/205 -func templateVarsHandler(t *testing.T, name string) http.Handler { +func templateVarsHandler(t *testing.T, name string) http.HandlerFunc { tmpl, err := template.New("myTpl").Parse(` @@ -49,19 +47,17 @@ func templateVarsHandler(t *testing.T, name string) http.Handler { t.Fatal(err) } - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - sess.Set(r, "name", name) - - data := struct { - Name string - }{Name: name} - if err = tmpl.Execute(w, data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + sess.Set(r, "name", name) + + data := struct { + Name string + }{Name: name} + if err = tmpl.Execute(w, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } } func TestSession(t *testing.T) { diff --git a/middleware/trace.go b/middleware/trace.go index 7998d4d1..4b8808c3 100644 --- a/middleware/trace.go +++ b/middleware/trace.go @@ -13,37 +13,35 @@ import ( const logIDKey = string(octx.LogCtxKey) // trace is a middleware that adds logID to request and response. -func trace(wrappedHandler http.Handler, domain string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +func trace(wrappedHandler http.Handler, domain string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - // set cookie/headers/ctx for logID. - logID := getLogId(r) - ctx = context.WithValue( - ctx, - // using this custom key is important, instead of using `logIDKey` - octx.LogCtxKey, - logID, - ) - r = r.WithContext(ctx) - r.Header.Set(logIDKey, logID) - w.Header().Set(logIDKey, logID) - cookie.Set( - w, - logIDKey, - logID, - domain, - // Hopefully 15mins is enough. - // Google considers a session to be 30mins. - // https://support.google.com/analytics/answer/2731565?hl=en#time-based-expiration - 15*time.Minute, - false, - ) + // set cookie/headers/ctx for logID. + logID := getLogId(r) + ctx = context.WithValue( + ctx, + // using this custom key is important, instead of using `logIDKey` + octx.LogCtxKey, + logID, + ) + r = r.WithContext(ctx) + r.Header.Set(logIDKey, logID) + w.Header().Set(logIDKey, logID) + cookie.Set( + w, + logIDKey, + logID, + domain, + // Hopefully 15mins is enough. + // Google considers a session to be 30mins. + // https://support.google.com/analytics/answer/2731565?hl=en#time-based-expiration + 15*time.Minute, + false, + ) - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } // getLogId returns a logID from the request or autogenerated if not available from the request. diff --git a/middleware/trace_test.go b/middleware/trace_test.go index 402bda88..9c919b80 100644 --- a/middleware/trace_test.go +++ b/middleware/trace_test.go @@ -18,24 +18,22 @@ import ( const someTraceHandlerHeader = "someTraceHandlerHeader" -func someTraceHandler(successMsg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // sleep so that the trace middleware has some useful duration metrics to report. - time.Sleep(3 * time.Millisecond) - if r.Header.Get(someTraceHandlerHeader) != "" { - http.Error( - w, - r.Header.Get(someTraceHandlerHeader), - http.StatusInternalServerError, - ) - return - } else { - fmt.Fprint(w, successMsg) - return - } - }, - ) +func someTraceHandler(successMsg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // sleep so that the trace middleware has some useful duration metrics to report. + time.Sleep(3 * time.Millisecond) + if r.Header.Get(someTraceHandlerHeader) != "" { + http.Error( + w, + r.Header.Get(someTraceHandlerHeader), + http.StatusInternalServerError, + ) + return + } else { + fmt.Fprint(w, successMsg) + return + } + } } func TestTraceMiddleware(t *testing.T) { diff --git a/mux/example_test.go b/mux/example_test.go index d4540893..58a634cb 100644 --- a/mux/example_test.go +++ b/mux/example_test.go @@ -11,21 +11,17 @@ import ( "github.com/komuw/ong/mux" ) -func LoginHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprint(w, "welcome to your favorite website.") - }, - ) +func LoginHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, "welcome to your favorite website.") + } } -func BooksByAuthorHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - author := mux.Param(r.Context(), "author") - _, _ = fmt.Fprintf(w, "fetching books by author: %s", author) - }, - ) +func BooksByAuthorHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + author := mux.Param(r.Context(), "author") + _, _ = fmt.Fprintf(w, "fetching books by author: %s", author) + } } func ExampleMux() { diff --git a/mux/mux_test.go b/mux/mux_test.go index 18abc347..641eb4a6 100644 --- a/mux/mux_test.go +++ b/mux/mux_test.go @@ -22,29 +22,23 @@ func getSecretKey() string { return key } -func someMuxHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - }, - ) +func someMuxHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + } } -func thisIsAnotherMuxHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "thisIsAnotherMuxHandler") - }, - ) +func thisIsAnotherMuxHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "thisIsAnotherMuxHandler") + } } -func checkAgeHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - age := Param(r.Context(), "age") - _, _ = fmt.Fprintf(w, "Age is %s", age) - }, - ) +func checkAgeHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + age := Param(r.Context(), "age") + _, _ = fmt.Fprintf(w, "Age is %s", age) + } } func TestMain(m *testing.M) { diff --git a/mux/route_test.go b/mux/route_test.go index d3ba591b..ae4379fb 100644 --- a/mux/route_test.go +++ b/mux/route_test.go @@ -302,20 +302,16 @@ func TestMultipleRoutesDifferentMethods(t *testing.T) { attest.Equal(t, match, "POST") } -func firstRoute(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprint(w, msg) - }, - ) +func firstRoute(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, msg) + } } -func secondRoute(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprint(w, msg) - }, - ) +func secondRoute(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, msg) + } } func TestConflicts(t *testing.T) { diff --git a/server/example_test.go b/server/example_test.go index 489ec35f..df3a5761 100644 --- a/server/example_test.go +++ b/server/example_test.go @@ -47,24 +47,20 @@ func ExampleRun() { } } -func hello(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - cspNonce := middleware.GetCspNonce(r.Context()) - csrfToken := middleware.GetCsrfToken(r.Context()) - fmt.Printf("hello called cspNonce: %s, csrfToken: %s", cspNonce, csrfToken) +func hello(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + cspNonce := middleware.GetCspNonce(r.Context()) + csrfToken := middleware.GetCsrfToken(r.Context()) + fmt.Printf("hello called cspNonce: %s, csrfToken: %s", cspNonce, csrfToken) - // use msg, which is a dependency specific to this handler - fmt.Fprint(w, msg) - }, - ) + // use msg, which is a dependency specific to this handler + fmt.Fprint(w, msg) + } } -func check() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - age := mux.Param(r.Context(), "age") - _, _ = fmt.Fprintf(w, "Age is %s", age) - }, - ) +func check() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + age := mux.Param(r.Context(), "age") + _, _ = fmt.Fprintf(w, "Age is %s", age) + } } diff --git a/server/server_test.go b/server/server_test.go index 9ae4a8dc..aafc08ee 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -136,21 +136,19 @@ func TestOpts(t *testing.T) { const tlsFingerPrintKey = "TlsFingerPrintKey" -func someServerTestHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(tlsFingerPrintKey, middleware.ClientFingerPrint(r)) - - if _, err := io.ReadAll(r.Body); err != nil { - // This is where the error produced by `http.MaxBytesHandler` is produced at. - // ie, its produced when a read is made. - fmt.Fprint(w, err.Error()) - return - } +func someServerTestHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(tlsFingerPrintKey, middleware.ClientFingerPrint(r)) + + if _, err := io.ReadAll(r.Body); err != nil { + // This is where the error produced by `http.MaxBytesHandler` is produced at. + // ie, its produced when a read is made. + fmt.Fprint(w, err.Error()) + return + } - fmt.Fprint(w, msg) - }, - ) + fmt.Fprint(w, msg) + } } func TestServer(t *testing.T) { @@ -389,12 +387,10 @@ func TestServer(t *testing.T) { }) } -func benchmarkServerHandler(msg string) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, msg) - }, - ) +func benchmarkServerHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + } } var result int //nolint:gochecknoglobals diff --git a/sess/example_test.go b/sess/example_test.go index 7e4c65c8..275a2dc9 100644 --- a/sess/example_test.go +++ b/sess/example_test.go @@ -12,19 +12,17 @@ import ( "github.com/komuw/ong/sess" ) -func loginHandler() http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - mySession := map[string]string{ - "name": "John Doe", - "favorite_color": "red", - "height": "5 feet 6 inches", - } - sess.SetM(r, mySession) - - fmt.Fprint(w, "welcome again.") - }, - ) +func loginHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + mySession := map[string]string{ + "name": "John Doe", + "favorite_color": "red", + "height": "5 feet 6 inches", + } + sess.SetM(r, mySession) + + fmt.Fprint(w, "welcome again.") + } } func ExampleSetM() { From 187bb2ee487eb95fe9d8f59276410c35c0dd8bde Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 7 Jun 2023 23:57:21 +0300 Subject: [PATCH 07/13] m --- middleware/middleware.go | 154 ++++++++++++++++------------------ middleware/middleware_test.go | 1 - 2 files changed, 71 insertions(+), 84 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index fc90b780..8f0cc2e0 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,5 +1,5 @@ // Package middleware provides helpful functions that implement some common functionalities in http servers. -// A middleware is a function that returns a [http.Handler] +// A middleware is a function that takes in a [http.Handler] as one of its arguments and returns a [http.Handler] // // The middlewares [All], [Get], [Post], [Head], [Put] & [Delete] wrap other internal middleware. // The effect of this is that the aforementioned middleware, in addition to their specialised functionality, will: @@ -208,11 +208,9 @@ func All(wrappedHandler http.Handler, o Opts) http.HandlerFunc { } func all(wrappedHandler http.Handler) http.HandlerFunc { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - wrappedHandler.ServeHTTP(w, r) - }, - ) + return func(w http.ResponseWriter, r *http.Request) { + wrappedHandler.ServeHTTP(w, r) + } } // Get is a middleware that only allows http GET requests and http OPTIONS requests. @@ -227,24 +225,22 @@ func Get(wrappedHandler http.Handler, o Opts) http.HandlerFunc { func get(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http GET" - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // We do not need to allow `http.MethodOptions` here. - // This is coz, the cors middleware has already handled that for us and it comes before the Get middleware. - if r.Method != http.MethodGet { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return func(w http.ResponseWriter, r *http.Request) { + // We do not need to allow `http.MethodOptions` here. + // This is coz, the cors middleware has already handled that for us and it comes before the Get middleware. + if r.Method != http.MethodGet { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } // Post is a middleware that only allows http POST requests and http OPTIONS requests. @@ -259,22 +255,20 @@ func Post(wrappedHandler http.Handler, o Opts) http.HandlerFunc { func post(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http POST" - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } // Head is a middleware that only allows http HEAD requests and http OPTIONS requests. @@ -289,22 +283,20 @@ func Head(wrappedHandler http.Handler, o Opts) http.HandlerFunc { func head(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http HEAD" - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodHead { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodHead { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } // Put is a middleware that only allows http PUT requests and http OPTIONS requests. @@ -319,22 +311,20 @@ func Put(wrappedHandler http.Handler, o Opts) http.HandlerFunc { func put(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http PUT" - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPut { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } // Delete is a middleware that only allows http DELETE requests and http OPTIONS requests. @@ -350,20 +340,18 @@ func Delete(wrappedHandler http.Handler, o Opts) http.HandlerFunc { // this is not called `delete` since that is a Go builtin func for deleting from maps. func deleteH(wrappedHandler http.Handler) http.HandlerFunc { msg := "http method: %s not allowed. only allows http DELETE" - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - errMsg := fmt.Sprintf(msg, r.Method) - w.Header().Set(ongMiddlewareErrorHeader, errMsg) - http.Error( - w, - errMsg, - http.StatusMethodNotAllowed, - ) - return - } + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + errMsg := fmt.Sprintf(msg, r.Method) + w.Header().Set(ongMiddlewareErrorHeader, errMsg) + http.Error( + w, + errMsg, + http.StatusMethodNotAllowed, + ) + return + } - wrappedHandler.ServeHTTP(w, r) - }, - ) + wrappedHandler.ServeHTTP(w, r) + } } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 46c5327c..aa78fc73 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,5 +1,4 @@ // Package middleware provides helpful functions that implement some common functionalities in http servers. -// A middleware is a func that returns a http.Handler package middleware import ( From 0c0eb48c62a4cc1d09753c60d1fd33fbb45e02f7 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Wed, 7 Jun 2023 23:59:29 +0300 Subject: [PATCH 08/13] m --- middleware/gzip.go | 4 +- middleware/log.go | 91 ++++++++++++++++++----------------- middleware/middleware_test.go | 2 +- 3 files changed, 49 insertions(+), 48 deletions(-) diff --git a/middleware/gzip.go b/middleware/gzip.go index 96e61c18..ae76fd4f 100644 --- a/middleware/gzip.go +++ b/middleware/gzip.go @@ -79,9 +79,9 @@ var ( _ http.ResponseWriter = &gzipRW{} _ http.Flusher = &gzipRW{} _ http.Hijacker = &gzipRW{} - _ http.Pusher = &logRW{} + _ http.Pusher = &gzipRW{} _ io.WriteCloser = &gzipRW{} - _ io.ReaderFrom = &logRW{} + _ io.ReaderFrom = &gzipRW{} // _ http.CloseNotifier = &gzipRW{} // `http.CloseNotifier` has been deprecated sinc Go v1.11(year 2018) ) diff --git a/middleware/log.go b/middleware/log.go index 8e6b693c..e2f5e59b 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -22,56 +22,57 @@ func logger(wrappedHandler http.Handler, l *slog.Logger) http.HandlerFunc { // // However, each request should get its own context. That's why we call `logger.WithCtx` for every request. - return func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - lrw := &logRW{ - ResponseWriter: w, + return func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + lrw := &logRW{ + ResponseWriter: w, + } + defer func() { + msg := "http_server" + flds := []any{ + "clientIP", ClientIP(r), + "clientFingerPrint", ClientFingerPrint(r), + "method", r.Method, + "path", r.URL.Redacted(), + "code", lrw.code, + "status", http.StatusText(lrw.code), + "durationMS", time.Since(start).Milliseconds(), + } + if ongError := lrw.Header().Get(ongMiddlewareErrorHeader); ongError != "" { + extra := []any{"ongError", ongError} + flds = append(flds, extra...) } - defer func() { - msg := "http_server" - flds := []any{ - "clientIP", ClientIP(r), - "clientFingerPrint", ClientFingerPrint(r), - "method", r.Method, - "path", r.URL.Redacted(), - "code", lrw.code, - "status", http.StatusText(lrw.code), - "durationMS", time.Since(start).Milliseconds(), - } - if ongError := lrw.Header().Get(ongMiddlewareErrorHeader); ongError != "" { - extra := []any{"ongError", ongError} - flds = append(flds, extra...) - } - // Remove header so that users dont see it. - // - // Note that this may not actually work. - // According to: https://pkg.go.dev/net/http#ResponseWriter - // Changing the header map after a call to WriteHeader (or - // Write) has no effect unless the HTTP status code was of the - // 1xx class or the modified headers are trailers. - lrw.Header().Del(ongMiddlewareErrorHeader) - - // The logger should be in the defer block so that it uses the updated context containing the logID. - reqL := log.WithID(r.Context(), l) - - if lrw.code == http.StatusServiceUnavailable || lrw.code == http.StatusTooManyRequests && w.Header().Get(retryAfterHeader) != "" { - // We are either in load shedding or rate-limiting. - // Only log 10% of the errors. - shouldLog := mathRand.Intn(100) > 90 - if shouldLog { - reqL.Error(msg, flds...) - } - } else if lrw.code >= http.StatusBadRequest { - // both client and server errors. + // Remove header so that users dont see it. + // + // Note that this may not actually work. + // According to: https://pkg.go.dev/net/http#ResponseWriter + // Changing the header map after a call to WriteHeader (or + // Write) has no effect unless the HTTP status code was of the + // 1xx class or the modified headers are trailers. + lrw.Header().Del(ongMiddlewareErrorHeader) + + // The logger should be in the defer block so that it uses the updated context containing the logID. + reqL := log.WithID(r.Context(), l) + + if lrw.code == http.StatusServiceUnavailable || lrw.code == http.StatusTooManyRequests && w.Header().Get(retryAfterHeader) != "" { + // We are either in load shedding or rate-limiting. + // Only log 10% of the errors. + shouldLog := mathRand.Intn(100) > 90 + if shouldLog { reqL.Error(msg, flds...) - } else { - reqL.Info(msg, flds...) } - }() + } else if lrw.code >= http.StatusBadRequest { + // both client and server errors. + reqL.Error(msg, flds...) + } else { + reqL.Info(msg, flds...) + } + }() - wrappedHandler.ServeHTTP(lrw, r) - } + wrappedHandler.ServeHTTP(lrw, r) + } +} // logRW provides an http.ResponseWriter interface, which logs requests/responses. type logRW struct { diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index aa78fc73..60e4cd8d 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -56,7 +56,7 @@ func TestAllMiddleware(t *testing.T) { errMsg := "not allowed. only allows http" tests := []struct { name string - middleware func(wrappedHandler http.Handler, o Opts) http.Handler + middleware func(wrappedHandler http.Handler, o Opts) http.HandlerFunc httpMethod string expectedStatusCode int expectedMsg string From f1b564b47c94be8ae1f28b6c16f40412ff26e064 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Thu, 8 Jun 2023 14:09:22 +0300 Subject: [PATCH 09/13] m --- mux/mux_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mux/mux_test.go b/mux/mux_test.go index 641eb4a6..75f76b88 100644 --- a/mux/mux_test.go +++ b/mux/mux_test.go @@ -207,8 +207,8 @@ func TestMux(t *testing.T) { rStr := fmt.Sprintf("%v", r) attest.Subsequence(t, rStr, uri2) attest.Subsequence(t, rStr, method) - attest.Subsequence(t, rStr, "ong/mux/mux_test.go:27") // location where `someMuxHandler` is declared. - attest.Subsequence(t, rStr, "ong/mux/mux_test.go:35") // location where `thisIsAnotherMuxHandler` is declared. + attest.Subsequence(t, rStr, "ong/mux/mux_test.go:26") // location where `someMuxHandler` is declared. + attest.Subsequence(t, rStr, "ong/mux/mux_test.go:32") // location where `thisIsAnotherMuxHandler` is declared. }() _ = New( @@ -261,28 +261,28 @@ func TestMux(t *testing.T) { "api", "/api/", MethodGet, - "ong/mux/mux_test.go:27", // location where `someMuxHandler` is declared. + "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. }, { "success with prefix slash", "/api", "/api/", MethodGet, - "ong/mux/mux_test.go:27", // location where `someMuxHandler` is declared. + "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. }, { "success with suffix slash", "api/", "/api/", MethodGet, - "ong/mux/mux_test.go:27", // location where `someMuxHandler` is declared. + "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. }, { "success with all slashes", "/api/", "/api/", MethodGet, - "ong/mux/mux_test.go:27", // location where `someMuxHandler` is declared. + "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. }, { "failure", @@ -296,14 +296,14 @@ func TestMux(t *testing.T) { "check/2625", "/check/:age/", MethodAll, - "ong/mux/mux_test.go:43", // location where `checkAgeHandler` is declared. + "ong/mux/mux_test.go:38", // location where `checkAgeHandler` is declared. }, { "url with domain name", "https://localhost/check/2625", "/check/:age/", MethodAll, - "ong/mux/mux_test.go:43", // location where `checkAgeHandler` is declared. + "ong/mux/mux_test.go:38", // location where `checkAgeHandler` is declared. }, } From 1b6c9197d6026e81488a5a2ae02f501eead02574 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Thu, 8 Jun 2023 14:13:12 +0300 Subject: [PATCH 10/13] m --- middleware/middleware_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 60e4cd8d..c69aa066 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,4 +1,3 @@ -// Package middleware provides helpful functions that implement some common functionalities in http servers. package middleware import ( From fb362fb6edd823fdde5df0a4856eead7949f4c11 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Thu, 8 Jun 2023 14:24:13 +0300 Subject: [PATCH 11/13] m --- client/client.go | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/client/client.go b/client/client.go index 6d054bc0..4975b1c2 100644 --- a/client/client.go +++ b/client/client.go @@ -20,6 +20,11 @@ import ( const ( logIDHeader = string(octx.LogCtxKey) errPrefix = "ong/client:" + // The wikipedia monitoring dashboards are public: https://grafana.wikimedia.org/?orgId=1 + // In there we can see that the p95 response times for http GET requests is ~700ms: https://grafana.wikimedia.org/d/RIA1lzDZk/application-servers-red?orgId=1 + // and the p95 response times for http POST requests is ~3seconds: + // Thus, we set the timeout to be twice that. + defaultTimeout = 3 * 2 * time.Second ) // Most of the code here is inspired by(or taken from): @@ -31,24 +36,28 @@ const ( // Safe creates a http client that has some good defaults & is safe from server-side request forgery (SSRF). // It also logs requests and responses using [log.Logger] -func Safe(l *slog.Logger) *http.Client { - return new(true, l) +// The timeout is optional. +func Safe(l *slog.Logger, timeout ...time.Duration) *http.Client { + t := defaultTimeout + if len(timeout) > 0 { + t = timeout[0] + } + return new(true, t, l) } // Unsafe creates a http client that has some good defaults & is NOT safe from server-side request forgery (SSRF). // It also logs requests and responses using [log.Logger] -func Unsafe(l *slog.Logger) *http.Client { - return new(false, l) +// The timeout is optional +func Unsafe(l *slog.Logger, timeout ...time.Duration) *http.Client { + t := defaultTimeout + if len(timeout) > 0 { + t = timeout[0] + } + return new(false, t, l) } // new creates a client. Use [Safe] or [Unsafe] instead. -func new(ssrfSafe bool, l *slog.Logger) *http.Client { - // The wikipedia monitoring dashboards are public: https://grafana.wikimedia.org/?orgId=1 - // In there we can see that the p95 response times for http GET requests is ~700ms: https://grafana.wikimedia.org/d/RIA1lzDZk/application-servers-red?orgId=1 - // and the p95 response times for http POST requests is ~3seconds: - // Thus, we set the timeout to be twice that. - timeout := 3 * 2 * time.Second - +func new(ssrfSafe bool, timeout time.Duration, l *slog.Logger) *http.Client { dialer := &net.Dialer{ // Using Dialer.ControlContext instead of Dialer.Control allows; // - propagation of logging contexts, metric context or other metadata down to the callback. @@ -81,7 +90,7 @@ func new(ssrfSafe bool, l *slog.Logger) *http.Client { MaxIdleConns: 100, IdleConnTimeout: 3 * timeout, TLSHandshakeTimeout: timeout, - ExpectContinueTimeout: 1 * time.Second, + ExpectContinueTimeout: (timeout / 5), } lr := &loggingRT{transport, l} From d3491747eeaa5fe961636f5edb3bbb0cf2c5bd4c Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Thu, 8 Jun 2023 14:25:38 +0300 Subject: [PATCH 12/13] m --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77b57dd9..733b4f5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Most recent version is listed first. # v0.0.49 - Add mux Resolve function: https://github.com/komuw/ong/pull/268 - Use http.Handler as the http middleware instead of http.HandlerFunc: https://github.com/komuw/ong/pull/269 +- Add optional http timeout: https://github.com/komuw/ong/pull/270 ## v0.0.48 - Change attest import path: https://github.com/komuw/ong/pull/265 From ab870fad5512899f76cc5a0e91ae8794bbf649c2 Mon Sep 17 00:00:00 2001 From: $MY_NAME Date: Thu, 8 Jun 2023 17:41:42 +0300 Subject: [PATCH 13/13] m --- mux/mux_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mux/mux_test.go b/mux/mux_test.go index 75f76b88..0b5b5668 100644 --- a/mux/mux_test.go +++ b/mux/mux_test.go @@ -285,7 +285,7 @@ func TestMux(t *testing.T) { "ong/mux/mux_test.go:26", // location where `someMuxHandler` is declared. }, { - "failure", + "bad", "/", "", "",