diff --git a/codegen/config/binder.go b/codegen/config/binder.go index 514ccc6742..a4f84fed80 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -151,7 +151,6 @@ func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { newRef := &TypeReference{ GO: types.NewPointer(ref.GO), GQL: ref.GQL, - CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -168,7 +167,6 @@ type TypeReference struct { GQL *ast.Type GO types.Type Target types.Type - CastType types.Type // Before calling marshalling functions cast from/to this base type Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler @@ -180,7 +178,6 @@ func (ref *TypeReference) Elem() *TypeReference { GO: p.Elem(), Target: ref.Target, GQL: ref.GQL, - CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -193,7 +190,6 @@ func (ref *TypeReference) Elem() *TypeReference { GO: ref.GO.(*types.Slice).Elem(), Target: ref.Target, GQL: ref.GQL.Elem, - CastType: ref.CastType, Definition: ref.Definition, Unmarshaler: ref.Unmarshaler, Marshaler: ref.Marshaler, @@ -349,27 +345,16 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret return nil, err } - if fun, isFunc := obj.(*types.Func); isFunc { + fun, isFunc := obj.(*types.Func) + switch { + case isFunc: ref.GO = fun.Type().(*types.Signature).Params().At(0).Type() ref.Marshaler = fun ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil) - } else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") { + case hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL"): ref.GO = obj.Type() ref.IsMarshaler = true - } else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String { - // Special case for named types wrapping strings. Used by default enum implementations. - - ref.GO = obj.Type() - ref.CastType = underlying - - underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil) - if err != nil { - return nil, err - } - - ref.Marshaler = underlyingRef.Marshaler - ref.Unmarshaler = underlyingRef.Unmarshaler - } else { + default: ref.GO = obj.Type() } @@ -446,19 +431,3 @@ func hasMethod(it types.Type, name string) bool { } return false } - -func basicUnderlying(it types.Type) *types.Basic { - if ptr, isPtr := it.(*types.Pointer); isPtr { - it = ptr.Elem() - } - namedType, ok := it.(*types.Named) - if !ok { - return nil - } - - if basic, ok := namedType.Underlying().(*types.Basic); ok { - return basic - } - - return nil -} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index fbba338ffe..65a117a243 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -277,7 +277,6 @@ type ComplexityRoot struct { EnumInInput func(childComplexity int, input *InputWithEnumValue) int ErrorBubble func(childComplexity int) int Errors func(childComplexity int) int - Fallback func(childComplexity int, arg FallbackToStringEncoding) int InputNullableSlice func(childComplexity int, arg []string) int InputSlice func(childComplexity int, arg []string) int InvalidIdentifier func(childComplexity int) int @@ -462,7 +461,6 @@ type QueryResolver interface { DefaultScalar(ctx context.Context, arg string) (string, error) Slices(ctx context.Context) (*Slices, error) ScalarSlice(ctx context.Context) ([]byte, error) - Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion(ctx context.Context) (TestUnion, error) ValidType(ctx context.Context) (*ValidType, error) WrappedStruct(ctx context.Context) (*WrappedStruct, error) @@ -1203,18 +1201,6 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Errors(childComplexity), true - case "Query.fallback": - if e.complexity.Query.Fallback == nil { - break - } - - args, err := ec.field_Query_fallback_args(context.TODO(), rawArgs) - if err != nil { - return 0, false - } - - return e.complexity.Query.Fallback(childComplexity, args["arg"].(FallbackToStringEncoding)), true - case "Query.inputNullableSlice": if e.complexity.Query.InputNullableSlice == nil { break @@ -2200,16 +2186,6 @@ type Slices { } scalar Bytes -`, BuiltIn: false}, - {Name: "typefallback.graphql", Input: `extend type Query { - fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! -} - -enum FallbackToStringEncoding { - A - B - C -} `, BuiltIn: false}, {Name: "useptr.graphql", Input: `type A { id: ID! @@ -2712,21 +2688,6 @@ func (ec *executionContext) field_Query_enumInInput_args(ctx context.Context, ra return args, nil } -func (ec *executionContext) field_Query_fallback_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { - var err error - args := map[string]interface{}{} - var arg0 FallbackToStringEncoding - if tmp, ok := rawArgs["arg"]; ok { - ctx := graphql.WithFieldInputContext(ctx, graphql.NewFieldInputWithField("arg")) - arg0, err = ec.unmarshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx, tmp) - if err != nil { - return nil, err - } - } - args["arg"] = arg0 - return args, nil -} - func (ec *executionContext) field_Query_inputNullableSlice_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -7310,44 +7271,6 @@ func (ec *executionContext) _Query_scalarSlice(ctx context.Context, field graphq return ec.marshalNBytes2ᚕbyte(ctx, field.Selections, res) } -func (ec *executionContext) _Query_fallback(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = graphql.Null - } - }() - fc := &graphql.FieldContext{ - Object: "Query", - Field: field, - Args: nil, - IsMethod: true, - } - - ctx = graphql.WithFieldContext(ctx, fc) - rawArgs := field.ArgumentMap(ec.Variables) - args, err := ec.field_Query_fallback_args(ctx, rawArgs) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - fc.Args = args - resTmp := ec._fieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Fallback(rctx, args["arg"].(FallbackToStringEncoding)) - }) - - if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } - return graphql.Null - } - res := resTmp.(FallbackToStringEncoding) - fc.Result = res - return ec.marshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx, field.Selections, res) -} - func (ec *executionContext) _Query_optionalUnion(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -12110,20 +12033,6 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) - case "fallback": - field := field - out.Concurrently(i, func() (res graphql.Marshaler) { - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - } - }() - res = ec._Query_fallback(ctx, field) - if res == graphql.Null { - atomic.AddUint32(&invalids, 1) - } - return res - }) case "optionalUnion": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -12953,22 +12862,6 @@ func (ec *executionContext) marshalNError2ᚖgithubᚗcomᚋ99designsᚋgqlgen return ec._Error(ctx, sel, v) } -func (ec *executionContext) unmarshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx context.Context, v interface{}) (FallbackToStringEncoding, error) { - tmp, err := graphql.UnmarshalString(v) - return FallbackToStringEncoding(tmp), graphql.WrapErrorWithInputPath(ctx, err) -} - -func (ec *executionContext) marshalNFallbackToStringEncoding2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐFallbackToStringEncoding(ctx context.Context, sel ast.SelectionSet, v FallbackToStringEncoding) graphql.Marshaler { - in := v - res := graphql.MarshalString(string(in)) - if res == graphql.Null { - if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { - ec.Errorf(ctx, "must not be null") - } - } - return res -} - func (ec *executionContext) unmarshalNID2int(ctx context.Context, v interface{}) (int, error) { res, err := graphql.UnmarshalIntID(v) return res, graphql.WrapErrorWithInputPath(ctx, err) @@ -13452,19 +13345,13 @@ func (ec *executionContext) marshalNWrappedMap2githubᚗcomᚋ99designsᚋgqlgen } func (ec *executionContext) unmarshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, v interface{}) (WrappedScalar, error) { - tmp, err := graphql.UnmarshalString(v) - return WrappedScalar(tmp), graphql.WrapErrorWithInputPath(ctx, err) + var res WrappedScalar + err := res.UnmarshalGQL(v) + return res, graphql.WrapErrorWithInputPath(ctx, err) } func (ec *executionContext) marshalNWrappedScalar2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedScalar(ctx context.Context, sel ast.SelectionSet, v WrappedScalar) graphql.Marshaler { - in := v - res := graphql.MarshalString(string(in)) - if res == graphql.Null { - if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { - ec.Errorf(ctx, "must not be null") - } - } - return res + return v } func (ec *executionContext) marshalNWrappedSlice2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐWrappedSlice(ctx context.Context, sel ast.SelectionSet, v WrappedSlice) graphql.Marshaler { diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index 9b0e2e584e..d96bcaca0c 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -263,10 +263,6 @@ func (r *queryResolver) ScalarSlice(ctx context.Context) ([]byte, error) { panic("not implemented") } -func (r *queryResolver) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { - panic("not implemented") -} - func (r *queryResolver) OptionalUnion(ctx context.Context) (TestUnion, error) { panic("not implemented") } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index b8580350bb..db82863779 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -92,7 +92,6 @@ type Stub struct { DefaultScalar func(ctx context.Context, arg string) (string, error) Slices func(ctx context.Context) (*Slices, error) ScalarSlice func(ctx context.Context) ([]byte, error) - Fallback func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) OptionalUnion func(ctx context.Context) (TestUnion, error) ValidType func(ctx context.Context) (*ValidType, error) WrappedStruct func(ctx context.Context) (*WrappedStruct, error) @@ -381,9 +380,6 @@ func (r *stubQuery) Slices(ctx context.Context) (*Slices, error) { func (r *stubQuery) ScalarSlice(ctx context.Context) ([]byte, error) { return r.QueryResolver.ScalarSlice(ctx) } -func (r *stubQuery) Fallback(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { - return r.QueryResolver.Fallback(ctx, arg) -} func (r *stubQuery) OptionalUnion(ctx context.Context) (TestUnion, error) { return r.QueryResolver.OptionalUnion(ctx) } diff --git a/codegen/testserver/typefallback.graphql b/codegen/testserver/typefallback.graphql deleted file mode 100644 index e1ff1a59d7..0000000000 --- a/codegen/testserver/typefallback.graphql +++ /dev/null @@ -1,9 +0,0 @@ -extend type Query { - fallback(arg: FallbackToStringEncoding!): FallbackToStringEncoding! -} - -enum FallbackToStringEncoding { - A - B - C -} diff --git a/codegen/testserver/typefallback_test.go b/codegen/testserver/typefallback_test.go deleted file mode 100644 index 8ebd091e9e..0000000000 --- a/codegen/testserver/typefallback_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package testserver - -import ( - "context" - "testing" - - "github.com/99designs/gqlgen/client" - "github.com/99designs/gqlgen/graphql/handler" - "github.com/stretchr/testify/require" -) - -func TestTypeFallback(t *testing.T) { - resolvers := &Stub{} - - c := client.New(handler.NewDefaultServer(NewExecutableSchema(Config{Resolvers: resolvers}))) - - resolvers.QueryResolver.Fallback = func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { - return arg, nil - } - - t.Run("fallback to string passthrough", func(t *testing.T) { - var resp struct { - Fallback string - } - c.MustPost(`query { fallback(arg: A) }`, &resp) - require.Equal(t, "A", resp.Fallback) - }) -} diff --git a/codegen/testserver/wrapped_type.go b/codegen/testserver/wrapped_type.go index d3aa63a79b..bd7ea31006 100644 --- a/codegen/testserver/wrapped_type.go +++ b/codegen/testserver/wrapped_type.go @@ -1,8 +1,28 @@ package testserver -import "github.com/99designs/gqlgen/codegen/testserver/otherpkg" +import ( + "fmt" + "io" + "strconv" + + "github.com/99designs/gqlgen/codegen/testserver/otherpkg" + "github.com/99designs/gqlgen/graphql" +) type WrappedScalar otherpkg.Scalar type WrappedStruct otherpkg.Struct type WrappedMap otherpkg.Map type WrappedSlice otherpkg.Slice + +func (e *WrappedScalar) UnmarshalGQL(v interface{}) error { + s, err := graphql.UnmarshalString(v) + if err != nil { + return err + } + *e = WrappedScalar(s) + return nil +} + +func (e WrappedScalar) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(string(e))) +} diff --git a/codegen/type.gotpl b/codegen/type.gotpl index f1639ea0db..e09107c3cc 100644 --- a/codegen/type.gotpl +++ b/codegen/type.gotpl @@ -25,10 +25,7 @@ return res, nil {{- else }} {{- if $type.Unmarshaler }} - {{- if $type.CastType }} - tmp, err := {{ $type.Unmarshaler | call }}(v) - return {{ $type.GO | ref }}(tmp), graphql.WrapErrorWithInputPath(ctx, err) - {{- else if and $type.IsTargetNilable (not $type.IsNilable) }} + {{- if and $type.IsTargetNilable (not $type.IsNilable) }} tmp, err := {{ $type.Unmarshaler | call }}(v) res := *tmp return res, graphql.WrapErrorWithInputPath(ctx, err) @@ -137,7 +134,7 @@ in := v {{- end }} {{- if $type.GQL.NonNull }} - res := {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}(in){{else}}in{{- end }}) + res := {{ $type.Marshaler | call }}(in) if res == graphql.Null { if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { ec.Errorf(ctx, "must not be null") @@ -145,7 +142,7 @@ } return res {{- else }} - return {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}(in){{else}}in{{- end }}) + return {{ $type.Marshaler | call }}(in) {{- end }} {{- else }} return ec._{{$type.Definition.Name}}(ctx, sel, {{ if not $type.IsNilable}}&{{end}} v)