diff --git a/.gitignore b/.gitignore index 7b3bcd13..cf07dd5e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ /internal/validation/testdata/graphql-js /internal/validation/testdata/node_modules /vendor +.DS_Store +.idea/ +.vscode/ diff --git a/config/config.go b/config/config.go new file mode 100644 index 00000000..05bc9f8f --- /dev/null +++ b/config/config.go @@ -0,0 +1,11 @@ +package config + +type Config struct { + UseResolverMethods bool +} + +func Default() *Config { + return &Config{ + UseResolverMethods: true, + } +} diff --git a/example/starwars/server/server.go b/example/starwars/server/server.go index cfd8d08d..c0b87d70 100644 --- a/example/starwars/server/server.go +++ b/example/starwars/server/server.go @@ -12,7 +12,7 @@ import ( var schema *graphql.Schema func init() { - schema = graphql.MustParseSchema(starwars.Schema, &starwars.Resolver{}) + schema = graphql.MustParseSchema(starwars.Schema, &starwars.Resolver{}, nil) } func main() { diff --git a/graphql.go b/graphql.go index 90779bb5..544c3eea 100644 --- a/graphql.go +++ b/graphql.go @@ -2,9 +2,8 @@ package graphql import ( "context" - "fmt" - "encoding/json" + "fmt" "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" diff --git a/graphql_test.go b/graphql_test.go index 1c1bd082..8a0f70b5 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -61,7 +61,7 @@ func (r *timeResolver) AddHour(args struct{ Time graphql.Time }) graphql.Time { return graphql.Time{Time: args.Time.Add(time.Hour)} } -var starwarsSchema = graphql.MustParseSchema(starwars.Schema, &starwars.Resolver{}) +var starwarsSchema = graphql.MustParseSchema(starwars.Schema, &starwars.Resolver{}, nil) func TestHelloWorld(t *testing.T) { gqltesting.RunTests(t, []*gqltesting.Test{ @@ -74,7 +74,7 @@ func TestHelloWorld(t *testing.T) { type Query { hello: String! } - `, &helloWorldResolver1{}), + `, &helloWorldResolver1{}, nil), Query: ` { hello @@ -96,7 +96,7 @@ func TestHelloWorld(t *testing.T) { type Query { hello: String! } - `, &helloWorldResolver2{}), + `, &helloWorldResolver2{}, nil), Query: ` { hello @@ -122,7 +122,7 @@ func TestHelloSnake(t *testing.T) { type Query { hello_html: String! } - `, &helloSnakeResolver1{}), + `, &helloSnakeResolver1{}, nil), Query: ` { hello_html @@ -144,7 +144,7 @@ func TestHelloSnake(t *testing.T) { type Query { hello_html: String! } - `, &helloSnakeResolver2{}), + `, &helloSnakeResolver2{}, nil), Query: ` { hello_html @@ -170,7 +170,7 @@ func TestHelloSnakeArguments(t *testing.T) { type Query { say_hello(full_name: String!): String! } - `, &helloSnakeResolver1{}), + `, &helloSnakeResolver1{}, nil), Query: ` { say_hello(full_name: "Rob Pike") @@ -192,7 +192,7 @@ func TestHelloSnakeArguments(t *testing.T) { type Query { say_hello(full_name: String!): String! } - `, &helloSnakeResolver2{}), + `, &helloSnakeResolver2{}, nil), Query: ` { say_hello(full_name: "Rob Pike") @@ -606,7 +606,7 @@ func TestDeprecatedDirective(t *testing.T) { b: Int! @deprecated c: Int! @deprecated(reason: "We don't like it") } - `, &testDeprecatedDirectiveResolver{}), + `, &testDeprecatedDirectiveResolver{}, nil), Query: ` { __type(name: "Query") { @@ -650,7 +650,7 @@ func TestDeprecatedDirective(t *testing.T) { B @deprecated C @deprecated(reason: "We don't like it") } - `, &testDeprecatedDirectiveResolver{}), + `, &testDeprecatedDirectiveResolver{}, nil), Query: ` { __type(name: "Test") { @@ -1441,7 +1441,7 @@ func TestMutationOrder(t *testing.T) { type Mutation { changeTheNumber(newNumber: Int!): Query } - `, &theNumberResolver{}), + `, &theNumberResolver{}, nil), Query: ` mutation { first: changeTheNumber(newNumber: 1) { @@ -1485,7 +1485,7 @@ func TestTime(t *testing.T) { } scalar Time - `, &timeResolver{}), + `, &timeResolver{}, nil), Query: ` query($t: Time!) { a: addHour(time: $t) @@ -1520,7 +1520,7 @@ func TestUnexportedMethod(t *testing.T) { type Mutation { changeTheNumber(newNumber: Int!): Int! } - `, &resolverWithUnexportedMethod{}) + `, &resolverWithUnexportedMethod{}, nil) if err == nil { t.Error("error expected") } @@ -1541,7 +1541,7 @@ func TestUnexportedField(t *testing.T) { type Mutation { changeTheNumber(newNumber: Int!): Int! } - `, &resolverWithUnexportedField{}) + `, &resolverWithUnexportedField{}, nil) if err == nil { t.Error("error expected") } @@ -1648,7 +1648,7 @@ func TestInput(t *testing.T) { Option1 Option2 } - `, &inputResolver{}) + `, &inputResolver{}, nil) gqltesting.RunTests(t, []*gqltesting.Test{ { Schema: coercionSchema, diff --git a/internal/exec/exec.go b/internal/exec/exec.go index f7149619..ced1c52e 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -173,22 +173,28 @@ func execFieldSelection(ctx context.Context, r *Request, f *fieldToExec, path *p return errors.Errorf("%s", err) // don't execute any more resolvers if context got cancelled } - var in []reflect.Value - if f.field.HasContext { - in = append(in, reflect.ValueOf(traceCtx)) - } - if f.field.ArgsPacker != nil { - in = append(in, f.field.PackedArgs) - } - callOut := f.resolver.Method(f.field.MethodIndex).Call(in) - result = callOut[0] - if f.field.HasError && !callOut[1].IsNil() { - resolverErr := callOut[1].Interface().(error) - err := errors.Errorf("%s", resolverErr) - err.Path = path.toSlice() - err.ResolverError = resolverErr - return err + if f.field.MethodIndex != -1 { + var in []reflect.Value + if f.field.HasContext { + in = append(in, reflect.ValueOf(traceCtx)) + } + if f.field.ArgsPacker != nil { + in = append(in, f.field.PackedArgs) + } + + callOut := f.resolver.Method(f.field.MethodIndex).Call(in) + result = callOut[0] + if f.field.HasError && !callOut[1].IsNil() { + resolverErr := callOut[1].Interface().(error) + err := errors.Errorf("%s", resolverErr) + err.Path = path.toSlice() + err.ResolverError = resolverErr + return err + } + } else { + result = f.resolver.Field(f.field.FieldIndex) } + return nil }() @@ -201,7 +207,6 @@ func execFieldSelection(ctx context.Context, r *Request, f *fieldToExec, path *p f.out.WriteString("null") // TODO handle non-nil return } - r.execSelectionSet(traceCtx, f.sels, f.field.Type, path, result, f.out) } diff --git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go index b7c1d93d..27f68c4e 100644 --- a/internal/exec/resolvable/resolvable.go +++ b/internal/exec/resolvable/resolvable.go @@ -6,9 +6,11 @@ import ( "reflect" "strings" + "github.com/graph-gophers/graphql-go/config" "github.com/graph-gophers/graphql-go/internal/common" "github.com/graph-gophers/graphql-go/internal/exec/packer" "github.com/graph-gophers/graphql-go/internal/schema" + "github.com/graph-gophers/graphql-go/introspection" ) type Schema struct { @@ -32,6 +34,7 @@ type Field struct { schema.Field TypeName string MethodIndex int + FieldIndex int HasContext bool HasError bool ArgsPacker *packer.StructPacker @@ -54,8 +57,15 @@ func (*Object) isResolvable() {} func (*List) isResolvable() {} func (*Scalar) isResolvable() {} +// TODO figure out a better way to handle passed config +// use this to save passed config in order to use it in functions or later +// this approach avoids updating signature of many functions +var conf *config.Config + func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) { + b := newBuilder(s) + conf = s.Config var query, mutation Resolvable @@ -181,13 +191,14 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er implementsType := false switch r := reflect.New(resolverType).Interface().(type) { case *int32: - implementsType = (t.Name == "Int") + implementsType = t.Name == "Int" case *float64: - implementsType = (t.Name == "Float") + implementsType = t.Name == "Float" case *string: - implementsType = (t.Name == "String") + // allow ID of type string + implementsType = t.Name == "String" || t.Name == "ID" case *bool: - implementsType = (t.Name == "Boolean") + implementsType = t.Name == "Boolean" case packer.Unmarshaler: implementsType = r.ImplementsGraphQLType(t.Name) } @@ -198,6 +209,7 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er } func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object, nonNull bool, resolverType reflect.Type) (*Object, error) { + if !nonNull { if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface { return nil, fmt.Errorf("%s is not a pointer or interface", resolverType) @@ -208,8 +220,24 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p Fields := make(map[string]*Field) for _, f := range fields { - methodIndex := findMethod(resolverType, f.Name) - if methodIndex == -1 { + methodIndex := -1 + fieldIndex := -1 + rt := unwrapPtr(resolverType) + + /** + * 1) Use resolver type's method for + * 1.1) __Type and __Schema requests + * 1.2) or when field has arguments + * 1.3) or it is configured to use method + * 2) Otherwise use resolver type's field + */ + if isResolverSchemaOrType(rt) == true || len(f.Args) > 0 || conf.UseResolverMethods == true { + methodIndex = findMethod(resolverType, f.Name) + } else { + fieldIndex = findField(rt, f.Name) + } + + if methodIndex == -1 && fieldIndex == -1 { hint := "" if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 { hint = " (hint: the method exists on the pointer type)" @@ -217,8 +245,14 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p return nil, fmt.Errorf("%s does not resolve %q: missing method for field %q%s", resolverType, typeName, f.Name, hint) } - m := resolverType.Method(methodIndex) - fe, err := b.makeFieldExec(typeName, f, m, methodIndex, methodHasReceiver) + var m reflect.Method + var sf reflect.StructField + if methodIndex != -1 { + m = resolverType.Method(methodIndex) + } else { + sf = rt.Field(fieldIndex) + } + fe, err := b.makeFieldExec(typeName, f, m, sf, methodIndex, fieldIndex, methodHasReceiver) if err != nil { return nil, fmt.Errorf("%s\n\treturned by (%s).%s", err, resolverType, m.Name) } @@ -253,45 +287,52 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() -func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, methodIndex int, methodHasReceiver bool) (*Field, error) { - in := make([]reflect.Type, m.Type.NumIn()) - for i := range in { - in[i] = m.Type.In(i) - } - if methodHasReceiver { - in = in[1:] // first parameter is receiver - } - - hasContext := len(in) > 0 && in[0] == contextType - if hasContext { - in = in[1:] - } +func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, sf reflect.StructField, + methodIndex, fieldIndex int, methodHasReceiver bool) (*Field, error) { var argsPacker *packer.StructPacker - if len(f.Args) > 0 { - if len(in) == 0 { - return nil, fmt.Errorf("must have parameter for field arguments") + var hasError bool + var hasContext bool + + if methodIndex != -1 { + in := make([]reflect.Type, m.Type.NumIn()) + for i := range in { + in[i] = m.Type.In(i) } - var err error - argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, in[0]) - if err != nil { - return nil, err + if methodHasReceiver { + in = in[1:] // first parameter is receiver } - in = in[1:] - } - if len(in) > 0 { - return nil, fmt.Errorf("too many parameters") - } + hasContext = len(in) > 0 && in[0] == contextType + if hasContext { + in = in[1:] + } - if m.Type.NumOut() > 2 { - return nil, fmt.Errorf("too many return values") - } + if len(f.Args) > 0 { + if len(in) == 0 { + return nil, fmt.Errorf("must have parameter for field arguments") + } + var err error + argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, in[0]) + if err != nil { + return nil, err + } + in = in[1:] + } + + if len(in) > 0 { + return nil, fmt.Errorf("too many parameters") + } - hasError := m.Type.NumOut() == 2 - if hasError { - if m.Type.Out(1) != errorType { - return nil, fmt.Errorf(`must have "error" as its second return value`) + if m.Type.NumOut() > 2 { + return nil, fmt.Errorf("too many return values") + } + + hasError = m.Type.NumOut() == 2 + if hasError { + if m.Type.Out(1) != errorType { + return nil, fmt.Errorf(`must have "error" as its second return value`) + } } } @@ -299,20 +340,41 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect. Field: *f, TypeName: typeName, MethodIndex: methodIndex, + FieldIndex: fieldIndex, HasContext: hasContext, ArgsPacker: argsPacker, HasError: hasError, TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name), } - if err := b.assignExec(&fe.ValueExec, f.Type, m.Type.Out(0)); err != nil { + + var out reflect.Type + if methodIndex != -1 { + out = m.Type.Out(0) + } else { + out = sf.Type + } + if err := b.assignExec(&fe.ValueExec, f.Type, out); err != nil { return nil, err } + return fe, nil } +// find method with the same name or one which has `Resolver` suffix func findMethod(t reflect.Type, name string) int { + resName := name + "Resolver" for i := 0; i < t.NumMethod(); i++ { - if strings.EqualFold(stripUnderscore(name), stripUnderscore(t.Method(i).Name)) { + if strings.EqualFold(stripUnderscore(name), stripUnderscore(t.Method(i).Name)) || + strings.EqualFold(stripUnderscore(resName), stripUnderscore(t.Method(i).Name)) { + return i + } + } + return -1 +} + +func findField(t reflect.Type, name string) int { + for i := 0; i < t.NumField(); i++ { + if strings.EqualFold(stripUnderscore(name), stripUnderscore(t.Field(i).Name)) { return i } } @@ -329,3 +391,33 @@ func unwrapNonNull(t common.Type) (common.Type, bool) { func stripUnderscore(s string) string { return strings.Replace(s, "_", "", -1) } + +func unwrapPtr(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + return t.Elem() + } + return t +} + +// Determine whether given resolver type is from request __Type & __Schema +func isResolverSchemaOrType(rt reflect.Type) bool { + if rt == reflect.TypeOf(introspection.Schema{}) { + return true + } + if rt == reflect.TypeOf(introspection.Type{}) { + return true + } + if rt == reflect.TypeOf(introspection.Field{}) { + return true + } + if rt == reflect.TypeOf(introspection.InputValue{}) { + return true + } + if rt == reflect.TypeOf(introspection.EnumValue{}) { + return true + } + if rt == reflect.TypeOf(introspection.Directive{}) { + return true + } + return false +} diff --git a/internal/schema/meta.go b/internal/schema/meta.go index b48bf7ac..7d1e2ed1 100644 --- a/internal/schema/meta.go +++ b/internal/schema/meta.go @@ -1,10 +1,12 @@ package schema +import "github.com/graph-gophers/graphql-go/config" + var Meta *Schema func init() { Meta = &Schema{} // bootstrap - Meta = New() + Meta = New(config.Default()) if err := Meta.Parse(metaSrc); err != nil { panic(err) } diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 62d6de12..c71562a2 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -4,6 +4,7 @@ import ( "fmt" "text/scanner" + "github.com/graph-gophers/graphql-go/config" "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" ) @@ -45,6 +46,8 @@ type Schema struct { objects []*Object unions []*Union enums []*Enum + + Config *config.Config } // Resolve a named type in the schema by its name. @@ -230,11 +233,12 @@ type Field struct { } // New initializes an instance of Schema. -func New() *Schema { +func New(config *config.Config) *Schema { s := &Schema{ entryPointNames: make(map[string]string), Types: make(map[string]NamedType), Directives: make(map[string]*DirectiveDecl), + Config: config, } for n, t := range Meta.Types { s.Types[n] = t diff --git a/relay/relay_test.go b/relay/relay_test.go index 329a9a58..2fb84d9d 100644 --- a/relay/relay_test.go +++ b/relay/relay_test.go @@ -10,7 +10,7 @@ import ( "github.com/graph-gophers/graphql-go/relay" ) -var starwarsSchema = graphql.MustParseSchema(starwars.Schema, &starwars.Resolver{}) +var starwarsSchema = graphql.MustParseSchema(starwars.Schema, &starwars.Resolver{}, nil) func TestServeHTTP(t *testing.T) { w := httptest.NewRecorder()