Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Not providing any token in requests results in wrong error message #149

Merged
merged 4 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Please keep the list sorted.

adiabatic <[email protected]>
Florian D. Loch <[email protected]>
Google LLC (https://opensource.google.com)
jamesgroat <[email protected]>
Joshua Carp <[email protected]>
Expand Down
17 changes: 12 additions & 5 deletions csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,23 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

// If the token returned from the session store is nil for non-idempotent
// ("unsafe") methods, call the error handler.
if realToken == nil {
// Retrieve the combined token (pad + masked) token...
maskedToken, err := cs.requestToken(r)

FlorianLoch marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
r = envError(r, ErrBadToken)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}

if maskedToken == nil {
r = envError(r, ErrNoToken)
cs.opts.ErrorHandler.ServeHTTP(w, r)
return
}

// Retrieve the combined token (pad + masked) token and unmask it.
requestToken := unmask(cs.requestToken(r))
// ... and unmask it.
requestToken := unmask(maskedToken)

// Compare the request token against the real token
if !compareTokens(requestToken, realToken) {
Expand Down
42 changes: 42 additions & 0 deletions csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,48 @@ func TestWithReferer(t *testing.T) {
}
}

// Requests without a token should fail with ErrNoToken.
func TestNoTokenProvided(t *testing.T) {
var finalErr error

s := http.NewServeMux()
p := Protect(testKey, ErrorHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
finalErr = FailureReason(r)
})))(s)

var token string
s.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token = Token(r)
}))

// Obtain a CSRF cookie via a GET request.
r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
p.ServeHTTP(rr, r)

// POST the token back in the header.
r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil)
if err != nil {
t.Fatal(err)
}

setCookie(rr, r)
// By accident we use the wrong header name for the token...
r.Header.Set("X-CSRF-nekot", token)
r.Header.Set("Referer", "http://www.gorillatoolkit.org/")

rr = httptest.NewRecorder()
p.ServeHTTP(rr, r)

if finalErr != nil && finalErr != ErrNoToken {
t.Fatalf("middleware failed to return correct error: got '%v' want '%v'", finalErr, ErrNoToken)
}
}

func setCookie(rr *httptest.ResponseRecorder, r *http.Request) {
r.Header.Set("Cookie", rr.Header().Get("Set-Cookie"))
}
11 changes: 8 additions & 3 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func unmask(issued []byte) []byte {

// requestToken returns the issued token (pad + masked token) from the HTTP POST
// body or HTTP header. It will return nil if the token fails to decode.
func (cs *csrf) requestToken(r *http.Request) []byte {
func (cs *csrf) requestToken(r *http.Request) ([]byte, error) {
// 1. Check the HTTP header first.
issued := r.Header.Get(cs.opts.RequestHeader)

Expand All @@ -123,14 +123,19 @@ func (cs *csrf) requestToken(r *http.Request) []byte {
}
}

// Return nil (equivalent to empty byte slice) if no token was found
if issued == "" {
return nil, nil
}

// Decode the "issued" (pad + masked) token sent in the request. Return a
// nil byte slice on a decoding error (this will fail upstream).
decoded, err := base64.StdEncoding.DecodeString(issued)
if err != nil {
return nil
return nil, err
}

return decoded
return decoded, nil
}

// generateRandomBytes returns securely generated random bytes.
Expand Down