From 2284a3eb7c2d41458c60f2ca603220d4e7ffa80e Mon Sep 17 00:00:00 2001 From: Mathew Byrne Date: Mon, 18 Mar 2019 17:27:41 +1100 Subject: [PATCH 1/2] Improve IsSlice logic to check GQL def Currently TypeReference.IsSlice only looks at the Go type to decide. This should also take into account the GraphQL type as well, to cover cases such as a scalar mapping to []byte --- codegen/config/binder.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codegen/config/binder.go b/codegen/config/binder.go index 5e8c1cf6969..f39563874b4 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -193,9 +193,9 @@ func (ref *TypeReference) Elem() *TypeReference { } } - if s, isSlice := ref.GO.(*types.Slice); isSlice { + if ref.IsSlice() { return &TypeReference{ - GO: s.Elem(), + GO: ref.GO.(*types.Slice).Elem(), GQL: ref.GQL.Elem, CastType: ref.CastType, Definition: ref.Definition, @@ -221,7 +221,7 @@ func (t *TypeReference) IsNilable() bool { func (t *TypeReference) IsSlice() bool { _, isSlice := t.GO.(*types.Slice) - return isSlice + return t.GQL.Elem != nil && isSlice } func (t *TypeReference) IsNamed() bool { From 515f22547466f30865a313e46d9ae499468d0cf1 Mon Sep 17 00:00:00 2001 From: Mathew Byrne Date: Mon, 18 Mar 2019 17:37:18 +1100 Subject: [PATCH 2/2] Add test case for custom scalar to slice --- codegen/testserver/bytes.go | 27 ++++++++++++++ codegen/testserver/generated.go | 61 +++++++++++++++++++++++++++++++ codegen/testserver/gqlgen.yml | 2 + codegen/testserver/resolver.go | 3 ++ codegen/testserver/slices.graphql | 3 ++ codegen/testserver/slices_test.go | 12 ++++++ codegen/testserver/stub.go | 4 ++ 7 files changed, 112 insertions(+) create mode 100644 codegen/testserver/bytes.go diff --git a/codegen/testserver/bytes.go b/codegen/testserver/bytes.go new file mode 100644 index 00000000000..42a4541f9d4 --- /dev/null +++ b/codegen/testserver/bytes.go @@ -0,0 +1,27 @@ +package testserver + +import ( + "fmt" + "io" + + "github.com/99designs/gqlgen/graphql" +) + +func MarshalBytes(b []byte) graphql.Marshaler { + return graphql.WriterFunc(func(w io.Writer) { + _, _ = fmt.Fprintf(w, "%q", string(b)) + }) +} + +func UnmarshalBytes(v interface{}) ([]byte, error) { + switch v := v.(type) { + case string: + return []byte(v), nil + case *string: + return []byte(*v), nil + case []byte: + return v, nil + default: + return nil, fmt.Errorf("%T is not []byte", v) + } +} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 7b9c87a1f44..ea9638c0e6f 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -172,6 +172,7 @@ type ComplexityRoot struct { Overlapping func(childComplexity int) int Panics func(childComplexity int) int Recursive func(childComplexity int, input *RecursiveInputSlice) int + ScalarSlice func(childComplexity int) int ShapeUnion func(childComplexity int) int Shapes func(childComplexity int) int Slices func(childComplexity int) int @@ -270,6 +271,7 @@ type QueryResolver interface { Panics(ctx context.Context) (*Panics, error) DefaultScalar(ctx context.Context, arg string) (string, error) Slices(ctx context.Context) (*Slices, error) + ScalarSlice(ctx context.Context) ([]byte, error) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion(ctx context.Context) (TestUnion, error) ValidType(ctx context.Context) (*ValidType, error) @@ -778,6 +780,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Recursive(childComplexity, args["input"].(*RecursiveInputSlice)), true + case "Query.ScalarSlice": + if e.complexity.Query.ScalarSlice == nil { + break + } + + return e.complexity.Query.ScalarSlice(childComplexity), true + case "Query.ShapeUnion": if e.complexity.Query.ShapeUnion == nil { break @@ -1312,6 +1321,7 @@ scalar Time `}, &ast.Source{Name: "slices.graphql", Input: `extend type Query { slices: Slices + scalarSlice: Bytes! } type Slices { @@ -1320,6 +1330,8 @@ type Slices { test3: [String]! test4: [String!]! } + +scalar Bytes `}, &ast.Source{Name: "typefallback.graphql", Input: `extend type Query { fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! @@ -3765,6 +3777,33 @@ func (ec *executionContext) _Query_slices(ctx context.Context, field graphql.Col return ec.marshalOSlices2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐSlices(ctx, field.Selections, res) } +func (ec *executionContext) _Query_scalarSlice(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().ScalarSlice(rctx) + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]byte) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNBytes2ᚕbyte(ctx, field.Selections, res) +} + func (ec *executionContext) _Query_fallback(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -6620,6 +6659,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr res = ec._Query_slices(ctx, field) return res }) + case "scalarSlice": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_scalarSlice(ctx, field) + if res == graphql.Null { + invalid = true + } + return res + }) case "fallback": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -7208,6 +7261,14 @@ func (ec *executionContext) marshalNBoolean2bool(ctx context.Context, sel ast.Se return graphql.MarshalBoolean(v) } +func (ec *executionContext) unmarshalNBytes2ᚕbyte(ctx context.Context, v interface{}) ([]byte, error) { + return UnmarshalBytes(v) +} + +func (ec *executionContext) marshalNBytes2ᚕbyte(ctx context.Context, sel ast.SelectionSet, v []byte) graphql.Marshaler { + return MarshalBytes(v) +} + func (ec *executionContext) unmarshalNDefaultScalarImplementation2string(ctx context.Context, v interface{}) (string, error) { return graphql.UnmarshalString(v) } diff --git a/codegen/testserver/gqlgen.yml b/codegen/testserver/gqlgen.yml index b3e47496542..83555211939 100644 --- a/codegen/testserver/gqlgen.yml +++ b/codegen/testserver/gqlgen.yml @@ -68,3 +68,5 @@ models: oldFoo: { fieldName: foo, resolver: true } FallbackToStringEncoding: model: "github.com/99designs/gqlgen/codegen/testserver.FallbackToStringEncoding" + Bytes: + model: "github.com/99designs/gqlgen/codegen/testserver.Bytes" diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index b00e7ee0b69..e752629b677 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -140,6 +140,9 @@ func (r *queryResolver) DefaultScalar(ctx context.Context, arg string) (string, func (r *queryResolver) Slices(ctx context.Context) (*Slices, error) { panic("not implemented") } +func (r *queryResolver) ScalarSlice(ctx context.Context) ([]byte, error) { + panic("not implemented") +} func (r *queryResolver) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { panic("not implemented") } diff --git a/codegen/testserver/slices.graphql b/codegen/testserver/slices.graphql index 7f4ca489f88..671235d0b37 100644 --- a/codegen/testserver/slices.graphql +++ b/codegen/testserver/slices.graphql @@ -1,5 +1,6 @@ extend type Query { slices: Slices + scalarSlice: Bytes! } type Slices { @@ -8,3 +9,5 @@ type Slices { test3: [String]! test4: [String!]! } + +scalar Bytes diff --git a/codegen/testserver/slices_test.go b/codegen/testserver/slices_test.go index 2516884705c..e67dd6fdb44 100644 --- a/codegen/testserver/slices_test.go +++ b/codegen/testserver/slices_test.go @@ -30,4 +30,16 @@ func TestSlices(t *testing.T) { require.NotNil(t, resp.Slices.Test3) require.NotNil(t, resp.Slices.Test4) }) + + t.Run("custom scalars to slices work", func(t *testing.T) { + resolvers.QueryResolver.ScalarSlice = func(ctx context.Context) ([]byte, error) { + return []byte("testing"), nil + } + + var resp struct { + ScalarSlice string + } + c.MustPost(`query { scalarSlice }`, &resp) + require.Equal(t, "testing", resp.ScalarSlice) + }) } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index 36a100b96fc..eaebd12d323 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -50,6 +50,7 @@ type Stub struct { Panics func(ctx context.Context) (*Panics, error) DefaultScalar func(ctx context.Context, arg string) (string, error) Slices func(ctx context.Context) (*Slices, error) + ScalarSlice func(ctx context.Context) ([]byte, error) Fallback func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion func(ctx context.Context) (TestUnion, error) ValidType func(ctx context.Context) (*ValidType, error) @@ -192,6 +193,9 @@ func (r *stubQuery) DefaultScalar(ctx context.Context, arg string) (string, erro func (r *stubQuery) Slices(ctx context.Context) (*Slices, error) { return r.QueryResolver.Slices(ctx) } +func (r *stubQuery) ScalarSlice(ctx context.Context) ([]byte, error) { + return r.QueryResolver.ScalarSlice(ctx) +} func (r *stubQuery) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { return r.QueryResolver.Fallback(ctx, arg) }