diff --git a/graphql.go b/graphql.go index 95cd118d75..eb3ac28991 100644 --- a/graphql.go +++ b/graphql.go @@ -52,17 +52,22 @@ func exec(s *Schema, t schema.Type, sel *query.SelectionSet, resolver reflect.Va case *schema.Object: res := make(map[string]interface{}) for _, f := range sel.Selections { + sf := t.Fields[f.Name] m := resolver.Method(findMethod(resolver.Type(), f.Name)) var in []reflect.Value - if len(f.Arguments) != 0 { - args := reflect.New(m.Type().In(0).Elem()) - for name, value := range f.Arguments { - f := args.Elem().FieldByNameFunc(func(n string) bool { return strings.EqualFold(n, name) }) - f.Set(reflect.ValueOf(value.Value)) + if len(sf.Parameters) != 0 { + args := reflect.New(m.Type().In(0)) + for name, param := range sf.Parameters { + value, ok := f.Arguments[name] + if !ok { + value = &query.Value{Value: param.Default} + } + rf := args.Elem().FieldByNameFunc(func(n string) bool { return strings.EqualFold(n, name) }) + rf.Set(reflect.ValueOf(value.Value)) } - in = []reflect.Value{args} + in = []reflect.Value{args.Elem()} } - res[f.Name] = exec(s, t.Fields[f.Name].Type, f.Sel, m.Call(in)[0]) + res[f.Name] = exec(s, sf.Type, f.Sel, m.Call(in)[0]) } return res } diff --git a/graphql_test.go b/graphql_test.go index 15790047b5..bab7f4ebd9 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -279,11 +279,12 @@ func init() { type starWarsResolver struct{} -func (r *starWarsResolver) Hero() characterResolver { +func (r *starWarsResolver) Hero(args struct{ Episode string }) characterResolver { + // TODO episode return &droidResolver{droidData["2001"]} } -func (r *starWarsResolver) Human(args *struct{ ID string }) *humanResolver { +func (r *starWarsResolver) Human(args struct{ ID string }) *humanResolver { h := humanData[args.ID] if h == nil { return nil @@ -309,8 +310,15 @@ func (r *humanResolver) Name() string { return r.h.Name } -func (r *humanResolver) Height() float64 { - return r.h.Height +func (r *humanResolver) Height(args struct{ Unit string }) float64 { + switch args.Unit { + case "METER": + return r.h.Height + case "FOOT": + return r.h.Height * 3.28084 + default: + panic("invalid unit") + } } func (r *humanResolver) Friends() []characterResolver { @@ -373,7 +381,7 @@ var tests = []struct { `, }, { - name: "StarWars1", + name: "StarWarsBasic", schema: starWarsSchema, resolver: &starWarsResolver{}, query: ` @@ -408,7 +416,7 @@ var tests = []struct { `, }, { - name: "StarWars2", + name: "StarWarsArguments1", schema: starWarsSchema, resolver: &starWarsResolver{}, query: ` @@ -428,6 +436,27 @@ var tests = []struct { } `, }, + { + name: "StarWarsArguments2", + schema: starWarsSchema, + resolver: &starWarsResolver{}, + query: ` + { + human(id: "1000") { + name + height(unit: FOOT) + } + } + `, + result: ` + { + "human": { + "name": "Luke Skywalker", + "height": 5.6430448 + } + } + `, + }, } func TestAll(t *testing.T) { diff --git a/internal/query/query.go b/internal/query/query.go index d07f586aca..b0d2e8ceab 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -82,7 +82,28 @@ func parseArgument(l *lexer.Lexer) (string, *Value) { return name, value } +type ValueType int + +const ( + Int ValueType = iota + Float + String + Boolean + Enum +) + func parseValue(l *lexer.Lexer) *Value { - value := l.ConsumeString() - return &Value{Value: value} + switch l.Peek() { + case scanner.String: + return &Value{ + Value: l.ConsumeString(), + } + case scanner.Ident: + return &Value{ + Value: l.ConsumeIdent(), + } + default: + l.SyntaxError("invalid value") + panic("unreachable") + } } diff --git a/internal/schema/schema.go b/internal/schema/schema.go index ac71ba5452..b9c554eb83 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -33,10 +33,16 @@ type Object struct { type Field struct { Name string - Parameters map[string]string + Parameters map[string]*Parameter Type Type } +type Parameter struct { + Name string + Type string + Default string +} + func Parse(schemaString string, filename string) (res *Schema, errRes error) { sc := &scanner.Scanner{ Mode: scanner.ScanIdents | scanner.ScanFloats | scanner.ScanStrings, @@ -144,18 +150,18 @@ func parseInputDecl(l *lexer.Lexer) { func parseField(l *lexer.Lexer) *Field { f := &Field{ - Parameters: make(map[string]string), + Parameters: make(map[string]*Parameter), } f.Name = l.ConsumeIdent() if l.Peek() == '(' { l.ConsumeToken('(') if l.Peek() != ')' { - name, typ := parseParameter(l) - f.Parameters[name] = typ + p := parseParameter(l) + f.Parameters[p.Name] = p for l.Peek() != ')' { l.ConsumeToken(',') - name, typ := parseParameter(l) - f.Parameters[name] = typ + p := parseParameter(l) + f.Parameters[p.Name] = p } } l.ConsumeToken(')') @@ -168,18 +174,19 @@ func parseField(l *lexer.Lexer) *Field { return f } -func parseParameter(l *lexer.Lexer) (string, string) { - name := l.ConsumeIdent() +func parseParameter(l *lexer.Lexer) *Parameter { + p := &Parameter{} + p.Name = l.ConsumeIdent() l.ConsumeToken(':') - typ := l.ConsumeIdent() + p.Type = l.ConsumeIdent() if l.Peek() == '!' { l.ConsumeToken('!') // TODO } if l.Peek() == '=' { l.ConsumeToken('=') - l.ConsumeIdent() // TODO + p.Default = l.ConsumeIdent() } - return name, typ + return p } func parseType(l *lexer.Lexer) Type {