diff --git a/CHANGELOG.md b/CHANGELOG.md
index 57524b04..d207d065 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,3 +31,4 @@ Most recent version is listed first.
- use acme for certificates: https://github.com/komuw/ong/pull/69
- issues/73: bind on 0.0.0.0 or localhost conditionally: https://github.com/komuw/ong/pull/74
- redirect IP to domain: https://github.com/komuw/ong/pull/75
+- dont require csrf for POST requests that have no cookies and arent http auth: https://github.com/komuw/ong/pull/77
diff --git a/example/main.go b/example/main.go
index 1a35fa3c..cb851ecb 100644
--- a/example/main.go
+++ b/example/main.go
@@ -3,6 +3,7 @@ package main
import (
"context"
"fmt"
+ "html/template"
"net/http"
"os"
"sync"
@@ -40,6 +41,12 @@ func main() {
api.check(200),
middleware.WithOpts("localhost"),
),
+ server.NewRoute(
+ "login",
+ server.MethodAll,
+ api.login(),
+ middleware.WithOpts("localhost"),
+ ),
})
_, _ = server.CreateDevCertKey()
@@ -116,7 +123,64 @@ func (s myAPI) check(code int) http.HandlerFunc {
csrfToken := middleware.GetCsrfToken(r.Context())
s.l.Info(log.F{"msg": "check called", "cspNonce": cspNonce, "csrfToken": csrfToken})
+ _, _ = fmt.Fprint(w, "hello from check/ endpoint")
// use code, which is a dependency specific to this handler
w.WriteHeader(code)
}
}
+
+func (s myAPI) login() http.HandlerFunc {
+ tmpl, err := template.New("myTpl").Parse(`
+
+
+
+ Welcome to awesome website.
+
+
+
+
+
+
+`)
+ if err != nil {
+ panic(err)
+ }
+
+ return func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ data := struct {
+ CsrfTokenName string
+ CsrfTokenValue string
+ CspNonceValue string
+ }{
+ CsrfTokenName: middleware.CsrfTokenFormName,
+ CsrfTokenValue: middleware.GetCsrfToken(r.Context()),
+ CspNonceValue: middleware.GetCspNonce(r.Context()),
+ }
+ if err = tmpl.Execute(w, data); err != nil {
+ panic(err)
+ }
+ return
+ }
+
+ if err = r.ParseForm(); err != nil {
+ panic(err)
+ }
+
+ fmt.Println("r.Form: ", r.Form)
+ for k, v := range r.Form {
+ fmt.Println("k, v: ", k, v)
+ }
+ _, _ = fmt.Fprintf(w, "you have submitted: %s", r.Form)
+ }
+}
diff --git a/middleware/csrf.go b/middleware/csrf.go
index 79a3f6b0..70881794 100644
--- a/middleware/csrf.go
+++ b/middleware/csrf.go
@@ -3,6 +3,7 @@ package middleware
import (
"context"
"errors"
+ "mime"
"net/http"
"sync"
"time"
@@ -29,13 +30,19 @@ var (
type csrfContextKey string
const (
- csrfCtxKey = csrfContextKey("csrfContextKey")
- csrfDefaultToken = ""
- csrfCookieName = "csrftoken" // named after what django uses.
- csrfHeader = "X-Csrf-Token" // named after what fiber uses.
- csrfCookieForm = csrfCookieName
- clientCookieHeader = "Cookie"
- varyHeader = "Vary"
+ // CsrfTokenFormName is the name of the html form name attribute for csrf token.
+ CsrfTokenFormName = "csrftoken" // named after what django uses.
+ csrfCtxKey = csrfContextKey("csrfContextKey")
+ csrfDefaultToken = ""
+ csrfCookieName = CsrfTokenFormName
+ csrfHeader = "X-Csrf-Token" // named after what fiber uses.
+ clientCookieHeader = "Cookie"
+ varyHeader = "Vary"
+ authorizationHeader = "Authorization"
+ proxyAuthorizationHeader = "Proxy-Authorization"
+ ctHeader = "Content-Type"
+ formUrlEncoded = "application/x-www-form-urlencoded"
+ multiformData = "multipart/form-data"
// gorilla/csrf; 12hrs
// django: 1yr??
@@ -76,6 +83,24 @@ func Csrf(wrappedHandler http.HandlerFunc, domain string) http.HandlerFunc {
default:
// For POST requests, we insist on a CSRF cookie, and in this way we can avoid all CSRF attacks, including login CSRF.
actualToken = getToken(r)
+
+ ct, _, err := mime.ParseMediaType(r.Header.Get(ctHeader))
+ if err == nil &&
+ ct != formUrlEncoded &&
+ ct != multiformData &&
+ r.Header.Get(clientCookieHeader) == "" &&
+ r.Header.Get(authorizationHeader) == "" &&
+ r.Header.Get(proxyAuthorizationHeader) == "" {
+ // For POST requests that;
+ // - are not form data.
+ // - have no cookies.
+ // - are not using http authentication.
+ // then it is okay to not validate csrf for them.
+ // This is especially useful for REST API endpoints.
+ // see: https://github.com/komuw/ong/issues/76
+ break
+ }
+
if !csrfStore.exists(actualToken) {
// we should fail the request since it means that the server is not aware of such a token.
cookie.Delete(w, csrfCookieName, domain)
@@ -179,10 +204,7 @@ func getToken(r *http.Request) (actualToken string) {
}
fromForm := func() string {
- if err := r.ParseForm(); err != nil {
- return ""
- }
- return r.Form.Get(csrfCookieForm)
+ return r.FormValue(CsrfTokenFormName) // calls ParseMultipartForm and ParseForm if necessary
}
fromHeader := func() string {
diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go
index 17285281..b8cc75b3 100644
--- a/middleware/csrf_test.go
+++ b/middleware/csrf_test.go
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
+ "strings"
"sync"
"testing"
@@ -186,7 +187,7 @@ func TestGetToken(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
err := req.ParseForm()
attest.Ok(t, err)
- req.Form.Add(csrfCookieForm, want)
+ req.Form.Add(CsrfTokenFormName, want)
got := getToken(req)
attest.Equal(t, got, want[csrfStringTokenlength:])
})
@@ -209,7 +210,7 @@ func TestGetToken(t *testing.T) {
req.Header.Set(csrfHeader, headerToken)
err := req.ParseForm()
attest.Ok(t, err)
- req.Form.Add(csrfCookieForm, formToken)
+ req.Form.Add(CsrfTokenFormName, formToken)
got := getToken(req)
attest.Equal(t, got, cookieToken[csrfStringTokenlength:])
@@ -477,6 +478,32 @@ func TestCsrf(t *testing.T) {
}
})
+ t.Run("POST requests with no cookies dont need csrf", func(t *testing.T) {
+ t.Parallel()
+
+ msg := "hello"
+ domain := "example.com"
+ wrappedHandler := Csrf(someCsrfHandler(msg), domain)
+
+ rec := httptest.NewRecorder()
+ postMsg := "my name is John"
+ body := strings.NewReader(postMsg)
+ req := httptest.NewRequest(http.MethodPost, "/someUri", body)
+ req.Header.Add(ctHeader, "application/json")
+ wrappedHandler.ServeHTTP(rec, req)
+
+ res := rec.Result()
+ defer res.Body.Close()
+
+ rb, err := io.ReadAll(res.Body)
+ attest.Ok(t, err)
+
+ attest.Equal(t, res.StatusCode, http.StatusOK)
+ attest.Equal(t, string(rb), msg)
+ attest.NotZero(t, res.Header.Get(tokenHeader))
+ attest.Equal(t, len(res.Cookies()), 1)
+ })
+
// concurrency safe
t.Run("POST requests with valid token from mutiple tabs", func(t *testing.T) {
t.Parallel()