Skip to content

Commit

Permalink
Fix CSRF middleware not being able to extract token from `multipart/f…
Browse files Browse the repository at this point in the history
…orm-data` form (#2136, fixes #2135)
  • Loading branch information
aldas authored Mar 15, 2022
1 parent 5c38c3b commit 01d7d01
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
4 changes: 2 additions & 2 deletions middleware/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ func valuesFromCookie(name string) ValuesExtractor {
// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
if parseErr := c.Request().ParseForm(); parseErr != nil {
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr)
if c.Request().Form == nil {
_ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does
}
values := c.Request().Form[name]
if len(values) == 0 {
Expand Down
39 changes: 29 additions & 10 deletions middleware/extractor_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package middleware

import (
"bytes"
"fmt"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -499,6 +501,25 @@ func TestValuesFromForm(t *testing.T) {
return req
}

exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request {
var b bytes.Buffer
w := multipart.NewWriter(&b)
w.WriteField("name", "Jon Snow")
w.WriteField("emails[]", "[email protected]")
if mod != nil {
mod(w)
}

fw, _ := w.CreateFormFile("upload", "my.file")
fw.Write([]byte(`<div>hi</div>`))
w.Close()

req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String()))
req.Header.Add(echo.HeaderContentType, w.FormDataContentType())

return req
}

var testCases = []struct {
name string
givenRequest *http.Request
Expand All @@ -520,6 +541,14 @@ func TestValuesFromForm(t *testing.T) {
whenName: "emails[]",
expectValues: []string{"[email protected]", "[email protected]"},
},
{
name: "ok, POST multipart/form, multiple value",
givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) {
w.WriteField("emails[]", "[email protected]")
}),
whenName: "emails[]",
expectValues: []string{"[email protected]", "[email protected]"},
},
{
name: "ok, GET form, single value",
givenRequest: exampleGetFormRequest(nil),
Expand All @@ -540,16 +569,6 @@ func TestValuesFromForm(t *testing.T) {
whenName: "nope",
expectError: errFormExtractorValueMissing.Error(),
},
{
name: "nok, POST form, form parsing error",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = nil
return req
}(),
whenName: "name",
expectError: "valuesFromForm parse form failed: missing form body",
},
{
name: "ok, cut values over extractorLimit",
givenRequest: examplePostFormRequest(func(v *url.Values) {
Expand Down

0 comments on commit 01d7d01

Please sign in to comment.