diff --git a/codegen/data.go b/codegen/data.go index ef92ec8b48f..d3b191f8f94 100644 --- a/codegen/data.go +++ b/codegen/data.go @@ -6,9 +6,11 @@ import ( "sort" "github.com/99designs/gqlgen/codegen/config" + "github.com/99designs/gqlgen/internal/code" "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/formatter" + "golang.org/x/tools/go/packages" ) // Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement @@ -88,6 +90,12 @@ func BuildData(cfg *config.Config, plugins []SchemaMutator) (*Data, error) { } } + pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...) + if err != nil { + return nil, errors.Wrap(err, "loading failed") + } + code.RecordPackagesList(pkgs) + s := Data{ Config: cfg, Directives: dataDirectives, diff --git a/codegen/templates/import_test.go b/codegen/templates/import_test.go index d225457576f..440b59147c5 100644 --- a/codegen/templates/import_test.go +++ b/codegen/templates/import_test.go @@ -5,7 +5,9 @@ import ( "os" "testing" + "github.com/99designs/gqlgen/internal/code" "github.com/stretchr/testify/require" + "golang.org/x/tools/go/packages" ) func TestImports(t *testing.T) { @@ -16,6 +18,11 @@ func TestImports(t *testing.T) { bBar := "github.com/99designs/gqlgen/codegen/templates/testdata/b/bar" mismatch := "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch" + ps, err := packages.Load(nil, aBar, bBar, mismatch) + require.NoError(t, err) + + code.RecordPackagesList(ps) + t.Run("multiple lookups is ok", func(t *testing.T) { a := Imports{destDir: wd} diff --git a/internal/code/imports.go b/internal/code/imports.go index 3660079633a..10b325ba6f1 100644 --- a/internal/code/imports.go +++ b/internal/code/imports.go @@ -2,6 +2,7 @@ package code import ( "errors" + "fmt" "go/build" "go/parser" "go/token" @@ -14,7 +15,8 @@ import ( "golang.org/x/tools/go/packages" ) -var nameForPackageCache = sync.Map{} +var nameForPackageCacheLock sync.Mutex +var nameForPackageCache []*packages.Package var gopaths []string @@ -107,24 +109,32 @@ func ImportPathForDir(dir string) (res string) { var modregex = regexp.MustCompile("module (.*)\n") +// RecordPackagesList records the list of packages to be used later by NameForPackage. +// It must be called exactly once during initialization, before NameForPackage is called. +func RecordPackagesList(newNameForPackageCache []*packages.Package) { + nameForPackageCache = newNameForPackageCache +} + // NameForPackage returns the package name for a given import path. This can be really slow. func NameForPackage(importPath string) string { if importPath == "" { panic(errors.New("import path can not be empty")) } - if v, ok := nameForPackageCache.Load(importPath); ok { - return v.(string) + if nameForPackageCache == nil { + panic(fmt.Errorf("NameForPackage called for %s before RecordPackagesList", importPath)) + } + nameForPackageCacheLock.Lock() + defer nameForPackageCacheLock.Unlock() + var p *packages.Package + for _, pkg := range nameForPackageCache { + if pkg.PkgPath == importPath { + p = pkg + break + } } - importPath = QualifyPackagePath(importPath) - p, _ := packages.Load(&packages.Config{ - Mode: packages.NeedName, - }, importPath) - if len(p) != 1 || p[0].Name == "" { + if p == nil || p.Name == "" { return SanitizePackageName(filepath.Base(importPath)) } - - nameForPackageCache.Store(importPath, p[0].Name) - - return p[0].Name + return p.Name } diff --git a/internal/code/imports_test.go b/internal/code/imports_test.go index e3bc9474f99..f7e6f24a5a5 100644 --- a/internal/code/imports_test.go +++ b/internal/code/imports_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/tools/go/packages" ) func TestImportPathForDir(t *testing.T) { @@ -31,11 +32,17 @@ func TestImportPathForDir(t *testing.T) { } func TestNameForPackage(t *testing.T) { - assert.Equal(t, "api", NameForPackage("github.com/99designs/gqlgen/api")) + testPkg1 := "github.com/99designs/gqlgen/api" + testPkg2 := "github.com/99designs/gqlgen/docs" + testPkg3 := "github.com" + ps, err := packages.Load(nil, testPkg1, testPkg2, testPkg3) + require.NoError(t, err) + RecordPackagesList(ps) + assert.Equal(t, "api", NameForPackage(testPkg1)) // does not contain go code, should still give a valid name - assert.Equal(t, "docs", NameForPackage("github.com/99designs/gqlgen/docs")) - assert.Equal(t, "github_com", NameForPackage("github.com")) + assert.Equal(t, "docs", NameForPackage(testPkg2)) + assert.Equal(t, "github_com", NameForPackage(testPkg3)) } func TestNameForDir(t *testing.T) { diff --git a/internal/imports/prune_test.go b/internal/imports/prune_test.go index d0691bf242e..5f1563e510b 100644 --- a/internal/imports/prune_test.go +++ b/internal/imports/prune_test.go @@ -4,10 +4,15 @@ import ( "io/ioutil" "testing" + "github.com/99designs/gqlgen/internal/code" "github.com/stretchr/testify/require" + "golang.org/x/tools/go/packages" ) func TestPrune(t *testing.T) { + // prime the packages cache so that it's not considered uninitialized + code.RecordPackagesList([]*packages.Package{}) + b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go")) require.NoError(t, err) require.Equal(t, string(mustReadFile("testdata/unused.expected.go")), string(b)) diff --git a/plugin/modelgen/models.go b/plugin/modelgen/models.go index 3458b42bdf5..8e976c65bf1 100644 --- a/plugin/modelgen/models.go +++ b/plugin/modelgen/models.go @@ -7,8 +7,11 @@ import ( "github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/codegen/templates" + "github.com/99designs/gqlgen/internal/code" "github.com/99designs/gqlgen/plugin" + "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" + "golang.org/x/tools/go/packages" ) type BuildMutateHook = func(b *ModelBuild) *ModelBuild @@ -246,6 +249,12 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { b = m.MutateHook(b) } + pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...) + if err != nil { + return errors.Wrap(err, "loading failed") + } + code.RecordPackagesList(pkgs) + return templates.Render(templates.Options{ PackageName: cfg.Model.Package, Filename: cfg.Model.Filename,