diff --git a/example/todo/todo_test.go b/example/todo/todo_test.go index 95972268401..ab607e02b4d 100644 --- a/example/todo/todo_test.go +++ b/example/todo/todo_test.go @@ -4,11 +4,10 @@ import ( "net/http/httptest" "testing" - "github.com/vektah/gqlgen/client" - "github.com/vektah/gqlgen/neelance/introspection" - "github.com/stretchr/testify/require" + "github.com/vektah/gqlgen/client" "github.com/vektah/gqlgen/handler" + "github.com/vektah/gqlgen/neelance/introspection" ) func TestTodo(t *testing.T) { diff --git a/graphql/string.go b/graphql/string.go index d2bcea0b96f..c549d99a289 100644 --- a/graphql/string.go +++ b/graphql/string.go @@ -6,9 +6,41 @@ import ( "strconv" ) +const alphabet = "0123456789ABCDEF" + func MarshalString(s string) Marshaler { return WriterFunc(func(w io.Writer) { - io.WriteString(w, strconv.Quote(s)) + start := 0 + io.WriteString(w, `"`) + + for i := 0; i < len(s); i++ { + c := s[i] + + if c < 0x20 || c == '\\' || c == '"' { + io.WriteString(w, s[start:i]) + + switch c { + case '\t': + io.WriteString(w, `\t`) + case '\r': + io.WriteString(w, `\r`) + case '\n': + io.WriteString(w, `\n`) + case '\\': + io.WriteString(w, `\\`) + case '"': + io.WriteString(w, `\"`) + default: + io.WriteString(w, `\u00`) + w.Write([]byte{alphabet[c>>4], alphabet[c&0xf]}) + } + + start = i + 1 + } + } + + io.WriteString(w, s[start:]) + io.WriteString(w, `"`) }) } func UnmarshalString(v interface{}) (string, error) { diff --git a/graphql/string_test.go b/graphql/string_test.go new file mode 100644 index 00000000000..915fadac472 --- /dev/null +++ b/graphql/string_test.go @@ -0,0 +1,27 @@ +package graphql + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestString(t *testing.T) { + assert.Equal(t, `"hello"`, doStrMarshal("hello")) + assert.Equal(t, `"he\tllo"`, doStrMarshal("he\tllo")) + assert.Equal(t, `"he\tllo"`, doStrMarshal("he llo")) + assert.Equal(t, `"he\nllo"`, doStrMarshal("he\nllo")) + assert.Equal(t, `"he\r\nllo"`, doStrMarshal("he\r\nllo")) + assert.Equal(t, `"he\\llo"`, doStrMarshal(`he\llo`)) + assert.Equal(t, `"quotes\"nested\"in\"quotes\""`, doStrMarshal(`quotes"nested"in"quotes"`)) + assert.Equal(t, `"\u0000"`, doStrMarshal("\u0000")) + assert.Equal(t, `"\u0000"`, doStrMarshal("\u0000")) + assert.Equal(t, "\"\U000fe4ed\"", doStrMarshal("\U000fe4ed")) +} + +func doStrMarshal(s string) string { + var buf bytes.Buffer + MarshalString(s).MarshalGQL(&buf) + return buf.String() +} diff --git a/handler/graphql.go b/handler/graphql.go index 98e1b90ecf4..fd53fa4ce26 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -175,11 +175,12 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc return } - ctx := graphql.WithRequestContext(r.Context(), cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)) + reqCtx := cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables) + ctx := graphql.WithRequestContext(r.Context(), reqCtx) defer func() { if err := recover(); err != nil { - userErr := cfg.recover(ctx, err) + userErr := reqCtx.Recover(ctx, err) sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error()) } }() diff --git a/handler/websocket.go b/handler/websocket.go index 7430619d87e..2775416a385 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -154,7 +154,8 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { return true } - ctx := graphql.WithRequestContext(c.ctx, c.cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)) + reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables) + ctx := graphql.WithRequestContext(c.ctx, reqCtx) if op.Type != query.Subscription { var result *graphql.Response @@ -176,7 +177,7 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { go func() { defer func() { if r := recover(); r != nil { - userErr := c.cfg.recover(ctx, r) + userErr := reqCtx.Recover(ctx, r) c.sendError(message.ID, &errors.QueryError{Message: userErr.Error()}) } }() diff --git a/test/generated.go b/test/generated.go index 7967281bbf2..1e114fa8d03 100644 --- a/test/generated.go +++ b/test/generated.go @@ -25,6 +25,7 @@ type Resolvers interface { Query_path(ctx context.Context) ([]Element, error) Query_date(ctx context.Context, filter models.DateFilter) (bool, error) Query_viewer(ctx context.Context) (*Viewer, error) + Query_jsonEncoding(ctx context.Context) (string, error) } type executableSchema struct { @@ -173,6 +174,8 @@ func (ec *executionContext) _Query(ctx context.Context, sel []query.Selection) g out.Values[i] = ec._Query_date(ctx, field) case "viewer": out.Values[i] = ec._Query_viewer(ctx, field) + case "jsonEncoding": + out.Values[i] = ec._Query_jsonEncoding(ctx, field) case "__schema": out.Values[i] = ec._Query___schema(ctx, field) case "__type": @@ -298,6 +301,36 @@ func (ec *executionContext) _Query_viewer(ctx context.Context, field graphql.Col }) } +func (ec *executionContext) _Query_jsonEncoding(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: nil, + Field: field, + }) + return graphql.Defer(func() (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + userErr := ec.Recover(ctx, r) + ec.Error(ctx, userErr) + ret = graphql.Null + } + }() + + resTmp, err := ec.ResolverMiddleware(ctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_jsonEncoding(ctx) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(string) + return graphql.MarshalString(res) + }) +} + func (ec *executionContext) _Query___schema(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { rctx := graphql.GetResolverContext(ctx) rctx.Object = "Query" @@ -1224,6 +1257,7 @@ type Query { path: [Element] date(filter: DateFilter!): Boolean! viewer: Viewer + jsonEncoding: String! } // this is a comment with a ` + "`" + `backtick` + "`" + ` diff --git a/test/resolvers_test.go b/test/resolvers_test.go index e732ed48bc8..ef3cb9c81bc 100644 --- a/test/resolvers_test.go +++ b/test/resolvers_test.go @@ -82,11 +82,28 @@ func TestInputDefaults(t *testing.T) { require.True(t, called) } +func TestJsonEncoding(t *testing.T) { + srv := httptest.NewServer(handler.GraphQL(MakeExecutableSchema(&testResolvers{}))) + c := client.New(srv.URL) + + var resp struct { + JsonEncoding string + } + + err := c.Post(`{ jsonEncoding }`, &resp) + require.NoError(t, err) + require.Equal(t, "\U000fe4ed", resp.JsonEncoding) +} + type testResolvers struct { err error queryDate func(ctx context.Context, filter models.DateFilter) (bool, error) } +func (r *testResolvers) Query_jsonEncoding(ctx context.Context) (string, error) { + return "\U000fe4ed", nil +} + func (r *testResolvers) Query_viewer(ctx context.Context) (*Viewer, error) { return &Viewer{ User: &remote_api.User{"Bob"}, diff --git a/test/schema.graphql b/test/schema.graphql index 9a2abd836ad..6a1918d8d5f 100644 --- a/test/schema.graphql +++ b/test/schema.graphql @@ -29,6 +29,7 @@ type Query { path: [Element] date(filter: DateFilter!): Boolean! viewer: Viewer + jsonEncoding: String! } // this is a comment with a `backtick`