diff --git a/handler/graphql.go b/handler/graphql.go index 5ead497c12..bbb680a655 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -21,10 +21,6 @@ import ( "github.com/vektah/gqlparser/validator" ) -const ( - defaultMaxMemory = 32 << 20 // 32 MB -) - type params struct { Query string `json:"query"` OperationName string `json:"operationName"` @@ -43,6 +39,7 @@ type Config struct { complexityLimitFunc graphql.ComplexityLimitFunc disableIntrospection bool connectionKeepAlivePingInterval time.Duration + uploadMaxMemory int64 } func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { @@ -249,6 +246,15 @@ func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { tw.tracer1.EndOperationExecution(ctx) } +// UploadMaxMemory sets the total of maxMemory bytes used to parse a request body +// as multipart/form-data in memory, with the remainder stored on disk in +// temporary files. +func UploadMaxMemory(maxMemory int64) Option { + return func(cfg *Config) { + cfg.uploadMaxMemory = maxMemory + } +} + // CacheSize sets the maximum size of the query cache. // If size is less than or equal to 0, the cache is disabled. func CacheSize(size int) Option { @@ -270,9 +276,15 @@ func WebsocketKeepAliveDuration(duration time.Duration) Option { const DefaultCacheSize = 1000 const DefaultConnectionKeepAlivePingInterval = 25 * time.Second +// DefaultUploadMaxMemory sets the total of maxMemory bytes used to parse a request body +// as multipart/form-data in memory, with the remainder stored on disk in +// temporary files. The default value is 32 MB. +const DefaultUploadMaxMemory = 32 << 20 + func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { cfg := &Config{ cacheSize: DefaultCacheSize, + uploadMaxMemory: DefaultUploadMaxMemory, connectionKeepAlivePingInterval: DefaultConnectionKeepAlivePingInterval, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, @@ -343,7 +355,7 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case http.MethodPost: contentType := strings.SplitN(r.Header.Get("Content-Type"), ";", 2)[0] if contentType == "multipart/form-data" { - if err := processMultipart(r, &reqParams); err != nil { + if err := processMultipart(r, &reqParams, gh.cfg.uploadMaxMemory); err != nil { sendErrorf(w, http.StatusBadRequest, "multipart body could not be decoded: "+err.Error()) return } @@ -508,9 +520,9 @@ func sendErrorf(w http.ResponseWriter, code int, format string, args ...interfac sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)}) } -func processMultipart(r *http.Request, request *params) error { +func processMultipart(r *http.Request, request *params, uploadMaxMemory int64) error { // Parse multipart form - if err := r.ParseMultipartForm(defaultMaxMemory); err != nil { + if err := r.ParseMultipartForm(uploadMaxMemory); err != nil { return errors.New("failed to parse multipart form") } diff --git a/handler/graphql_test.go b/handler/graphql_test.go index 254a2ede59..dfac2d90bf 100644 --- a/handler/graphql_test.go +++ b/handler/graphql_test.go @@ -321,7 +321,7 @@ func TestProcessMultipart(t *testing.T) { Body: ioutil.NopCloser(new(bytes.Buffer)), } var reqParams params - err := processMultipart(req, &reqParams) + err := processMultipart(req, &reqParams, DefaultFileMaxMemory) require.NotNil(t, err) errMsg := err.Error() require.Equal(t, errMsg, "failed to parse multipart form") @@ -332,7 +332,7 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, operations, validMap, validFiles) var reqParams params - err := processMultipart(req, &reqParams) + err := processMultipart(req, &reqParams, DefaultFileMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "operations form field could not be decoded") }) @@ -342,7 +342,7 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, validOperations, mapData, validFiles) var reqParams params - err := processMultipart(req, &reqParams) + err := processMultipart(req, &reqParams, DefaultFileMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "map form field could not be decoded") }) @@ -352,7 +352,7 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, validOperations, validMap, files) var reqParams params - err := processMultipart(req, &reqParams) + err := processMultipart(req, &reqParams, DefaultFileMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "failed to get key 0 from form") }) @@ -362,7 +362,7 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, validOperations, mapData, validFiles) var reqParams params - err := processMultipart(req, &reqParams) + err := processMultipart(req, &reqParams, DefaultFileMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "invalid value for key 0") }) @@ -372,7 +372,7 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, validOperations, mapData, validFiles) var reqParams params - err := processMultipart(req, &reqParams) + err := processMultipart(req, &reqParams, DefaultFileMaxMemory) require.NotNil(t, err) require.Equal(t, err.Error(), "invalid value for key 0") }) @@ -381,7 +381,7 @@ func TestProcessMultipart(t *testing.T) { req := createUploadRequest(t, validOperations, validMap, validFiles) var reqParams params - err := processMultipart(req, &reqParams) + err := processMultipart(req, &reqParams, DefaultFileMaxMemory) require.Nil(t, err) require.Equal(t, "mutation ($file: Upload!) { singleUpload(file: $file) { id } }", reqParams.Query) require.Equal(t, "", reqParams.OperationName)