Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

single packages.Load for NameForPackage #944

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions codegen/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"sort"

"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/internal/code"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/packages"
)

// Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement
Expand All @@ -25,6 +27,9 @@ type Data struct {
QueryRoot *Object
MutationRoot *Object
SubscriptionRoot *Object

// This is important for looking up packages during code generation
NameForPackage code.NameForPackage
}

type builder struct {
Expand Down Expand Up @@ -75,12 +80,18 @@ func BuildData(cfg *config.Config) (*Data, error) {
}
}

pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...)
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}

s := Data{
Config: cfg,
Directives: dataDirectives,
Schema: b.Schema,
SchemaStr: b.SchemaStr,
Interfaces: map[string]*Interface{},
Config: cfg,
Directives: dataDirectives,
Schema: b.Schema,
SchemaStr: b.SchemaStr,
Interfaces: map[string]*Interface{},
NameForPackage: code.NewNameForPackage(pkgs),
}

for _, schemaType := range b.Schema.Types {
Expand Down
1 change: 1 addition & 0 deletions codegen/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ func GenerateCode(data *Data) error {
Data: data,
RegionTags: true,
GeneratedHeader: true,
NameForPackage: data.NameForPackage,
})
}
26 changes: 15 additions & 11 deletions codegen/templates/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ import (
)

type Import struct {
Name string
Path string
Alias string
NameForPackage code.NameForPackage
Name string
Path string
Alias string
}

type Imports struct {
imports []*Import
destDir string
nameForPackage code.NameForPackage
imports []*Import
destDir string
}

func (i *Import) String() string {
Expand Down Expand Up @@ -49,7 +51,7 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) {
return "", nil
}

name := code.NameForPackage(path)
name := s.nameForPackage.Get(path)
var alias string
if len(aliases) != 1 {
alias = name
Expand All @@ -69,9 +71,10 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) {
}

s.imports = append(s.imports, &Import{
Name: name,
Path: path,
Alias: alias,
NameForPackage: s.nameForPackage,
Name: name,
Path: path,
Alias: alias,
})

return "", nil
Expand All @@ -94,8 +97,9 @@ func (s *Imports) Lookup(path string) string {
}

imp := &Import{
Name: code.NameForPackage(path),
Path: path,
NameForPackage: s.nameForPackage,
Name: s.nameForPackage.Get(path),
Path: path,
}
s.imports = append(s.imports, imp)

Expand Down
19 changes: 13 additions & 6 deletions codegen/templates/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"os"
"testing"

"github.com/99designs/gqlgen/internal/code"
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/packages"
)

func TestImports(t *testing.T) {
Expand All @@ -16,15 +18,20 @@ func TestImports(t *testing.T) {
bBar := "github.com/99designs/gqlgen/codegen/templates/testdata/b/bar"
mismatch := "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"

ps, err := packages.Load(nil, aBar, bBar, mismatch)
require.NoError(t, err)

nameForPackage := code.NewNameForPackage(ps)

t.Run("multiple lookups is ok", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

require.Equal(t, "bar", a.Lookup(aBar))
require.Equal(t, "bar", a.Lookup(aBar))
})

t.Run("lookup by type", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

pkg := types.NewPackage("github.com/99designs/gqlgen/codegen/templates/testdata/b/bar", "bar")
typ := types.NewNamed(types.NewTypeName(0, pkg, "Boolean", types.Typ[types.Bool]), types.Typ[types.Bool], nil)
Expand All @@ -33,7 +40,7 @@ func TestImports(t *testing.T) {
})

t.Run("duplicates are decollisioned", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

require.Equal(t, "bar", a.Lookup(aBar))
require.Equal(t, "bar1", a.Lookup(bBar))
Expand All @@ -44,13 +51,13 @@ func TestImports(t *testing.T) {
})

t.Run("package name defined in code will be used", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

require.Equal(t, "turtles", a.Lookup(mismatch))
})

