From 7f9c22eb6772a423c81c3fe3d1624327e5a2dc2f Mon Sep 17 00:00:00 2001 From: Komu Wairagu Date: Wed, 15 Jun 2022 20:01:00 +0300 Subject: [PATCH] issues/28: add csrf middleware (#32) fixes: https://github.com/komuw/goweb/issues/28 --- CHANGELOG.md | 1 + middleware/csrf.go | 213 ++++++++++++++++ middleware/csrf_test.go | 479 ++++++++++++++++++++++++++++++++++++ middleware/panic_test.go | 6 + middleware/security.go | 14 +- middleware/security_test.go | 4 +- server/server_test.go | 10 + 7 files changed, 718 insertions(+), 9 deletions(-) create mode 100644 middleware/csrf.go create mode 100644 middleware/csrf_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 811cee34..8f0d00b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,3 +9,4 @@ Most recent version is listed first. - harmonize timeouts: https://github.com/komuw/goweb/pull/25 - add panic middleware: https://github.com/komuw/goweb/pull/26 - cookies: https://github.com/komuw/goweb/pull/27 +- csrf middleware: https://github.com/komuw/goweb/pull/32 diff --git a/middleware/csrf.go b/middleware/csrf.go new file mode 100644 index 00000000..a243f07f --- /dev/null +++ b/middleware/csrf.go @@ -0,0 +1,213 @@ +package middleware + +import ( + "context" + "errors" + "net/http" + "sync" + "time" + + "github.com/komuw/goweb/cookie" + "github.com/rs/xid" +) + +var ( + errCsrfTokenNotFound = errors.New("csrf token not found/recognized") + // csrfStore needs to be a global var so that different handlers that are decorated with the Csrf middleware can use same store. + // Image if you had `Csrf(loginHandler, domain)` & `Csrf(cartCheckoutHandler, domain)`, if they didn't share a global store, + // a customer navigating from login to checkout would get a errCsrfTokenNotFound error; which is not what we want. + csrfStore = newStore() //nolint:gochecknoglobals +) + +type csrfContextKey string + +const ( + csrfCtxKey = csrfContextKey("csrfContextKey") + csrfDefaultToken = "" + csrfCookieName = "csrftoken" // named after what django uses. + csrfHeader = "X-Csrf-Token" // named after what fiber uses. + csrfCookieForm = csrfCookieName + clientCookieHeader = "Cookie" + varyHeader = "Vary" + + // The memory store is reset(for memory efficiency) if either resetDuration OR maxRequestsToReset occurs. + tokenMaxAge = 1 * time.Hour // same max-age as what fiber uses. django seems to use one year. + resetDuration = tokenMaxAge + (7 * time.Minute) +) + +// Csrf is a middleware that provides protection against Cross Site Request Forgeries. +// If maxRequestsToReset <= 0, it is set to a high default value. +func Csrf(wrappedHandler http.HandlerFunc, domain string, maxRequestsToReset int32) http.HandlerFunc { + start := time.Now() + requestsServed := int32(0) + if maxRequestsToReset <= 0 { + maxRequestsToReset = 10_000_000 + } + + return func(w http.ResponseWriter, r *http.Request) { + // - https://docs.djangoproject.com/en/4.0/ref/csrf/ + // - https://github.com/django/django/blob/4.0.5/django/middleware/csrf.py + // - https://github.com/gofiber/fiber/blob/v2.34.1/middleware/csrf/csrf.go + + // 1. check http method. + // - if it is a 'safe' method like GET, try and get csrfToken from request. + // - if it is not a 'safe' method, try and get csrfToken from header/cookies/httpForm + // - take the found token and try to get it from memory store. + // - if not found in memory store, delete the cookie & return an error. + + ctx := r.Context() + + csrfToken := "" + switch r.Method { + // safe methods under rfc7231: https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 + case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + csrfToken = getToken(r) + default: + // For POST requests, we insist on a CSRF cookie, and in this way we can avoid all CSRF attacks, including login CSRF. + csrfToken = getToken(r) + + if csrfToken == "" || !csrfStore.exists(csrfToken) { + // we should fail the request since it means that the server is not aware of such a token. + cookie.Delete(w, csrfCookieName, domain) + http.Error( + w, + errCsrfTokenNotFound.Error(), + http.StatusBadRequest, + ) + return + } + } + + // 2. If csrfToken is still an empty string. generate it. + if csrfToken == "" { + csrfToken = xid.New().String() + } + + // 3. create cookie + cookie.Set( + w, + csrfCookieName, + csrfToken, + domain, + tokenMaxAge, + true, + ) + + // 4. set cookie header + w.Header().Set( + csrfHeader, + csrfToken, + ) + + // 5. update Vary header. + w.Header().Add(varyHeader, clientCookieHeader) + + // 6. store csrfToken in context + r = r.WithContext(context.WithValue(ctx, csrfCtxKey, csrfToken)) + + // 7. save csrfToken in memory store. + csrfStore.set(csrfToken) + + // 8. reset memory to decrease its size. + requestsServed = requestsServed + 1 + now := time.Now() + diff := now.Sub(start) + if (diff > resetDuration) || (requestsServed > maxRequestsToReset) { + csrfStore.reset() + start = now + requestsServed = 0 + } + + wrappedHandler(w, r) + } +} + +// GetCsrfToken returns the csrf token was set for that particular request. +// +// usage: +// func myHandler(w http.ResponseWriter, r *http.Request) { +// csrfToken := middleware.GetCsrfToken(r.Context()) +// _ = csrfToken +// } +func GetCsrfToken(c context.Context) string { + v := c.Value(csrfCtxKey) + if v != nil { + s, ok := v.(string) + if ok { + return s + } + } + return csrfDefaultToken +} + +// getToken tries to fetch a csrf token from the incoming request r. +// It tries to fetch from cookies, headers, http-forms in that order. +func getToken(r *http.Request) string { + fromCookie := func() string { + c, err := r.Cookie(csrfCookieName) + if err != nil { + return "" + } + return c.Value + } + + fromHeader := func() string { + return r.Header.Get(csrfHeader) + } + + fromForm := func() string { + if err := r.ParseForm(); err != nil { + return "" + } + return r.Form.Get(csrfCookieForm) + } + + tok := fromCookie() + if tok == "" { + tok = fromHeader() + } + if tok == "" { + tok = fromForm() + } + + return tok +} + +// store persists csrf tokens server-side, in-memory. +type store struct { + mu sync.RWMutex // protects m + m map[string]struct{} +} + +func newStore() *store { + return &store{ + m: map[string]struct{}{}, + } +} + +func (s *store) exists(csrfToken string) bool { + s.mu.RLock() + _, ok := s.m[csrfToken] + s.mu.RUnlock() + return ok +} + +func (s *store) set(csrfToken string) { + s.mu.Lock() + s.m[csrfToken] = struct{}{} + s.mu.Unlock() +} + +func (s *store) reset() { + s.mu.Lock() + s.m = map[string]struct{}{} + s.mu.Unlock() +} + +// used in tests +func (s *store) _len() int { + s.mu.RLock() + l := len(s.m) + s.mu.RUnlock() + return l +} diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go new file mode 100644 index 00000000..15e2359c --- /dev/null +++ b/middleware/csrf_test.go @@ -0,0 +1,479 @@ +package middleware + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/akshayjshah/attest" + "github.com/rs/xid" +) + +func TestStore(t *testing.T) { + t.Parallel() + + t.Run("concurrency safe", func(t *testing.T) { + t.Parallel() + + store := newStore() + + tokens := []string{ + "a", "aa", "aaa", "aaron", "ab", "abandoned", "abc", "aberdeen", "abilities", "ability", "able", "aboriginal", "abortion", + "about", "above", "abraham", "abroad", "abs", "absence", "absent", "absolute", "absolutely", "absorption", "abstract", + "abstracts", "abu", "abuse", "ac", "academic", "academics", "academy", "acc", "accent", "accept", "acceptable", "acceptance", + "accepted", "accepting", "accepts", "access", "accessed", "accessibility", "accessible", "accessing", "accessories", + "accessory", "accident", "accidents", "accommodate", "accommodation", "accommodations", "accompanied", "accompanying", + "accomplish", "accomplished", "accordance", "according", "accordingly", "account", "accountability", "accounting", "accounts", + "accreditation", "accredited", "accuracy", "accurate", "accurately", "accused", "acdbentity", "ace", + } + + for _, tok := range tokens { + go func(t string) { + store.set(t) + }(tok) + } + + for _, tok := range tokens { + go func(t string) { + store.exists(t) + }(tok) + } + + for _, tok := range tokens { + go func(t string) { + store.reset() + }(tok) + } + + wg := &sync.WaitGroup{} + for _, tok := range tokens { + wg.Add(1) + go func(t string) { + store.set(t) + wg.Done() + }(tok) + } + wg.Wait() + }) + + t.Run("set", func(t *testing.T) { + t.Parallel() + + store := newStore() + + tokens := []string{ + "a", "aa", "aaa", "aaron", "ab", "abandoned", "abc", "aberdeen", "abilities", "ability", "able", "aboriginal", "abortion", + "about", "above", "abraham", "abroad", "abs", "absence", "absent", "absolute", "absolutely", "absorption", "abstract", + "abstracts", "abu", "abuse", "ac", "academic", "academics", "academy", "acc", "accent", "accept", "acceptable", "acceptance", + "accepted", "accepting", "accepts", "access", "accessed", "accessibility", "accessible", "accessing", "accessories", + "accessory", "accident", "accidents", "accommodate", "accommodation", "accommodations", "accompanied", "accompanying", + "accomplish", "accomplished", "accordance", "according", "accordingly", "account", "accountability", "accounting", "accounts", + "accreditation", "accredited", "accuracy", "accurate", "accurately", "accused", "acdbentity", "ace", + } + wg := &sync.WaitGroup{} + for _, tok := range tokens { + wg.Add(1) + go func(t string) { + store.set(t) + wg.Done() + }(tok) + } + wg.Wait() + + attest.Equal(t, store._len(), len(tokens)) + }) + + t.Run("reset", func(t *testing.T) { + t.Parallel() + + store := newStore() + + tokens := []string{"aaron", "abandoned", "according", "accreditation", "accurately", "accused"} + wg := &sync.WaitGroup{} + for _, tok := range tokens { + wg.Add(1) + go func(t string) { + store.set(t) + wg.Done() + }(tok) + } + wg.Wait() + + attest.Equal(t, store._len(), len(tokens)) + + store.reset() + attest.Equal(t, store._len(), 0) + }) +} + +func TestGetToken(t *testing.T) { + t.Parallel() + + t.Run("empty request", func(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + tok := getToken(r) + attest.Zero(t, tok) + }) + + t.Run("from cookie", func(t *testing.T) { + t.Parallel() + + want := xid.New().String() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + r.AddCookie(&http.Cookie{ + Name: csrfCookieName, + Value: want, + Path: "/", + HttpOnly: false, // If true, makes cookie inaccessible to JS. Should be false for csrf cookies. + Secure: true, // https only. + SameSite: http.SameSiteStrictMode, + }) + got := getToken(r) + attest.Equal(t, got, want) + }) + + t.Run("from header", func(t *testing.T) { + t.Parallel() + + want := xid.New().String() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + r.Header.Set(csrfHeader, want) + got := getToken(r) + attest.Equal(t, got, want) + }) + + t.Run("from form", func(t *testing.T) { + t.Parallel() + + want := xid.New().String() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + err := r.ParseForm() + attest.Ok(t, err) + r.Form.Add(csrfCookieForm, want) + got := getToken(r) + attest.Equal(t, got, want) + }) + + t.Run("cookie takes precedence", func(t *testing.T) { + t.Parallel() + + cookieToken := xid.New().String() + headerToken := xid.New().String() + formToken := xid.New().String() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + r.AddCookie(&http.Cookie{ + Name: csrfCookieName, + Value: cookieToken, + Path: "/", + HttpOnly: false, // If true, makes cookie inaccessible to JS. Should be false for csrf cookies. + Secure: true, // https only. + SameSite: http.SameSiteStrictMode, + }) + r.Header.Set(csrfHeader, headerToken) + err := r.ParseForm() + attest.Ok(t, err) + r.Form.Add(csrfCookieForm, formToken) + + got := getToken(r) + attest.Equal(t, got, cookieToken) + }) +} + +const tokenHeader = "CUSTOM-CSRF-TOKEN-TEST-HEADER" + +func someCsrfHandler(msg string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(tokenHeader, GetCsrfToken(r.Context())) + fmt.Fprint(w, msg) + } +} + +func TestCsrf(t *testing.T) { + t.Parallel() + + t.Run("middleware succeds", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + wrappedHandler := Csrf(someCsrfHandler(msg), domain, -1) + + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + wrappedHandler.ServeHTTP(rec, r) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + }) + + t.Run("fetch token from GET requests", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + wrappedHandler := Csrf(someCsrfHandler(msg), domain, -1) + + reqCsrfTok := xid.New().String() + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + r.AddCookie(&http.Cookie{ + Name: csrfCookieName, + Value: reqCsrfTok, + Path: "/", + HttpOnly: false, // If true, makes cookie inaccessible to JS. Should be false for csrf cookies. + Secure: true, // https only. + SameSite: http.SameSiteStrictMode, + }) + wrappedHandler.ServeHTTP(rec, r) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + attest.Equal(t, res.Header.Get(tokenHeader), reqCsrfTok) + }) + + t.Run("fetch token from HEAD requests", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + wrappedHandler := Csrf(someCsrfHandler(msg), domain, -1) + + reqCsrfTok := xid.New().String() + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodHead, "/someUri", nil) + r.Header.Set(csrfHeader, reqCsrfTok) + wrappedHandler.ServeHTTP(rec, r) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + attest.Equal(t, res.Header.Get(tokenHeader), reqCsrfTok) + }) + + t.Run("can generate csrf tokens", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + wrappedHandler := Csrf(someCsrfHandler(msg), domain, -1) + + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + wrappedHandler.ServeHTTP(rec, r) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + attest.NotZero(t, res.Header.Get(tokenHeader)) + }) + + t.Run("token is set in all required places", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + wrappedHandler := Csrf(someCsrfHandler(msg), domain, -1) + + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + wrappedHandler.ServeHTTP(rec, r) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + + // assert that: + // (a) csrf cookie is set. + // (b) cookie header is set. + // (c) vary header is updated. + // (d) r.context is updated. + // (e) memory store is updated. + + // (a) + attest.Equal(t, len(res.Cookies()), 1) + attest.Equal(t, res.Cookies()[0].Name, csrfCookieName) + attest.Equal(t, res.Cookies()[0].Value, res.Header.Get(tokenHeader)) + + // (b) + attest.Equal(t, res.Header.Get(csrfHeader), res.Header.Get(tokenHeader)) + + // (c) + attest.Equal(t, res.Header.Get(varyHeader), clientCookieHeader) + + // (d) + attest.NotZero(t, res.Header.Get(tokenHeader)) + + // (e) + attest.True(t, csrfStore.exists(res.Header.Get(tokenHeader))) + attest.True(t, csrfStore._len() > 0) + }) + + t.Run("POST requests must have valid token", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + wrappedHandler := Csrf(someCsrfHandler(msg), domain, -1) + + reqCsrfTok := xid.New().String() + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/someUri", nil) + r.AddCookie(&http.Cookie{ + Name: csrfCookieName, + Value: reqCsrfTok, + Path: "/", + HttpOnly: false, // If true, makes cookie inaccessible to JS. Should be false for csrf cookies. + Secure: true, // https only. + SameSite: http.SameSiteStrictMode, + }) + wrappedHandler.ServeHTTP(rec, r) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusBadRequest) + attest.Equal(t, string(rb), errCsrfTokenNotFound.Error()+"\n") + attest.Zero(t, res.Header.Get(tokenHeader)) + attest.Equal(t, len(res.Cookies()), 0) + }) + + t.Run("POST requests with valid token", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + wrappedHandler := Csrf(someCsrfHandler(msg), domain, -1) + + reqCsrfTok := xid.New().String() + + { + // make GET request + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + r.Header.Set(csrfHeader, reqCsrfTok) + wrappedHandler.ServeHTTP(rec, r) + res := rec.Result() + defer res.Body.Close() + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + } + + { + // make POST request using same csrf token + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/someUri", nil) + r.AddCookie(&http.Cookie{ + Name: csrfCookieName, + Value: reqCsrfTok, + Path: "/", + HttpOnly: false, // If true, makes cookie inaccessible to JS. Should be false for csrf cookies. + Secure: true, // https only. + SameSite: http.SameSiteStrictMode, + }) + wrappedHandler.ServeHTTP(rec, r) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + + // assert that: + // (a) csrf cookie is set. + // (b) cookie header is set. + // (c) vary header is updated. + // (d) r.context is updated. + // (e) memory store is updated. + + // (a) + attest.Equal(t, len(res.Cookies()), 1) + attest.Equal(t, res.Cookies()[0].Name, csrfCookieName) + attest.Equal(t, res.Cookies()[0].Value, res.Header.Get(tokenHeader)) + + // (b) + attest.Equal(t, res.Header.Get(csrfHeader), res.Header.Get(tokenHeader)) + + // (c) + attest.Equal(t, res.Header.Get(varyHeader), clientCookieHeader) + + // (d) + attest.NotZero(t, res.Header.Get(tokenHeader)) + + // (e) + attest.True(t, csrfStore.exists(res.Header.Get(tokenHeader))) + attest.True(t, csrfStore._len() > 0) + } + }) + + t.Run("memory store reset", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + maxRequestsToReset := int32(50) + wrappedHandler := Csrf(someCsrfHandler(msg), domain, maxRequestsToReset) + + for i := 0; i < (int(maxRequestsToReset) + 1); i++ { + reqCsrfTok := xid.New().String() + rec := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/someUri", nil) + r.Header.Set(csrfHeader, reqCsrfTok) + wrappedHandler.ServeHTTP(rec, r) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + attest.Equal(t, res.Header.Get(tokenHeader), reqCsrfTok) + if i < int(maxRequestsToReset) { + attest.True(t, csrfStore._len() > 0) + } else { + attest.Equal(t, csrfStore._len(), 0) + } + } + }) +} diff --git a/middleware/panic_test.go b/middleware/panic_test.go index 62c6c31a..b1f058c0 100644 --- a/middleware/panic_test.go +++ b/middleware/panic_test.go @@ -21,7 +21,11 @@ func handlerThatPanics(msg string, shouldPanic bool) http.HandlerFunc { } func TestPanic(t *testing.T) { + t.Parallel() + t.Run("catches panics", func(t *testing.T) { + t.Parallel() + msg := "hello" wrappedHandler := Panic(handlerThatPanics(msg, true)) @@ -36,6 +40,8 @@ func TestPanic(t *testing.T) { }) t.Run("ok if no panic", func(t *testing.T) { + t.Parallel() + msg := "hello" wrappedHandler := Panic(handlerThatPanics(msg, false)) diff --git a/middleware/security.go b/middleware/security.go index 1a1dcea3..f766e665 100644 --- a/middleware/security.go +++ b/middleware/security.go @@ -12,8 +12,8 @@ import ( type cspContextKey string const ( - ck = cspContextKey("cspContextKey") - defaultNonce = "" + cspCtxKey = cspContextKey("cspContextKey") + cspDefaultNonce = "" // allow or block the use of browser features(eg accelerometer, camera, autoplay etc). permissionsPolicyHeader = "Permissions-Policy" @@ -54,7 +54,7 @@ func Security(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { // var inline = 1; // nonce := xid.New().String() - r = r.WithContext(context.WithValue(ctx, ck, nonce)) + r = r.WithContext(context.WithValue(ctx, cspCtxKey, nonce)) w.Header().Set( cspHeader, // - https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP @@ -115,18 +115,18 @@ func Security(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { // // usage: // func myHandler(w http.ResponseWriter, r *http.Request) { -// nonce := middleware.GetCspNonce(r.Context()) -// _ = nonce +// cspNonce := middleware.GetCspNonce(r.Context()) +// _ = cspNonce // } func GetCspNonce(c context.Context) string { - v := c.Value(ck) + v := c.Value(cspCtxKey) if v != nil { s, ok := v.(string) if ok { return s } } - return defaultNonce + return cspDefaultNonce } func getCsp(domain, nonce string) string { diff --git a/middleware/security_test.go b/middleware/security_test.go index 18ea5e85..d27b3850 100644 --- a/middleware/security_test.go +++ b/middleware/security_test.go @@ -12,7 +12,7 @@ import ( "github.com/akshayjshah/attest" ) -const nonceHeader = "CUSTOM-CSP-NONCE" +const nonceHeader = "CUSTOM-CSP-NONCE-TEST-HEADER" // echoHandler echos back in the response, the msg that was passed in. func echoHandler(msg string) http.HandlerFunc { @@ -98,6 +98,6 @@ func TestGetCspNonce(t *testing.T) { got := res.Header.Get(nonceHeader) attest.NotZero(t, got) - attest.True(t, got != defaultNonce) + attest.True(t, got != cspDefaultNonce) }) } diff --git a/server/server_test.go b/server/server_test.go index fd438804..1555b64f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -50,7 +50,11 @@ func TestDrainDuration(t *testing.T) { } func TestOpts(t *testing.T) { + t.Parallel() + t.Run("sensible defaults", func(t *testing.T) { + t.Parallel() + got := DefaultOpts() want := opts{ port: "8080", @@ -66,6 +70,8 @@ func TestOpts(t *testing.T) { }) t.Run("sensible defaults", func(t *testing.T) { + t.Parallel() + got := WithOpts("80", "localhost") want := opts{ port: "80", @@ -105,7 +111,11 @@ func TestOpts(t *testing.T) { // } // func TestRun(t *testing.T) { +// t.Parallel() +// // t.Run("success", func(t *testing.T) { +// t.Parallel() +// // eh := &myEH{router: http.NewServeMux()} // err := Run(eh, WithOpts("0", "localhost")) // attest.Ok(t, err)