diff --git a/CHANGELOG.md b/CHANGELOG.md index 57524b04..d207d065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,3 +31,4 @@ Most recent version is listed first. - use acme for certificates: https://github.com/komuw/ong/pull/69 - issues/73: bind on 0.0.0.0 or localhost conditionally: https://github.com/komuw/ong/pull/74 - redirect IP to domain: https://github.com/komuw/ong/pull/75 +- dont require csrf for POST requests that have no cookies and arent http auth: https://github.com/komuw/ong/pull/77 diff --git a/example/main.go b/example/main.go index 1a35fa3c..cb851ecb 100644 --- a/example/main.go +++ b/example/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "html/template" "net/http" "os" "sync" @@ -40,6 +41,12 @@ func main() { api.check(200), middleware.WithOpts("localhost"), ), + server.NewRoute( + "login", + server.MethodAll, + api.login(), + middleware.WithOpts("localhost"), + ), }) _, _ = server.CreateDevCertKey() @@ -116,7 +123,64 @@ func (s myAPI) check(code int) http.HandlerFunc { csrfToken := middleware.GetCsrfToken(r.Context()) s.l.Info(log.F{"msg": "check called", "cspNonce": cspNonce, "csrfToken": csrfToken}) + _, _ = fmt.Fprint(w, "hello from check/ endpoint") // use code, which is a dependency specific to this handler w.WriteHeader(code) } } + +func (s myAPI) login() http.HandlerFunc { + tmpl, err := template.New("myTpl").Parse(` + + + +

Welcome to awesome website.

+
+
+
+
+
+ +
+ +
+ + + + + +`) + if err != nil { + panic(err) + } + + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + data := struct { + CsrfTokenName string + CsrfTokenValue string + CspNonceValue string + }{ + CsrfTokenName: middleware.CsrfTokenFormName, + CsrfTokenValue: middleware.GetCsrfToken(r.Context()), + CspNonceValue: middleware.GetCspNonce(r.Context()), + } + if err = tmpl.Execute(w, data); err != nil { + panic(err) + } + return + } + + if err = r.ParseForm(); err != nil { + panic(err) + } + + fmt.Println("r.Form: ", r.Form) + for k, v := range r.Form { + fmt.Println("k, v: ", k, v) + } + _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form) + } +} diff --git a/middleware/csrf.go b/middleware/csrf.go index 79a3f6b0..70881794 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -3,6 +3,7 @@ package middleware import ( "context" "errors" + "mime" "net/http" "sync" "time" @@ -29,13 +30,19 @@ var ( type csrfContextKey string const ( - csrfCtxKey = csrfContextKey("csrfContextKey") - csrfDefaultToken = "" - csrfCookieName = "csrftoken" // named after what django uses. - csrfHeader = "X-Csrf-Token" // named after what fiber uses. - csrfCookieForm = csrfCookieName - clientCookieHeader = "Cookie" - varyHeader = "Vary" + // CsrfTokenFormName is the name of the html form name attribute for csrf token. + CsrfTokenFormName = "csrftoken" // named after what django uses. + csrfCtxKey = csrfContextKey("csrfContextKey") + csrfDefaultToken = "" + csrfCookieName = CsrfTokenFormName + csrfHeader = "X-Csrf-Token" // named after what fiber uses. + clientCookieHeader = "Cookie" + varyHeader = "Vary" + authorizationHeader = "Authorization" + proxyAuthorizationHeader = "Proxy-Authorization" + ctHeader = "Content-Type" + formUrlEncoded = "application/x-www-form-urlencoded" + multiformData = "multipart/form-data" // gorilla/csrf; 12hrs // django: 1yr?? @@ -76,6 +83,24 @@ func Csrf(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc { default: // For POST requests, we insist on a CSRF cookie, and in this way we can avoid all CSRF attacks, including login CSRF. actualToken = getToken(r) + + ct, _, err := mime.ParseMediaType(r.Header.Get(ctHeader)) + if err == nil && + ct != formUrlEncoded && + ct != multiformData && + r.Header.Get(clientCookieHeader) == "" && + r.Header.Get(authorizationHeader) == "" && + r.Header.Get(proxyAuthorizationHeader) == "" { + // For POST requests that; + // - are not form data. + // - have no cookies. + // - are not using http authentication. + // then it is okay to not validate csrf for them. + // This is especially useful for REST API endpoints. + // see: https://github.com/komuw/ong/issues/76 + break + } + if !csrfStore.exists(actualToken) { // we should fail the request since it means that the server is not aware of such a token. cookie.Delete(w, csrfCookieName, domain) @@ -179,10 +204,7 @@ func getToken(r *http.Request) (actualToken string) { } fromForm := func() string { - if err := r.ParseForm(); err != nil { - return "" - } - return r.Form.Get(csrfCookieForm) + return r.FormValue(CsrfTokenFormName) // calls ParseMultipartForm and ParseForm if necessary } fromHeader := func() string { diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 17285281..b8cc75b3 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "sync" "testing" @@ -186,7 +187,7 @@ func TestGetToken(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/someUri", nil) err := req.ParseForm() attest.Ok(t, err) - req.Form.Add(csrfCookieForm, want) + req.Form.Add(CsrfTokenFormName, want) got := getToken(req) attest.Equal(t, got, want[csrfStringTokenlength:]) }) @@ -209,7 +210,7 @@ func TestGetToken(t *testing.T) { req.Header.Set(csrfHeader, headerToken) err := req.ParseForm() attest.Ok(t, err) - req.Form.Add(csrfCookieForm, formToken) + req.Form.Add(CsrfTokenFormName, formToken) got := getToken(req) attest.Equal(t, got, cookieToken[csrfStringTokenlength:]) @@ -477,6 +478,32 @@ func TestCsrf(t *testing.T) { } }) + t.Run("POST requests with no cookies dont need csrf", func(t *testing.T) { + t.Parallel() + + msg := "hello" + domain := "example.com" + wrappedHandler := Csrf(someCsrfHandler(msg), domain) + + rec := httptest.NewRecorder() + postMsg := "my name is John" + body := strings.NewReader(postMsg) + req := httptest.NewRequest(http.MethodPost, "/someUri", body) + req.Header.Add(ctHeader, "application/json") + wrappedHandler.ServeHTTP(rec, req) + + res := rec.Result() + defer res.Body.Close() + + rb, err := io.ReadAll(res.Body) + attest.Ok(t, err) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + attest.NotZero(t, res.Header.Get(tokenHeader)) + attest.Equal(t, len(res.Cookies()), 1) + }) + // concurrency safe t.Run("POST requests with valid token from mutiple tabs", func(t *testing.T) { t.Parallel()