diff --git a/handler/graphql.go b/handler/graphql.go index 6d2a787a03..a22542225f 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "mime" "net/http" "os" "strconv" @@ -16,7 +17,7 @@ import ( "github.com/99designs/gqlgen/complexity" "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" "github.com/vektah/gqlparser/parser" @@ -369,8 +370,20 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } case http.MethodPost: - contentType := strings.SplitN(r.Header.Get("Content-Type"), ";", 2)[0] - if contentType == "multipart/form-data" { + mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + sendErrorf(w, http.StatusBadRequest, "error parsing request Content-Type") + return + } + + switch mediaType { + case "application/json": + if err := jsonDecode(r.Body, &reqParams); err != nil { + sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) + return + } + + case "multipart/form-data": var closers []io.Closer var tmpFiles []string defer func() { @@ -385,11 +398,9 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sendErrorf(w, http.StatusBadRequest, "multipart body could not be decoded: "+err.Error()) return } - } else { - if err := jsonDecode(r.Body, &reqParams); err != nil { - sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) - return - } + default: + sendErrorf(w, http.StatusBadRequest, "unsupported Content-Type: "+mediaType) + return } default: w.WriteHeader(http.StatusMethodNotAllowed) diff --git a/handler/graphql_test.go b/handler/graphql_test.go index 70946ab949..c97c5290ea 100644 --- a/handler/graphql_test.go +++ b/handler/graphql_test.go @@ -15,7 +15,6 @@ import ( "testing" "github.com/99designs/gqlgen/graphql" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vektah/gqlparser/ast" @@ -88,6 +87,49 @@ func TestHandlerPOST(t *testing.T) { assert.Equal(t, resp.Header().Get("Content-Type"), "application/json") assert.Equal(t, `{"errors":[{"message":"mutations are not supported"}],"data":null}`, resp.Body.String()) }) + + t.Run("validate content type", func(t *testing.T) { + doReq := func(handler http.Handler, method string, target string, body string, contentType string) *httptest.ResponseRecorder { + r := httptest.NewRequest(method, target, strings.NewReader(body)) + if contentType != "" { + r.Header.Set("Content-Type", contentType) + } + w := httptest.NewRecorder() + + handler.ServeHTTP(w, r) + return w + } + + validContentTypes := []string{ + "application/json", + "application/json; charset=utf-8", + } + + for _, contentType := range validContentTypes { + t.Run(fmt.Sprintf("allow for content type %s", contentType), func(t *testing.T) { + resp := doReq(h, "POST", "/graphql", `{"query":"{ me { name } }"}`, contentType) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + }) + } + + invalidContentTypes := []struct{ contentType, expectedError string }{ + {"", "error parsing request Content-Type"}, + {"text/plain", "unsupported Content-Type: text/plain"}, + + // These content types are currently not supported, but they are supported by other GraphQL servers, like express-graphql. + {"application/x-www-form-urlencoded", "unsupported Content-Type: application/x-www-form-urlencoded"}, + {"application/graphql", "unsupported Content-Type: application/graphql"}, + } + + for _, tc := range invalidContentTypes { + t.Run(fmt.Sprintf("reject for content type %s", tc.contentType), func(t *testing.T) { + resp := doReq(h, "POST", "/graphql", `{"query":"{ me { name } }"}`, tc.contentType) + assert.Equal(t, http.StatusBadRequest, resp.Code) + assert.Equal(t, fmt.Sprintf(`{"errors":[{"message":"%s"}],"data":null}`, tc.expectedError), resp.Body.String()) + }) + } + }) } func TestHandlerGET(t *testing.T) { @@ -640,6 +682,7 @@ func createUploadRequest(t *testing.T, operations, mapData string, files []file) func doRequest(handler http.Handler, method string, target string, body string) *httptest.ResponseRecorder { r := httptest.NewRequest(method, target, strings.NewReader(body)) + r.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() handler.ServeHTTP(w, r)