From 31478cf4f74564682f6b0160867661c25b0bbe78 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Fri, 10 Aug 2018 15:47:19 +1000 Subject: [PATCH] Stop force resolver from picking up types from matching fields --- codegen/object.go | 2 +- codegen/testserver/generated.go | 60 ++++++++++++++++++++++++++++ codegen/testserver/generated_test.go | 9 +++++ codegen/testserver/gqlgen.yml | 4 ++ codegen/testserver/models.go | 5 +++ codegen/testserver/resolver.go | 9 +++++ codegen/testserver/schema.graphql | 4 ++ codegen/util.go | 8 +++- example/todo/todo_test.go | 8 ++-- 9 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 codegen/testserver/models.go diff --git a/codegen/object.go b/codegen/object.go index c9cf52ddbe..262dccc91b 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -71,7 +71,7 @@ func (o *Object) HasResolvers() bool { } func (f *Field) IsResolver() bool { - return f.ForceResolver || f.GoFieldName == "" + return f.GoFieldName == "" } func (f *Field) IsMethod() bool { diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 643ad729f0..a7097c1b27 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -30,11 +30,15 @@ type Config struct { } type ResolverRoot interface { + ForcedResolver() ForcedResolverResolver Query() QueryResolver } type DirectiveRoot struct { } +type ForcedResolverResolver interface { + Field(ctx context.Context, obj *ForcedResolver) (*Circle, error) +} type QueryResolver interface { InvalidIdentifier(ctx context.Context) (*invalid_packagename.InvalidIdentifier, error) Collision(ctx context.Context) (*introspection1.It, error) @@ -144,6 +148,58 @@ func (ec *executionContext) _Circle_area(ctx context.Context, field graphql.Coll return graphql.MarshalFloat(res) } +var forcedResolverImplementors = []string{"ForcedResolver"} + +// nolint: gocyclo, errcheck, gas, goconst +func (ec *executionContext) _ForcedResolver(ctx context.Context, sel ast.SelectionSet, obj *ForcedResolver) graphql.Marshaler { + fields := graphql.CollectFields(ctx, sel, forcedResolverImplementors) + + out := graphql.NewOrderedMap(len(fields)) + for i, field := range fields { + out.Keys[i] = field.Alias + + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ForcedResolver") + case "field": + out.Values[i] = ec._ForcedResolver_field(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + + return out +} + +func (ec *executionContext) _ForcedResolver_field(ctx context.Context, field graphql.CollectedField, obj *ForcedResolver) graphql.Marshaler { + ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "ForcedResolver", + Args: nil, + Field: field, + }) + return graphql.Defer(func() (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + userErr := ec.Recover(ctx, r) + ec.Error(ctx, userErr) + ret = graphql.Null + } + }() + + resTmp := ec.FieldMiddleware(ctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.ForcedResolver().Field(ctx, obj) + }) + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*Circle) + if res == nil { + return graphql.Null + } + return ec._Circle(ctx, field.Selections, res) + }) +} + var innerObjectImplementors = []string{"InnerObject"} // nolint: gocyclo, errcheck, gas, goconst @@ -2362,5 +2418,9 @@ type Rectangle implements Shape { area: Float } union ShapeUnion = Circle | Rectangle + +type ForcedResolver { + field: Circle +} `}, ) diff --git a/codegen/testserver/generated_test.go b/codegen/testserver/generated_test.go index d055a065ae..b4bb895b74 100644 --- a/codegen/testserver/generated_test.go +++ b/codegen/testserver/generated_test.go @@ -7,7 +7,10 @@ import ( "net/http" "testing" + "reflect" + "github.com/99designs/gqlgen/handler" + "github.com/stretchr/testify/require" ) func TestCompiles(t *testing.T) { @@ -15,3 +18,9 @@ func TestCompiles(t *testing.T) { Resolvers: &Resolver{}, }))) } + +func TestForcedResolverFieldIsPointer(t *testing.T) { + field, ok := reflect.TypeOf((*ForcedResolverResolver)(nil)).Elem().MethodByName("Field") + require.True(t, ok) + require.Equal(t, "*testserver.Circle", field.Type.Out(0).String()) +} diff --git a/codegen/testserver/gqlgen.yml b/codegen/testserver/gqlgen.yml index d7d70f769d..455fecc455 100644 --- a/codegen/testserver/gqlgen.yml +++ b/codegen/testserver/gqlgen.yml @@ -25,3 +25,7 @@ models: model: "github.com/99designs/gqlgen/codegen/testserver.Circle" Rectangle: model: "github.com/99designs/gqlgen/codegen/testserver.Rectangle" + ForcedResolver: + model: "github.com/99designs/gqlgen/codegen/testserver.ForcedResolver" + fields: + field: { resolver: true } diff --git a/codegen/testserver/models.go b/codegen/testserver/models.go new file mode 100644 index 0000000000..4eb1a45112 --- /dev/null +++ b/codegen/testserver/models.go @@ -0,0 +1,5 @@ +package testserver + +type ForcedResolver struct { + Field Circle +} diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index eb52ef5a06..ace486f35b 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -11,10 +11,19 @@ import ( type Resolver struct{} +func (r *Resolver) ForcedResolver() ForcedResolverResolver { + return &forcedResolverResolver{r} +} func (r *Resolver) Query() QueryResolver { return &queryResolver{r} } +type forcedResolverResolver struct{ *Resolver } + +func (r *forcedResolverResolver) Field(ctx context.Context, obj *ForcedResolver) (*Circle, error) { + panic("not implemented") +} + type queryResolver struct{ *Resolver } func (r *queryResolver) InvalidIdentifier(ctx context.Context) (*invalid_packagename.InvalidIdentifier, error) { diff --git a/codegen/testserver/schema.graphql b/codegen/testserver/schema.graphql index 1ae35d45cb..4024fc4271 100644 --- a/codegen/testserver/schema.graphql +++ b/codegen/testserver/schema.graphql @@ -113,3 +113,7 @@ type Rectangle implements Shape { area: Float } union ShapeUnion = Circle | Rectangle + +type ForcedResolver { + field: Circle +} diff --git a/codegen/util.go b/codegen/util.go index f1732ad112..fae94adead 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -166,6 +166,10 @@ func bindObject(t types.Type, object *Object, imports *Imports) BindErrors { for i := range object.Fields { field := &object.Fields[i] + if field.ForceResolver { + continue + } + // first try binding to a method methodErr := bindMethod(imports, t, field) if methodErr == nil { @@ -261,7 +265,9 @@ nextArg: param := params.At(j) for _, oldArg := range field.Args { if strings.EqualFold(oldArg.GQLName, param.Name()) { - oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) + if !field.ForceResolver { + oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) + } newArgs = append(newArgs, oldArg) continue nextArg } diff --git a/example/todo/todo_test.go b/example/todo/todo_test.go index e7bbad936f..bd7dee60c9 100644 --- a/example/todo/todo_test.go +++ b/example/todo/todo_test.go @@ -223,10 +223,10 @@ func TestSkipAndIncludeDirectives(t *testing.T) { Expected bool } table := []TestCase{ - TestCase{Skip: true, Include: true, Expected: false}, - TestCase{Skip: true, Include: false, Expected: false}, - TestCase{Skip: false, Include: true, Expected: true}, - TestCase{Skip: false, Include: false, Expected: false}, + {Skip: true, Include: true, Expected: false}, + {Skip: true, Include: false, Expected: false}, + {Skip: false, Include: true, Expected: true}, + {Skip: false, Include: false, Expected: false}, } q := `query Test($skip: Boolean!, $include: Boolean!) { todo(id: 1) @skip(if: $skip) @include(if: $include) { __typename } }` for _, tc := range table {