From c9fdb5162fedc74570917d5a96f709e218bd57fc Mon Sep 17 00:00:00 2001 From: Komu Wairagu Date: Thu, 8 Dec 2022 13:10:59 +0300 Subject: [PATCH] make most middleware private (#186) What: - make most middleware private Why: - Ideally we should only use the `middleware.Get`, `middleware.Post`, etc middlewares --- CHANGELOG.md | 3 +- middleware/cors.go | 4 +- middleware/cors_test.go | 22 +++++----- middleware/csrf.go | 4 +- middleware/csrf_test.go | 18 ++++---- middleware/example_test.go | 10 ++++- middleware/gzip.go | 14 +++---- middleware/gzip_test.go | 24 +++++------ middleware/loadshed.go | 4 +- middleware/loadshed_test.go | 6 +-- middleware/log.go | 7 ++-- middleware/log_test.go | 10 ++--- middleware/middleware.go | 68 +++++++++++++++---------------- middleware/panic.go | 8 ++-- middleware/panic_test.go | 10 ++--- middleware/ratelimiter.go | 4 +- middleware/ratelimiter_test.go | 8 ++-- middleware/redirect.go | 4 +- middleware/redirect_test.go | 18 ++++---- middleware/reload_protect.go | 4 +- middleware/reload_protect_test.go | 6 +-- middleware/security.go | 6 +-- middleware/security_test.go | 8 ++-- middleware/session.go | 4 +- middleware/session_test.go | 6 +-- sess/example_test.go | 12 +++--- sess/sess.go | 10 ++--- 27 files changed, 156 insertions(+), 146 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1fa18f9..f2ff934e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Most recent version is listed first. ## v0.0.25 - ong/client: Use roundTripper for logging: https://github.com/komuw/ong/pull/185 +- Make most middleware private: https://github.com/komuw/ong/pull/186 ## v0.0.24 - Set session cookie only if non-empty: https://github.com/komuw/ong/pull/170 @@ -15,7 +16,7 @@ Most recent version is listed first. - ong/client: Add log id http header: https://github.com/komuw/ong/pull/166 ## v0.0.22 -- Panic middleware should include correct stack trace: https://github.com/komuw/ong/pull/164 +- Panic/recoverer middleware should include correct stack trace: https://github.com/komuw/ong/pull/164 - Log client address without port: https://github.com/komuw/ong/pull/165 ## v0.0.21 diff --git a/middleware/cors.go b/middleware/cors.go index 4b16fb7c..05ed5ac8 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -54,12 +54,12 @@ const ( corsCacheDur = 2 * time.Hour ) -// Cors is a middleware to implement Cross-Origin Resource Sharing support. +// cors is a middleware to implement Cross-Origin Resource Sharing support. // // If allowedOrigins is nil, all origins are allowed. You can also use * to allow all. // If allowedMethods is nil, "GET", "POST", "HEAD" are allowed. Use * to allow all. // If allowedHeaders is nil, "Origin", "Accept", "Content-Type", "X-Requested-With" are allowed. Use * to allow all. -func Cors( +func cors( wrappedHandler http.HandlerFunc, allowedOrigins []string, allowedMethods []string, diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 297c91b7..ed25ecac 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -24,7 +24,7 @@ func TestCorsPreflight(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil) + wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodOptions, "/someUri", nil) req.Header.Add(acrmHeader, "is-set") // preflight request header set @@ -44,7 +44,7 @@ func TestCorsPreflight(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"}) + wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodOptions, "/someUri", nil) req.Header.Add(acrmHeader, http.MethodGet) // preflight request header set @@ -75,7 +75,7 @@ func TestCorsPreflight(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil) + wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodOptions, "/someUri", nil) // preflight request header NOT set @@ -132,7 +132,7 @@ func TestCorsPreflight(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"}) + wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodOptions, "/someUri", nil) req.Header.Add(acrmHeader, "is-set") // preflight request header set @@ -206,7 +206,7 @@ func TestCorsPreflight(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"}) + wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodOptions, "/someUri", nil) req.Header.Add(originHeader, "http://some-origin.com") @@ -270,7 +270,7 @@ func TestCorsPreflight(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, tt.allowedHeaders) + wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, tt.allowedHeaders) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodOptions, "/someUri", nil) req.Header.Add(acrmHeader, "is-set") // preflight request header set @@ -302,7 +302,7 @@ func TestCorsActualRequest(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil) + wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) wrappedHandler.ServeHTTP(rec, req) @@ -321,7 +321,7 @@ func TestCorsActualRequest(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"}) + wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) req.Header.Add(originHeader, "http://example.com") @@ -381,7 +381,7 @@ func TestCorsActualRequest(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"}) + wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"}) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) req.Header.Add(originHeader, tt.origin) @@ -445,7 +445,7 @@ func TestCorsActualRequest(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"}) + wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"}) rec := httptest.NewRecorder() req := httptest.NewRequest(tt.method, "/someUri", nil) req.Header.Add(originHeader, "http://some-origin.com") @@ -474,7 +474,7 @@ func TestCorsActualRequest(t *testing.T) { msg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil) + wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/csrf.go b/middleware/csrf.go index ffb9ecb3..813690ba 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -54,10 +54,10 @@ const ( csrfBytesTokenLength = 32 ) -// Csrf is a middleware that provides protection against Cross Site Request Forgeries. +// csrf is a middleware that provides protection against Cross Site Request Forgeries. // // If a csrf token is not provided(or is not valid), when it ought to have been; this middleware will issue a http GET redirect to the same url. -func Csrf(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc { +func csrf(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc { once.Do(func() { enc = cry.New(secretKey) }) diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 50b51218..ff0ef95a 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -127,7 +127,7 @@ func TestCsrf(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -148,7 +148,7 @@ func TestCsrf(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) reqCsrfTok := id.Random(csrfBytesTokenLength) rec := httptest.NewRecorder() @@ -179,7 +179,7 @@ func TestCsrf(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) reqCsrfTok := id.Random(csrfBytesTokenLength) rec := httptest.NewRecorder() @@ -203,7 +203,7 @@ func TestCsrf(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -225,7 +225,7 @@ func TestCsrf(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -266,7 +266,7 @@ func TestCsrf(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) reqCsrfTok := id.Random(csrfBytesTokenLength) rec := httptest.NewRecorder() @@ -295,7 +295,7 @@ func TestCsrf(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) key := getSecretKey() enc2 := cry.New(key) @@ -365,7 +365,7 @@ func TestCsrf(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) rec := httptest.NewRecorder() postMsg := "my name is John" @@ -394,7 +394,7 @@ func TestCsrf(t *testing.T) { domain := "example.com" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain) + wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain) key := getSecretKey() enc2 := cry.New(key) diff --git a/middleware/example_test.go b/middleware/example_test.go index 31b1f720..2e0a0c3f 100644 --- a/middleware/example_test.go +++ b/middleware/example_test.go @@ -20,7 +20,10 @@ func loginHandler() http.HandlerFunc { } func Example_getCspNonce() { - handler := middleware.SecurityHeaders(loginHandler(), "example.com") + handler := middleware.Get( + loginHandler(), + middleware.WithOpts("example.com", 443, "secretKey", log.New(os.Stdout, 100)), + ) _ = handler // use handler // Output: @@ -36,7 +39,10 @@ func welcomeHandler() http.HandlerFunc { } func Example_getCsrfToken() { - handler := middleware.Csrf(welcomeHandler(), "some-secret-key", "example.com") + handler := middleware.Get( + welcomeHandler(), + middleware.WithOpts("example.com", 443, "secretKey", log.New(os.Stdout, 100)), + ) _ = handler // use handler // Output: diff --git a/middleware/gzip.go b/middleware/gzip.go index 80a53b00..f56a3591 100644 --- a/middleware/gzip.go +++ b/middleware/gzip.go @@ -2,7 +2,7 @@ package middleware import ( "bufio" - "compress/gzip" + stdGzip "compress/gzip" "fmt" "io" "net" @@ -28,8 +28,8 @@ const ( thisMiddlewareEncoding = "gzip" ) -// Gzip is a middleware that transparently gzips the http response body, for clients that support it. -func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc { +// gzip is a middleware that transparently gzips the http response body, for clients that support it. +func gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Add(varyHeader, acceptEncodingHeader) @@ -38,7 +38,7 @@ func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc { return } - gzipWriter, _ := gzip.NewWriterLevel(w, gzip.BestSpeed) + gzipWriter, _ := stdGzip.NewWriterLevel(w, stdGzip.BestSpeed) grw := &gzipRW{ ResponseWriter: w, // Bytes written during ServeHTTP are redirected to this gzip writer @@ -65,7 +65,7 @@ func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc { // writers, so don't forget to do that. type gzipRW struct { http.ResponseWriter - gw *gzip.Writer + gw *stdGzip.Writer buf []byte // Holds the first part of the write before reaching the minSize or the end of the write. @@ -176,7 +176,7 @@ func (grw *gzipRW) handleGzipped(ct string, lenB int) (int, error) { return lenB, nil } -// Close will close the gzip.Writer. +// Close will close the stdGzip.Writer. func (grw *gzipRW) Close() error { if !grw.handledZip { // GZIP not triggered yet, write out regular response. @@ -187,7 +187,7 @@ func (grw *gzipRW) Close() error { return grw.gw.Close() // will also call gzip flush() } -// Flush flushes the underlying *gzip.Writer and then the +// Flush flushes the underlying *stdGzip.Writer and then the // underlying http.ResponseWriter if it is an http.Flusher. // This makes gzipRW an http.Flusher. func (grw *gzipRW) Flush() { diff --git a/middleware/gzip_test.go b/middleware/gzip_test.go index 4ebf9d87..e7080029 100644 --- a/middleware/gzip_test.go +++ b/middleware/gzip_test.go @@ -1,7 +1,7 @@ package middleware import ( - "compress/gzip" + stdGzip "compress/gzip" "fmt" "html/template" "io" @@ -103,7 +103,7 @@ func readBody(t *testing.T, res *http.Response) (strBody string) { if res.Header.Get(contentEncodingHeader) == "gzip" { // the body is gzipped. - reader, err := gzip.NewReader(body) + reader, err := stdGzip.NewReader(body) attest.Ok(t, err) defer reader.Close() rb, err := io.ReadAll(reader) @@ -126,7 +126,7 @@ func TestGzip(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Gzip(someGzipHandler(msg)) + wrappedHandler := gzip(someGzipHandler(msg)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -144,7 +144,7 @@ func TestGzip(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Gzip(someGzipHandler(msg)) + wrappedHandler := gzip(someGzipHandler(msg)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) req.Header.Add(acceptEncodingHeader, "br;q=1.0, gzip;q=0.8, *;q=0.1") @@ -162,7 +162,7 @@ func TestGzip(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Gzip(handlerImplementingFlush(msg)) + wrappedHandler := gzip(handlerImplementingFlush(msg)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) req.Header.Add(acceptEncodingHeader, "br;q=1.0, gzip;q=0.8, *;q=0.1") @@ -182,7 +182,7 @@ func TestGzip(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Gzip(handlerImplementingFlush(msg)) + wrappedHandler := gzip(handlerImplementingFlush(msg)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) req.Header.Add(acceptEncodingHeader, "br;q=1.0, gzip;q=0.8, *;q=0.1") @@ -202,7 +202,7 @@ func TestGzip(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := Gzip(someGzipHandler(msg)) + wrappedHandler := gzip(someGzipHandler(msg)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) req.Header.Add(acceptEncodingHeader, "br;q=1.0, compress;q=0.8, *;q=0.1") @@ -219,7 +219,7 @@ func TestGzip(t *testing.T) { t.Run("issues/81", func(t *testing.T) { t.Parallel() - wrappedHandler := Gzip(login()) + wrappedHandler := gzip(login()) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -237,7 +237,7 @@ func TestGzip(t *testing.T) { t.Run("issues/81", func(t *testing.T) { t.Parallel() - wrappedHandler := Gzip(login()) + wrappedHandler := gzip(login()) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -258,7 +258,7 @@ func TestGzip(t *testing.T) { msg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := Gzip(someGzipHandler(msg)) + wrappedHandler := gzip(someGzipHandler(msg)) runhandler := func() { rec := httptest.NewRecorder() @@ -317,7 +317,7 @@ var result int //nolint:gochecknoglobals func BenchmarkOngGzip(b *testing.B) { var r int - wrappedHandler := Gzip(gzipBenchmarkHandler()) + wrappedHandler := gzip(gzipBenchmarkHandler()) b.ReportAllocs() b.ResetTimer() @@ -409,7 +409,7 @@ func BenchmarkTmthrgdGzip(b *testing.B) { func BenchmarkNoGzip(b *testing.B) { var r int - wrappedHandler := Gzip(gzipBenchmarkHandler()) + wrappedHandler := gzip(gzipBenchmarkHandler()) b.ReportAllocs() b.ResetTimer() diff --git a/middleware/loadshed.go b/middleware/loadshed.go index 8565f0b7..ca2e1c57 100644 --- a/middleware/loadshed.go +++ b/middleware/loadshed.go @@ -40,8 +40,8 @@ const ( resizePeriod = samplingPeriod + (3 * time.Minute) ) -// LoadShedder is a middleware that sheds load based on http response latencies. -func LoadShedder(wrappedHandler http.HandlerFunc) http.HandlerFunc { +// loadShedder is a middleware that sheds load based on http response latencies. +func loadShedder(wrappedHandler http.HandlerFunc) http.HandlerFunc { mathRand.Seed(time.Now().UTC().UnixNano()) // lq should not be a global variable, we want it to be per handler. // This is because different handlers(URIs) could have different latencies and we want each to be loadshed independently. diff --git a/middleware/loadshed_test.go b/middleware/loadshed_test.go index 7250303b..c020e841 100644 --- a/middleware/loadshed_test.go +++ b/middleware/loadshed_test.go @@ -35,7 +35,7 @@ func TestLoadShedder(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := LoadShedder(someLoadShedderHandler(msg)) + wrappedHandler := loadShedder(someLoadShedderHandler(msg)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -58,7 +58,7 @@ func TestLoadShedder(t *testing.T) { msg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := LoadShedder(someLoadShedderHandler(msg)) + wrappedHandler := loadShedder(someLoadShedderHandler(msg)) runhandler := func() { rec := httptest.NewRecorder() @@ -194,7 +194,7 @@ func loadShedderBenchmarkHandler() http.HandlerFunc { func BenchmarkLoadShedder(b *testing.B) { var r int - wrappedHandler := LoadShedder(loadShedderBenchmarkHandler()) + wrappedHandler := loadShedder(loadShedderBenchmarkHandler()) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) diff --git a/middleware/log.go b/middleware/log.go index ea97cf60..34e20f89 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -17,8 +17,8 @@ import ( const logIDKey = string(log.CtxKey) -// Log is a middleware that logs http requests and responses using [log.Logger]. -func Log(wrappedHandler http.HandlerFunc, domain string, l log.Logger) http.HandlerFunc { +// logger is a middleware that logs http requests and responses using [log.Logger]. +func logger(wrappedHandler http.HandlerFunc, domain string, l log.Logger) http.HandlerFunc { // We pass the logger as an argument so that the middleware can share the same logger as the app. // That way, if the app logs an error, the middleware logs are also flushed. // This makes debugging easier for developers. @@ -26,6 +26,7 @@ func Log(wrappedHandler http.HandlerFunc, domain string, l log.Logger) http.Hand // However, each request should get its own context. That's why we call `logger.WithCtx` for every request. mathRand.Seed(time.Now().UTC().UnixNano()) + pid := os.Getpid() return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -73,7 +74,7 @@ func Log(wrappedHandler http.HandlerFunc, domain string, l log.Logger) http.Hand "code": lrw.code, "status": http.StatusText(lrw.code), "durationMS": time.Since(start).Milliseconds(), - "pid": os.Getpid(), + "pid": pid, } if ongError := lrw.Header().Get(ongMiddlewareErrorHeader); ongError != "" { flds["ongError"] = ongError diff --git a/middleware/log_test.go b/middleware/log_test.go index f8fc7fb3..1c615c6f 100644 --- a/middleware/log_test.go +++ b/middleware/log_test.go @@ -49,7 +49,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" domain := "example.com" - wrappedHandler := Log(someLogHandler(successMsg), domain, getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), domain, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -79,7 +79,7 @@ func TestLogMiddleware(t *testing.T) { errorMsg := "someLogHandler failed" successMsg := "hello" domain := "example.com" - wrappedHandler := Log(someLogHandler(successMsg), domain, getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), domain, getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodHead, "/someUri", nil) @@ -120,7 +120,7 @@ func TestLogMiddleware(t *testing.T) { successMsg := "hello" errorMsg := "someLogHandler failed" domain := "example.com" - wrappedHandler := Log(someLogHandler(successMsg), domain, getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), domain, getLogger(logOutput)) { // first request that succeds @@ -218,7 +218,7 @@ func TestLogMiddleware(t *testing.T) { logOutput := &bytes.Buffer{} successMsg := "hello" domain := "example.com" - wrappedHandler := Log(someLogHandler(successMsg), domain, getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), domain, getLogger(logOutput)) someLogID := "hey-some-log-id:" + id.New() @@ -256,7 +256,7 @@ func TestLogMiddleware(t *testing.T) { domain := "example.com" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := Log(someLogHandler(successMsg), domain, getLogger(logOutput)) + wrappedHandler := logger(someLogHandler(successMsg), domain, getLogger(logOutput)) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/middleware.go b/middleware/middleware.go index 2fa0103f..d34a7a5e 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -11,10 +11,10 @@ import ( // ongMiddlewareErrorHeader is a http header that is set by Ong // whenever any of it's middlewares return an error. -// The Log & Panic middleware will log the value of this header if it is set. +// The logger & recoverer middleware will log the value of this header if it is set. // // An example, is when the Get middleware fails because it has been called with the wrong http method. -// Or when the Csrf middleware fails because a csrf token was not found for POST/DELETE/etc requests. +// Or when the csrf middleware fails because a csrf token was not found for POST/DELETE/etc requests. const ongMiddlewareErrorHeader = "Ong-Middleware-Error" // Opts are the various parameters(optionals) that can be used to configure middlewares. @@ -34,7 +34,7 @@ type Opts struct { // // domain is the domain name of your website. // httpsPort is the tls port where http requests will be redirected to. -// allowedOrigins, allowedMethods, & allowedHeaders are used by the [Cors] middleware. +// allowedOrigins, allowedMethods, & allowedHeaders are used by the [cors] middleware. // // The secretKey should be kept secret and should not be shared. // If it becomes compromised, generate a new one and restart your application using the new one. @@ -78,36 +78,36 @@ func allDefaultMiddlewares( allowedMethods := o.allowedOrigins allowedHeaders := o.allowedHeaders secretKey := o.secretKey - logger := o.l + l := o.l // The way the middlewares are layered is: - // 1. Panic on the outer since we want it to watch all other middlewares. - // 2. Log since we would like to get logs as early in the lifecycle as possible. - // 3. RateLimiter since we want bad traffic to be filtered early. - // 4. LoadShedder for the same reason. - // 5. HttpsRedirector since it can be cpu intensive, thus should be behind the ratelimiter & loadshedder. - // 6. SecurityHeaders since we want some minimum level of security. - // 7. Cors since we might get pre-flight requests and we don't want those to go through all the middlewares for performance reasons. - // 8. Csrf since this one is a bit more involved perf-wise. + // 1. recoverer on the outer since we want it to watch all other middlewares. + // 2. logger since we would like to get logs as early in the lifecycle as possible. + // 3. rateLimiter since we want bad traffic to be filtered early. + // 4. loadShedder for the same reason. + // 5. httpsRedirector since it can be cpu intensive, thus should be behind the ratelimiter & loadshedder. + // 6. securityHeaders since we want some minimum level of security. + // 7. cors since we might get pre-flight requests and we don't want those to go through all the middlewares for performance reasons. + // 8. csrf since this one is a bit more involved perf-wise. // 9. Gzip since it is very involved perf-wise. - // 10. ReloadProtector, ideally I feel like it should come earlier but I'm yet to figure out where. - // 11. Session since we want sessions to saved as soon as possible. + // 10. reloadProtector, ideally I feel like it should come earlier but I'm yet to figure out where. + // 11. session since we want sessions to saved as soon as possible. // - // user -> Panic -> Log -> RateLimiter -> LoadShedder -> HttpsRedirector -> SecurityHeaders -> Cors -> Csrf -> Gzip -> ReloadProtector -> Session -> actual-handler + // user -> recoverer -> logger -> rateLimiter -> loadShedder -> httpsRedirector -> securityHeaders -> cors -> csrf -> Gzip -> reloadProtector -> session -> actual-handler // We have disabled Gzip for now, since it is about 2.5times slower than no-gzip for a 50MB sample response. // see: https://github.com/komuw/ong/issues/85 - return Panic( - Log( - RateLimiter( - LoadShedder( - HttpsRedirector( - SecurityHeaders( - Cors( - Csrf( - ReloadProtector( - Session( + return recoverer( + logger( + rateLimiter( + loadShedder( + httpsRedirector( + securityHeaders( + cors( + csrf( + reloadProtector( + session( wrappedHandler, secretKey, domain, @@ -129,15 +129,15 @@ func allDefaultMiddlewares( ), ), domain, - logger, + l, ), - logger, + l, ) } // All is a middleware that allows all http methods. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. +// It is composed of the [recoverer], [logger], [rateLimiter], [loadShedder], [httpsRedirector], [securityHeaders], [cors], [csrf], [reloadProtector] & [session] middleware. // As such, it provides the features and functionalities of all those middlewares. func All(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -154,7 +154,7 @@ func all(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Get is a middleware that only allows http GET requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. +// It is composed of the [recoverer], [logger], [rateLimiter], [loadShedder], [httpsRedirector], [securityHeaders], [cors], [csrf], [reloadProtector] & [session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Get(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -167,7 +167,7 @@ func get(wrappedHandler http.HandlerFunc) http.HandlerFunc { msg := "http method: %s not allowed. only allows http GET" return func(w http.ResponseWriter, r *http.Request) { // We do not need to allow `http.MethodOptions` here. - // This is coz, the Cors middleware has already handled that for us and it comes before the Get middleware. + // This is coz, the cors middleware has already handled that for us and it comes before the Get middleware. if r.Method != http.MethodGet { errMsg := fmt.Sprintf(msg, r.Method) w.Header().Set(ongMiddlewareErrorHeader, errMsg) @@ -185,7 +185,7 @@ func get(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Post is a middleware that only allows http POST requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. +// It is composed of the [recoverer], [logger], [rateLimiter], [loadShedder], [httpsRedirector], [securityHeaders], [cors], [csrf], [reloadProtector] & [session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Post(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -214,7 +214,7 @@ func post(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Head is a middleware that only allows http HEAD requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. +// It is composed of the [recoverer], [logger], [rateLimiter], [loadShedder], [httpsRedirector], [securityHeaders], [cors], [csrf], [reloadProtector] & [session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Head(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -243,7 +243,7 @@ func head(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Put is a middleware that only allows http PUT requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. +// It is composed of the [recoverer], [logger], [rateLimiter], [loadShedder], [httpsRedirector], [securityHeaders], [cors], [csrf], [reloadProtector] & [session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Put(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( @@ -272,7 +272,7 @@ func put(wrappedHandler http.HandlerFunc) http.HandlerFunc { // Delete is a middleware that only allows http DELETE requests and http OPTIONS requests. // -// It is composed of the [Panic], [Log], [RateLimiter], [LoadShedder], [HttpsRedirector], [SecurityHeaders], [Cors], [Csrf], [ReloadProtector] & [Session] middleware. +// It is composed of the [recoverer], [logger], [rateLimiter], [loadShedder], [httpsRedirector], [securityHeaders], [cors], [csrf], [reloadProtector] & [session] middleware. // As such, it provides the features and functionalities of all those middlewares. func Delete(wrappedHandler http.HandlerFunc, o Opts) http.HandlerFunc { return allDefaultMiddlewares( diff --git a/middleware/panic.go b/middleware/panic.go index a22de5bf..e6dda719 100644 --- a/middleware/panic.go +++ b/middleware/panic.go @@ -13,9 +13,11 @@ import ( // Most of the code here is insipired(or taken from) by: // (a) https://github.com/eliben/code-for-blog whose license(Unlicense) can be found here: https://github.com/eliben/code-for-blog/blob/464a32f686d7646ba3fc612c19dbb550ec8a05b1/LICENSE -// Panic is a middleware that recovers from panics in wrappedHandler. +// recoverer is a middleware that recovers from panics in wrappedHandler. // When/if a panic occurs, it logs the stack trace and returns an InternalServerError response. -func Panic(wrappedHandler http.HandlerFunc, l log.Logger) http.HandlerFunc { +func recoverer(wrappedHandler http.HandlerFunc, l log.Logger) http.HandlerFunc { + pid := os.Getpid() + return func(w http.ResponseWriter, r *http.Request) { defer func() { errR := recover() @@ -42,7 +44,7 @@ func Panic(wrappedHandler http.HandlerFunc, l log.Logger) http.HandlerFunc { "path": r.URL.Redacted(), "code": code, "status": status, - "pid": os.Getpid(), + "pid": pid, } if ongError := w.Header().Get(ongMiddlewareErrorHeader); ongError != "" { flds["ongError"] = ongError diff --git a/middleware/panic_test.go b/middleware/panic_test.go index 88d54da0..54a2c645 100644 --- a/middleware/panic_test.go +++ b/middleware/panic_test.go @@ -55,7 +55,7 @@ func TestPanic(t *testing.T) { logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := Panic(handlerThatPanics(msg, false, nil), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -72,7 +72,7 @@ func TestPanic(t *testing.T) { logOutput := &bytes.Buffer{} msg := "hello" - wrappedHandler := Panic(handlerThatPanics(msg, true, nil), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -102,7 +102,7 @@ func TestPanic(t *testing.T) { msg := "hello" errMsg := "99 problems" err := errors.New(errMsg) - wrappedHandler := Panic(handlerThatPanics(msg, false, err), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -130,7 +130,7 @@ func TestPanic(t *testing.T) { t.Parallel() logOutput := &bytes.Buffer{} - wrappedHandler := Panic(anotherHandlerThatPanics(), getLogger(logOutput)) + wrappedHandler := recoverer(anotherHandlerThatPanics(), getLogger(logOutput)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -151,7 +151,7 @@ func TestPanic(t *testing.T) { err := errors.New(msg) // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := Panic(handlerThatPanics(msg, false, err), getLogger(logOutput)) + wrappedHandler := recoverer(handlerThatPanics(msg, false, err), getLogger(logOutput)) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/ratelimiter.go b/middleware/ratelimiter.go index c99977df..bda54a2a 100644 --- a/middleware/ratelimiter.go +++ b/middleware/ratelimiter.go @@ -25,8 +25,8 @@ import ( // rateLimiterSendRate is the rate limit in requests/sec. var rateLimiterSendRate = 100.00 //nolint:gochecknoglobals -// RateLimiter is a middleware that limits requests by IP address. -func RateLimiter(wrappedHandler http.HandlerFunc) http.HandlerFunc { +// rateLimiter is a middleware that limits requests by IP address. +func rateLimiter(wrappedHandler http.HandlerFunc) http.HandlerFunc { rl := newRl() const retryAfter = 15 * time.Minute diff --git a/middleware/ratelimiter_test.go b/middleware/ratelimiter_test.go index 2fcde3f6..d249d549 100644 --- a/middleware/ratelimiter_test.go +++ b/middleware/ratelimiter_test.go @@ -26,7 +26,7 @@ func TestRateLimiter(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := RateLimiter(someRateLimiterHandler(msg)) + wrappedHandler := rateLimiter(someRateLimiterHandler(msg)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -46,7 +46,7 @@ func TestRateLimiter(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := RateLimiter(someRateLimiterHandler(msg)) + wrappedHandler := rateLimiter(someRateLimiterHandler(msg)) msgsDelivered := []int{} start := time.Now().UTC() @@ -82,7 +82,7 @@ func TestRateLimiter(t *testing.T) { t.Parallel() msg := "hello" - wrappedHandler := RateLimiter(someRateLimiterHandler(msg)) + wrappedHandler := rateLimiter(someRateLimiterHandler(msg)) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -105,7 +105,7 @@ func TestRateLimiter(t *testing.T) { msg := "hello" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := RateLimiter(someRateLimiterHandler(msg)) + wrappedHandler := rateLimiter(someRateLimiterHandler(msg)) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/redirect.go b/middleware/redirect.go index 32c9ce42..e22a0bcb 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -8,11 +8,11 @@ import ( "unicode" ) -// HttpsRedirector is a middleware that redirects http requests to https. +// httpsRedirector is a middleware that redirects http requests to https. // // domain is the domain name of your website. // httpsPort is the tls port where http requests will be redirected to. -func HttpsRedirector(wrappedHandler http.HandlerFunc, httpsPort uint16, domain string) http.HandlerFunc { +func httpsRedirector(wrappedHandler http.HandlerFunc, httpsPort uint16, domain string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { isTls := strings.EqualFold(r.URL.Scheme, "https") || r.TLS != nil if !isTls { diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 53e7ddca..c669145a 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -44,7 +44,7 @@ func TestHttpsRedirector(t *testing.T) { msg := "hello world" port := uint16(443) - wrappedHandler := HttpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") + wrappedHandler := httpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) wrappedHandler.ServeHTTP(rec, req) @@ -61,7 +61,7 @@ func TestHttpsRedirector(t *testing.T) { msg := "hello you" port := uint16(443) - wrappedHandler := HttpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") + wrappedHandler := httpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/someUri", nil) wrappedHandler.ServeHTTP(rec, req) @@ -78,7 +78,7 @@ func TestHttpsRedirector(t *testing.T) { msg := "hello world" port := uint16(443) - wrappedHandler := HttpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") + wrappedHandler := httpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") for _, uri := range []string{ "/someUri", @@ -114,7 +114,7 @@ func TestHttpsRedirector(t *testing.T) { msg := "hello world" port := uint16(443) - wrappedHandler := HttpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") + wrappedHandler := httpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") ts := httptest.NewTLSServer( wrappedHandler, ) @@ -136,11 +136,11 @@ func TestHttpsRedirector(t *testing.T) { t.Parallel() // this test also asserts that a http POST is not converted to a http GET - // as might happen if `HttpsRedirector` was using `http.StatusMovedPermanently` + // as might happen if `httpsRedirector` was using `http.StatusMovedPermanently` msg := "hello world" port := uint16(443) - wrappedHandler := HttpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") + wrappedHandler := httpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") ts := httptest.NewTLSServer( wrappedHandler, ) @@ -173,7 +173,7 @@ func TestHttpsRedirector(t *testing.T) { uint16(88), uint16(65535), } { - wrappedHandler := HttpsRedirector(someHttpsRedirectorHandler(msg), p, domain) + wrappedHandler := httpsRedirector(someHttpsRedirectorHandler(msg), p, domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, uri, nil) wrappedHandler.ServeHTTP(rec, req) @@ -198,7 +198,7 @@ func TestHttpsRedirector(t *testing.T) { msg := "hello world" port := uint16(443) domain := "localhost" - wrappedHandler := HttpsRedirector(someHttpsRedirectorHandler(msg), port, domain) + wrappedHandler := httpsRedirector(someHttpsRedirectorHandler(msg), port, domain) ts := httptest.NewTLSServer( wrappedHandler, ) @@ -244,7 +244,7 @@ func TestHttpsRedirector(t *testing.T) { port := uint16(443) // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := HttpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") + wrappedHandler := httpsRedirector(someHttpsRedirectorHandler(msg), port, "localhost") runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/reload_protect.go b/middleware/reload_protect.go index d025e4bb..dab0f7d9 100644 --- a/middleware/reload_protect.go +++ b/middleware/reload_protect.go @@ -12,10 +12,10 @@ import ( const reloadProtectCookiePrefix = "ong_form_reload_protect" -// ReloadProtector is a middleware that attempts to provides protection against a form re-submission when a user reloads/refreshes an already submitted web page/form. +// reloadProtector is a middleware that attempts to provides protection against a form re-submission when a user reloads/refreshes an already submitted web page/form. // // If such a situation is detected; this middleware will issue a http GET redirect to the same url. -func ReloadProtector(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { +func reloadProtector(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { safeMethods := []string{ // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 http.MethodGet, diff --git a/middleware/reload_protect_test.go b/middleware/reload_protect_test.go index b2863033..d455dcb8 100644 --- a/middleware/reload_protect_test.go +++ b/middleware/reload_protect_test.go @@ -47,7 +47,7 @@ func TestReloadProtector(t *testing.T) { domain := "localhost" expectedFormName := "user_name" expectedFormValue := "John Doe" - wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) + wrappedHandler := reloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -70,7 +70,7 @@ func TestReloadProtector(t *testing.T) { domain := "localhost" expectedFormName := "user_name" expectedFormValue := "John Doe" - wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) + wrappedHandler := reloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) req := httptest.NewRequest(http.MethodPost, "/someUri", nil) err := req.ParseForm() @@ -122,7 +122,7 @@ func TestReloadProtector(t *testing.T) { domain := "localhost" expectedFormName := "user_name" expectedFormValue := "John Doe" - wrappedHandler := ReloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) + wrappedHandler := reloadProtector(someReloadProtectorHandler(msg, expectedFormName, expectedFormValue), domain) runhandler := func() { rec := httptest.NewRecorder() diff --git a/middleware/security.go b/middleware/security.go index 80d07d8f..f0ceee49 100644 --- a/middleware/security.go +++ b/middleware/security.go @@ -32,10 +32,10 @@ const ( cspBytesTokenLength = csrfBytesTokenLength ) -// SecurityHeaders is a middleware that adds some important HTTP security headers and assigns them sensible default values. +// securityHeaders is a middleware that adds some important HTTP security headers and assigns them sensible default values. // -// Some of the headers set are Permissions-Policy, Content-SecurityHeaders-Policy, X-Content-Type-Options, X-Frame-Options, Cross-Origin-Resource-Policy, Cross-Origin-Opener-Policy, Referrer-Policy & Strict-Transport-SecurityHeaders -func SecurityHeaders(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { +// Some of the headers set are Permissions-Policy, Content-securityHeaders-Policy, X-Content-Type-Options, X-Frame-Options, Cross-Origin-Resource-Policy, Cross-Origin-Opener-Policy, Referrer-Policy & Strict-Transport-securityHeaders +func securityHeaders(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/middleware/security_test.go b/middleware/security_test.go index c6b10912..b95c6a43 100644 --- a/middleware/security_test.go +++ b/middleware/security_test.go @@ -31,7 +31,7 @@ func TestSecurity(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := SecurityHeaders(echoHandler(msg), domain) + wrappedHandler := securityHeaders(echoHandler(msg), domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -52,7 +52,7 @@ func TestSecurity(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := SecurityHeaders(echoHandler(msg), domain) + wrappedHandler := securityHeaders(echoHandler(msg), domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -86,7 +86,7 @@ func TestSecurity(t *testing.T) { domain := "example.com" // for this concurrency test, we have to re-use the same wrappedHandler // so that state is shared and thus we can see if there is any state which is not handled correctly. - wrappedHandler := SecurityHeaders(echoHandler(msg), domain) + wrappedHandler := securityHeaders(echoHandler(msg), domain) runhandler := func() { rec := httptest.NewRecorder() @@ -124,7 +124,7 @@ func TestGetCspNonce(t *testing.T) { msg := "hello" domain := "example.com" - wrappedHandler := SecurityHeaders(echoHandler(msg), domain) + wrappedHandler := securityHeaders(echoHandler(msg), domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) diff --git a/middleware/session.go b/middleware/session.go index de470d03..342954cf 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -17,11 +17,11 @@ const ( sessionMaxAge = 14 * time.Hour ) -// Session is a middleware that implements http sessions. +// session is a middleware that implements http sessions. // It lets you store and retrieve arbitrary data on a per-site-visitor basis. // // This middleware works best when used together with the [sess] package. -func Session(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc { +func session(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // 1. Read from cookies and check for session cookie. // 2. Get that cookie and save it to r.context diff --git a/middleware/session_test.go b/middleware/session_test.go index a09416e5..e95bae70 100644 --- a/middleware/session_test.go +++ b/middleware/session_test.go @@ -42,7 +42,7 @@ func TestSession(t *testing.T) { domain := "localhost" key := "name" value := "John Doe" - wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain) + wrappedHandler := session(someSessionHandler(msg, key, value), secretKey, domain) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/someUri", nil) @@ -66,7 +66,7 @@ func TestSession(t *testing.T) { domain := "localhost" key := "name" value := "John Doe" - wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain) + wrappedHandler := session(someSessionHandler(msg, key, value), secretKey, domain) ts := httptest.NewServer( wrappedHandler, @@ -115,7 +115,7 @@ func TestSession(t *testing.T) { domain := "localhost" key := "name" value := "John Doe" - wrappedHandler := Session(someSessionHandler(msg, key, value), secretKey, domain) + wrappedHandler := session(someSessionHandler(msg, key, value), secretKey, domain) runhandler := func() { rec := httptest.NewRecorder() diff --git a/sess/example_test.go b/sess/example_test.go index 45b447ea..e8fde196 100644 --- a/sess/example_test.go +++ b/sess/example_test.go @@ -4,16 +4,13 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" + "github.com/komuw/ong/log" "github.com/komuw/ong/middleware" "github.com/komuw/ong/sess" ) -const ( - secretKey = "some-secretKey" - domain = "example.com" -) - func loginHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { mySession := map[string]string{ @@ -30,7 +27,10 @@ func loginHandler() http.HandlerFunc { func ExampleSetM() { rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/login", nil) - handler := middleware.Session(loginHandler(), secretKey, domain) + handler := middleware.Get( + loginHandler(), + middleware.WithOpts("example.com", 443, "secretKey", log.New(os.Stdout, 100)), + ) handler.ServeHTTP(rec, req) res := rec.Result() diff --git a/sess/sess.go b/sess/sess.go index eb363b26..71861f04 100644 --- a/sess/sess.go +++ b/sess/sess.go @@ -1,5 +1,5 @@ // Package sess provides an implementation of http sessions that is backed by tamper-proof & encrypted cookies. -// This package should ideally be used together with the [github.com/komuw/ong/middleware.Session] middleware. +// This package should ideally be used together with the ong [github.com/komuw/ong/middleware] middlewares. package sess import ( @@ -26,8 +26,8 @@ const ( // Initialise returns a new http.Request (based on r) that has sessions properly setup. // -// You do not need to call this function, if you are also using the [github.com/komuw/ong/middleware.Session] middleware. -// That middleware does so automatically for you. +// You do not need to call this function, if you are also using the ong [github.com/komuw/ong/middleware] middleware. +// Those middleware does so automatically for you. func Initialise(r *http.Request, secretKey string) *http.Request { ctx := r.Context() var sessVal M // should be per request. @@ -108,8 +108,8 @@ func GetM(r *http.Request) map[string]string { // Save writes(to http cookies) any key-value pairs that have already been added to the current http session. // -// You do not need to call this function, if you are also using the [github.com/komuw/ong/middleware.Session] middleware. -// That middleware does so automatically for you. +// You do not need to call this function, if you are also using the ong [github.com/komuw/ong/middleware] middleware. +// Those middleware does so automatically for you. func Save( r *http.Request, w http.ResponseWriter,