diff --git a/plugin/federation/entity.go b/plugin/federation/entity.go index 0d7fbed6f4d..04a3c033b06 100644 --- a/plugin/federation/entity.go +++ b/plugin/federation/entity.go @@ -1,7 +1,10 @@ package federation import ( + "go/types" + "github.com/99designs/gqlgen/codegen/config" + "github.com/99designs/gqlgen/codegen/templates" "github.com/99designs/gqlgen/plugin/federation/fieldset" "github.com/vektah/gqlparser/v2/ast" ) @@ -17,9 +20,14 @@ type Entity struct { } type EntityResolver struct { - ResolverName string // The resolver name, such as FindUserByID - KeyFields []*KeyField // The fields declared in @key. - InputType string // The Go generated input type for multi entity resolvers + ResolverName string // The resolver name, such as FindUserByID + KeyFields []*KeyField // The fields declared in @key. + InputType types.Type // The Go generated input type for multi entity resolvers + InputTypeName string +} + +func (e *EntityResolver) LookupInputType() string { + return templates.CurrentImports.LookupType(e.InputType) } type KeyField struct { diff --git a/plugin/federation/federation.go b/plugin/federation/federation.go index d0d3cac8a20..f38b618fcba 100644 --- a/plugin/federation/federation.go +++ b/plugin/federation/federation.go @@ -134,12 +134,12 @@ func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source { if entityResolverInputDefinitions != "" { entityResolverInputDefinitions += "\n\n" } - entityResolverInputDefinitions += "input " + r.InputType + " {\n" + entityResolverInputDefinitions += "input " + r.InputTypeName + " {\n" for _, keyField := range r.KeyFields { entityResolverInputDefinitions += fmt.Sprintf("\t%s: %s\n", keyField.Field.ToGo(), keyField.Definition.Type.String()) } entityResolverInputDefinitions += "}" - resolvers += fmt.Sprintf("\t%s(reps: [%s!]!): [%s]\n", r.ResolverName, r.InputType, e.Name) + resolvers += fmt.Sprintf("\t%s(reps: [%s!]!): [%s]\n", r.ResolverName, r.InputTypeName, e.Name) } else { resolverArgs := "" for _, keyField := range r.KeyFields { @@ -234,6 +234,23 @@ func (f *federation) GenerateCode(data *codegen.Data) error { } } + // fill in types for resolver inputs + // + for _, entity := range f.Entities { + if !entity.Multi { + continue + } + + for _, resolver := range entity.Resolvers { + obj := data.Inputs.ByName(resolver.InputTypeName) + if obj == nil { + return fmt.Errorf("input object %s not found", resolver.InputTypeName) + } + + resolver.InputType = obj.Type + } + } + return templates.Render(templates.Options{ PackageName: data.Config.Federation.Package, Filename: data.Config.Federation.Filename, @@ -327,9 +344,9 @@ func (f *federation) setEntities(schema *ast.Schema) { } e.Resolvers = append(e.Resolvers, &EntityResolver{ - ResolverName: resolverName, - KeyFields: keyFields, - InputType: resolverFieldsToGo + "Input", + ResolverName: resolverName, + KeyFields: keyFields, + InputTypeName: resolverFieldsToGo + "Input", }) } diff --git a/plugin/federation/federation.gotpl b/plugin/federation/federation.gotpl index 4a30b6c9787..7cf84287eb6 100644 --- a/plugin/federation/federation.gotpl +++ b/plugin/federation/federation.gotpl @@ -133,7 +133,7 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati {{ if and .Resolvers .Multi -}} case "{{.Def.Name}}": {{range $i, $_ := .Resolvers -}} - _reps := make([]*{{.InputType}}, len(reps)) + _reps := make([]*{{.LookupInputType}}, len(reps)) for i, rep := range reps { {{ range $i, $keyField := .KeyFields -}} @@ -143,7 +143,7 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati } {{end}} - _reps[i] = &{{.InputType}} { + _reps[i] = &{{.LookupInputType}} { {{ range $i, $keyField := .KeyFields -}} {{$keyField.Field.ToGo}}: id{{$i}}, {{end}}