Skip to content

Commit

Permalink
Modify graphql.Upload to use io.ReadCloser. Change the way upload fil…
Browse files Browse the repository at this point in the history
…es are managed.
  • Loading branch information
hantonelli committed Apr 19, 2019
1 parent 0306783 commit f848415
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 103 deletions.
2 changes: 1 addition & 1 deletion docs/content/reference/scalars.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ scalar Upload
Maps a `Upload` GraphQL scalar to a `graphql.Upload` struct, defined as follows:
```
type Upload struct {
File multipart.File
File io.ReadCloser
Filename string
Size int64
}
Expand Down
55 changes: 30 additions & 25 deletions example/fileupload/fileupload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,13 @@ func TestFileUpload(t *testing.T) {

resp, err := client.Do(req)
require.Nil(t, err)
defer func() {
_ = resp.Body.Close()
}()
require.Equal(t, http.StatusOK, resp.StatusCode)
responseBody, err := ioutil.ReadAll(resp.Body)
require.Nil(t, err)
responseString := string(responseBody)
require.Equal(t, `{"data":{"singleUpload":{"id":1,"name":"a.txt","content":"test"}}}`, responseString)
err = resp.Body.Close()
require.Nil(t, err)
})

t.Run("valid single file upload with payload", func(t *testing.T) {
Expand Down Expand Up @@ -94,13 +93,12 @@ func TestFileUpload(t *testing.T) {

resp, err := client.Do(req)
require.Nil(t, err)
defer func() {
_ = resp.Body.Close()
}()
require.Equal(t, http.StatusOK, resp.StatusCode)
responseBody, err := ioutil.ReadAll(resp.Body)
require.Nil(t, err)
require.Equal(t, `{"data":{"singleUploadWithPayload":{"id":1,"name":"a.txt","content":"test"}}}`, string(responseBody))
err = resp.Body.Close()
require.Nil(t, err)
})

t.Run("valid file list upload", func(t *testing.T) {
Expand Down Expand Up @@ -145,13 +143,12 @@ func TestFileUpload(t *testing.T) {

resp, err := client.Do(req)
require.Nil(t, err)
defer func() {
_ = resp.Body.Close()
}()
require.Equal(t, http.StatusOK, resp.StatusCode)
responseBody, err := ioutil.ReadAll(resp.Body)
require.Nil(t, err)
require.Equal(t, `{"data":{"multipleUpload":[{"id":1,"name":"a.txt","content":"test1"},{"id":2,"name":"b.txt","content":"test2"}]}}`, string(responseBody))
err = resp.Body.Close()
require.Nil(t, err)
})

t.Run("valid file list upload with payload", func(t *testing.T) {
Expand Down Expand Up @@ -200,13 +197,12 @@ func TestFileUpload(t *testing.T) {

resp, err := client.Do(req)
require.Nil(t, err)
defer func() {
_ = resp.Body.Close()
}()
require.Equal(t, http.StatusOK, resp.StatusCode)
responseBody, err := ioutil.ReadAll(resp.Body)
require.Nil(t, err)
require.Equal(t, `{"data":{"multipleUploadWithPayload":[{"id":1,"name":"a.txt","content":"test1"},{"id":2,"name":"b.txt","content":"test2"}]}}`, string(responseBody))
err = resp.Body.Close()
require.Nil(t, err)
})

t.Run("valid file list upload with payload and file reuse", func(t *testing.T) {
Expand All @@ -220,7 +216,6 @@ func TestFileUpload(t *testing.T) {
require.NotNil(t, req[i].File)
require.NotNil(t, req[i].File.File)
ids = append(ids, req[i].ID)
req[i].File.File.Seek(0, 0)
content, err := ioutil.ReadAll(req[i].File.File)
require.Nil(t, err)
contents = append(contents, string(content))
Expand All @@ -235,8 +230,6 @@ func TestFileUpload(t *testing.T) {
return resp, nil
},
}
srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolver}), handler.UploadMaxMemory(2)))
defer srv.Close()

operations := `{ "query": "mutation($req: [UploadFile!]!) { multipleUploadWithPayload(req: $req) { id, name, content } }", "variables": { "req": [ { "id": 1, "file": null }, { "id": 2, "file": null } ] } }`
mapData := `{ "0": ["variables.req.0.file", "variables.req.1.file"] }`
Expand All @@ -247,17 +240,29 @@ func TestFileUpload(t *testing.T) {
content: "test1",
},
}
req := createUploadRequest(t, srv.URL, operations, mapData, files)

resp, err := client.Do(req)
require.Nil(t, err)
defer func() {
_ = resp.Body.Close()
}()
require.Equal(t, http.StatusOK, resp.StatusCode)
responseBody, err := ioutil.ReadAll(resp.Body)
require.Nil(t, err)
require.Equal(t, `{"data":{"multipleUploadWithPayload":[{"id":1,"name":"a.txt","content":"test1"},{"id":2,"name":"a.txt","content":"test1"}]}}`, string(responseBody))
test := func(uploadMaxMemory int64) {
memory := handler.UploadMaxMemory(uploadMaxMemory)
srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolver}), memory))
defer srv.Close()
req := createUploadRequest(t, srv.URL, operations, mapData, files)
resp, err := client.Do(req)
require.Nil(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
responseBody, err := ioutil.ReadAll(resp.Body)
require.Nil(t, err)
require.Equal(t, `{"data":{"multipleUploadWithPayload":[{"id":1,"name":"a.txt","content":"test1"},{"id":2,"name":"a.txt","content":"test1"}]}}`, string(responseBody))
err = resp.Body.Close()
require.Nil(t, err)
}

t.Run("payload smaller than UploadMaxMemory, stored in memory", func(t *testing.T){
test(5000)
})

t.Run("payload bigger than UploadMaxMemory, persisted to disk", func(t *testing.T){
test(2)
})
})
}

Expand Down
3 changes: 1 addition & 2 deletions graphql/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ package graphql
import (
"fmt"
"io"
"mime/multipart"
)

type Upload struct {
File multipart.File
File io.Reader
Filename string
Size int64
}
Expand Down
104 changes: 83 additions & 21 deletions handler/graphql.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package handler

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -369,7 +372,17 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case http.MethodPost:
contentType := strings.SplitN(r.Header.Get("Content-Type"), ";", 2)[0]
if contentType == "multipart/form-data" {
if err := processMultipart(w, r, &reqParams, gh.cfg.uploadMaxSize, gh.cfg.uploadMaxMemory); err != nil {
var closers []io.Closer
var tmpFiles []string
defer func() {
for i := len(closers) - 1; 0 <= i; i-- {
_ = closers[i].Close()
}
for _, tmpFile := range tmpFiles {
_ = os.Remove(tmpFile)
}
}()
if err := processMultipart(w, r, &reqParams, &closers, &tmpFiles, gh.cfg.uploadMaxSize, gh.cfg.uploadMaxMemory); err != nil {
sendErrorf(w, http.StatusBadRequest, "multipart body could not be decoded: "+err.Error())
return
}
Expand Down Expand Up @@ -534,7 +547,7 @@ func sendErrorf(w http.ResponseWriter, code int, format string, args ...interfac
sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
}

func processMultipart(w http.ResponseWriter, r *http.Request, request *params, uploadMaxSize, uploadMaxMemory int64) error {
func processMultipart(w http.ResponseWriter, r *http.Request, request *params, closers *[]io.Closer, tmpFiles *[]string, uploadMaxSize, uploadMaxMemory int64) error {
var err error
if r.ContentLength > uploadMaxSize {
return errors.New("failed to parse multipart form, request body too large")
Expand All @@ -546,6 +559,7 @@ func processMultipart(w http.ResponseWriter, r *http.Request, request *params, u
}
return errors.New("failed to parse multipart form")
}
*closers = append(*closers, r.Body)

if err = jsonDecode(strings.NewReader(r.Form.Get("operations")), &request); err != nil {
return errors.New("operations form field could not be decoded")
Expand All @@ -558,46 +572,94 @@ func processMultipart(w http.ResponseWriter, r *http.Request, request *params, u

var upload graphql.Upload
for key, paths := range uploadsMap {
err = func() error {
file, header, err := r.FormFile(key)
if err != nil {
return fmt.Errorf("failed to get key %s from form", key)
}
if len(paths) == 0 {
return fmt.Errorf("invalid empty operations paths list for key %s", key)
}
if len(paths) == 0 {
return fmt.Errorf("invalid empty operations paths list for key %s", key)
}
file, header, err := r.FormFile(key)
if err != nil {
return fmt.Errorf("failed to get key %s from form", key)
}
*closers = append(*closers, file)

if len(paths) == 1 {
upload = graphql.Upload{
File: file,
Size: header.Size,
Filename: header.Filename,
}
for _, path := range paths {
if !strings.HasPrefix(path, "variables.") {
return fmt.Errorf("invalid operations paths for key %s", key)
err = addUploadToOperations(request, upload, key, paths[0])
if err != nil {
return err
}
} else {
if r.ContentLength < uploadMaxMemory {
fileContent, err := ioutil.ReadAll(file)
if err != nil {
return fmt.Errorf("failed to read file for key %s", key)
}
for _, path := range paths {
upload = graphql.Upload{
File: ioutil.NopCloser(bytes.NewReader(fileContent)),
Size: header.Size,
Filename: header.Filename,
}
err = addUploadToOperations(request, upload, key, path)
if err != nil {
return err
}
}
} else {
tmpFile, err := ioutil.TempFile(os.TempDir(), "gqlgen-")
if err != nil {
return fmt.Errorf("failed to create temp file for key %s", key)
}
err = addUploadToOperations(request, upload, path)
tmpName := tmpFile.Name()
*tmpFiles = append(*tmpFiles, tmpName)
_, err = io.Copy(tmpFile, file)
if err != nil {
return err
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to copy to temp file and close temp file for key %s", key)
}
return fmt.Errorf("failed to copy to temp file for key %s", key)
}
if err := tmpFile.Close(); err != nil {
return fmt.Errorf("failed to close temp file for key %s", key)
}
for _, path := range paths {
pathTmpFile, err := os.Open(tmpName)
if err != nil {
return fmt.Errorf("failed to open temp file for key %s", key)
}
*closers = append(*closers, pathTmpFile)
upload = graphql.Upload{
File: pathTmpFile,
Size: header.Size,
Filename: header.Filename,
}
err = addUploadToOperations(request, upload, key, path)
if err != nil {
return err
}
}
}
return nil
}()
if err != nil {
return err
}
}
return nil
}

func addUploadToOperations(request *params, upload graphql.Upload, path string) error {
func addUploadToOperations(request *params, upload graphql.Upload, key, path string) error {
if !strings.HasPrefix(path, "variables.") {
return fmt.Errorf("invalid operations paths for key %s", key)
}

var ptr interface{} = request.Variables
parts := strings.Split(path, ".")

// skip the first part (variables) because we started there
for i, p := range parts[1:] {
last := i == len(parts)-2
if ptr == nil {
return fmt.Errorf("variables is missing, path: %s", path)
return fmt.Errorf("path is missing \"variables.\" prefix, key: %s, path: %s", key, path)
}
if index, parseNbrErr := strconv.Atoi(p); parseNbrErr == nil {
if last {
Expand Down
Loading

0 comments on commit f848415

Please sign in to comment.