From 85fa63b9570357b8ce954b88ccfd2ba2dd437d15 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Mon, 19 Feb 2018 21:07:27 +1100 Subject: [PATCH] Automatically add type conversions around wrapped types --- codegen/build.go | 10 +-- codegen/input_build.go | 4 +- codegen/object.go | 9 --- codegen/object_build.go | 4 +- codegen/type.go | 30 +++++++- codegen/util.go | 129 +++++++++++++++++++++------------ example/scalars/generated.go | 17 ++++- example/scalars/model.go | 4 + example/scalars/schema.graphql | 2 + example/starwars/model.go | 2 +- example/starwars/resolvers.go | 8 +- 11 files changed, 145 insertions(+), 74 deletions(-) diff --git a/codegen/build.go b/codegen/build.go index b07fe372b0..eb5fb11b17 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -35,9 +35,9 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (* b := &Build{ PackageName: filepath.Base(destDir), - Objects: buildObjects(namedTypes, schema, prog), + Objects: buildObjects(namedTypes, schema, prog, imports), Interfaces: buildInterfaces(namedTypes, schema), - Inputs: buildInputs(namedTypes, schema, prog), + Inputs: buildInputs(namedTypes, schema, prog, imports), Imports: imports, } @@ -56,19 +56,19 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (* // Poke a few magic methods into query q := b.Objects.ByName(b.QueryRoot.GQLType) q.Fields = append(q.Fields, Field{ - Type: &Type{namedTypes["__Schema"], []string{modPtr}}, + Type: &Type{namedTypes["__Schema"], []string{modPtr}, ""}, GQLName: "__schema", NoErr: true, GoMethodName: "ec.introspectSchema", Object: q, }) q.Fields = append(q.Fields, Field{ - Type: &Type{namedTypes["__Type"], []string{modPtr}}, + Type: &Type{namedTypes["__Type"], []string{modPtr}, ""}, GQLName: "__type", NoErr: true, GoMethodName: "ec.introspectType", Args: []FieldArgument{ - {GQLName: "name", Type: &Type{namedTypes["String"], []string{}}}, + {GQLName: "name", Type: &Type{namedTypes["String"], []string{}, ""}}, }, Object: q, }) diff --git a/codegen/input_build.go b/codegen/input_build.go index f42aa88058..685c5df855 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -11,7 +11,7 @@ import ( "golang.org/x/tools/go/loader" ) -func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program) Objects { +func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program, imports Imports) Objects { var inputs Objects for _, typ := range s.Types { @@ -25,7 +25,7 @@ func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program) } if def != nil { input.Marshaler = buildInputMarshaler(typ, def) - bindObject(def.Type(), input) + bindObject(def.Type(), input, imports) } inputs = append(inputs, input) diff --git a/codegen/object.go b/codegen/object.go index cf001b95d8..518ff7a533 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -38,15 +38,6 @@ type FieldArgument struct { type Objects []*Object -func (o *Object) GetField(name string) *Field { - for i, field := range o.Fields { - if strings.EqualFold(field.GQLName, name) { - return &o.Fields[i] - } - } - return nil -} - func (o *Object) Implementors() string { satisfiedBy := strconv.Quote(o.GQLType) for _, s := range o.Satisfies { diff --git a/codegen/object_build.go b/codegen/object_build.go index b8212c7224..7ea1e3d0fd 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -10,7 +10,7 @@ import ( "golang.org/x/tools/go/loader" ) -func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Objects { +func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program, imports Imports) Objects { var objects Objects for _, typ := range s.Types { @@ -23,7 +23,7 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje fmt.Fprintf(os.Stderr, err.Error()) } if def != nil { - bindObject(def.Type(), obj) + bindObject(def.Type(), obj, imports) } objects = append(objects, obj) diff --git a/codegen/type.go b/codegen/type.go index 16ad100743..19fa5e870c 100644 --- a/codegen/type.go +++ b/codegen/type.go @@ -24,6 +24,7 @@ type Type struct { *NamedType Modifiers []string + CastType string // the type to cast to when unmarshalling } const ( @@ -46,6 +47,15 @@ func (t Type) Signature() string { return strings.Join(t.Modifiers, "") + t.FullName() } +func (t Type) FullSignature() string { + pkg := "" + if t.Package != "" { + pkg = t.Package + "." + } + + return strings.Join(t.Modifiers, "") + pkg + t.GoType +} + func (t Type) IsPtr() bool { return len(t.Modifiers) > 0 && t.Modifiers[0] == modPtr } @@ -59,18 +69,32 @@ func (t NamedType) IsMarshaled() bool { } func (t Type) Unmarshal(result, raw string) string { - if t.Marshaler != nil { - return result + ", err := " + t.Marshaler.pkgDot() + "Unmarshal" + t.Marshaler.GoType + "(" + raw + ")" + realResult := result + if t.CastType != "" { + result = "castTmp" } - return tpl(`var {{.result}} {{.type}} + ret := tpl(`var {{.result}} {{.type}} err := (&{{.result}}).UnmarshalGQL({{.raw}})`, map[string]interface{}{ "result": result, "raw": raw, "type": t.FullName(), }) + + if t.Marshaler != nil { + ret = result + ", err := " + t.Marshaler.pkgDot() + "Unmarshal" + t.Marshaler.GoType + "(" + raw + ")" + } + + if t.CastType != "" { + ret += "\n" + realResult + " := " + t.CastType + "(castTmp)" + } + return ret } func (t Type) Marshal(result, val string) string { + if t.CastType != "" { + val = t.GoType + "(" + val + ")" + } + if t.Marshaler != nil { return result + " = " + t.Marshaler.pkgDot() + "Marshal" + t.Marshaler.GoType + "(" + val + ")" } diff --git a/codegen/util.go b/codegen/util.go index fa53c7dcc0..baed7a5e0e 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -47,69 +47,104 @@ func isMethod(t types.Object) bool { return f.Type().(*types.Signature).Recv() != nil } -func bindObject(t types.Type, object *Object) bool { - switch t := t.(type) { - case *types.Named: - for i := 0; i < t.NumMethods(); i++ { - method := t.Method(i) - if !method.Exported() { - continue - } +func findMethod(typ *types.Named, name string) *types.Func { + for i := 0; i < typ.NumMethods(); i++ { + method := typ.Method(i) + if !method.Exported() { + continue + } + + if strings.EqualFold(method.Name(), name) { + return method + } + } + return nil +} + +func findField(typ *types.Struct, name string) *types.Var { + for i := 0; i < typ.NumFields(); i++ { + field := typ.Field(i) + if !field.Exported() { + continue + } - if methodField := object.GetField(method.Name()); methodField != nil { - methodField.GoMethodName = "it." + method.Name() - sig := method.Type().(*types.Signature) + if strings.EqualFold(field.Name(), name) { + return field + } + } + return nil +} - methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type()) +func bindObject(t types.Type, object *Object, imports Imports) { + namedType, ok := t.(*types.Named) + if !ok { + fmt.Fprintf(os.Stderr, "expected %s to be a named struct, instead found %s", object.FullName(), t.String()) + return + } - // check arg order matches code, not gql + underlying, ok := t.Underlying().(*types.Struct) + if !ok { + fmt.Fprintf(os.Stderr, "expected %s to be a named struct, instead found %s", object.FullName(), t.String()) + return + } - var newArgs []FieldArgument - l2: - for j := 0; j < sig.Params().Len(); j++ { - param := sig.Params().At(j) - for _, oldArg := range methodField.Args { - if strings.EqualFold(oldArg.GQLName, param.Name()) { - oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) - newArgs = append(newArgs, oldArg) - continue l2 - } + for i := range object.Fields { + field := &object.Fields[i] + if method := findMethod(namedType, field.GQLName); method != nil { + sig := method.Type().(*types.Signature) + field.GoMethodName = "it." + method.Name() + field.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type()) + + // check arg order matches code, not gql + var newArgs []FieldArgument + l2: + for j := 0; j < sig.Params().Len(); j++ { + param := sig.Params().At(j) + for _, oldArg := range field.Args { + if strings.EqualFold(oldArg.GQLName, param.Name()) { + oldArg.Type.Modifiers = modifiersFromGoType(param.Type()) + newArgs = append(newArgs, oldArg) + continue l2 } - fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String()) } - methodField.Args = newArgs + fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String()) + } + field.Args = newArgs - if sig.Results().Len() == 1 { - methodField.NoErr = true - } else if sig.Results().Len() != 2 { - fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name()) - } + if sig.Results().Len() == 1 { + field.NoErr = true + } else if sig.Results().Len() != 2 { + fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name()) } + continue } - bindObject(t.Underlying(), object) - return true + if structField := findField(underlying, field.GQLName); structField != nil { + field.Type.Modifiers = modifiersFromGoType(structField.Type()) + field.GoVarName = "it." + structField.Name() - case *types.Struct: - for i := 0; i < t.NumFields(); i++ { - field := t.Field(i) - // Todo: struct tags, name and - at least + switch field.Type.FullSignature() { + case structField.Type().String(): + // everything is fine - if !field.Exported() { - continue - } + case structField.Type().Underlying().String(): + pkg, typ := pkgAndType(structField.Type().String()) + imp := imports.findByPkg(pkg) + field.CastType = typ + if imp.Name != "" { + field.CastType = imp.Name + "." + typ + } - // Todo: check for type matches before binding too? - if objectField := object.GetField(field.Name()); objectField != nil { - objectField.GoVarName = "it." + field.Name() - objectField.Type.Modifiers = modifiersFromGoType(field.Type()) + default: + fmt.Fprintf(os.Stderr, "type mismatch on %s.%s, expected %s got %s\n", object.GQLType, field.GQLName, field.Type.FullSignature(), structField.Type()) } + continue } - t.Underlying() - return true - } - return false + if field.IsScalar { + fmt.Fprintf(os.Stderr, "unable to bind %s.%s to anything, %s has no suitable fields or methods\n", object.GQLType, field.GQLName, namedType.String()) + } + } } func modifiersFromGoType(t types.Type) []string { diff --git a/example/scalars/generated.go b/example/scalars/generated.go index 36f13bfc1b..5a810eeb75 100644 --- a/example/scalars/generated.go +++ b/example/scalars/generated.go @@ -218,6 +218,14 @@ func (ec *executionContext) _user(sel []query.Selection, it *User) graphql.Marsh res := it.Location out.Values[i] = res + case "isBanned": + badArgs := false + if badArgs { + continue + } + res := it.IsBanned + + out.Values[i] = graphql.MarshalBoolean(bool(res)) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -807,13 +815,20 @@ func UnmarshalSearchArgs(v interface{}) (SearchArgs, error) { return it, err } it.CreatedAfter = &val + case "isBanned": + castTmp, err := graphql.UnmarshalBoolean(v) + val := Banned(castTmp) + if err != nil { + return it, err + } + it.IsBanned = val } } return it, nil } -var parsedSchema = schema.MustParse("schema {\n query: Query\n}\n\ntype Query {\n user(id: ID!): User\n search(input: SearchArgs!): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n}\n\nscalar Timestamp\nscalar Point\n") +var parsedSchema = schema.MustParse("schema {\n query: Query\n}\n\ntype Query {\n user(id: ID!): User\n search(input: SearchArgs!): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n isBanned: Boolean!\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n isBanned: Boolean\n}\n\nscalar Timestamp\nscalar Point\n") func (ec *executionContext) introspectSchema() *introspection.Schema { return introspection.WrapSchema(parsedSchema) diff --git a/example/scalars/model.go b/example/scalars/model.go index 2569932ebc..6b593f2614 100644 --- a/example/scalars/model.go +++ b/example/scalars/model.go @@ -11,11 +11,14 @@ import ( "github.com/vektah/gqlgen/graphql" ) +type Banned bool + type User struct { ID string Name string Location Point // custom scalar types Created time.Time // direct binding to builtin types with external Marshal/Unmarshal methods + IsBanned Banned // aliased primitive } // Point is serialized as a simple array, eg [1, 2] @@ -71,4 +74,5 @@ func UnmarshalTimestamp(v interface{}) (time.Time, error) { type SearchArgs struct { Location *Point CreatedAfter *time.Time + IsBanned Banned } diff --git a/example/scalars/schema.graphql b/example/scalars/schema.graphql index 27d44f5453..18ac9a60b0 100644 --- a/example/scalars/schema.graphql +++ b/example/scalars/schema.graphql @@ -12,11 +12,13 @@ type User { name: String! created: Timestamp location: Point + isBanned: Boolean! } input SearchArgs { location: Point createdAfter: Timestamp + isBanned: Boolean } scalar Timestamp diff --git a/example/starwars/model.go b/example/starwars/model.go index 80b8d759ca..cdf3eeae91 100644 --- a/example/starwars/model.go +++ b/example/starwars/model.go @@ -33,7 +33,7 @@ func (h *Human) Height(unit string) float64 { type Starship struct { ID string Name string - History [][2]int + History [][]int lengthMeters float64 } diff --git a/example/starwars/resolvers.go b/example/starwars/resolvers.go index a394bc7fd6..c8e808ef69 100644 --- a/example/starwars/resolvers.go +++ b/example/starwars/resolvers.go @@ -223,7 +223,7 @@ func NewResolver() *Resolver { "3000": { ID: "3000", Name: "Millennium Falcon", - History: [][2]int{ + History: [][]int{ {1, 2}, {4, 5}, {1, 2}, @@ -234,7 +234,7 @@ func NewResolver() *Resolver { "3001": { ID: "3001", Name: "X-Wing", - History: [][2]int{ + History: [][]int{ {6, 4}, {3, 2}, {2, 3}, @@ -245,7 +245,7 @@ func NewResolver() *Resolver { "3002": { ID: "3002", Name: "TIE Advanced x1", - History: [][2]int{ + History: [][]int{ {3, 2}, {7, 2}, {6, 4}, @@ -256,7 +256,7 @@ func NewResolver() *Resolver { "3003": { ID: "3003", Name: "Imperial shuttle", - History: [][2]int{ + History: [][]int{ {1, 7}, {3, 5}, {5, 3},