diff --git a/codegen/build.go b/codegen/build.go index e1d9f2581a..9319c4448c 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -15,7 +15,7 @@ type Build struct { Objects Objects Inputs Objects Interfaces []*Interface - Imports Imports + Imports []*Import QueryRoot *Object MutationRoot *Object SubscriptionRoot *Object @@ -24,7 +24,7 @@ type Build struct { type ModelBuild struct { PackageName string - Imports Imports + Imports []*Import Models []Model Enums []Enum } @@ -33,11 +33,11 @@ type ModelBuild struct { func (cfg *Config) models() (*ModelBuild, error) { namedTypes := cfg.buildNamedTypes() - imports := buildImports(namedTypes, cfg.modelDir) - prog, err := cfg.loadProgram(imports, true) + prog, err := cfg.loadProgram(namedTypes, true) if err != nil { return nil, errors.Wrap(err, "loading failed") } + imports := buildImports(namedTypes, cfg.modelDir) cfg.bindTypes(imports, namedTypes, cfg.modelDir, prog) @@ -49,7 +49,7 @@ func (cfg *Config) models() (*ModelBuild, error) { PackageName: cfg.ModelPackageName, Models: models, Enums: cfg.buildEnums(namedTypes), - Imports: buildImports(namedTypes, cfg.modelDir), + Imports: imports.finalize(), }, nil } @@ -57,13 +57,13 @@ func (cfg *Config) models() (*ModelBuild, error) { func (cfg *Config) bind() (*Build, error) { namedTypes := cfg.buildNamedTypes() - imports := buildImports(namedTypes, cfg.execDir) - prog, err := cfg.loadProgram(imports, false) + prog, err := cfg.loadProgram(namedTypes, false) if err != nil { return nil, errors.Wrap(err, "loading failed") } - imports = cfg.bindTypes(imports, namedTypes, cfg.execDir, prog) + imports := buildImports(namedTypes, cfg.execDir) + cfg.bindTypes(imports, namedTypes, cfg.execDir, prog) objects, err := cfg.buildObjects(namedTypes, prog, imports) if err != nil { @@ -80,7 +80,7 @@ func (cfg *Config) bind() (*Build, error) { Objects: objects, Interfaces: cfg.buildInterfaces(namedTypes, prog), Inputs: inputs, - Imports: imports, + Imports: imports.finalize(), } if qr, ok := cfg.schema.EntryPoints["query"]; ok { @@ -122,7 +122,7 @@ func (cfg *Config) bind() (*Build, error) { return b, nil } -func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Program, error) { +func (cfg *Config) loadProgram(namedTypes NamedTypes, allowErrors bool) (*loader.Program, error) { conf := loader.Config{} if allowErrors { conf = loader.Config{ @@ -132,7 +132,11 @@ func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Progr }, } } - for _, imp := range imports { + for _, imp := range ambientImports { + conf.Import(imp) + } + + for _, imp := range namedTypes { if imp.Package != "" { conf.Import(imp.Package) } diff --git a/codegen/codegen.go b/codegen/codegen.go index 0fd36ef5dc..e36f3226ce 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "os" "path/filepath" + "regexp" "strings" "syscall" @@ -136,6 +137,12 @@ func (cfg *Config) normalize() error { return cfg.schema.Parse(cfg.SchemaStr) } +var invalidPackageNameChar = regexp.MustCompile(`[^\w]`) + +func sanitizePackageName(pkg string) string { + return invalidPackageNameChar.ReplaceAllLiteralString(filepath.Base(pkg), "_") +} + func abs(path string) string { absPath, err := filepath.Abs(path) if err != nil { diff --git a/codegen/import.go b/codegen/import.go index ed1ab0780c..123d310f6d 100644 --- a/codegen/import.go +++ b/codegen/import.go @@ -5,12 +5,16 @@ import ( ) type Import struct { - Name string - Package string + Name string + Alias string + Path string } -type Imports []*Import +type Imports struct { + imports []*Import + destDir string +} func (i *Import) Write() string { - return i.Name + " " + strconv.Quote(i.Package) + return i.Alias + " " + strconv.Quote(i.Path) } diff --git a/codegen/import_build.go b/codegen/import_build.go index f917a188db..24646b74a2 100644 --- a/codegen/import_build.go +++ b/codegen/import_build.go @@ -1,85 +1,114 @@ package codegen import ( - "path/filepath" - "regexp" + "fmt" + "go/build" + "sort" "strconv" "strings" ) -func buildImports(types NamedTypes, destDir string) Imports { +// These imports are referenced by the generated code, and are assumed to have the +// default alias. So lets make sure they get added first, and any later collisions get +// renamed. +var ambientImports = []string{ + "context", + "fmt", + "io", + "strconv", + "time", + "sync", + "github.com/vektah/gqlgen/neelance/introspection", + "github.com/vektah/gqlgen/neelance/errors", + "github.com/vektah/gqlgen/neelance/query", + "github.com/vektah/gqlgen/neelance/schema", + "github.com/vektah/gqlgen/neelance/validation", + "github.com/vektah/gqlgen/graphql", +} + +func buildImports(types NamedTypes, destDir string) *Imports { imports := Imports{ - {"context", "context"}, - {"fmt", "fmt"}, - {"io", "io"}, - {"strconv", "strconv"}, - {"time", "time"}, - {"sync", "sync"}, - {"introspection", "github.com/vektah/gqlgen/neelance/introspection"}, - {"errors", "github.com/vektah/gqlgen/neelance/errors"}, - {"query", "github.com/vektah/gqlgen/neelance/query"}, - {"schema", "github.com/vektah/gqlgen/neelance/schema"}, - {"validation", "github.com/vektah/gqlgen/neelance/validation"}, - {"graphql", "github.com/vektah/gqlgen/graphql"}, + destDir: destDir, } + for _, ambient := range ambientImports { + imports.add(ambient) + } + + // Imports from top level user types for _, t := range types { - imports, t.Import = imports.addPkg(types, destDir, t.Package) + t.Import = imports.add(t.Package) } - return imports + return &imports } -var invalidPackageNameChar = regexp.MustCompile(`[^\w]`) +func (s *Imports) add(path string) *Import { + if path == "" { + return nil + } -func sanitizePackageName(pkg string) string { - return invalidPackageNameChar.ReplaceAllLiteralString(filepath.Base(pkg), "_") -} + if stringHasSuffixFold(s.destDir, path) { + return nil + } + + if existing := s.findByPath(path); existing != nil { + return existing + } -func (s Imports) addPkg(types NamedTypes, destDir string, pkg string) (Imports, *Import) { - if pkg == "" { - return s, nil + pkg, err := build.Default.Import(path, s.destDir, 0) + if err != nil { + panic(err) } - if existing := s.findByPkg(pkg); existing != nil { - return s, existing + imp := &Import{ + Name: pkg.Name, + Path: path, } + s.imports = append(s.imports, imp) + + return imp +} + +func stringHasSuffixFold(s, suffix string) bool { + return len(s) >= len(suffix) && strings.EqualFold(s[len(s)-len(suffix):], suffix) +} + +func (s Imports) finalize() []*Import { + // ensure stable ordering by sorting + sort.Slice(s.imports, func(i, j int) bool { + return s.imports[i].Path > s.imports[j].Path + }) + + for _, imp := range s.imports { + alias := imp.Name - localName := "" - if !strings.HasSuffix(destDir, pkg) { - localName = sanitizePackageName(filepath.Base(pkg)) i := 1 - imp := s.findByName(localName) - for imp != nil && imp.Package != pkg { - localName = sanitizePackageName(filepath.Base(pkg)) + strconv.Itoa(i) - imp = s.findByName(localName) + for s.findByAlias(alias) != nil { + alias = imp.Name + strconv.Itoa(i) i++ if i > 10 { - panic("too many collisions") + panic(fmt.Errorf("too many collisions, last attempt was %s", imp.Alias)) } } + imp.Alias = alias } - imp := &Import{ - Name: localName, - Package: pkg, - } - s = append(s, imp) - return s, imp + return s.imports } -func (s Imports) findByPkg(pkg string) *Import { - for _, imp := range s { - if imp.Package == pkg { +func (s Imports) findByPath(importPath string) *Import { + for _, imp := range s.imports { + if imp.Path == importPath { return imp } } return nil } -func (s Imports) findByName(name string) *Import { - for _, imp := range s { - if imp.Name == name { +func (s Imports) findByAlias(alias string) *Import { + for _, imp := range s.imports { + if imp.Alias == alias { return imp } } diff --git a/codegen/import_test.go b/codegen/import_test.go index 3a0704f07e..6e7c497a61 100644 --- a/codegen/import_test.go +++ b/codegen/import_test.go @@ -36,3 +36,21 @@ func TestImportCollisions(t *testing.T) { require.NoError(t, err) } + +func TestDeterministicDecollisioning(t *testing.T) { + a := Imports{ + imports: []*Import{ + {Name: "types", Path: "foobar/types"}, + {Name: "types", Path: "bazfoo/types"}, + }, + }.finalize() + + b := Imports{ + imports: []*Import{ + {Name: "types", Path: "bazfoo/types"}, + {Name: "types", Path: "foobar/types"}, + }, + }.finalize() + + require.EqualValues(t, a, b) +} diff --git a/codegen/input_build.go b/codegen/input_build.go index be1ebbd8ea..81da1ffe52 100644 --- a/codegen/input_build.go +++ b/codegen/input_build.go @@ -10,7 +10,7 @@ import ( "golang.org/x/tools/go/loader" ) -func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports Imports) (Objects, error) { +func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) { var inputs Objects for _, typ := range cfg.schema.Types { diff --git a/codegen/object_build.go b/codegen/object_build.go index 20fd867f7d..0809f7d26c 100644 --- a/codegen/object_build.go +++ b/codegen/object_build.go @@ -9,7 +9,7 @@ import ( "golang.org/x/tools/go/loader" ) -func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports Imports) (Objects, error) { +func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) { var objects Objects for _, typ := range cfg.schema.Types { diff --git a/codegen/type.go b/codegen/type.go index 0138e8e7ac..31fcefc131 100644 --- a/codegen/type.go +++ b/codegen/type.go @@ -40,10 +40,10 @@ func (t Ref) FullName() string { } func (t Ref) PkgDot() string { - if t.Import == nil || t.Import.Name == "" { + if t.Import == nil || t.Import.Alias == "" { return "" } - return t.Import.Name + "." + return t.Import.Alias + "." } func (t Type) Signature() string { diff --git a/codegen/type_build.go b/codegen/type_build.go index 6ecdbfdf45..4e55bedf65 100644 --- a/codegen/type_build.go +++ b/codegen/type_build.go @@ -31,7 +31,7 @@ func (cfg *Config) buildNamedTypes() NamedTypes { return types } -func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) Imports { +func (cfg *Config) bindTypes(imports *Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) { for _, t := range namedTypes { if t.Package == "" { continue @@ -45,10 +45,9 @@ func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir str t.Marshaler = &cpy t.Package, t.GoType = pkgAndType(sig.Params().At(0).Type().String()) - imports, t.Import = imports.addPkg(namedTypes, destDir, t.Package) + t.Import = imports.add(t.Package) } } - return imports } // namedTypeFromSchema objects for every graphql type, including primitives. diff --git a/codegen/util.go b/codegen/util.go index b647069171..dfd44d50a4 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -131,7 +131,7 @@ func findField(typ *types.Struct, name string) *types.Var { return nil } -func bindObject(t types.Type, object *Object, imports Imports) error { +func bindObject(t types.Type, object *Object, imports *Imports) error { namedType, ok := t.(*types.Named) if !ok { return errors.Errorf("expected %s to be a named struct, instead found %s", object.FullName(), t.String()) @@ -184,10 +184,10 @@ func bindObject(t types.Type, object *Object, imports Imports) error { case normalizeVendor(structField.Type().Underlying().String()): pkg, typ := pkgAndType(structField.Type().String()) - imp := imports.findByPkg(pkg) + imp := imports.findByPath(pkg) field.CastType = typ - if imp.Name != "" { - field.CastType = imp.Name + "." + typ + if imp != nil { + field.CastType = imp.Alias + "." + typ } default: