From 46c40b748d55bcecd236c1ca827cef00808a5363 Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 8 May 2019 20:59:52 +1000 Subject: [PATCH] Fix interface casing fixes #694 --- codegen/testserver/generated.go | 145 ++++++++++++++++++++++++++ codegen/testserver/models-gen.go | 16 +++ codegen/testserver/validtypes.graphql | 10 ++ plugin/modelgen/models.gotpl | 2 +- 4 files changed, 172 insertions(+), 1 deletion(-) diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index d10e371039f..a5d1f79e5b8 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -85,6 +85,14 @@ type ComplexityRoot struct { Radius func(childComplexity int) int } + ContentPost struct { + Foo func(childComplexity int) int + } + + ContentUser struct { + Foo func(childComplexity int) int + } + EmbeddedDefaultScalar struct { Value func(childComplexity int) int } @@ -376,6 +384,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Circle.Radius(childComplexity), true + case "Content_Post.foo": + if e.complexity.ContentPost.Foo == nil { + break + } + + return e.complexity.ContentPost.Foo(childComplexity), true + + case "Content_User.foo": + if e.complexity.ContentUser.Foo == nil { + break + } + + return e.complexity.ContentUser.Foo(childComplexity), true + case "EmbeddedDefaultScalar.value": if e.complexity.EmbeddedDefaultScalar.Value == nil { break @@ -1425,6 +1447,16 @@ input ValidInput { _: String! } +# see https://github.com/99designs/gqlgen/issues/694 +type Content_User { + foo: String +} + +type Content_Post { + foo: String +} + +union Content_Child = Content_User | Content_Post `}, &ast.Source{Name: "weird_type_cases.graphql", Input: `# regression test for https://github.com/99designs/gqlgen/issues/583 @@ -2331,6 +2363,54 @@ func (ec *executionContext) _Circle_area(ctx context.Context, field graphql.Coll return ec.marshalOFloat2float64(ctx, field.Selections, res) } +func (ec *executionContext) _Content_Post_foo(ctx context.Context, field graphql.CollectedField, obj *ContentPost) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Content_Post", + Field: field, + Args: nil, + IsMethod: false, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Foo, nil + }) + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalOString2áš–string(ctx, field.Selections, res) +} + +func (ec *executionContext) _Content_User_foo(ctx context.Context, field graphql.CollectedField, obj *ContentUser) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Content_User", + Field: field, + Args: nil, + IsMethod: false, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Foo, nil + }) + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalOString2áš–string(ctx, field.Selections, res) +} + func (ec *executionContext) _EmbeddedDefaultScalar_value(ctx context.Context, field graphql.CollectedField, obj *EmbeddedDefaultScalar) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { ec.Tracer.EndFieldExecution(ctx) }() @@ -5654,6 +5734,23 @@ func (ec *executionContext) unmarshalInputValidInput(ctx context.Context, v inte // region ************************** interface.gotpl *************************** +func (ec *executionContext) _Content_Child(ctx context.Context, sel ast.SelectionSet, obj *ContentChild) graphql.Marshaler { + switch obj := (*obj).(type) { + case nil: + return graphql.Null + case ContentUser: + return ec._Content_User(ctx, sel, &obj) + case *ContentUser: + return ec._Content_User(ctx, sel, obj) + case ContentPost: + return ec._Content_Post(ctx, sel, &obj) + case *ContentPost: + return ec._Content_Post(ctx, sel, obj) + default: + panic(fmt.Errorf("unexpected type %T", obj)) + } +} + func (ec *executionContext) _Shape(ctx context.Context, sel ast.SelectionSet, obj *Shape) graphql.Marshaler { switch obj := (*obj).(type) { case nil: @@ -5882,6 +5979,54 @@ func (ec *executionContext) _Circle(ctx context.Context, sel ast.SelectionSet, o return out } +var content_PostImplementors = []string{"Content_Post", "Content_Child"} + +func (ec *executionContext) _Content_Post(ctx context.Context, sel ast.SelectionSet, obj *ContentPost) graphql.Marshaler { + fields := graphql.CollectFields(ec.RequestContext, sel, content_PostImplementors) + + out := graphql.NewFieldSet(fields) + invalid := false + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("Content_Post") + case "foo": + out.Values[i] = ec._Content_Post_foo(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalid { + return graphql.Null + } + return out +} + +var content_UserImplementors = []string{"Content_User", "Content_Child"} + +func (ec *executionContext) _Content_User(ctx context.Context, sel ast.SelectionSet, obj *ContentUser) graphql.Marshaler { + fields := graphql.CollectFields(ec.RequestContext, sel, content_UserImplementors) + + out := graphql.NewFieldSet(fields) + invalid := false + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("Content_User") + case "foo": + out.Values[i] = ec._Content_User_foo(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalid { + return graphql.Null + } + return out +} + var embeddedDefaultScalarImplementors = []string{"EmbeddedDefaultScalar"} func (ec *executionContext) _EmbeddedDefaultScalar(ctx context.Context, sel ast.SelectionSet, obj *EmbeddedDefaultScalar) graphql.Marshaler { diff --git a/codegen/testserver/models-gen.go b/codegen/testserver/models-gen.go index c1331c1dba3..f21aa31b3a8 100644 --- a/codegen/testserver/models-gen.go +++ b/codegen/testserver/models-gen.go @@ -9,6 +9,10 @@ import ( "time" ) +type ContentChild interface { + IsContentChild() +} + type TestUnion interface { IsTestUnion() } @@ -33,6 +37,18 @@ type B struct { func (B) IsTestUnion() {} +type ContentPost struct { + Foo *string `json:"foo"` +} + +func (ContentPost) IsContentChild() {} + +type ContentUser struct { + Foo *string `json:"foo"` +} + +func (ContentUser) IsContentChild() {} + type EmbeddedDefaultScalar struct { Value *string `json:"value"` } diff --git a/codegen/testserver/validtypes.graphql b/codegen/testserver/validtypes.graphql index f344a9f2de5..5ed068c5f47 100644 --- a/codegen/testserver/validtypes.graphql +++ b/codegen/testserver/validtypes.graphql @@ -66,3 +66,13 @@ input ValidInput { _: String! } +# see https://github.com/99designs/gqlgen/issues/694 +type Content_User { + foo: String +} + +type Content_Post { + foo: String +} + +union Content_Child = Content_User | Content_Post diff --git a/plugin/modelgen/models.gotpl b/plugin/modelgen/models.gotpl index d06cf050f7d..6df200ee064 100644 --- a/plugin/modelgen/models.gotpl +++ b/plugin/modelgen/models.gotpl @@ -31,7 +31,7 @@ } {{- range $iface := .Implements }} - func ({{ $model.Name|go }}) Is{{ $iface }}() {} + func ({{ $model.Name|go }}) Is{{ $iface|go }}() {} {{- end }} {{- end}}