diff --git a/codegen/config/binder.go b/codegen/config/binder.go index 26c9cdb7318..00cb5f62fd9 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -206,7 +206,7 @@ func (t *TypeReference) IsPtr() bool { } func (t *TypeReference) IsNilable() bool { - return isNilable(t.GO) + return IsNilable(t.GO) } func (t *TypeReference) IsSlice() bool { @@ -403,14 +403,14 @@ func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type { _, isInterface = named.Underlying().(*types.Interface) } - if !isInterface && !isNilable(base) && !t.NonNull { + if !isInterface && !IsNilable(base) && !t.NonNull { return types.NewPointer(base) } return base } -func isNilable(t types.Type) bool { +func IsNilable(t types.Type) bool { if namedType, isNamed := t.(*types.Named); isNamed { t = namedType.Underlying() } diff --git a/codegen/data.go b/codegen/data.go index 2f0ec3fb4e0..513723ebb53 100644 --- a/codegen/data.go +++ b/codegen/data.go @@ -80,7 +80,10 @@ func BuildData(cfg *config.Config) (*Data, error) { s.Inputs = append(s.Inputs, input) case ast.Union, ast.Interface: - s.Interfaces[schemaType.Name] = b.buildInterface(schemaType) + s.Interfaces[schemaType.Name], err = b.buildInterface(schemaType) + if err != nil { + return nil, errors.Wrap(err, "unable to bind to interface") + } } } diff --git a/codegen/interface.go b/codegen/interface.go index f59e8ed0715..0ccbf87cda1 100644 --- a/codegen/interface.go +++ b/codegen/interface.go @@ -1,9 +1,13 @@ package codegen import ( + "fmt" "go/types" + "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" + + "github.com/99designs/gqlgen/codegen/config" ) type Interface struct { @@ -16,11 +20,11 @@ type Interface struct { type InterfaceImplementor struct { *ast.Definition - Interface *Interface - Type types.Type + Type types.Type + TakeRef bool } -func (b *builder) buildInterface(typ *ast.Definition) *Interface { +func (b *builder) buildInterface(typ *ast.Definition) (*Interface, error) { obj, err := b.Binder.DefaultUserObject(typ.Name) if err != nil { panic(err) @@ -32,32 +36,53 @@ func (b *builder) buildInterface(typ *ast.Definition) *Interface { InTypemap: b.Config.Models.UserDefined(typ.Name), } + interfaceType, err := findGoInterface(i.Type) + if interfaceType == nil || err != nil { + return nil, fmt.Errorf("%s is not an interface", i.Type) + } + for _, implementor := range b.Schema.GetPossibleTypes(typ) { obj, err := b.Binder.DefaultUserObject(implementor.Name) if err != nil { - panic(err) + return nil, fmt.Errorf("%s has no backing go type", implementor.Name) } - i.Implementors = append(i.Implementors, InterfaceImplementor{ - Definition: implementor, - Type: obj, - Interface: i, - }) - } + implementorType, err := findGoNamedType(obj) + if err != nil { + return nil, errors.Wrapf(err, "can not find backing go type %s", obj.String()) + } else if implementorType == nil { + return nil, fmt.Errorf("can not find backing go type %s", obj.String()) + } - return i -} + anyValid := false -func (i *InterfaceImplementor) ValueReceiver() bool { - interfaceType, err := findGoInterface(i.Interface.Type) - if interfaceType == nil || err != nil { - return true - } + // first check if the value receiver can be nil, eg can we type switch on case Thing: + if types.Implements(implementorType, interfaceType) { + i.Implementors = append(i.Implementors, InterfaceImplementor{ + Definition: implementor, + Type: obj, + TakeRef: !types.IsInterface(obj), + }) + anyValid = true + } + + // then check if the pointer receiver can be nil, eg can we type switch on case *Thing: + if types.Implements(types.NewPointer(implementorType), interfaceType) { + i.Implementors = append(i.Implementors, InterfaceImplementor{ + Definition: implementor, + Type: types.NewPointer(obj), + }) + anyValid = true + } - implementorType, err := findGoNamedType(i.Type) - if implementorType == nil || err != nil { - return true + if !anyValid { + return nil, fmt.Errorf("%s does not satisfy the interface %s", implementorType.String(), i.Type.String()) + } } - return types.Implements(implementorType, interfaceType) + return i, nil +} + +func (i *InterfaceImplementor) CanBeNil() bool { + return config.IsNilable(i.Type) } diff --git a/codegen/interface.gotpl b/codegen/interface.gotpl index bfb42b25d64..e9d560c8f64 100644 --- a/codegen/interface.gotpl +++ b/codegen/interface.gotpl @@ -5,15 +5,13 @@ func (ec *executionContext) _{{$interface.Name}}(ctx context.Context, sel ast.Se case nil: return graphql.Null {{- range $implementor := $interface.Implementors }} - {{- if $implementor.ValueReceiver }} - case {{$implementor.Type | ref}}: - return ec._{{$implementor.Name}}(ctx, sel, &obj) - {{- end}} - case *{{$implementor.Type | ref}}: - if obj == nil { - return graphql.Null - } - return ec._{{$implementor.Name}}(ctx, sel, obj) + case {{$implementor.Type | ref}}: + {{- if $implementor.CanBeNil }} + if obj == nil { + return graphql.Null + } + {{- end }} + return ec._{{$implementor.Name}}(ctx, sel, {{ if $implementor.TakeRef }}&{{ end }}obj) {{- end }} default: panic(fmt.Errorf("unexpected type %T", obj)) diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index bdddfe6bd44..53ac24f87bf 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -116,6 +116,11 @@ type ComplexityRoot struct { Name func(childComplexity int) int } + ConcreteNodeInterface struct { + Child func(childComplexity int) int + ID func(childComplexity int) int + } + ContentPost struct { Foo func(childComplexity int) int } @@ -607,6 +612,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ConcreteNodeA.Name(childComplexity), true + case "ConcreteNodeInterface.child": + if e.complexity.ConcreteNodeInterface.Child == nil { + break + } + + return e.complexity.ConcreteNodeInterface.Child(childComplexity), true + + case "ConcreteNodeInterface.id": + if e.complexity.ConcreteNodeInterface.ID == nil { + break + } + + return e.complexity.ConcreteNodeInterface.ID(childComplexity), true + case "Content_Post.foo": if e.complexity.ContentPost.Foo == nil { break @@ -1823,6 +1842,12 @@ type ConcreteNodeA implements Node { child: Node! name: String! } + +""" Implements the Node interface with another interface """ +type ConcreteNodeInterface implements Node { + id: ID! + child: Node! +} `, BuiltIn: false}, &ast.Source{Name: "issue896.graphql", Input: `# This example should build stable output. If the file content starts # alternating nondeterministically between two outputs, then see @@ -3653,6 +3678,68 @@ func (ec *executionContext) _ConcreteNodeA_name(ctx context.Context, field graph return ec.marshalNString2string(ctx, field.Selections, res) } +func (ec *executionContext) _ConcreteNodeInterface_id(ctx context.Context, field graphql.CollectedField, obj ConcreteNodeInterface) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ConcreteNodeInterface", + Field: field, + Args: nil, + IsMethod: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp := ec._fieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ID(), nil + }) + + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNID2string(ctx, field.Selections, res) +} + +func (ec *executionContext) _ConcreteNodeInterface_child(ctx context.Context, field graphql.CollectedField, obj ConcreteNodeInterface) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "ConcreteNodeInterface", + Field: field, + Args: nil, + IsMethod: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp := ec._fieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Child() + }) + + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(Node) + fc.Result = res + return ec.marshalNNode2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐNode(ctx, field.Selections, res) +} + func (ec *executionContext) _Content_Post_foo(ctx context.Context, field graphql.CollectedField, obj *ContentPost) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -9399,6 +9486,11 @@ func (ec *executionContext) _Node(ctx context.Context, sel ast.SelectionSet, obj return graphql.Null } return ec._ConcreteNodeA(ctx, sel, obj) + case ConcreteNodeInterface: + if obj == nil { + return graphql.Null + } + return ec._ConcreteNodeInterface(ctx, sel, obj) default: panic(fmt.Errorf("unexpected type %T", obj)) } @@ -9789,6 +9881,38 @@ func (ec *executionContext) _ConcreteNodeA(ctx context.Context, sel ast.Selectio return out } +var concreteNodeInterfaceImplementors = []string{"ConcreteNodeInterface", "Node"} + +func (ec *executionContext) _ConcreteNodeInterface(ctx context.Context, sel ast.SelectionSet, obj ConcreteNodeInterface) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, concreteNodeInterfaceImplementors) + + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ConcreteNodeInterface") + case "id": + out.Values[i] = ec._ConcreteNodeInterface_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "child": + out.Values[i] = ec._ConcreteNodeInterface_child(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var content_PostImplementors = []string{"Content_Post", "Content_Child"} func (ec *executionContext) _Content_Post(ctx context.Context, sel ast.SelectionSet, obj *ContentPost) graphql.Marshaler { diff --git a/codegen/testserver/interfaces.go b/codegen/testserver/interfaces.go index 41eb7e24008..87243965a2f 100644 --- a/codegen/testserver/interfaces.go +++ b/codegen/testserver/interfaces.go @@ -47,6 +47,25 @@ func (n *ConcreteNodeA) Child() (Node, error) { return n.child, nil } +// Implements the Node interface with another interface +type ConcreteNodeInterface interface { + Node + ID() string +} + +type ConcreteNodeInterfaceImplementor struct{} + +func (c ConcreteNodeInterfaceImplementor) ID() string { + return "CNII" +} + +func (c ConcreteNodeInterfaceImplementor) Child() (Node, error) { + return &ConcreteNodeA{ + ID: "Child", + Name: "child", + }, nil +} + type BackedByInterface interface { ThisShouldBind() string ThisShouldBindWithError() (string, error) diff --git a/codegen/testserver/interfaces.graphql b/codegen/testserver/interfaces.graphql index 08eb42e8e8c..8db93578173 100644 --- a/codegen/testserver/interfaces.graphql +++ b/codegen/testserver/interfaces.graphql @@ -54,3 +54,9 @@ type ConcreteNodeA implements Node { child: Node! name: String! } + +""" Implements the Node interface with another interface """ +type ConcreteNodeInterface implements Node { + id: ID! + child: Node! +} diff --git a/codegen/testserver/interfaces_test.go b/codegen/testserver/interfaces_test.go index 1e44bbfc5ad..b8ae9129b1a 100644 --- a/codegen/testserver/interfaces_test.go +++ b/codegen/testserver/interfaces_test.go @@ -176,4 +176,25 @@ func TestInterfaces(t *testing.T) { err := c.Post(`{ notAnInterface { id, thisShouldBind, thisShouldBindWithError } }`, &resp) require.EqualError(t, err, `[{"message":"boom","path":["notAnInterface","thisShouldBindWithError"]}]`) }) + + t.Run("interfaces can implement other interfaces", func(t *testing.T) { + resolvers := &Stub{} + resolvers.QueryResolver.Node = func(ctx context.Context) (node Node, err error) { + return ConcreteNodeInterfaceImplementor{}, nil + } + + c := client.New(handler.NewDefaultServer(NewExecutableSchema(Config{Resolvers: resolvers}))) + + var resp struct { + Node struct { + ID string + Child struct { + ID string + } + } + } + c.MustPost(`{ node { id, child { id } } }`, &resp) + require.Equal(t, "CNII", resp.Node.ID) + require.Equal(t, "Child", resp.Node.Child.ID) + }) }