t.Run("string printing for import block", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}
a.Lookup(aBar)
a.Lookup(bBar)
a.Lookup(mismatch)
Expand All @@ -65,7 +72,7 @@ turtles "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"`,
})

t.Run("aliased imports will not collide", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

_, _ = a.Reserve(aBar, "abar")
_, _ = a.Reserve(bBar, "bbar")
Expand Down
12 changes: 8 additions & 4 deletions codegen/templates/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"text/template"
"unicode"

"github.com/99designs/gqlgen/internal/code"
"github.com/99designs/gqlgen/internal/imports"
"github.com/pkg/errors"
)
Expand Down Expand Up @@ -43,6 +44,9 @@ type Options struct {
// Data will be passed to the template execution.
Data interface{}
Funcs template.FuncMap

// Lookups for pre-cached package names
NameForPackage code.NameForPackage
}

// Render renders a gql plugin template from the given Options. Render is an
Expand All @@ -53,7 +57,7 @@ func Render(cfg Options) error {
if CurrentImports != nil {
panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
}
CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
CurrentImports = &Imports{nameForPackage: cfg.NameForPackage, destDir: filepath.Dir(cfg.Filename)}

// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
Expand Down Expand Up @@ -143,7 +147,7 @@ func Render(cfg Options) error {
}
CurrentImports = nil

return write(cfg.Filename, result.Bytes())
return write(cfg.Filename, result.Bytes(), cfg.NameForPackage)
}

func center(width int, pad string, s string) string {
Expand Down Expand Up @@ -551,13 +555,13 @@ func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
return buf, t.Execute(buf, tpldata)
}

func write(filename string, b []byte) error {
func write(filename string, b []byte, nameForPackage code.NameForPackage) error {
err := os.MkdirAll(filepath.Dir(filename), 0755)
if err != nil {
return errors.Wrap(err, "failed to create directory")
}

formatted, err := imports.Prune(filename, b)
formatted, err := imports.Prune(filename, b, nameForPackage)
if err != nil {
fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
formatted = b
Expand Down
36 changes: 26 additions & 10 deletions internal/code/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"golang.org/x/tools/go/packages"
)

var nameForPackageCache = sync.Map{}

var gopaths []string

func init() {
Expand Down Expand Up @@ -93,23 +91,41 @@ func ImportPathForDir(dir string) (res string) {
var modregex = regexp.MustCompile("module (.*)\n")

// NameForPackage returns the package name for a given import path. This can be really slow.
func NameForPackage(importPath string) string {
type NameForPackage struct {
cache *sync.Map
packages []*packages.Package
}

// NewNameForPackage creates a NameForPackage
func NewNameForPackage(packages []*packages.Package) NameForPackage {
return NameForPackage{
cache: &sync.Map{},
packages: packages,
}
}

// Get returns the package name for a given import path. This can be really slow.
func (n NameForPackage) Get(importPath string) string {
if importPath == "" {
panic(errors.New("import path can not be empty"))
}
if v, ok := nameForPackageCache.Load(importPath); ok {

if v, ok := n.cache.Load(importPath); ok {
return v.(string)
}
importPath = QualifyPackagePath(importPath)
p, _ := packages.Load(&packages.Config{
Mode: packages.NeedName,
}, importPath)
var p *packages.Package
for _, pkg := range n.packages {
if pkg.PkgPath == importPath {
p = pkg
}
}

if len(p) != 1 || p[0].Name == "" {
if p == nil || p.Name == "" {
return SanitizePackageName(filepath.Base(importPath))
}

nameForPackageCache.Store(importPath, p[0].Name)
n.cache.Store(importPath, p.Name)

return p[0].Name
return p.Name
}
11 changes: 8 additions & 3 deletions internal/code/imports_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/packages"
)

func TestImportPathForDir(t *testing.T) {
Expand All @@ -31,11 +32,15 @@ func TestImportPathForDir(t *testing.T) {
}

func TestNameForPackage(t *testing.T) {
assert.Equal(t, "api", NameForPackage("github.com/99designs/gqlgen/api"))
ps, _ := packages.Load(&packages.Config{Mode: packages.NeedName},
"github.com/99designs/gqlgen/api", "github.com/99designs/gqlgen/docs", "github.com")
nfp := NewNameForPackage(ps)

assert.Equal(t, "api", nfp.Get("github.com/99designs/gqlgen/api"))

// does not contain go code, should still give a valid name
assert.Equal(t, "docs", NameForPackage("github.com/99designs/gqlgen/docs"))
assert.Equal(t, "github_com", NameForPackage("github.com"))
assert.Equal(t, "docs", nfp.Get("github.com/99designs/gqlgen/docs"))
assert.Equal(t, "github_com", nfp.Get("github.com"))
}

func TestNameForDir(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions internal/imports/prune.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor {
}

// Prune removes any unused imports
func Prune(filename string, src []byte) ([]byte, error) {
func Prune(filename string, src []byte, nameForPackage code.NameForPackage) ([]byte, error) {
fset := token.NewFileSet()

file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.AllErrors)
if err != nil {
return nil, err
}

unused := getUnusedImports(file)
unused := getUnusedImports(file, nameForPackage)
for ipath, name := range unused {
astutil.DeleteNamedImport(fset, file, name, ipath)
}
Expand All @@ -46,7 +46,7 @@ func Prune(filename string, src []byte) ([]byte, error) {
return imports.Process(filename, buf.Bytes(), &imports.Options{FormatOnly: true, Comments: true, TabIndent: true, TabWidth: 8})
}

func getUnusedImports(file ast.Node) map[string]string {
func getUnusedImports(file ast.Node, nameForPackage code.NameForPackage) map[string]string {
imported := map[string]*ast.ImportSpec{}
used := map[string]bool{}

Expand All @@ -65,7 +65,7 @@ func getUnusedImports(file ast.Node) map[string]string {
break
}

local := code.NameForPackage(ipath)
local := nameForPackage.Get(ipath)

imported[local] = v
case *ast.SelectorExpr:
Expand Down
3 changes: 2 additions & 1 deletion internal/imports/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"io/ioutil"
"testing"

"github.com/99designs/gqlgen/internal/code"
"github.com/stretchr/testify/require"
)

func TestPrune(t *testing.T) {
b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"))
b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"), code.NewNameForPackage(nil))
require.NoError(t, err)
require.Equal(t, string(mustReadFile("testdata/unused.expected.go")), string(b))
}
Expand Down
9 changes: 9 additions & 0 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ import (

"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/internal/code"
"github.com/99designs/gqlgen/plugin"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/packages"
)

type BuildMutateHook = func(b *ModelBuild) *ModelBuild
Expand Down Expand Up @@ -235,11 +238,17 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
b = m.MutateHook(b)
}

pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...)
if err != nil {
return errors.Wrap(err, "loading failed")
}

return templates.Render(templates.Options{
PackageName: cfg.Model.Package,
Filename: cfg.Model.Filename,
Data: b,
GeneratedHeader: true,
NameForPackage: code.NewNameForPackage(pkgs),
})
}

Expand Down
Loading