Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make most middleware private #186

Merged
merged 13 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Most recent version is listed first.

## v0.0.25
- ong/client: Use roundTripper for logging: https://github.com/komuw/ong/pull/185
- Make most middleware private: https://github.com/komuw/ong/pull/186

## v0.0.24
- Set session cookie only if non-empty: https://github.com/komuw/ong/pull/170
Expand All @@ -15,7 +16,7 @@ Most recent version is listed first.
- ong/client: Add log id http header: https://github.com/komuw/ong/pull/166

## v0.0.22
- Panic middleware should include correct stack trace: https://github.com/komuw/ong/pull/164
- Panic/recoverer middleware should include correct stack trace: https://github.com/komuw/ong/pull/164
- Log client address without port: https://github.com/komuw/ong/pull/165

## v0.0.21
Expand Down
4 changes: 2 additions & 2 deletions middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ const (
corsCacheDur = 2 * time.Hour
)

// Cors is a middleware to implement Cross-Origin Resource Sharing support.
// cors is a middleware to implement Cross-Origin Resource Sharing support.
//
// If allowedOrigins is nil, all origins are allowed. You can also use * to allow all.
// If allowedMethods is nil, "GET", "POST", "HEAD" are allowed. Use * to allow all.
// If allowedHeaders is nil, "Origin", "Accept", "Content-Type", "X-Requested-With" are allowed. Use * to allow all.
func Cors(
func cors(
wrappedHandler http.HandlerFunc,
allowedOrigins []string,
allowedMethods []string,
Expand Down
22 changes: 11 additions & 11 deletions middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand All @@ -44,7 +44,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, http.MethodGet) // preflight request header set
Expand Down Expand Up @@ -75,7 +75,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
// preflight request header NOT set
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand Down Expand Up @@ -206,7 +206,7 @@ func TestCorsPreflight(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(originHeader, "http://some-origin.com")
Expand Down Expand Up @@ -270,7 +270,7 @@ func TestCorsPreflight(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, tt.allowedHeaders)
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, tt.allowedHeaders)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodOptions, "/someUri", nil)
req.Header.Add(acrmHeader, "is-set") // preflight request header set
Expand Down Expand Up @@ -302,7 +302,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
wrappedHandler.ServeHTTP(rec, req)
Expand All @@ -321,7 +321,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, []string{"*"}, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
req.Header.Add(originHeader, "http://example.com")
Expand Down Expand Up @@ -381,7 +381,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), tt.allowedOrigins, []string{"*"}, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
req.Header.Add(originHeader, tt.origin)
Expand Down Expand Up @@ -445,7 +445,7 @@ func TestCorsActualRequest(t *testing.T) {
t.Parallel()

msg := "hello"
wrappedHandler := Cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"})
wrappedHandler := cors(someCorsHandler(msg), []string{"*"}, tt.allowedMethods, []string{"*"})
rec := httptest.NewRecorder()
req := httptest.NewRequest(tt.method, "/someUri", nil)
req.Header.Add(originHeader, "http://some-origin.com")
Expand Down Expand Up @@ -474,7 +474,7 @@ func TestCorsActualRequest(t *testing.T) {
msg := "hello"
// for this concurrency test, we have to re-use the same wrappedHandler
// so that state is shared and thus we can see if there is any state which is not handled correctly.
wrappedHandler := Cors(someCorsHandler(msg), nil, nil, nil)
wrappedHandler := cors(someCorsHandler(msg), nil, nil, nil)

runhandler := func() {
rec := httptest.NewRecorder()
Expand Down
4 changes: 2 additions & 2 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ const (
csrfBytesTokenLength = 32
)

// Csrf is a middleware that provides protection against Cross Site Request Forgeries.
// csrf is a middleware that provides protection against Cross Site Request Forgeries.
//
// If a csrf token is not provided(or is not valid), when it ought to have been; this middleware will issue a http GET redirect to the same url.
func Csrf(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc {
func csrf(wrappedHandler http.HandlerFunc, secretKey, domain string) http.HandlerFunc {
once.Do(func() {
enc = cry.New(secretKey)
})
Expand Down
18 changes: 9 additions & 9 deletions middleware/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand All @@ -148,7 +148,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

reqCsrfTok := id.Random(csrfBytesTokenLength)
rec := httptest.NewRecorder()
Expand Down Expand Up @@ -179,7 +179,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

reqCsrfTok := id.Random(csrfBytesTokenLength)
rec := httptest.NewRecorder()
Expand All @@ -203,7 +203,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand All @@ -225,7 +225,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

reqCsrfTok := id.Random(csrfBytesTokenLength)
rec := httptest.NewRecorder()
Expand Down Expand Up @@ -295,7 +295,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

key := getSecretKey()
enc2 := cry.New(key)
Expand Down Expand Up @@ -365,7 +365,7 @@ func TestCsrf(t *testing.T) {

msg := "hello"
domain := "example.com"
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

rec := httptest.NewRecorder()
postMsg := "my name is John"
Expand Down Expand Up @@ -394,7 +394,7 @@ func TestCsrf(t *testing.T) {
domain := "example.com"
// for this concurrency test, we have to re-use the same wrappedHandler
// so that state is shared and thus we can see if there is any state which is not handled correctly.
wrappedHandler := Csrf(someCsrfHandler(msg), getSecretKey(), domain)
wrappedHandler := csrf(someCsrfHandler(msg), getSecretKey(), domain)

key := getSecretKey()
enc2 := cry.New(key)
Expand Down
10 changes: 8 additions & 2 deletions middleware/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ func loginHandler() http.HandlerFunc {
}

func Example_getCspNonce() {
handler := middleware.SecurityHeaders(loginHandler(), "example.com")
handler := middleware.Get(
loginHandler(),
middleware.WithOpts("example.com", 443, "secretKey", log.New(os.Stdout, 100)),
)
_ = handler // use handler

// Output:
Expand All @@ -36,7 +39,10 @@ func welcomeHandler() http.HandlerFunc {
}

func Example_getCsrfToken() {
handler := middleware.Csrf(welcomeHandler(), "some-secret-key", "example.com")
handler := middleware.Get(
welcomeHandler(),
middleware.WithOpts("example.com", 443, "secretKey", log.New(os.Stdout, 100)),
)
_ = handler // use handler

// Output:
Expand Down
14 changes: 7 additions & 7 deletions middleware/gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package middleware

import (
"bufio"
"compress/gzip"
stdGzip "compress/gzip"
"fmt"
"io"
"net"
Expand All @@ -28,8 +28,8 @@ const (
thisMiddlewareEncoding = "gzip"
)

// Gzip is a middleware that transparently gzips the http response body, for clients that support it.
func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc {
// gzip is a middleware that transparently gzips the http response body, for clients that support it.
func gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(varyHeader, acceptEncodingHeader)

Expand All @@ -38,7 +38,7 @@ func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc {
return
}

gzipWriter, _ := gzip.NewWriterLevel(w, gzip.BestSpeed)
gzipWriter, _ := stdGzip.NewWriterLevel(w, stdGzip.BestSpeed)
grw := &gzipRW{
ResponseWriter: w,
// Bytes written during ServeHTTP are redirected to this gzip writer
Expand All @@ -65,7 +65,7 @@ func Gzip(wrappedHandler http.HandlerFunc) http.HandlerFunc {
// writers, so don't forget to do that.
type gzipRW struct {
http.ResponseWriter
gw *gzip.Writer
gw *stdGzip.Writer

buf []byte // Holds the first part of the write before reaching the minSize or the end of the write.

Expand Down Expand Up @@ -176,7 +176,7 @@ func (grw *gzipRW) handleGzipped(ct string, lenB int) (int, error) {
return lenB, nil
}

// Close will close the gzip.Writer.
// Close will close the stdGzip.Writer.
func (grw *gzipRW) Close() error {
if !grw.handledZip {
// GZIP not triggered yet, write out regular response.
Expand All @@ -187,7 +187,7 @@ func (grw *gzipRW) Close() error {
return grw.gw.Close() // will also call gzip flush()
}

// Flush flushes the underlying *gzip.Writer and then the
// Flush flushes the underlying *stdGzip.Writer and then the
// underlying http.ResponseWriter if it is an http.Flusher.
// This makes gzipRW an http.Flusher.
func (grw *gzipRW) Flush() {
Expand Down
Loading