From abf85a104ab1b06c6c181c7dae74d84b3d88628c Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Sat, 23 Jun 2018 12:57:07 +1000 Subject: [PATCH] get package name from package source --- codegen/build.go | 26 +++++---- codegen/import.go | 11 ++-- codegen/import_build.go | 101 +++++++++++++++++++++------------- codegen/input_build.go | 2 +- codegen/object_build.go | 2 +- codegen/type.go | 4 +- codegen/type_build.go | 5 +- codegen/util.go | 8 +-- example/scalars/generated.go | 2 +- example/scalars/models_gen.go | 2 +- test/generated.go | 2 +- 11 files changed, 99 insertions(+), 66 deletions(-) diff --git a/codegen/build.go b/codegen/build.go index e1d9f2581a7..872ce5f9ead 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -15,7 +15,7 @@ type Build struct { Objects Objects Inputs Objects Interfaces []*Interface - Imports Imports + Imports []*Import QueryRoot *Object MutationRoot *Object SubscriptionRoot *Object @@ -24,7 +24,7 @@ type Build struct { type ModelBuild struct { PackageName string - Imports Imports + Imports []*Import Models []Model Enums []Enum } @@ -33,11 +33,11 @@ type ModelBuild struct { func (cfg *Config) models() (*ModelBuild, error) { namedTypes := cfg.buildNamedTypes() - imports := buildImports(namedTypes, cfg.modelDir) - prog, err := cfg.loadProgram(imports, true) + prog, err := cfg.loadProgram(namedTypes, true) if err != nil { return nil, errors.Wrap(err, "loading failed") } + imports := buildImports(namedTypes, cfg.modelDir) cfg.bindTypes(imports, namedTypes, cfg.modelDir, prog) @@ -49,7 +49,7 @@ func (cfg *Config) models() (*ModelBuild, error) { PackageName: cfg.ModelPackageName, Models: models, Enums: cfg.buildEnums(namedTypes), - Imports: buildImports(namedTypes, cfg.modelDir), + Imports: imports.imports, }, nil } @@ -57,13 +57,13 @@ func (cfg *Config) models() (*ModelBuild, error) { func (cfg *Config) bind() (*Build, error) { namedTypes := cfg.buildNamedTypes() - imports := buildImports(namedTypes, cfg.execDir) - prog, err := cfg.loadProgram(imports, false) + prog, err := cfg.loadProgram(namedTypes, false) if err != nil { return nil, errors.Wrap(err, "loading failed") } - imports = cfg.bindTypes(imports, namedTypes, cfg.execDir, prog) + imports := buildImports(namedTypes, cfg.execDir) + cfg.bindTypes(imports, namedTypes, cfg.execDir, prog) objects, err := cfg.buildObjects(namedTypes, prog, imports) if err != nil { @@ -80,7 +80,7 @@ func (cfg *Config) bind() (*Build, error) { Objects: objects, Interfaces: cfg.buildInterfaces(namedTypes, prog), Inputs: inputs, - Imports: imports, + Imports: imports.imports, } if qr, ok := cfg.schema.EntryPoints["query"]; ok { @@ -122,7 +122,7 @@ func (cfg *Config) bind() (*Build, error) { return b, nil } -func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Program, error) { +func (cfg *Config) loadProgram(namedTypes NamedTypes, allowErrors bool) (*loader.Program, error) { conf := loader.Config{} if allowErrors { conf = loader.Config{ @@ -132,7 +132,11 @@ func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Progr }, } } - for _, imp := range imports { + for _, imp := range ambientImports { + conf.Import(imp) + } + + for _, imp := range namedTypes { if imp.Package != "" { conf.Import(imp.Package) } diff --git a/codegen/import.go b/codegen/import.go index ed1ab0780c5..53f1888c2d4 100644 --- a/codegen/import.go +++ b/codegen/import.go @@ -5,12 +5,15 @@ import ( ) type Import struct { - Name string - Package string + Alias string + Path string } -type Imports []*Import +type Imports struct { + imports []*Import + destDir string +} func (i *Import) Write() string { - return i.Name + " " + strconv.Quote(i.Package) + return i.Alias + " " + strconv.Quote(i.Path) } diff --git a/codegen/import_build.go b/codegen/import_build.go index f917a188dba..f81f6fd5e56 100644 --- a/codegen/import_build.go +++ b/codegen/import_build.go @@ -1,33 +1,47 @@ package codegen import ( + "go/build" "path/filepath" "regexp" + "sort" "strconv" "strings" ) -func buildImports(types NamedTypes, destDir string) Imports { +// These imports are referenced by the generated code, and are assumed to have the +// default alias. So lets make sure they get added first, and any later collisions get +// renamed. +var ambientImports = []string{ + "context", + "fmt", + "io", + "strconv", + "time", + "sync", + "github.com/vektah/gqlgen/neelance/introspection", + "github.com/vektah/gqlgen/neelance/errors", + "github.com/vektah/gqlgen/neelance/query", + "github.com/vektah/gqlgen/neelance/schema", + "github.com/vektah/gqlgen/neelance/validation", + "github.com/vektah/gqlgen/graphql", +} + +func buildImports(types NamedTypes, destDir string) *Imports { imports := Imports{ - {"context", "context"}, - {"fmt", "fmt"}, - {"io", "io"}, - {"strconv", "strconv"}, - {"time", "time"}, - {"sync", "sync"}, - {"introspection", "github.com/vektah/gqlgen/neelance/introspection"}, - {"errors", "github.com/vektah/gqlgen/neelance/errors"}, - {"query", "github.com/vektah/gqlgen/neelance/query"}, - {"schema", "github.com/vektah/gqlgen/neelance/schema"}, - {"validation", "github.com/vektah/gqlgen/neelance/validation"}, - {"graphql", "github.com/vektah/gqlgen/graphql"}, + destDir: destDir, } + for _, ambient := range ambientImports { + imports.add(ambient) + } + + // Imports from top level user types for _, t := range types { - imports, t.Import = imports.addPkg(types, destDir, t.Package) + t.Import = imports.add(t.Package) } - return imports + return &imports } var invalidPackageNameChar = regexp.MustCompile(`[^\w]`) @@ -36,23 +50,32 @@ func sanitizePackageName(pkg string) string { return invalidPackageNameChar.ReplaceAllLiteralString(filepath.Base(pkg), "_") } -func (s Imports) addPkg(types NamedTypes, destDir string, pkg string) (Imports, *Import) { - if pkg == "" { - return s, nil +func (s *Imports) add(path string) *Import { + if path == "" { + return nil } - if existing := s.findByPkg(pkg); existing != nil { - return s, existing + if existing := s.findByPath(path); existing != nil { + return existing } - localName := "" - if !strings.HasSuffix(destDir, pkg) { - localName = sanitizePackageName(filepath.Base(pkg)) + pkg, err := build.Default.Import(path, s.destDir, 0) + if err != nil { + panic(err) + } + + alias := "" + if !strings.HasSuffix(s.destDir, path) { + if pkg == nil { + panic(path + " was not loaded") + } + + alias = pkg.Name i := 1 - imp := s.findByName(localName) - for imp != nil && imp.Package != pkg { - localName = sanitizePackageName(filepath.Base(pkg)) + strconv.Itoa(i) - imp = s.findByName(localName) + imp := s.findByAlias(alias) + for imp != nil && imp.Path != path { + alias = pkg.Name + strconv.Itoa(i) + imp = s.findByAlias(alias) i++ if i > 10 { panic("too many collisions") @@ -61,25 +84,29 @@ func (s Imports) addPkg(types NamedTypes, destDir string, pkg string) (Imports, } imp := &Import{ - Name: localName, - Package: pkg, + Alias: alias, + Path: path, } - s = append(s, imp) - return s, imp + s.imports = append(s.imports, imp) + sort.Slice(s.imports, func(i, j int) bool { + return s.imports[i].Alias > s.imports[j].Alias + }) + + return imp } -func (s Imports) findByPkg(pkg string) *Import { - for _, imp := range s { - if imp.Package == pkg { +func (s Imports) findByPath(importPath string) *Import { + for _, imp := range s.imports { + if imp.Path == importPath { return imp } } return nil } -func (s Imports) findByName(name string) *Import { - for _, imp := range s { - if imp.Name == name { +func (s Imports) findByAlias(alias string) *Import { + for _, imp := range s.imports { + if imp.Alias == alias { return imp } } diff --git a/codegen/input_build.go b/codegen/input_build.go index be1ebbd8eae..81da1ffe52d 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -10,7 +10,7 @@ import ( "golang.org/x/tools/go/loader" ) -func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports Imports) (Objects, error) { +func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) { var inputs Objects for _, typ := range cfg.schema.Types { diff --git a/codegen/object_build.go b/codegen/object_build.go index 20fd867f7d9..0809f7d26cb 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -9,7 +9,7 @@ import ( "golang.org/x/tools/go/loader" ) -func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports Imports) (Objects, error) { +func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) { var objects Objects for _, typ := range cfg.schema.Types { diff --git a/codegen/type.go b/codegen/type.go index 0138e8e7acc..31fcefc1317 100644 --- a/codegen/type.go +++ b/codegen/type.go @@ -40,10 +40,10 @@ func (t Ref) FullName() string { } func (t Ref) PkgDot() string { - if t.Import == nil || t.Import.Name == "" { + if t.Import == nil || t.Import.Alias == "" { return "" } - return t.Import.Name + "." + return t.Import.Alias + "." } func (t Type) Signature() string { diff --git a/codegen/type_build.go b/codegen/type_build.go index 6ecdbfdf453..4e55bedf65a 100644 --- a/codegen/type_build.go +++ b/codegen/type_build.go @@ -31,7 +31,7 @@ func (cfg *Config) buildNamedTypes() NamedTypes { return types } -func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) Imports { +func (cfg *Config) bindTypes(imports *Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) { for _, t := range namedTypes { if t.Package == "" { continue @@ -45,10 +45,9 @@ func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir str t.Marshaler = &cpy t.Package, t.GoType = pkgAndType(sig.Params().At(0).Type().String()) - imports, t.Import = imports.addPkg(namedTypes, destDir, t.Package) + t.Import = imports.add(t.Package) } } - return imports } // namedTypeFromSchema objects for every graphql type, including primitives. diff --git a/codegen/util.go b/codegen/util.go index b647069171b..553920fa786 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -131,7 +131,7 @@ func findField(typ *types.Struct, name string) *types.Var { return nil } -func bindObject(t types.Type, object *Object, imports Imports) error { +func bindObject(t types.Type, object *Object, imports *Imports) error { namedType, ok := t.(*types.Named) if !ok { return errors.Errorf("expected %s to be a named struct, instead found %s", object.FullName(), t.String()) @@ -184,10 +184,10 @@ func bindObject(t types.Type, object *Object, imports Imports) error { case normalizeVendor(structField.Type().Underlying().String()): pkg, typ := pkgAndType(structField.Type().String()) - imp := imports.findByPkg(pkg) + imp := imports.findByPath(pkg) field.CastType = typ - if imp.Name != "" { - field.CastType = imp.Name + "." + typ + if imp.Alias != "" { + field.CastType = imp.Alias + "." + typ } default: diff --git a/example/scalars/generated.go b/example/scalars/generated.go index e176b3e04db..a3c79811ac9 100644 --- a/example/scalars/generated.go +++ b/example/scalars/generated.go @@ -5,10 +5,10 @@ package scalars import ( "bytes" context "context" - external "external" strconv "strconv" time "time" + external "github.com/vektah/gqlgen/example/scalars/vendor/external" graphql "github.com/vektah/gqlgen/graphql" introspection "github.com/vektah/gqlgen/neelance/introspection" query "github.com/vektah/gqlgen/neelance/query" diff --git a/example/scalars/models_gen.go b/example/scalars/models_gen.go index 4a13829c15c..07705d9274d 100644 --- a/example/scalars/models_gen.go +++ b/example/scalars/models_gen.go @@ -3,7 +3,7 @@ package scalars import ( - external "external" + external "github.com/vektah/gqlgen/example/scalars/vendor/external" ) type Address struct { diff --git a/test/generated.go b/test/generated.go index 2336ac610c7..e3eabeb372c 100644 --- a/test/generated.go +++ b/test/generated.go @@ -5,7 +5,6 @@ package test import ( "bytes" context "context" - remote_api "remote_api" strconv "strconv" graphql "github.com/vektah/gqlgen/graphql" @@ -13,6 +12,7 @@ import ( query "github.com/vektah/gqlgen/neelance/query" schema "github.com/vektah/gqlgen/neelance/schema" models "github.com/vektah/gqlgen/test/models" + remote_api "github.com/vektah/gqlgen/test/vendor/remote_api" ) func MakeExecutableSchema(resolvers Resolvers) graphql.ExecutableSchema {