Skip to content

Commit

Permalink
get package name from package source
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Jun 23, 2018
1 parent a39c63a commit abf85a1
Show file tree
Hide file tree
Showing 11 changed files with 99 additions and 66 deletions.
26 changes: 15 additions & 11 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type Build struct {
Objects Objects
Inputs Objects
Interfaces []*Interface
Imports Imports
Imports []*Import
QueryRoot *Object
MutationRoot *Object
SubscriptionRoot *Object
Expand All @@ -24,7 +24,7 @@ type Build struct {

type ModelBuild struct {
PackageName string
Imports Imports
Imports []*Import
Models []Model
Enums []Enum
}
Expand All @@ -33,11 +33,11 @@ type ModelBuild struct {
func (cfg *Config) models() (*ModelBuild, error) {
namedTypes := cfg.buildNamedTypes()

imports := buildImports(namedTypes, cfg.modelDir)
prog, err := cfg.loadProgram(imports, true)
prog, err := cfg.loadProgram(namedTypes, true)
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}
imports := buildImports(namedTypes, cfg.modelDir)

cfg.bindTypes(imports, namedTypes, cfg.modelDir, prog)

Expand All @@ -49,21 +49,21 @@ func (cfg *Config) models() (*ModelBuild, error) {
PackageName: cfg.ModelPackageName,
Models: models,
Enums: cfg.buildEnums(namedTypes),
Imports: buildImports(namedTypes, cfg.modelDir),
Imports: imports.imports,
}, nil
}

// bind a schema together with some code to generate a Build
func (cfg *Config) bind() (*Build, error) {
namedTypes := cfg.buildNamedTypes()

imports := buildImports(namedTypes, cfg.execDir)
prog, err := cfg.loadProgram(imports, false)
prog, err := cfg.loadProgram(namedTypes, false)
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}

imports = cfg.bindTypes(imports, namedTypes, cfg.execDir, prog)
imports := buildImports(namedTypes, cfg.execDir)
cfg.bindTypes(imports, namedTypes, cfg.execDir, prog)

objects, err := cfg.buildObjects(namedTypes, prog, imports)
if err != nil {
Expand All @@ -80,7 +80,7 @@ func (cfg *Config) bind() (*Build, error) {
Objects: objects,
Interfaces: cfg.buildInterfaces(namedTypes, prog),
Inputs: inputs,
Imports: imports,
Imports: imports.imports,
}

if qr, ok := cfg.schema.EntryPoints["query"]; ok {
Expand Down Expand Up @@ -122,7 +122,7 @@ func (cfg *Config) bind() (*Build, error) {
return b, nil
}

func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Program, error) {
func (cfg *Config) loadProgram(namedTypes NamedTypes, allowErrors bool) (*loader.Program, error) {
conf := loader.Config{}
if allowErrors {
conf = loader.Config{
Expand All @@ -132,7 +132,11 @@ func (cfg *Config) loadProgram(imports Imports, allowErrors bool) (*loader.Progr
},
}
}
for _, imp := range imports {
for _, imp := range ambientImports {
conf.Import(imp)
}

for _, imp := range namedTypes {
if imp.Package != "" {
conf.Import(imp.Package)
}
Expand Down
11 changes: 7 additions & 4 deletions codegen/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ import (
)

type Import struct {
Name string
Package string
Alias string
Path string
}

type Imports []*Import
type Imports struct {
imports []*Import
destDir string
}

func (i *Import) Write() string {
return i.Name + " " + strconv.Quote(i.Package)
return i.Alias + " " + strconv.Quote(i.Path)
}
101 changes: 64 additions & 37 deletions codegen/import_build.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,47 @@
package codegen

import (
"go/build"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
)

func buildImports(types NamedTypes, destDir string) Imports {
// These imports are referenced by the generated code, and are assumed to have the
// default alias. So lets make sure they get added first, and any later collisions get
// renamed.
var ambientImports = []string{
"context",
"fmt",
"io",
"strconv",
"time",
"sync",
"github.com/vektah/gqlgen/neelance/introspection",
"github.com/vektah/gqlgen/neelance/errors",
"github.com/vektah/gqlgen/neelance/query",
"github.com/vektah/gqlgen/neelance/schema",
"github.com/vektah/gqlgen/neelance/validation",
"github.com/vektah/gqlgen/graphql",
}

func buildImports(types NamedTypes, destDir string) *Imports {
imports := Imports{
{"context", "context"},
{"fmt", "fmt"},
{"io", "io"},
{"strconv", "strconv"},
{"time", "time"},
{"sync", "sync"},
{"introspection", "github.com/vektah/gqlgen/neelance/introspection"},
{"errors", "github.com/vektah/gqlgen/neelance/errors"},
{"query", "github.com/vektah/gqlgen/neelance/query"},
{"schema", "github.com/vektah/gqlgen/neelance/schema"},
{"validation", "github.com/vektah/gqlgen/neelance/validation"},
{"graphql", "github.com/vektah/gqlgen/graphql"},
destDir: destDir,
}

for _, ambient := range ambientImports {
imports.add(ambient)
}

// Imports from top level user types
for _, t := range types {
imports, t.Import = imports.addPkg(types, destDir, t.Package)
t.Import = imports.add(t.Package)
}

return imports
return &imports
}

var invalidPackageNameChar = regexp.MustCompile(`[^\w]`)
Expand All @@ -36,23 +50,32 @@ func sanitizePackageName(pkg string) string {
return invalidPackageNameChar.ReplaceAllLiteralString(filepath.Base(pkg), "_")
}

func (s Imports) addPkg(types NamedTypes, destDir string, pkg string) (Imports, *Import) {
if pkg == "" {
return s, nil
func (s *Imports) add(path string) *Import {
if path == "" {
return nil
}

if existing := s.findByPkg(pkg); existing != nil {
return s, existing
if existing := s.findByPath(path); existing != nil {
return existing
}

localName := ""
if !strings.HasSuffix(destDir, pkg) {
localName = sanitizePackageName(filepath.Base(pkg))
pkg, err := build.Default.Import(path, s.destDir, 0)
if err != nil {
panic(err)
}

alias := ""
if !strings.HasSuffix(s.destDir, path) {
if pkg == nil {
panic(path + " was not loaded")
}

alias = pkg.Name
i := 1
imp := s.findByName(localName)
for imp != nil && imp.Package != pkg {
localName = sanitizePackageName(filepath.Base(pkg)) + strconv.Itoa(i)
imp = s.findByName(localName)
imp := s.findByAlias(alias)
for imp != nil && imp.Path != path {
alias = pkg.Name + strconv.Itoa(i)
imp = s.findByAlias(alias)
i++
if i > 10 {
panic("too many collisions")
Expand All @@ -61,25 +84,29 @@ func (s Imports) addPkg(types NamedTypes, destDir string, pkg string) (Imports,
}

imp := &Import{
Name: localName,
Package: pkg,
Alias: alias,
Path: path,
}
s = append(s, imp)
return s, imp
s.imports = append(s.imports, imp)
sort.Slice(s.imports, func(i, j int) bool {
return s.imports[i].Alias > s.imports[j].Alias
})

return imp
}

func (s Imports) findByPkg(pkg string) *Import {
for _, imp := range s {
if imp.Package == pkg {
func (s Imports) findByPath(importPath string) *Import {
for _, imp := range s.imports {
if imp.Path == importPath {
return imp
}
}
return nil
}

func (s Imports) findByName(name string) *Import {
for _, imp := range s {
if imp.Name == name {
func (s Imports) findByAlias(alias string) *Import {
for _, imp := range s.imports {
if imp.Alias == alias {
return imp
}
}
Expand Down
2 changes: 1 addition & 1 deletion codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"golang.org/x/tools/go/loader"
)

func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports Imports) (Objects, error) {
func (cfg *Config) buildInputs(namedTypes NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) {
var inputs Objects

for _, typ := range cfg.schema.Types {
Expand Down
2 changes: 1 addition & 1 deletion codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"golang.org/x/tools/go/loader"
)

func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports Imports) (Objects, error) {
func (cfg *Config) buildObjects(types NamedTypes, prog *loader.Program, imports *Imports) (Objects, error) {
var objects Objects

for _, typ := range cfg.schema.Types {
Expand Down
4 changes: 2 additions & 2 deletions codegen/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ func (t Ref) FullName() string {
}

func (t Ref) PkgDot() string {
if t.Import == nil || t.Import.Name == "" {
if t.Import == nil || t.Import.Alias == "" {
return ""
}
return t.Import.Name + "."
return t.Import.Alias + "."
}

func (t Type) Signature() string {
Expand Down
5 changes: 2 additions & 3 deletions codegen/type_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (cfg *Config) buildNamedTypes() NamedTypes {
return types
}

func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) Imports {
func (cfg *Config) bindTypes(imports *Imports, namedTypes NamedTypes, destDir string, prog *loader.Program) {
for _, t := range namedTypes {
if t.Package == "" {
continue
Expand All @@ -45,10 +45,9 @@ func (cfg *Config) bindTypes(imports Imports, namedTypes NamedTypes, destDir str
t.Marshaler = &cpy

t.Package, t.GoType = pkgAndType(sig.Params().At(0).Type().String())
imports, t.Import = imports.addPkg(namedTypes, destDir, t.Package)
t.Import = imports.add(t.Package)
}
}
return imports
}

// namedTypeFromSchema objects for every graphql type, including primitives.
Expand Down
8 changes: 4 additions & 4 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func findField(typ *types.Struct, name string) *types.Var {
return nil
}

func bindObject(t types.Type, object *Object, imports Imports) error {
func bindObject(t types.Type, object *Object, imports *Imports) error {
namedType, ok := t.(*types.Named)
if !ok {
return errors.Errorf("expected %s to be a named struct, instead found %s", object.FullName(), t.String())
Expand Down Expand Up @@ -184,10 +184,10 @@ func bindObject(t types.Type, object *Object, imports Imports) error {

case normalizeVendor(structField.Type().Underlying().String()):
pkg, typ := pkgAndType(structField.Type().String())
imp := imports.findByPkg(pkg)
imp := imports.findByPath(pkg)
field.CastType = typ
if imp.Name != "" {
field.CastType = imp.Name + "." + typ
if imp.Alias != "" {
field.CastType = imp.Alias + "." + typ
}

default:
Expand Down
2 changes: 1 addition & 1 deletion example/scalars/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion example/scalars/models_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion test/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit abf85a1

Please sign in to comment.