Skip to content

Commit

Permalink
Merge pull request #23 from gorilla/pass-field-name
Browse files Browse the repository at this point in the history
[feature] csrf.TemplateField now uses custom FieldNames
  • Loading branch information
elithrar committed Nov 30, 2015
2 parents d5a1c5d + 12a0c0e commit 5af6691
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 3 deletions.
3 changes: 3 additions & 0 deletions csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const tokenLength = 32
// Context/session keys & prefixes
const (
tokenKey string = "gorilla.csrf.Token"
formKey string = "gorilla.csrf.Form"
errorKey string = "gorilla.csrf.Error"
cookieName string = "_gorilla_csrf"
errorPrefix string = "gorilla/csrf: "
Expand Down Expand Up @@ -198,6 +199,8 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Save the masked token to the request context
context.Set(r, tokenKey, mask(realToken, r))
// Save the field name to the request context
context.Set(r, formKey, cs.opts.FieldName)

// HTTP methods not defined as idempotent ("safe") under RFC7231 require
// inspection.
Expand Down
11 changes: 8 additions & 3 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,15 @@ func FailureReason(r *http.Request) error {
// <input type="hidden" name="gorilla.csrf.Token" value="<token>">
//
func TemplateField(r *http.Request) template.HTML {
fragment := fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
fieldName, Token(r))
name, ok := context.GetOk(r, formKey)
if ok {
fragment := fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
name, Token(r))

return template.HTML(fragment)
return template.HTML(fragment)
}

return template.HTML("")
}

// mask returns a unique-per-request token to mitigate the BREACH attack
Expand Down
40 changes: 40 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"mime/multipart"
"net/http"
Expand Down Expand Up @@ -214,3 +215,42 @@ func TestGenerateRandomBytes(t *testing.T) {
t.Fatalf("generateRandomBytes did not report a short read: only read %d bytes", len(b))
}
}

func TestTemplateField(t *testing.T) {
s := http.NewServeMux()

// Make the token & template field available outside of the handler.
var token string
var templateField string
s.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token = Token(r)
templateField = string(TemplateField(r))
t := template.Must((template.New("base").Parse(testTemplate)))
t.Execute(w, map[string]interface{}{
TemplateTag: TemplateField(r),
})
}))

testFieldName := "custom_field_name"
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
p := Protect(testKey, FieldName(testFieldName))(s)
p.ServeHTTP(rr, r)

expectedField := fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
testFieldName, token)

if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
rr.Code, http.StatusOK)
}

if templateField != expectedField {
t.Fatalf("custom FieldName was not set correctly: got %v want %v",
templateField, expectedField)
}
}

0 comments on commit 5af6691

Please sign in to comment.