Skip to content

Commit

Permalink
support for custom scalars (fixes #3)
Browse files Browse the repository at this point in the history
  • Loading branch information
neelance committed Oct 31, 2016
1 parent bdfd5ce commit 2ab9d76
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 15 deletions.
22 changes: 15 additions & 7 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ import (
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"

"reflect"

"github.com/neelance/graphql-go/errors"
"github.com/neelance/graphql-go/internal/exec"
"github.com/neelance/graphql-go/internal/query"
"github.com/neelance/graphql-go/internal/schema"
)

func ParseSchema(schemaString string, resolver interface{}) (*Schema, error) {
b, err := New().Parse(schemaString)
if err != nil {
b := New()
if err := b.Parse(schemaString); err != nil {
return nil, err
}
return b.ApplyResolver(resolver)
Expand All @@ -33,11 +35,12 @@ func New() *SchemaBuilder {
}
}

func (b *SchemaBuilder) Parse(schemaString string) (*SchemaBuilder, error) {
if err := b.schema.Parse(schemaString); err != nil {
return nil, err
}
return b, nil
func (b *SchemaBuilder) Parse(schemaString string) error {
return b.schema.Parse(schemaString)
}

func (b *SchemaBuilder) AddCustomScalar(name string, scalar *ScalarConfig) {
exec.AddCustomScalar(b.schema, name, scalar.ReflectType, scalar.CoerceInput)
}

func (b *SchemaBuilder) ApplyResolver(resolver interface{}) (*Schema, error) {
Expand Down Expand Up @@ -104,3 +107,8 @@ func SchemaToJSON(schemaString string) ([]byte, error) {

return json.Marshal(result)
}

type ScalarConfig struct {
ReflectType reflect.Type
CoerceInput func(input interface{}) (interface{}, error)
}
51 changes: 50 additions & 1 deletion graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"testing"
"time"

"github.com/neelance/graphql-go/example/starwars"
)
Expand Down Expand Up @@ -34,8 +35,15 @@ func (r *theNumberResolver) ChangeTheNumber(args *struct{ NewNumber int32 }) *th
return r
}

type timeResolver struct{}

func (r *timeResolver) AddHour(args *struct{ Time time.Time }) time.Time {
return args.Time.Add(time.Hour)
}

var tests = []struct {
name string
setup func(b *SchemaBuilder)
schema string
variables map[string]interface{}
resolver interface{}
Expand Down Expand Up @@ -1070,12 +1078,53 @@ var tests = []struct {
}
`,
},

{
name: "Time",
setup: func(b *SchemaBuilder) {
b.AddCustomScalar("Time", Time)
},
schema: `
schema {
query: Query
}
type Query {
addHour(time: Time = "2001-02-03T04:05:06Z"): Time!
}
`,
resolver: &timeResolver{},
query: `
query($t: Time!) {
a: addHour(time: $t)
b: addHour
}
`,
variables: map[string]interface{}{
"t": time.Date(2000, 2, 3, 4, 5, 6, 0, time.UTC),
},
result: `
{
"a": "2000-02-03T05:05:06Z",
"b": "2001-02-03T05:05:06Z"
}
`,
},
}

func TestAll(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
schema, err := ParseSchema(test.schema, test.resolver)
b := New()
if test.setup != nil {
test.setup(b)
}

if err := b.Parse(test.schema); err != nil {
t.Fatal(err)
}

schema, err := b.ApplyResolver(test.resolver)
if err != nil {
t.Fatal(err)
}
Expand Down
12 changes: 8 additions & 4 deletions internal/exec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,11 @@ func makeInputObjectExec(s *schema.Schema, obj *common.InputMap, typ reflect.Typ
}

if f.Default != nil {
defaultStruct.FieldByIndex(fe.fieldIndex).Set(fe.exec.eval(f.Default.Eval(nil)))
defaultValue, err := coerceValue(f.Type, f.Default.Eval(nil))
if err != nil {
return nil, err
}
defaultStruct.FieldByIndex(fe.fieldIndex).Set(fe.exec.eval(defaultValue))
}

fields = append(fields, fe)
Expand Down Expand Up @@ -449,7 +453,7 @@ func coerceInputObject(io *common.InputMap, variables map[string]interface{}) (m
coerced[iv.Name] = iv.Default.Eval(nil)
continue
}
c, err := coerceInputValue(iv, value)
c, err := coerceValue(iv.Type, value)
if err != nil {
return nil, err
}
Expand All @@ -458,8 +462,8 @@ func coerceInputObject(io *common.InputMap, variables map[string]interface{}) (m
return coerced, nil
}

func coerceInputValue(iv *common.InputValue, value interface{}) (interface{}, error) {
t, _ := unwrapNonNull(iv.Type)
func coerceValue(typ common.Type, value interface{}) (interface{}, error) {
t, _ := unwrapNonNull(typ)
switch t := t.(type) {
case *scalar:
return t.coerceInput(value)
Expand Down
8 changes: 8 additions & 0 deletions internal/exec/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,11 @@ func AddBuiltinScalars(s *schema.Schema) {
s.Types[scalar.name] = scalar
}
}

func AddCustomScalar(s *schema.Schema, name string, reflectType reflect.Type, coerceInput func(input interface{}) (interface{}, error)) {
s.Types[name] = &scalar{
name: name,
reflectType: reflectType,
coerceInput: coerceInput,
}
}
10 changes: 7 additions & 3 deletions internal/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,20 @@ func (Field) isSelection() {}
func (FragmentSpread) isSelection() {}
func (InlineFragment) isSelection() {}

func Parse(queryString string, resolver common.Resolver) (doc *Document, err *errors.QueryError) {
func Parse(queryString string, resolver common.Resolver) (*Document, *errors.QueryError) {
sc := &scanner.Scanner{
Mode: scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings,
}
sc.Init(strings.NewReader(queryString))

l := lexer.New(sc)
err = l.CatchSyntaxError(func() {
var doc *Document
err := l.CatchSyntaxError(func() {
doc = parseDocument(l)
})
if err != nil {
return nil, err
}

for _, op := range doc.Operations {
for _, v := range op.Vars.Fields {
Expand All @@ -95,7 +99,7 @@ func Parse(queryString string, resolver common.Resolver) (doc *Document, err *er
}
}

return
return doc, nil
}

func parseDocument(l *lexer.Lexer) *Document {
Expand Down
24 changes: 24 additions & 0 deletions time.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package graphql

import (
"fmt"
"reflect"
"time"
)

var Time = &ScalarConfig{
ReflectType: reflect.TypeOf(time.Time{}),
CoerceInput: func(input interface{}) (interface{}, error) {
switch input := input.(type) {
case time.Time:
return input, nil
case string:
t, err := time.Parse(time.RFC3339, input)
return t, err
case int:
return time.Unix(int64(input), 0), nil
default:
return nil, fmt.Errorf("could not convert %v (%t) to time", input, input)
}
},
}

0 comments on commit 2ab9d76

Please sign in to comment.