diff --git a/graphql.go b/graphql.go index 9704798c489..c3b3601fdf8 100644 --- a/graphql.go +++ b/graphql.go @@ -1,23 +1,27 @@ package graphql import ( + "encoding/json" "errors" - "fmt" "reflect" "strings" "text/scanner" ) type Schema struct { - types map[string]*object + types map[string]*object + resolver reflect.Value } type typ interface { - exec() interface{} + exec(schema *Schema, sel *selectionSet, resolver reflect.Value) interface{} } type scalar struct { - resolver reflect.Value +} + +type typeName struct { + name string } type object struct { @@ -41,104 +45,130 @@ func NewSchema(schema string, filename string, resolver interface{}) (res *Schem } }() - return parseFile(sc, reflect.ValueOf(resolver)), nil + s := parseSchema(newLexer(sc)) + s.resolver = reflect.ValueOf(resolver) + // TODO type check resolver + return s, nil } -func parseFile(sc *scanner.Scanner, r reflect.Value) *Schema { - types := make(map[string]*object) - - scanToken(sc, scanner.Ident) - switch sc.TokenText() { - case "type": - name, obj := parseTypeDecl(sc, r) - types[name] = obj - default: - syntaxError(sc, `"type"`) +func parseSchema(l *lexer) *Schema { + s := &Schema{ + types: make(map[string]*object), } - return &Schema{ - types: types, + for l.peek() != scanner.EOF { + switch l.consumeIdent() { + case "type": + name, obj := parseTypeDecl(l) + s.types[name] = obj + default: + l.syntaxError(`"type"`) + } } -} - -func parseTypeDecl(sc *scanner.Scanner, r reflect.Value) (string, *object) { - typeName := scanIdent(sc) - scanToken(sc, '{') - - fields := make(map[string]typ) - fieldName := scanIdent(sc) - m := r.MethodByName(strings.ToUpper(fieldName[:1]) + fieldName[1:]) - scanToken(sc, ':') - fields[fieldName] = parseType(sc, m) + return s +} - scanToken(sc, '}') +func parseTypeDecl(l *lexer) (string, *object) { + typeName := l.consumeIdent() + l.consumeToken('{') - return typeName, &object{ - fields: fields, + o := &object{ + fields: make(map[string]typ), + } + for l.peek() != '}' { + fieldName := l.consumeIdent() + l.consumeToken(':') + o.fields[fieldName] = parseType(l) } + l.consumeToken('}') + + return typeName, o } -func parseType(sc *scanner.Scanner, r reflect.Value) typ { +func parseType(l *lexer) typ { // TODO check args // TODO check return type - scanToken(sc, scanner.Ident) - return &scalar{ - resolver: r, + name := l.consumeIdent() + if name == "String" { + return &scalar{} } -} - -func scanIdent(sc *scanner.Scanner) string { - scanToken(sc, scanner.Ident) - return sc.TokenText() -} - -func scanToken(sc *scanner.Scanner, expected rune) { - if got := sc.Scan(); got != expected { - syntaxError(sc, scanner.TokenString(expected)) + return &typeName{ + name: name, } } -func syntaxError(sc *scanner.Scanner, expected string) { - panic(parseError(fmt.Sprintf("%s:%d: syntax error: unexpected %q, expecting %s", sc.Filename, sc.Line, sc.TokenText(), expected))) -} - -func (s *Schema) Exec(query string) (interface{}, error) { +func (s *Schema) Exec(query string) (res []byte, errRes error) { sc := &scanner.Scanner{} sc.Init(strings.NewReader(query)) - res := s.types["Query"].exec(parseSelectionSet(sc)) - return res, nil + defer func() { + if err := recover(); err != nil { + if err, ok := err.(parseError); ok { + errRes = errors.New(string(err)) + return + } + panic(err) + } + }() + + rawRes := s.types["Query"].exec(s, parseSelectionSet(newLexer(sc)), s.resolver) + return json.Marshal(rawRes) } type selectionSet struct { selections []*field } -func parseSelectionSet(sc *scanner.Scanner) *selectionSet { - scanToken(sc, '{') - f := parseField(sc) - scanToken(sc, '}') - return &selectionSet{ - selections: []*field{f}, +func parseSelectionSet(l *lexer) *selectionSet { + sel := &selectionSet{} + l.consumeToken('{') + for l.peek() != '}' { + sel.selections = append(sel.selections, parseField(l)) } + l.consumeToken('}') + return sel } type field struct { name string + sel *selectionSet } -func parseField(sc *scanner.Scanner) *field { - name := scanIdent(sc) - return &field{ - name: name, +func parseField(l *lexer) *field { + f := &field{} + f.name = l.consumeIdent() + if l.peek() == '{' { + f.sel = parseSelectionSet(l) } + return f } -func (o *object) exec(sel *selectionSet) interface{} { - return o.fields[sel.selections[0].name].exec() +func (o *object) exec(schema *Schema, sel *selectionSet, resolver reflect.Value) interface{} { + res := make(map[string]interface{}) + for _, f := range sel.selections { + m := findMethod(resolver.Type(), f.name) + res[f.name] = o.fields[f.name].exec(schema, f.sel, resolver.Method(m).Call(nil)[0]) + } + return res +} + +func findMethod(t reflect.Type, name string) int { + for i := 0; i < t.NumMethod(); i++ { + if strings.EqualFold(name, t.Method(i).Name) { + return i + } + } + return -1 +} + +func (s *scalar) exec(schema *Schema, sel *selectionSet, resolver reflect.Value) interface{} { + if !resolver.IsValid() { + return "bad" + } + return resolver.Interface() } -func (s *scalar) exec() interface{} { - return s.resolver.Call(nil)[0].Interface() +func (s *typeName) exec(schema *Schema, sel *selectionSet, resolver reflect.Value) interface{} { + return schema.types[s.name].exec(schema, sel, resolver) } diff --git a/graphql_test.go b/graphql_test.go index 0e3872c2c29..a880e0ce883 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -1,29 +1,120 @@ package graphql -import "testing" +import ( + "bytes" + "encoding/json" + "testing" +) -type testStringResolver struct{} +type helloWorldResolver struct{} -func (r *testStringResolver) Hello() string { +func (r *helloWorldResolver) Hello() string { return "Hello world!" } -func TestString(t *testing.T) { - schema, err := NewSchema(` - type Query { - hello: String - } - `, "test", &testStringResolver{}) - if err != nil { - t.Fatal(err) - } +type starWarsResolver struct{} - got, err := schema.Exec(`{ hello }`) - if err != nil { - t.Fatal(err) - } +func (r *starWarsResolver) Hero() *userResolver { + return &userResolver{id: "2001", name: "R2-D2"} +} + +type userResolver struct { + id string + name string +} + +func (r *userResolver) ID() string { + return r.id +} + +func (r *userResolver) Name() string { + return r.name +} + +var tests = []struct { + name string + schema string + resolver interface{} + query string + result string +}{ + { + name: "HelloWorld", + schema: ` + type Query { + hello: String + } + `, + resolver: &helloWorldResolver{}, + query: ` + { + hello + } + `, + result: ` + { + "hello": "Hello world!" + } + `, + }, + { + name: "User", + schema: ` + type Query { + hero: User + } + + type User { + id: String + name: String + } + `, + resolver: &starWarsResolver{}, + query: ` + { + hero { + id + name + } + } + `, + result: ` + { + "hero": { + "id": "2001", + "name": "R2-D2" + } + } + `, + }, +} - if want := "Hello world!"; got != want { - t.Errorf("want %#v, got %#v", want, got) +func TestAll(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + schema, err := NewSchema(test.schema, test.name, test.resolver) + if err != nil { + t.Fatal(err) + } + + got, err := schema.Exec(test.query) + if err != nil { + t.Fatal(err) + } + + want := formatJSON([]byte(test.result)) + if !bytes.Equal(got, want) { + t.Logf("want: %s", want) + t.Logf("got: %s", got) + t.Fail() + } + }) } } + +func formatJSON(data []byte) []byte { + var v interface{} + json.Unmarshal(data, &v) + b, _ := json.Marshal(v) + return b +} diff --git a/lexer.go b/lexer.go new file mode 100644 index 00000000000..f7e0f1af73d --- /dev/null +++ b/lexer.go @@ -0,0 +1,42 @@ +package graphql + +import ( + "fmt" + "text/scanner" +) + +type lexer struct { + sc *scanner.Scanner + next rune +} + +func newLexer(sc *scanner.Scanner) *lexer { + l := &lexer{sc: sc} + l.consume() + return l +} + +func (l *lexer) peek() rune { + return l.next +} + +func (l *lexer) consume() { + l.next = l.sc.Scan() +} + +func (l *lexer) consumeIdent() string { + text := l.sc.TokenText() + l.consumeToken(scanner.Ident) + return text +} + +func (l *lexer) consumeToken(expected rune) { + if l.next != expected { + l.syntaxError(scanner.TokenString(expected)) + } + l.consume() +} + +func (l *lexer) syntaxError(expected string) { + panic(parseError(fmt.Sprintf("%s:%d: syntax error: unexpected %q, expecting %s", l.sc.Filename, l.sc.Line, l.sc.TokenText(), expected))) +}