From c29d0ed8bd33eaecca81c744dbe94f9a6904e938 Mon Sep 17 00:00:00 2001 From: Matias Anaya Date: Sun, 3 Nov 2019 19:21:43 +1100 Subject: [PATCH] Fixes #843: Bind to embedded struct method or field --- codegen/field.go | 148 ++++++++++----- codegen/field_test.go | 18 +- codegen/testserver/embedded.go | 30 +++ codegen/testserver/embedded.graphql | 12 ++ codegen/testserver/embedded_test.go | 46 +++++ codegen/testserver/generated.go | 281 ++++++++++++++++++++++++++++ codegen/testserver/resolver.go | 6 + codegen/testserver/stub.go | 8 + 8 files changed, 492 insertions(+), 57 deletions(-) create mode 100644 codegen/testserver/embedded.go create mode 100644 codegen/testserver/embedded.graphql create mode 100644 codegen/testserver/embedded_test.go diff --git a/codegen/field.go b/codegen/field.go index 9d9b60934fe..738a55413b0 100644 --- a/codegen/field.go +++ b/codegen/field.go @@ -184,75 +184,111 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) { } } -// findField attempts to match the name to a struct field with the following -// priorites: -// 1. Any method with a matching name -// 2. Any Fields with a struct tag (see config.StructTag) -// 3. Any fields with a matching name -// 4. Same logic again for embedded fields +// findBindTarget attempts to match the name to a struct field or method +// with the following priorites: +// 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found +// 2. Any method or field with a matching name. Errors if more than one match is found +// 3. Same logic again for embedded fields func (b *builder) findBindTarget(named *types.Named, name string) (types.Object, error) { + strukt, isStruct := named.Underlying().(*types.Struct) + if isStruct { + // NOTE: a struct tag will override both methods and fields + // Bind to struct tag + found, err := b.findBindStructTagTarget(strukt, name) + if found != nil || err != nil { + return found, err + } + } + + // Search for a method to bind to + var foundMethod types.Object for i := 0; i < named.NumMethods(); i++ { method := named.Method(i) - if !method.Exported() { + if !method.Exported() || !strings.EqualFold(method.Name(), name) { continue } - if !strings.EqualFold(method.Name(), name) { - continue + if foundMethod != nil { + return nil, errors.Errorf("found more than one matching method to bind for %s", name) } - return method, nil - } - - strukt, ok := named.Underlying().(*types.Struct) - if !ok { - return nil, fmt.Errorf("not a struct") + foundMethod = method } - return b.findBindStructTarget(strukt, name) -} - -func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types.Object, error) { - // struct tags have the highest priority - if b.Config.StructTag != "" { - var foundField *types.Var - for i := 0; i < strukt.NumFields(); i++ { - field := strukt.Field(i) - if !field.Exported() { - continue - } - tags := reflect.StructTag(strukt.Tag(i)) - if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { - if foundField != nil { - return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val) - } - foundField = field - } + // Search for a field to bind to + if isStruct { + foundField, err := b.findBindFieldTarget(strukt, name) + if err != nil { + return nil, err } - if foundField != nil { + + switch { + case foundField == nil && foundMethod == nil: + // Search embeds + return b.findBindEmbedsTarget(strukt, name) + case foundField == nil && foundMethod != nil: + // Bind to method + return foundMethod, nil + case foundField != nil && foundMethod == nil: + // Bind to field return foundField, nil + case foundField != nil && foundMethod != nil: + // Error + return nil, errors.Errorf("found more than one way to bind for %s", name) } } - // Then matching field names + // Bind to method or don't bind at all + return foundMethod, nil +} + +func (b *builder) findBindStructTagTarget(strukt *types.Struct, name string) (types.Object, error) { + if b.Config.StructTag == "" { + return nil, nil + } + + var found types.Object for i := 0; i < strukt.NumFields(); i++ { field := strukt.Field(i) - if !field.Exported() { + if !field.Exported() || field.Embedded() { continue } - if equalFieldName(field.Name(), name) { // aqui! - return field, nil + tags := reflect.StructTag(strukt.Tag(i)) + if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { + if found != nil { + return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val) + } + + found = field } } - // Then look in embedded structs + return found, nil +} + +func (b *builder) findBindFieldTarget(strukt *types.Struct, name string) (types.Object, error) { + var found types.Object for i := 0; i < strukt.NumFields(); i++ { field := strukt.Field(i) - if !field.Exported() { + if !field.Exported() || !equalFieldName(field.Name(), name) { continue } - if !field.Anonymous() { + if found != nil { + return nil, errors.Errorf("found more than one matching field to bind for %s", name) + } + + found = field + } + + return found, nil +} + +func (b *builder) findBindEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) { + var found types.Object + for i := 0; i < strukt.NumFields(); i++ { + field := strukt.Field(i) + if !field.Embedded() { continue } @@ -267,23 +303,39 @@ func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types if err != nil { return nil, err } + if f != nil && found != nil { + return nil, errors.Errorf("found more than one way to bind for %s", name) + } if f != nil { - return f, nil + found = f } case *types.Struct: - f, err := b.findBindStructTarget(fieldType, name) + f, err := b.findBindStructTagTarget(fieldType, name) + if err != nil { + return nil, err + } + if f != nil && found != nil { + return nil, errors.Errorf("found more than one way to bind for %s", name) + } + if f != nil { + found = f + continue + } + + f, err = b.findBindFieldTarget(fieldType, name) if err != nil { return nil, err } + if f != nil && found != nil { + return nil, errors.Errorf("found more than one way to bind for %s", name) + } if f != nil { - return f, nil + found = f } - default: - panic(fmt.Errorf("unknown embedded field type %T", field.Type())) } } - return nil, nil + return found, nil } func (f *Field) HasDirectives() bool { diff --git a/codegen/field_test.go b/codegen/field_test.go index 90dddb53780..13177049e3a 100644 --- a/codegen/field_test.go +++ b/codegen/field_test.go @@ -40,15 +40,15 @@ type Embed struct { scope, err := parseScope(input, "test") require.NoError(t, err) - std := scope.Lookup("Std").Type().Underlying().(*types.Struct) - anon := scope.Lookup("Anon").Type().Underlying().(*types.Struct) - tags := scope.Lookup("Tags").Type().Underlying().(*types.Struct) - amb := scope.Lookup("Amb").Type().Underlying().(*types.Struct) - embed := scope.Lookup("Embed").Type().Underlying().(*types.Struct) + std := scope.Lookup("Std").Type().(*types.Named) + anon := scope.Lookup("Anon").Type().(*types.Named) + tags := scope.Lookup("Tags").Type().(*types.Named) + amb := scope.Lookup("Amb").Type().(*types.Named) + embed := scope.Lookup("Embed").Type().(*types.Named) tests := []struct { Name string - Struct *types.Struct + Named *types.Named Field string Tag string Expected string @@ -65,13 +65,13 @@ type Embed struct { for _, tt := range tests { b := builder{Config: &config.Config{StructTag: tt.Tag}} - field, err := b.findBindStructTarget(tt.Struct, tt.Field) + target, err := b.findBindTarget(tt.Named, tt.Field) if tt.ShouldError { - require.Nil(t, field, tt.Name) + require.Nil(t, target, tt.Name) require.Error(t, err, tt.Name) } else { require.NoError(t, err, tt.Name) - require.Equal(t, tt.Expected, field.Name(), tt.Name) + require.Equal(t, tt.Expected, target.Name(), tt.Name) } } } diff --git a/codegen/testserver/embedded.go b/codegen/testserver/embedded.go new file mode 100644 index 00000000000..83d56962e93 --- /dev/null +++ b/codegen/testserver/embedded.go @@ -0,0 +1,30 @@ +package testserver + +// EmbeddedCase1 model +type EmbeddedCase1 struct { + Empty + *ExportedEmbeddedPointerAfterInterface +} + +// Empty interface +type Empty interface{} + +// ExportedEmbeddedPointerAfterInterface model +type ExportedEmbeddedPointerAfterInterface struct{} + +// ExportedEmbeddedPointerExportedMethod method +func (*ExportedEmbeddedPointerAfterInterface) ExportedEmbeddedPointerExportedMethod() string { + return "ExportedEmbeddedPointerExportedMethodResponse" +} + +// EmbeddedCase2 model +type EmbeddedCase2 struct { + *unexportedEmbeddedPointer +} + +type unexportedEmbeddedPointer struct{} + +// UnexportedEmbeddedPointerExportedMethod method +func (*unexportedEmbeddedPointer) UnexportedEmbeddedPointerExportedMethod() string { + return "UnexportedEmbeddedPointerExportedMethodResponse" +} diff --git a/codegen/testserver/embedded.graphql b/codegen/testserver/embedded.graphql new file mode 100644 index 00000000000..ea83c955576 --- /dev/null +++ b/codegen/testserver/embedded.graphql @@ -0,0 +1,12 @@ +extend type Query { + embeddedCase1: EmbeddedCase1 + embeddedCase2: EmbeddedCase2 +} + +type EmbeddedCase1 @goModel(model:"testserver.EmbeddedCase1") { + exportedEmbeddedPointerExportedMethod: String! +} + +type EmbeddedCase2 @goModel(model:"testserver.EmbeddedCase2") { + unexportedEmbeddedPointerExportedMethod: String! +} diff --git a/codegen/testserver/embedded_test.go b/codegen/testserver/embedded_test.go new file mode 100644 index 00000000000..ab6d4bdd863 --- /dev/null +++ b/codegen/testserver/embedded_test.go @@ -0,0 +1,46 @@ +package testserver + +import ( + "context" + "testing" + + "github.com/99designs/gqlgen/client" + "github.com/99designs/gqlgen/handler" + "github.com/stretchr/testify/require" +) + +func TestEmbedded(t *testing.T) { + resolver := &Stub{} + resolver.QueryResolver.EmbeddedCase1 = func(ctx context.Context) (*EmbeddedCase1, error) { + return &EmbeddedCase1{}, nil + } + resolver.QueryResolver.EmbeddedCase2 = func(ctx context.Context) (*EmbeddedCase2, error) { + return &EmbeddedCase2{&unexportedEmbeddedPointer{}}, nil + } + + c := client.New(handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolver}), + )) + + t.Run("embedded case 1", func(t *testing.T) { + var resp struct { + EmbeddedCase1 struct { + ExportedEmbeddedPointerExportedMethod string + } + } + err := c.Post(`query { embeddedCase1 { exportedEmbeddedPointerExportedMethod } }`, &resp) + require.NoError(t, err) + require.Equal(t, resp.EmbeddedCase1.ExportedEmbeddedPointerExportedMethod, "ExportedEmbeddedPointerExportedMethodResponse") + }) + + t.Run("embedded case 2", func(t *testing.T) { + var resp struct { + EmbeddedCase2 struct { + UnexportedEmbeddedPointerExportedMethod string + } + } + err := c.Post(`query { embeddedCase2 { unexportedEmbeddedPointerExportedMethod } }`, &resp) + require.NoError(t, err) + require.Equal(t, resp.EmbeddedCase2.UnexportedEmbeddedPointerExportedMethod, "UnexportedEmbeddedPointerExportedMethodResponse") + }) +} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 0891bd075fa..70c762446c1 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -109,6 +109,14 @@ type ComplexityRoot struct { Foo func(childComplexity int) int } + EmbeddedCase1 struct { + ExportedEmbeddedPointerExportedMethod func(childComplexity int) int + } + + EmbeddedCase2 struct { + UnexportedEmbeddedPointerExportedMethod func(childComplexity int) int + } + EmbeddedDefaultScalar struct { Value func(childComplexity int) int } @@ -224,6 +232,8 @@ type ComplexityRoot struct { DirectiveObject func(childComplexity int) int DirectiveObjectWithCustomGoModel func(childComplexity int) int DirectiveUnimplemented func(childComplexity int) int + EmbeddedCase1 func(childComplexity int) int + EmbeddedCase2 func(childComplexity int) int ErrorBubble func(childComplexity int) int Errors func(childComplexity int) int Fallback func(childComplexity int, arg FallbackToStringEncoding) int @@ -366,6 +376,8 @@ type QueryResolver interface { DirectiveField(ctx context.Context) (*string, error) DirectiveDouble(ctx context.Context) (*string, error) DirectiveUnimplemented(ctx context.Context) (*string, error) + EmbeddedCase1(ctx context.Context) (*EmbeddedCase1, error) + EmbeddedCase2(ctx context.Context) (*EmbeddedCase2, error) Shapes(ctx context.Context) ([]Shape, error) NoShape(ctx context.Context) (Shape, error) MapStringInterface(ctx context.Context, in map[string]interface{}) (map[string]interface{}, error) @@ -503,6 +515,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ContentUser.Foo(childComplexity), true + case "EmbeddedCase1.exportedEmbeddedPointerExportedMethod": + if e.complexity.EmbeddedCase1.ExportedEmbeddedPointerExportedMethod == nil { + break + } + + return e.complexity.EmbeddedCase1.ExportedEmbeddedPointerExportedMethod(childComplexity), true + + case "EmbeddedCase2.unexportedEmbeddedPointerExportedMethod": + if e.complexity.EmbeddedCase2.UnexportedEmbeddedPointerExportedMethod == nil { + break + } + + return e.complexity.EmbeddedCase2.UnexportedEmbeddedPointerExportedMethod(childComplexity), true + case "EmbeddedDefaultScalar.value": if e.complexity.EmbeddedDefaultScalar.Value == nil { break @@ -926,6 +952,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.DirectiveUnimplemented(childComplexity), true + case "Query.embeddedCase1": + if e.complexity.Query.EmbeddedCase1 == nil { + break + } + + return e.complexity.Query.EmbeddedCase1(childComplexity), true + + case "Query.embeddedCase2": + if e.complexity.Query.EmbeddedCase2 == nil { + break + } + + return e.complexity.Query.EmbeddedCase2(childComplexity), true + case "Query.errorBubble": if e.complexity.Query.ErrorBubble == nil { break @@ -1518,6 +1558,19 @@ type ObjectDirectives { type ObjectDirectivesWithCustomGoModel { nullableText: String @toNull } +`}, + &ast.Source{Name: "embedded.graphql", Input: `extend type Query { + embeddedCase1: EmbeddedCase1 + embeddedCase2: EmbeddedCase2 +} + +type EmbeddedCase1 @goModel(model:"testserver.EmbeddedCase1") { + exportedEmbeddedPointerExportedMethod: String! +} + +type EmbeddedCase2 @goModel(model:"testserver.EmbeddedCase2") { + unexportedEmbeddedPointerExportedMethod: String! +} `}, &ast.Source{Name: "interfaces.graphql", Input: `extend type Query { shapes: [Shape] @@ -3154,6 +3207,74 @@ func (ec *executionContext) _Content_User_foo(ctx context.Context, field graphql return ec.marshalOString2ᚖstring(ctx, field.Selections, res) } +func (ec *executionContext) _EmbeddedCase1_exportedEmbeddedPointerExportedMethod(ctx context.Context, field graphql.CollectedField, obj *EmbeddedCase1) (ret graphql.Marshaler) { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + ec.Tracer.EndFieldExecution(ctx) + }() + rctx := &graphql.ResolverContext{ + Object: "EmbeddedCase1", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec._fieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ExportedEmbeddedPointerExportedMethod(), nil + }) + + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) _EmbeddedCase2_unexportedEmbeddedPointerExportedMethod(ctx context.Context, field graphql.CollectedField, obj *EmbeddedCase2) (ret graphql.Marshaler) { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + ec.Tracer.EndFieldExecution(ctx) + }() + rctx := &graphql.ResolverContext{ + Object: "EmbeddedCase2", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec._fieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.UnexportedEmbeddedPointerExportedMethod(), nil + }) + + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNString2string(ctx, field.Selections, res) +} + func (ec *executionContext) _EmbeddedDefaultScalar_value(ctx context.Context, field graphql.CollectedField, obj *EmbeddedDefaultScalar) (ret graphql.Marshaler) { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { @@ -5573,6 +5694,68 @@ func (ec *executionContext) _Query_directiveUnimplemented(ctx context.Context, f return ec.marshalOString2ᚖstring(ctx, field.Selections, res) } +func (ec *executionContext) _Query_embeddedCase1(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + ec.Tracer.EndFieldExecution(ctx) + }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec._fieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().EmbeddedCase1(rctx) + }) + + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*EmbeddedCase1) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalOEmbeddedCase12ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEmbeddedCase1(ctx, field.Selections, res) +} + +func (ec *executionContext) _Query_embeddedCase2(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + ec.Tracer.EndFieldExecution(ctx) + }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec._fieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().EmbeddedCase2(rctx) + }) + + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*EmbeddedCase2) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalOEmbeddedCase22ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEmbeddedCase2(ctx, field.Selections, res) +} + func (ec *executionContext) _Query_shapes(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { @@ -8992,6 +9175,60 @@ func (ec *executionContext) _Content_User(ctx context.Context, sel ast.Selection return out } +var embeddedCase1Implementors = []string{"EmbeddedCase1"} + +func (ec *executionContext) _EmbeddedCase1(ctx context.Context, sel ast.SelectionSet, obj *EmbeddedCase1) graphql.Marshaler { + fields := graphql.CollectFields(ec.RequestContext, sel, embeddedCase1Implementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("EmbeddedCase1") + case "exportedEmbeddedPointerExportedMethod": + out.Values[i] = ec._EmbeddedCase1_exportedEmbeddedPointerExportedMethod(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + +var embeddedCase2Implementors = []string{"EmbeddedCase2"} + +func (ec *executionContext) _EmbeddedCase2(ctx context.Context, sel ast.SelectionSet, obj *EmbeddedCase2) graphql.Marshaler { + fields := graphql.CollectFields(ec.RequestContext, sel, embeddedCase2Implementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("EmbeddedCase2") + case "unexportedEmbeddedPointerExportedMethod": + out.Values[i] = ec._EmbeddedCase2_unexportedEmbeddedPointerExportedMethod(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var embeddedDefaultScalarImplementors = []string{"EmbeddedDefaultScalar"} func (ec *executionContext) _EmbeddedDefaultScalar(ctx context.Context, sel ast.SelectionSet, obj *EmbeddedDefaultScalar) graphql.Marshaler { @@ -10050,6 +10287,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr res = ec._Query_directiveUnimplemented(ctx, field) return res }) + case "embeddedCase1": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_embeddedCase1(ctx, field) + return res + }) + case "embeddedCase2": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_embeddedCase2(ctx, field) + return res + }) case "shapes": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -11695,6 +11954,28 @@ func (ec *executionContext) marshalODefaultScalarImplementation2ᚖstring(ctx co return ec.marshalODefaultScalarImplementation2string(ctx, sel, *v) } +func (ec *executionContext) marshalOEmbeddedCase12githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEmbeddedCase1(ctx context.Context, sel ast.SelectionSet, v EmbeddedCase1) graphql.Marshaler { + return ec._EmbeddedCase1(ctx, sel, &v) +} + +func (ec *executionContext) marshalOEmbeddedCase12ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEmbeddedCase1(ctx context.Context, sel ast.SelectionSet, v *EmbeddedCase1) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._EmbeddedCase1(ctx, sel, v) +} + +func (ec *executionContext) marshalOEmbeddedCase22githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEmbeddedCase2(ctx context.Context, sel ast.SelectionSet, v EmbeddedCase2) graphql.Marshaler { + return ec._EmbeddedCase2(ctx, sel, &v) +} + +func (ec *executionContext) marshalOEmbeddedCase22ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEmbeddedCase2(ctx context.Context, sel ast.SelectionSet, v *EmbeddedCase2) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._EmbeddedCase2(ctx, sel, v) +} + func (ec *executionContext) marshalOError2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐError(ctx context.Context, sel ast.SelectionSet, v Error) graphql.Marshaler { return ec._Error(ctx, sel, &v) } diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index 29df1b49451..a4edc9f6afc 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -179,6 +179,12 @@ func (r *queryResolver) DirectiveDouble(ctx context.Context) (*string, error) { func (r *queryResolver) DirectiveUnimplemented(ctx context.Context) (*string, error) { panic("not implemented") } +func (r *queryResolver) EmbeddedCase1(ctx context.Context) (*EmbeddedCase1, error) { + panic("not implemented") +} +func (r *queryResolver) EmbeddedCase2(ctx context.Context) (*EmbeddedCase2, error) { + panic("not implemented") +} func (r *queryResolver) Shapes(ctx context.Context) ([]Shape, error) { panic("not implemented") } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index bcc7ad7f243..0d4d3ee26c7 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -63,6 +63,8 @@ type Stub struct { DirectiveField func(ctx context.Context) (*string, error) DirectiveDouble func(ctx context.Context) (*string, error) DirectiveUnimplemented func(ctx context.Context) (*string, error) + EmbeddedCase1 func(ctx context.Context) (*EmbeddedCase1, error) + EmbeddedCase2 func(ctx context.Context) (*EmbeddedCase2, error) Shapes func(ctx context.Context) ([]Shape, error) NoShape func(ctx context.Context) (Shape, error) MapStringInterface func(ctx context.Context, in map[string]interface{}) (map[string]interface{}, error) @@ -263,6 +265,12 @@ func (r *stubQuery) DirectiveDouble(ctx context.Context) (*string, error) { func (r *stubQuery) DirectiveUnimplemented(ctx context.Context) (*string, error) { return r.QueryResolver.DirectiveUnimplemented(ctx) } +func (r *stubQuery) EmbeddedCase1(ctx context.Context) (*EmbeddedCase1, error) { + return r.QueryResolver.EmbeddedCase1(ctx) +} +func (r *stubQuery) EmbeddedCase2(ctx context.Context) (*EmbeddedCase2, error) { + return r.QueryResolver.EmbeddedCase2(ctx) +} func (r *stubQuery) Shapes(ctx context.Context) ([]Shape, error) { return r.QueryResolver.Shapes(ctx) }