diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 44571873233..fb63c1910e3 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -21,7 +21,7 @@ type Exec struct { } func Make(s *schema.Schema, resolver interface{}) (*Exec, error) { - t := s.AllTypes[s.EntryPoints["query"]] + t := s.Types[s.EntryPoints["query"]] e, err := makeWithType(s, t, resolver) if err != nil { return nil, err @@ -97,7 +97,7 @@ func makeExec2(s *schema.Schema, t schema.Type, resolverType reflect.Type, typeR return nil, err } - typeAssertions, err := makeTypeAssertions(s, t.Name, t.ImplementedBy, resolverType, typeRefMap) + typeAssertions, err := makeTypeAssertions(s, t.Name, t.PossibleTypes, resolverType, typeRefMap) if err != nil { return nil, err } @@ -109,7 +109,7 @@ func makeExec2(s *schema.Schema, t schema.Type, resolverType reflect.Type, typeR }, nil case *schema.Union: - typeAssertions, err := makeTypeAssertions(s, t.Name, t.Types, resolverType, typeRefMap) + typeAssertions, err := makeTypeAssertions(s, t.Name, t.PossibleTypes, resolverType, typeRefMap) if err != nil { return nil, err } @@ -210,24 +210,20 @@ func makeFieldExecs(s *schema.Schema, typeName string, fields map[string]*schema return fieldExecs, nil } -func makeTypeAssertions(s *schema.Schema, typeName string, impls []string, resolverType reflect.Type, typeRefMap map[typeRefMapKey]*typeRef) (map[string]*typeAssertExec, error) { +func makeTypeAssertions(s *schema.Schema, typeName string, impls []*schema.Object, resolverType reflect.Type, typeRefMap map[typeRefMapKey]*typeRef) (map[string]*typeAssertExec, error) { typeAssertions := make(map[string]*typeAssertExec) for _, impl := range impls { - methodIndex := findMethod(resolverType, "to"+impl) + methodIndex := findMethod(resolverType, "to"+impl.Name) if methodIndex == -1 { - return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "to"+impl, impl) - } - refT, ok := s.AllTypes[impl] - if !ok { - return nil, fmt.Errorf("type %q not found", impl) + return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "to"+impl.Name, impl.Name) } a := &typeAssertExec{ methodIndex: methodIndex, } - if err := makeExec(&a.typeExec, s, refT, resolverType.Method(methodIndex).Type.Out(0), typeRefMap); err != nil { + if err := makeExec(&a.typeExec, s, impl, resolverType.Method(methodIndex).Type.Out(0), typeRefMap); err != nil { return nil, err } - typeAssertions[impl] = a + typeAssertions[impl.Name] = a } return typeAssertions, nil } diff --git a/internal/exec/introspection.go b/internal/exec/introspection.go index 07a2282df53..af6069051ca 100644 --- a/internal/exec/introspection.go +++ b/internal/exec/introspection.go @@ -26,7 +26,7 @@ func init() { { var err error - schemaExec, err = makeWithType(metaSchema, metaSchema.AllTypes["__Schema"], &schemaResolver{}) + schemaExec, err = makeWithType(metaSchema, metaSchema.Types["__Schema"], &schemaResolver{}) if err != nil { panic(err) } @@ -34,7 +34,7 @@ func init() { { var err error - typeExec, err = makeWithType(metaSchema, metaSchema.AllTypes["__Type"], &typeResolver{}) + typeExec, err = makeWithType(metaSchema, metaSchema.Types["__Type"], &typeResolver{}) if err != nil { panic(err) } @@ -46,7 +46,7 @@ func introspectSchema(r *request, selSet *query.SelectionSet) interface{} { } func introspectType(r *request, name string, selSet *query.SelectionSet) interface{} { - t, ok := r.schema.AllTypes[name] + t, ok := r.schema.Types[name] if !ok { return nil } @@ -144,14 +144,14 @@ func (r *schemaResolver) Types() []*typeResolver { var l []*typeResolver addTypes := func(s *schema.Schema, metaOnly bool) { var names []string - for name := range s.AllTypes { + for name := range s.Types { if !metaOnly || strings.HasPrefix(name, "__") { names = append(names, name) } } sort.Strings(names) for _, name := range names { - l = append(l, &typeResolver{s.AllTypes[name]}) + l = append(l, &typeResolver{s.Types[name]}) } } addTypes(r.schema, false) @@ -160,11 +160,11 @@ func (r *schemaResolver) Types() []*typeResolver { } func (r *schemaResolver) QueryType() *typeResolver { - return &typeResolver{typ: r.schema.AllTypes[r.schema.EntryPoints["query"]]} + return &typeResolver{typ: r.schema.Types[r.schema.EntryPoints["query"]]} } func (r *schemaResolver) MutationType() *typeResolver { - return &typeResolver{typ: r.schema.AllTypes[r.schema.EntryPoints["mutation"]]} + return &typeResolver{typ: r.schema.Types[r.schema.EntryPoints["mutation"]]} } func (r *schemaResolver) Directives() []*directiveResolver { diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 2bbc34deb53..e35db182711 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -11,9 +11,10 @@ import ( type Schema struct { EntryPoints map[string]string - AllTypes map[string]Type - Objects map[string]*Object - Interfaces map[string]*Interface + Types map[string]Type + + objects []*Object + unions []*Union } type Type interface { @@ -26,21 +27,25 @@ type Scalar struct { type Object struct { Name string - Implements string + Interfaces []*Interface Fields map[string]*Field FieldOrder []string + + interfaceNames []string } type Interface struct { Name string - ImplementedBy []string + PossibleTypes []*Object Fields map[string]*Field FieldOrder []string } type Union struct { - Name string - Types []string + Name string + PossibleTypes []*Object + + typeNames []string } type Enum struct { @@ -99,25 +104,51 @@ func Parse(schemaString string) (s *Schema, err *errors.GraphQLError) { } sc.Init(strings.NewReader(schemaString)) + c := &context{} l := lexer.New(sc) err = l.CatchSyntaxError(func() { - c := &context{} s = parseSchema(l, c) - for _, ref := range c.typeRefs { - *ref.target = s.AllTypes[ref.name] - } }) if err != nil { return nil, err } - for _, obj := range s.Objects { - if obj.Implements != "" { - intf, ok := s.Interfaces[obj.Implements] + for _, ref := range c.typeRefs { + t, ok := s.Types[ref.name] + if !ok { + return nil, errors.Errorf("type %q not found", ref.name) + } + *ref.target = t + } + + for _, obj := range s.objects { + obj.Interfaces = make([]*Interface, len(obj.interfaceNames)) + for i, intfName := range obj.interfaceNames { + t, ok := s.Types[intfName] + if !ok { + return nil, errors.Errorf("interface %q not found", intfName) + } + intf, ok := t.(*Interface) if !ok { - return nil, errors.Errorf("interface %q not found", obj.Implements) + return nil, errors.Errorf("type %q is not an interface", intfName) } - intf.ImplementedBy = append(intf.ImplementedBy, obj.Name) + obj.Interfaces[i] = intf + intf.PossibleTypes = append(intf.PossibleTypes, obj) + } + } + + for _, union := range s.unions { + union.PossibleTypes = make([]*Object, len(union.typeNames)) + for i, name := range union.typeNames { + t, ok := s.Types[name] + if !ok { + return nil, errors.Errorf("object type %q not found", name) + } + obj, ok := t.(*Object) + if !ok { + return nil, errors.Errorf("type %q is not an object", name) + } + union.PossibleTypes[i] = obj } } @@ -127,15 +158,13 @@ func Parse(schemaString string) (s *Schema, err *errors.GraphQLError) { func parseSchema(l *lexer.Lexer, c *context) *Schema { s := &Schema{ EntryPoints: make(map[string]string), - AllTypes: map[string]Type{ + Types: map[string]Type{ "Int": &Scalar{Name: "Int"}, "Float": &Scalar{Name: "Float"}, "String": &Scalar{Name: "String"}, "Boolean": &Scalar{Name: "Boolean"}, "ID": &Scalar{Name: "ID"}, }, - Objects: make(map[string]*Object), - Interfaces: make(map[string]*Interface), } for l.Peek() != scanner.EOF { @@ -151,21 +180,21 @@ func parseSchema(l *lexer.Lexer, c *context) *Schema { l.ConsumeToken('}') case "type": obj := parseObjectDecl(l, c) - s.AllTypes[obj.Name] = obj - s.Objects[obj.Name] = obj + s.Types[obj.Name] = obj + s.objects = append(s.objects, obj) case "interface": intf := parseInterfaceDecl(l, c) - s.AllTypes[intf.Name] = intf - s.Interfaces[intf.Name] = intf + s.Types[intf.Name] = intf case "union": union := parseUnionDecl(l, c) - s.AllTypes[union.Name] = union + s.Types[union.Name] = union + s.unions = append(s.unions, union) case "enum": enum := parseEnumDecl(l, c) - s.AllTypes[enum.Name] = enum + s.Types[enum.Name] = enum case "input": input := parseInputDecl(l, c) - s.AllTypes[input.Name] = input + s.Types[input.Name] = input default: l.SyntaxError(fmt.Sprintf(`unexpected %q, expecting "schema", "type", "enum", "interface", "union" or "input"`, x)) } @@ -179,7 +208,12 @@ func parseObjectDecl(l *lexer.Lexer, c *context) *Object { o.Name = l.ConsumeIdent() if l.Peek() == scanner.Ident { l.ConsumeKeyword("implements") - o.Implements = l.ConsumeIdent() + for { + o.interfaceNames = append(o.interfaceNames, l.ConsumeIdent()) + if l.Peek() == '{' { + break + } + } } l.ConsumeToken('{') o.Fields, o.FieldOrder = parseFields(l, c) @@ -200,10 +234,10 @@ func parseUnionDecl(l *lexer.Lexer, c *context) *Union { union := &Union{} union.Name = l.ConsumeIdent() l.ConsumeToken('=') - union.Types = []string{l.ConsumeIdent()} + union.typeNames = []string{l.ConsumeIdent()} for l.Peek() == '|' { l.ConsumeToken('|') - union.Types = append(union.Types, l.ConsumeIdent()) + union.typeNames = append(union.typeNames, l.ConsumeIdent()) } return union }