Skip to content

Commit

Permalink
make most middleware private (#186)
Browse files Browse the repository at this point in the history
What:
- make most middleware private

Why:
- Ideally we should only use the `middleware.Get`, `middleware.Post`, etc middlewares
  • Loading branch information
komuw authored Dec 8, 2022
1 parent 8a3a2b0 commit c9fdb51
Show file tree
Hide file tree
Showing 27 changed files with 156 additions and 146 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Most recent version is listed first.

## v0.0.25
- ong/client: Use roundTripper for logging: https://github.com/komuw/ong/pull/185
- Make most middleware private: https://github.com/komuw/ong/pull/186

## v0.0.24
- Set session cookie only if non-empty: https://github.com/komuw/ong/pull/170
Expand All @@ -15,7 +16,7 @@ Most recent version is listed first.
- ong/client: Add log id http header: https://github.com/komuw/ong/pull/166

## v0.0.22
- Panic middleware should include correct stack trace: https://github.com/komuw/ong/pull/164
- Panic/recoverer middleware should include correct stack trace: https://github.com/komuw/ong/pull/164
- Log client address without port: https://github.com/komuw/ong/pull/165

## v0.0.21
Expand Down
4 changes: 2 additions & 2 deletions middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ const (
corsCacheDur = 2 * time.Hour
)

// Cors is a middleware to implement Cross-Origin Resource Sharing support.
// cors is a middleware to implement Cross-Origin Resource Sharing support.
//
// If allowedOrigins is nil, all origins are allowed. You can also use * to allow all.
// If allowedMethods is nil, "GET", "POST", "HEAD" are allowed. Use * to allow all.
// If allowedHeaders is nil, "Origin", "Accept", "Content-Type", "X-Requested-With" are allowed. Use * to allow all.
func Cors(
func cors(
wrappedHandler http.HandlerFunc,
allowedOrigins []string,
allowedMethods []string,
Expand Down
22 changes: 11 additions & 11 deletions middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand All @@ -44,7 +44,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, http.MethodGet) // preflight request header set
Expand Down Expand Up @@ -75,7 +75,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
// preflight request header NOT set
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand Down Expand Up @@ -206,7 +206,7 @@ func TestCorsPreflight(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(originHeader, "http://some-origin.com")
Expand Down Expand Up @@ -270,7 +270,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, tt.allowedHeaders)
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, tt.allowedHeaders)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand Down Expand Up @@ -302,7 +302,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
wrappedHandler.ServeHTTP(rec, req)
Expand All @@ -321,7 +321,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
req.Header.Add(originHeader, "http://example.com")
Expand Down Expand Up @@ -381,7 +381,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
req.Header.Add(originHeader, tt.origin)
Expand Down Expand Up @@ -445,7 +445,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(tt.method, "/someUri", nil)
req.Header.Add(originHeader, "http://some-origin.com")
Expand Down Expand Up @@ -474,7 +474,7 @@ func TestCorsActualRequest(t *testing.T) {
msg := "hello"
// for this concurrency test, we have to re-use the same wrappedHandler
// so that state is shared and thus we can see if there is any state which is not handled correctly.
wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil)

runhandler := func() {
rec := httptest.NewRecorder()
Expand Down
4 changes: 2 additions & 2 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ const (
csrfBytesTokenLength = 32
)

// Csrf is a middleware that provides protection against Cross Site Request Forgeries.
// csrf is a middleware that provides protection against Cross Site Request Forgeries.
//
// If a csrf token is not provided(or is not valid), when it ought to have been; this middleware will issue a http GET redirect to the same url.
func Csrf(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc {
func csrf(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc {
once.Do(func() {
enc = cry.New(secretKey)
})
Expand Down
18 changes: 9 additions & 9 deletions middleware/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand All @@ -148,7 +148,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

reqCsrfTok := id.Random(csrfBytesTokenLength)
rec := httptest.NewRecorder()
Expand Down Expand Up @@ -179,7 +179,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

reqCsrfTok := id.Random(csrfBytesTokenLength)
rec := httptest.NewRecorder()
Expand All @@ -203,7 +203,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand All @@ -225,7 +225,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

reqCsrfTok := id.Random(csrfBytesTokenLength)
rec := httptest.NewRecorder()
Expand Down Expand Up @@ -295,7 +295,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

key := getSecretKey()
enc2 := cry.New(key)
Expand Down Expand Up @@ -365,7 +365,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

rec := httptest.NewRecorder()
postMsg := "my name is John"
Expand Down Expand Up @@ -394,7 +394,7 @@ func TestCsrf(t *testing.T) {
domain := "example.com"
// for this concurrency test, we have to re-use the same wrappedHandler
// so that state is shared and thus we can see if there is any state which is not handled correctly.
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

key := getSecretKey()
enc2 := cry.New(key)
Expand Down
10 changes: 8 additions & 2 deletions middleware/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ func loginHandler() http.HandlerFunc {
}

func Example_getCspNonce() {
handler := middleware.SecurityHeaders(loginHandler(), "example.com")
handler := middleware.Get(
loginHandler(),
middleware.WithOpts("example.com", 443, "secretKey", log.New(os.Stdout, 100)),
)
_ = handler // use handler

// Output:
Expand All @@ -36,7 +39,10 @@ func welcomeHandler() http.HandlerFunc {
}

func Example_getCsrfToken() {
handler := middleware.Csrf(welcomeHandler(), "some-secret-key", "example.com")
handler := middleware.Get(
welcomeHandler(),
middleware.WithOpts("example.com", 443, "secretKey", log.New(os.Stdout, 100)),
)
_ = handler // use handler

// Output:
Expand Down
14 changes: 7 additions & 7 deletions middleware/gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package middleware

import (
"bufio"
"compress/gzip"
stdGzip "compress/gzip"
"fmt"
"io"
"net"
Expand All @@ -28,8 +28,8 @@ const (
thisMiddlewareEncoding = "gzip"
)

// Gzip is a middleware that transparently gzips the http response body, for clients that support it.
func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc {
// gzip is a middleware that transparently gzips the http response body, for clients that support it.
func gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(varyHeader, acceptEncodingHeader)

Expand All @@ -38,7 +38,7 @@ func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc {
return
}

gzipWriter, _ := gzip.NewWriterLevel(w, gzip.BestSpeed)
gzipWriter, _ := stdGzip.NewWriterLevel(w, stdGzip.BestSpeed)
grw := &gzipRW{
ResponseWriter: w,
// Bytes written during ServeHTTP are redirected to this gzip writer
Expand All @@ -65,7 +65,7 @@ func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc {
// writers, so don't forget to do that.
type gzipRW struct {
http.ResponseWriter
gw *gzip.Writer
gw *stdGzip.Writer

buf []byte // Holds the first part of the write before reaching the minSize or the end of the write.

Expand Down Expand Up @@ -176,7 +176,7 @@ func (grw *gzipRW) handleGzipped(ct string, lenB int) (int, error) {
return lenB, nil
}

// Close will close the gzip.Writer.
// Close will close the stdGzip.Writer.
func (grw *gzipRW) Close() error {
if !grw.handledZip {
// GZIP not triggered yet, write out regular response.
Expand All @@ -187,7 +187,7 @@ func (grw *gzipRW) Close() error {
return grw.gw.Close() // will also call gzip flush()
}

// Flush flushes the underlying *gzip.Writer and then the
// Flush flushes the underlying *stdGzip.Writer and then the
// underlying http.ResponseWriter if it is an http.Flusher.
// This makes gzipRW an http.Flusher.
func (grw *gzipRW) Flush() {
Expand Down
Loading

0 comments on commit c9fdb51

Please sign in to comment.