Skip to content

Commit

Permalink
Merge pull request 99designs#147 from vektah/import-overhaul
Browse files Browse the repository at this point in the history
Improve import handling
  • Loading branch information
vektah authored Jun 26, 2018
2 parents 9f26b08 + 9a99155 commit e7657b9
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 72 deletions.
26 changes: 15 additions & 11 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type Build struct {
Objects Objects
Inputs Objects
Interfaces []*Interface
Imports Imports
Imports []*Import
QueryRoot *Object
MutationRoot *Object
SubscriptionRoot *Object
Expand All @@ -24,7 +24,7 @@ type Build struct {

type ModelBuild struct {
PackageName string
Imports Imports
Imports []*Import
Models []Model
Enums []Enum
}
Expand All @@ -33,11 +33,11 @@ type ModelBuild struct {
func (cfg *Config) models() (*ModelBuild, error) {
namedTypes := cfg.buildNamedTypes()

imports := buildImports(namedTypes, cfg.modelDir)
prog, err := cfg.loadProgram(imports, true)
prog, err := cfg.loadProgram(namedTypes, true)
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}
imports := buildImports(namedTypes, cfg.modelDir)

cfg.bindTypes(imports, namedTypes, cfg.modelDir, prog)

Expand All @@ -49,21 +49,21 @@ func (cfg *Config) models() (*ModelBuild, error) {
PackageName: cfg.ModelPackageName,
Models: models,
Enums: cfg.buildEnums(namedTypes),
Imports: buildImports(namedTypes, cfg.modelDir),
Imports: imports.finalize(),
}, nil
}

// bind a schema together with some code to generate a Build
func (cfg *Config) bind() (*Build, error) {
namedTypes := cfg.buildNamedTypes()

imports := buildImports(namedTypes, cfg.execDir)
prog, err := cfg.loadProgram(imports, false)
prog, err := cfg.loadProgram(namedTypes, false)
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}

imports = cfg.bindTypes(imports, namedTypes, cfg.execDir, prog)
imports := buildImports(namedTypes, cfg.execDir)
cfg.bindTypes(imports, namedTypes, cfg.execDir, prog)

objects, err := cfg.buildObjects(namedTypes, prog, imports)
if err != nil {
Expand All @@ -80,7 +80,7 @@ func (cfg *Config) bind() (*Build, error) {
Objects: objects,
Interfaces: cfg.buildInterfaces(namedTypes, prog),
Inputs: inputs,
Imports: imports,
Imports: imports.finalize(),
}

if qr, ok := cfg.schema.EntryPoints["query"]; ok {
Expand Down Expand Up @@ -122,7 +122,7 @@ func (cfg *Config) bind() (*Build, error) {
return b, nil
}

func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Program, error) {
func (cfg *Config) loadProgram(namedTypes NamedTypes, allowErrors bool) (*loader.Program, error) {
conf := loader.Config{}
if allowErrors {
conf = loader.Config{
Expand All @@ -132,7 +132,11 @@ func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Progr
},
}
}
for _, imp := range imports {
for _, imp := range ambientImports {
conf.Import(imp)
}

for _, imp := range namedTypes {
if imp.Package != "" {
conf.Import(imp.Package)
}
Expand Down
7 changes: 7 additions & 0 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strings"
"syscall"

Expand Down Expand Up @@ -136,6 +137,12 @@ func (cfg *Config) normalize() error {
return cfg.schema.Parse(cfg.SchemaStr)
}

var invalidPackageNameChar = regexp.MustCompile(`[^\w]`)

func sanitizePackageName(pkg string) string {
return invalidPackageNameChar.ReplaceAllLiteralString(filepath.Base(pkg), "_")
}

func abs(path string) string {
absPath, err := filepath.Abs(path)
if err != nil {
Expand Down
12 changes: 8 additions & 4 deletions codegen/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import (
)

type Import struct {
Name string
Package string
Name string
Alias string
Path string
}

type Imports []*Import
type Imports struct {
imports []*Import
destDir string
}

func (i *Import) Write() string {
return i.Name + " " + strconv.Quote(i.Package)
return i.Alias + " " + strconv.Quote(i.Path)
}
121 changes: 75 additions & 46 deletions codegen/import_build.go
Original file line number Diff line number Diff line change
@@ -1,85 +1,114 @@
package codegen

import (
"path/filepath"
"regexp"
"fmt"
"go/build"
"sort"
"strconv"
"strings"
)

func buildImports(types NamedTypes, destDir string) Imports {
// These imports are referenced by the generated code, and are assumed to have the
// default alias. So lets make sure they get added first, and any later collisions get
// renamed.
var ambientImports = []string{
"context",
"fmt",
"io",
"strconv",
"time",
"sync",
"github.com/vektah/gqlgen/neelance/introspection",
"github.com/vektah/gqlgen/neelance/errors",
"github.com/vektah/gqlgen/neelance/query",
"github.com/vektah/gqlgen/neelance/schema",
"github.com/vektah/gqlgen/neelance/validation",
"github.com/vektah/gqlgen/graphql",
}

func buildImports(types NamedTypes, destDir string) *Imports {
imports := Imports{
{"context", "context"},
{"fmt", "fmt"},
{"io", "io"},
{"strconv", "strconv"},
{"time", "time"},
{"sync", "sync"},
{"introspection", "github.com/vektah/gqlgen/neelance/introspection"},
{"errors", "github.com/vektah/gqlgen/neelance/errors"},
{"query", "github.com/vektah/gqlgen/neelance/query"},
{"schema", "github.com/vektah/gqlgen/neelance/schema"},
{"validation", "github.com/vektah/gqlgen/neelance/validation"},
{"graphql", "github.com/vektah/gqlgen/graphql"},
destDir: destDir,
}

for _, ambient := range ambientImports {
imports.add(ambient)
}

// Imports from top level user types
for _, t := range types {
imports, t.Import = imports.addPkg(types, destDir, t.Package)
t.Import = imports.add(t.Package)
}

return imports
return &imports
}

var invalidPackageNameChar = regexp.MustCompile(`[^\w]`)
func (s *Imports) add(path string) *Import {
if path == "" {
return nil
}

func sanitizePackageName(pkg string) string {
return invalidPackageNameChar.ReplaceAllLiteralString(filepath.Base(pkg), "_")
}
if stringHasSuffixFold(s.destDir, path) {
return nil
}

if existing := s.findByPath(path); existing != nil {
return existing
}

func (s Imports) addPkg(types NamedTypes, destDir string, pkg string) (Imports, *Import) {
if pkg == "" {
return s, nil
pkg, err := build.Default.Import(path, s.destDir, 0)
if err != nil {
panic(err)
}

if existing := s.findByPkg(pkg); existing != nil {
return s, existing
imp := &Import{
Name: pkg.Name,
Path: path,
}
s.imports = append(s.imports, imp)

return imp
}

func stringHasSuffixFold(s, suffix string) bool {
return len(s) >= len(suffix) && strings.EqualFold(s[len(s)-len(suffix):], suffix)
}

func (s Imports) finalize() []*Import {
// ensure stable ordering by sorting
sort.Slice(s.imports, func(i, j int) bool {
return s.imports[i].Path > s.imports[j].Path
})

for _, imp := range s.imports {
alias := imp.Name

localName := ""
if !strings.HasSuffix(destDir, pkg) {
localName = sanitizePackageName(filepath.Base(pkg))
i := 1
imp := s.findByName(localName)
for imp != nil && imp.Package != pkg {
localName = sanitizePackageName(filepath.Base(pkg)) + strconv.Itoa(i)
imp = s.findByName(localName)
for s.findByAlias(alias) != nil {
alias = imp.Name + strconv.Itoa(i)
i++
if i > 10 {
panic("too many collisions")
panic(fmt.Errorf("too many collisions, last attempt was %s", imp.Alias))
}
}
imp.Alias = alias
}

imp := &Import{
Name: localName,
Package: pkg,
}
s = append(s, imp)
return s, imp
return s.imports
}

func (s Imports) findByPkg(pkg string) *Import {
for _, imp := range s {
if imp.Package == pkg {
func (s Imports) findByPath(importPath string) *Import {
for _, imp := range s.imports {
if imp.Path == importPath {
return imp
}
}
return nil
}

func (s Imports) findByName(name string) *Import {
for _, imp := range s {
if imp.Name == name {
func (s Imports) findByAlias(alias string) *Import {
for _, imp := range s.imports {
if imp.Alias == alias {
return imp
}
}
Expand Down
18 changes: 18 additions & 0 deletions codegen/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,21 @@ func TestImportCollisions(t *testing.T) {

require.NoError(t, err)
}

func TestDeterministicDecollisioning(t *testing.T) {
a := Imports{
imports: []*Import{
{Name: "types", Path: "foobar/types"},
{Name: "types", Path: "bazfoo/types"},
},
}.finalize()

b := Imports{
imports: []*Import{
{Name: "types", Path: "bazfoo/types"},
{Name: "types", Path: "foobar/types"},
},
}.finalize()

require.EqualValues(t, a, b)
}
2 changes: 1 addition & 1 deletion codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"golang.org/x/tools/go/loader"
)

func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports Imports) (Objects, error) {
func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) {
var inputs Objects

for _, typ := range cfg.schema.Types {
Expand Down
2 changes: 1 addition & 1 deletion codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"golang.org/x/tools/go/loader"
)

func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports Imports) (Objects, error) {
func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) {
var objects Objects

for _, typ := range cfg.schema.Types {
Expand Down
4 changes: 2 additions & 2 deletions codegen/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ func (t Ref) FullName() string {
}

func (t Ref) PkgDot() string {
if t.Import == nil || t.Import.Name == "" {
if t.Import == nil || t.Import.Alias == "" {
return ""
}
return t.Import.Name + "."
return t.Import.Alias + "."
}

func (t Type) Signature() string {
Expand Down
5 changes: 2 additions & 3 deletions codegen/type_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (cfg *Config) buildNamedTypes() NamedTypes {
return types
}

func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) Imports {
func (cfg *Config) bindTypes(imports *Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) {
for _, t := range namedTypes {
if t.Package == "" {
continue
Expand All @@ -45,10 +45,9 @@ func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir str
t.Marshaler = &cpy

t.Package, t.GoType = pkgAndType(sig.Params().At(0).Type().String())
imports, t.Import = imports.addPkg(namedTypes, destDir, t.Package)
t.Import = imports.add(t.Package)
}
}
return imports
}

// namedTypeFromSchema objects for every graphql type, including primitives.
Expand Down
8 changes: 4 additions & 4 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func findField(typ *types.Struct, name string) *types.Var {
return nil
}

func bindObject(t types.Type, object *Object, imports Imports) error {
func bindObject(t types.Type, object *Object, imports *Imports) error {
namedType, ok := t.(*types.Named)
if !ok {
return errors.Errorf("expected %s to be a named struct, instead found %s", object.FullName(), t.String())
Expand Down Expand Up @@ -184,10 +184,10 @@ func bindObject(t types.Type, object *Object, imports Imports) error {

case normalizeVendor(structField.Type().Underlying().String()):
pkg, typ := pkgAndType(structField.Type().String())
imp := imports.findByPkg(pkg)
imp := imports.findByPath(pkg)
field.CastType = typ
if imp.Name != "" {
field.CastType = imp.Name + "." + typ
if imp != nil {
field.CastType = imp.Alias + "." + typ
}

default:
Expand Down

0 comments on commit e7657b9

Please sign in to comment.