Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix vendored import paths in generated models #64

Merged
merged 2 commits into from
Mar 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func Models(schema *schema.Schema, userTypes map[string]string, destDir string)
panic(err)
}

bindTypes(imports, namedTypes, prog)
bindTypes(imports, namedTypes, destDir, prog)

models := buildModels(namedTypes, schema, prog)
return &ModelBuild{
Expand All @@ -59,7 +59,7 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
return nil, err
}

bindTypes(imports, namedTypes, prog)
imports = bindTypes(imports, namedTypes, destDir, prog)

objects := buildObjects(namedTypes, schema, prog, imports)
inputs := buildInputs(namedTypes, schema, prog, imports)
Expand Down
65 changes: 34 additions & 31 deletions codegen/import_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,52 +23,55 @@ func buildImports(types NamedTypes, destDir string) Imports {
}

for _, t := range types {
if t.Package == "" {
continue
}
imports, t.Import = imports.addPkg(types, destDir, t.Package)
}

if existing := imports.findByPkg(t.Package); existing != nil {
t.Import = existing
continue
}
return imports
}

localName := ""
if !strings.HasSuffix(destDir, t.Package) {
localName = filepath.Base(t.Package)
i := 1
imp := imports.findByName(localName)
for imp != nil && imp.Package != t.Package {
localName = filepath.Base(t.Package) + strconv.Itoa(i)
imp = imports.findByName(localName)
i++
if i > 10 {
panic("too many collisions")
}
}
}
func (s Imports) addPkg(types NamedTypes, destDir string, pkg string) (Imports, *Import) {
if pkg == "" {
return s, nil
}

if existing := s.findByPkg(pkg); existing != nil {
return s, existing
}

imp := &Import{
Name: localName,
Package: t.Package,
localName := ""
if !strings.HasSuffix(destDir, pkg) {
localName = filepath.Base(pkg)
i := 1
imp := s.findByName(localName)
for imp != nil && imp.Package != pkg {
localName = filepath.Base(pkg) + strconv.Itoa(i)
imp = s.findByName(localName)
i++
if i > 10 {
panic("too many collisions")
}
}
t.Import = imp
imports = append(imports, imp)
}

return imports
imp := &Import{
Name: localName,
Package: pkg,
}
s = append(s, imp)
return s, imp
}

func (i Imports) findByPkg(pkg string) *Import {
for _, imp := range i {
func (s Imports) findByPkg(pkg string) *Import {
for _, imp := range s {
if imp.Package == pkg {
return imp
}
}
return nil
}

func (i Imports) findByName(name string) *Import {
for _, imp := range i {
func (s Imports) findByName(name string) *Import {
for _, imp := range s {
if imp.Name == name {
return imp
}
Expand Down
12 changes: 9 additions & 3 deletions codegen/type_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func buildNamedTypes(s *schema.Schema, userTypes map[string]string) NamedTypes {
return types
}

func bindTypes(imports Imports, namedTypes NamedTypes, prog *loader.Program) {
func bindTypes(imports Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) Imports {
for _, t := range namedTypes {
if t.Package == "" {
continue
Expand All @@ -44,9 +44,10 @@ func bindTypes(imports Imports, namedTypes NamedTypes, prog *loader.Program) {
t.Marshaler = &cpy

t.Package, t.GoType = pkgAndType(sig.Params().At(0).Type().String())
t.Import = imports.findByName(t.Package)
imports, t.Import = imports.addPkg(namedTypes, destDir, t.Package)
}
}
return imports
}

// namedTypeFromSchema objects for every graphql type, including primitives.
Expand All @@ -71,7 +72,12 @@ func pkgAndType(name string) (string, string) {
return "", name
}

return strings.Join(parts[:len(parts)-1], "."), parts[len(parts)-1]
return normalizeVendor(strings.Join(parts[:len(parts)-1], ".")), parts[len(parts)-1]
}

func normalizeVendor(pkg string) string {
parts := strings.Split(pkg, "/vendor/")
return parts[len(parts)-1]
}

func (n NamedTypes) getType(t common.Type) *Type {
Expand Down
4 changes: 2 additions & 2 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ func bindObject(t types.Type, object *Object, imports Imports) {
field.GoVarName = structField.Name()

switch field.Type.FullSignature() {
case structField.Type().String():
case normalizeVendor(structField.Type().String()):
// everything is fine

case structField.Type().Underlying().String():
case normalizeVendor(structField.Type().Underlying().String()):
pkg, typ := pkgAndType(structField.Type().String())
imp := imports.findByPkg(pkg)
field.CastType = typ
Expand Down
62 changes: 50 additions & 12 deletions example/scalars/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package scalars
import (
"bytes"
context "context"
external "external"
strconv "strconv"
time "time"

Expand All @@ -24,7 +25,7 @@ func MakeExecutableSchema(resolvers Resolvers, opts ...ExecutableOption) graphql
}

type Resolvers interface {
Query_user(ctx context.Context, id string) (*User, error)
Query_user(ctx context.Context, id external.ObjectID) (*User, error)
Query_search(ctx context.Context, input SearchArgs) ([]User, error)

User_primitiveResolver(ctx context.Context, obj *User) (string, error)
Expand Down Expand Up @@ -85,6 +86,43 @@ type executionContext struct {
recover graphql.RecoverFunc
}

var addressImplementors = []string{"Address"}

// nolint: gocyclo, errcheck, gas, goconst
func (ec *executionContext) _Address(sel []query.Selection, obj *Address) graphql.Marshaler {
fields := graphql.CollectFields(ec.doc, sel, addressImplementors, ec.variables)
out := graphql.NewOrderedMap(len(fields))
for i, field := range fields {
out.Keys[i] = field.Alias

switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString("Address")
case "id":
out.Values[i] = ec._Address_id(field, obj)
case "location":
out.Values[i] = ec._Address_location(field, obj)
default:
panic("unknown field " + strconv.Quote(field.Name))
}
}

return out
}

func (ec *executionContext) _Address_id(field graphql.CollectedField, obj *Address) graphql.Marshaler {
res := obj.ID
return MarshalID(res)
}

func (ec *executionContext) _Address_location(field graphql.CollectedField, obj *Address) graphql.Marshaler {
res := obj.Location
if res == nil {
return graphql.Null
}
return *res
}

var queryImplementors = []string{"Query"}

// nolint: gocyclo, errcheck, gas, goconst
Expand Down Expand Up @@ -114,10 +152,10 @@ func (ec *executionContext) _Query(sel []query.Selection) graphql.Marshaler {
}

func (ec *executionContext) _Query_user(field graphql.CollectedField) graphql.Marshaler {
var arg0 string
var arg0 external.ObjectID
if tmp, ok := field.Args["id"]; ok {
var err error
arg0, err = graphql.UnmarshalID(tmp)
arg0, err = UnmarshalID(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down Expand Up @@ -226,14 +264,14 @@ func (ec *executionContext) _User(sel []query.Selection, obj *User) graphql.Mars
out.Values[i] = ec._User_name(field, obj)
case "created":
out.Values[i] = ec._User_created(field, obj)
case "location":
out.Values[i] = ec._User_location(field, obj)
case "isBanned":
out.Values[i] = ec._User_isBanned(field, obj)
case "primitiveResolver":
out.Values[i] = ec._User_primitiveResolver(field, obj)
case "customResolver":
out.Values[i] = ec._User_customResolver(field, obj)
case "address":
out.Values[i] = ec._User_address(field, obj)
default:
panic("unknown field " + strconv.Quote(field.Name))
}
Expand All @@ -244,7 +282,7 @@ func (ec *executionContext) _User(sel []query.Selection, obj *User) graphql.Mars

func (ec *executionContext) _User_id(field graphql.CollectedField, obj *User) graphql.Marshaler {
res := obj.ID
return graphql.MarshalID(res)
return MarshalID(res)
}

func (ec *executionContext) _User_name(field graphql.CollectedField, obj *User) graphql.Marshaler {
Expand All @@ -257,11 +295,6 @@ func (ec *executionContext) _User_created(field graphql.CollectedField, obj *Use
return MarshalTimestamp(res)
}

func (ec *executionContext) _User_location(field graphql.CollectedField, obj *User) graphql.Marshaler {
res := obj.Location
return res
}

func (ec *executionContext) _User_isBanned(field graphql.CollectedField, obj *User) graphql.Marshaler {
res := obj.IsBanned
return graphql.MarshalBoolean(bool(res))
Expand Down Expand Up @@ -303,6 +336,11 @@ func (ec *executionContext) _User_customResolver(field graphql.CollectedField, o
})
}

func (ec *executionContext) _User_address(field graphql.CollectedField, obj *User) graphql.Marshaler {
res := obj.Address
return ec._Address(field.Selections, &res)
}

var __DirectiveImplementors = []string{"__Directive"}

// nolint: gocyclo, errcheck, gas, goconst
Expand Down Expand Up @@ -837,7 +875,7 @@ func UnmarshalSearchArgs(v interface{}) (SearchArgs, error) {
return it, nil
}

var parsedSchema = schema.MustParse("type Query {\n user(id: ID!): User\n search(input: SearchArgs = {location: \"37,144\"}): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n isBanned: Boolean!\n primitiveResolver: String!\n customResolver: Point!\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n isBanned: Boolean\n}\n\nscalar Timestamp\nscalar Point\n")
var parsedSchema = schema.MustParse("type Query {\n user(id: ID!): User\n search(input: SearchArgs = {location: \"37,144\"}): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n isBanned: Boolean!\n primitiveResolver: String!\n customResolver: Point!\n address: Address\n}\n\ntype Address {\n id: ID!\n location: Point\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n isBanned: Boolean\n}\n\nscalar Timestamp\nscalar Point\n")

func (ec *executionContext) introspectSchema() *introspection.Schema {
return introspection.WrapSchema(parsedSchema)
Expand Down
23 changes: 21 additions & 2 deletions example/scalars/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@ import (
"strings"
"time"

"external"

"github.com/vektah/gqlgen/graphql"
)

type Banned bool

type User struct {
ID string
ID external.ObjectID
Name string
Location Point // custom scalar types
Created time.Time // direct binding to builtin types with external Marshal/Unmarshal methods
IsBanned Banned // aliased primitive
Address Address
}

// Point is serialized as a simple array, eg [1, 2]
Expand Down Expand Up @@ -71,6 +73,23 @@ func UnmarshalTimestamp(v interface{}) (time.Time, error) {
return time.Time{}, errors.New("time should be a unix timestamp")
}

// Lets redefine the base ID type to use an id from an external library
func MarshalID(id external.ObjectID) graphql.Marshaler {
return graphql.WriterFunc(func(w io.Writer) {
io.WriteString(w, strconv.Quote(fmt.Sprintf("=%d=", id)))
})
}

// And the same for the unmarshaler
func UnmarshalID(v interface{}) (external.ObjectID, error) {
str, ok := v.(string)
if !ok {
return 0, fmt.Errorf("ids must be strings")
}
i, err := strconv.Atoi(str[1 : len(str)-1])
return external.ObjectID(i), err
}

type SearchArgs struct {
Location *Point
CreatedAfter *time.Time
Expand Down
12 changes: 12 additions & 0 deletions example/scalars/models_gen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// This file was generated by github.com/vektah/gqlgen, DO NOT EDIT

package scalars

import (
external "external"
)

type Address struct {
ID external.ObjectID
Location *Point
}
29 changes: 16 additions & 13 deletions example/scalars/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@ package scalars

import (
context "context"
"fmt"
"time"

"external"
)

type Resolver struct {
}

func (r *Resolver) Query_user(ctx context.Context, id string) (*User, error) {
func (r *Resolver) Query_user(ctx context.Context, id external.ObjectID) (*User, error) {
return &User{
ID: id,
Name: "Test User " + id,
Created: time.Now(),
Location: Point{1, 2},
ID: id,
Name: fmt.Sprintf("Test User %d", id),
Created: time.Now(),
Address: Address{ID: 1, Location: &Point{1, 2}},
}, nil
}

Expand All @@ -32,16 +35,16 @@ func (r *Resolver) Query_search(ctx context.Context, input SearchArgs) ([]User,

return []User{
{
ID: "1",
Name: "Test User 1",
Created: created,
Location: location,
ID: 1,
Name: "Test User 1",
Created: created,
Address: Address{ID: 2, Location: &location},
},
{
ID: "2",
Name: "Test User 2",
Created: created,
Location: location,
ID: 2,
Name: "Test User 2",
Created: created,
Address: Address{ID: 1, Location: &location},
},
}, nil
}
Expand Down
Loading