diff --git a/codegen/config/binder_test.go b/codegen/config/binder_test.go index fed3913adba..ae6b3027fbc 100644 --- a/codegen/config/binder_test.go +++ b/codegen/config/binder_test.go @@ -187,7 +187,7 @@ func createBinder(cfg Config) (*Binder, *ast.Schema) { Model: []string{"github.com/99designs/gqlgen/graphql.String"}, }, } - cfg.Packages = &code.Packages{} + cfg.Packages = code.NewPackages() cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: ` type Message { id: ID } diff --git a/codegen/config/config.go b/codegen/config/config.go index 8a7205f06bf..9f4be075ff9 100644 --- a/codegen/config/config.go +++ b/codegen/config/config.go @@ -26,6 +26,7 @@ type Config struct { Models TypeMap `yaml:"models,omitempty"` StructTag string `yaml:"struct_tag,omitempty"` Directives map[string]DirectiveConfig `yaml:"directives,omitempty"` + GoBuildTags StringList `yaml:"go_build_tags,omitempty"` GoInitialisms GoInitialismsConfig `yaml:"go_initialisms,omitempty"` OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"` OmitGetters bool `yaml:"omit_getters,omitempty"` @@ -211,7 +212,9 @@ func CompleteConfig(config *Config) error { func (c *Config) Init() error { if c.Packages == nil { - c.Packages = &code.Packages{} + c.Packages = code.NewPackages( + code.WithBuildTags(c.GoBuildTags...), + ) } if c.Schema == nil { @@ -671,7 +674,9 @@ func (c *Config) injectBuiltins() { func (c *Config) LoadSchema() error { if c.Packages != nil { - c.Packages = &code.Packages{} + c.Packages = code.NewPackages( + code.WithBuildTags(c.GoBuildTags...), + ) } if err := c.check(); err != nil { diff --git a/codegen/config/config_test.go b/codegen/config/config_test.go index 2ed5d5ac86a..68e30d4b953 100644 --- a/codegen/config/config_test.go +++ b/codegen/config/config_test.go @@ -192,7 +192,7 @@ func TestAutobinding(t *testing.T) { "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/chat", "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/scalars/model", }, - Packages: &code.Packages{}, + Packages: code.NewPackages(), } cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: ` @@ -212,7 +212,7 @@ func TestAutobinding(t *testing.T) { AutoBind: []string{ "../chat", }, - Packages: &code.Packages{}, + Packages: code.NewPackages(), } cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: ` diff --git a/codegen/templates/import_test.go b/codegen/templates/import_test.go index baf88b50ab2..483fd5ece82 100644 --- a/codegen/templates/import_test.go +++ b/codegen/templates/import_test.go @@ -20,14 +20,14 @@ func TestImports(t *testing.T) { mismatch := "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch" t.Run("multiple lookups is ok", func(t *testing.T) { - a := Imports{destDir: wd, packages: &code.Packages{}} + a := Imports{destDir: wd, packages: code.NewPackages()} require.Equal(t, "bar", a.Lookup(aBar)) require.Equal(t, "bar", a.Lookup(aBar)) }) t.Run("lookup by type", func(t *testing.T) { - a := Imports{destDir: wd, packages: &code.Packages{}} + a := Imports{destDir: wd, packages: code.NewPackages()} pkg := types.NewPackage("github.com/99designs/gqlgen/codegen/templates/testdata/b/bar", "bar") typ := types.NewNamed(types.NewTypeName(0, pkg, "Boolean", types.Typ[types.Bool]), types.Typ[types.Bool], nil) @@ -36,7 +36,7 @@ func TestImports(t *testing.T) { }) t.Run("duplicates are decollisioned", func(t *testing.T) { - a := Imports{destDir: wd, packages: &code.Packages{}} + a := Imports{destDir: wd, packages: code.NewPackages()} require.Equal(t, "bar", a.Lookup(aBar)) require.Equal(t, "bar1", a.Lookup(bBar)) @@ -47,7 +47,7 @@ func TestImports(t *testing.T) { }) t.Run("duplicates above 10 are decollisioned", func(t *testing.T) { - a := Imports{destDir: wd, packages: &code.Packages{}} + a := Imports{destDir: wd, packages: code.NewPackages()} for i := 0; i < 100; i++ { cBar := fmt.Sprintf("github.com/99designs/gqlgen/codegen/templates/testdata/%d/bar", i) if i > 0 { @@ -59,13 +59,13 @@ func TestImports(t *testing.T) { }) t.Run("package name defined in code will be used", func(t *testing.T) { - a := Imports{destDir: wd, packages: &code.Packages{}} + a := Imports{destDir: wd, packages: code.NewPackages()} require.Equal(t, "turtles", a.Lookup(mismatch)) }) t.Run("string printing for import block", func(t *testing.T) { - a := Imports{destDir: wd, packages: &code.Packages{}} + a := Imports{destDir: wd, packages: code.NewPackages()} a.Lookup(aBar) a.Lookup(bBar) a.Lookup(mismatch) @@ -80,7 +80,7 @@ turtles "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"`, }) t.Run("aliased imports will not collide", func(t *testing.T) { - a := Imports{destDir: wd, packages: &code.Packages{}} + a := Imports{destDir: wd, packages: code.NewPackages()} _, _ = a.Reserve(aBar, "abar") _, _ = a.Reserve(bBar, "bbar") diff --git a/codegen/templates/templates_test.go b/codegen/templates/templates_test.go index 2f57fc9ca15..12ccbb23c10 100644 --- a/codegen/templates/templates_test.go +++ b/codegen/templates/templates_test.go @@ -326,7 +326,7 @@ func TestTemplateOverride(t *testing.T) { } defer f.Close() defer os.RemoveAll(f.Name()) - err = Render(Options{Template: "hello", Filename: f.Name(), Packages: &code.Packages{}}) + err = Render(Options{Template: "hello", Filename: f.Name(), Packages: code.NewPackages()}) if err != nil { t.Fatal(err) } @@ -346,7 +346,7 @@ func TestRenderFS(t *testing.T) { } defer f.Close() defer os.RemoveAll(f.Name()) - err = Render(Options{TemplateFS: templateFS, Filename: f.Name(), Packages: &code.Packages{}}) + err = Render(Options{TemplateFS: templateFS, Filename: f.Name(), Packages: code.NewPackages()}) if err != nil { t.Fatal(err) } diff --git a/docs/content/config.md b/docs/content/config.md index 1d167e2ef0e..547c62b8d81 100644 --- a/docs/content/config.md +++ b/docs/content/config.md @@ -84,6 +84,11 @@ resolver: # Optional: set to skip running `go mod tidy` when generating server code # skip_mod_tidy: true +# Optional: set build tags that will be used to load packages +# go_build_tags: +# - private +# - enterprise + # Optional: set to modify the initialisms regarded for Go names # go_initialisms: # replace_defaults: false # if true, the default initialisms will get dropped in favor of the new ones instead of being added diff --git a/internal/code/packages.go b/internal/code/packages.go index 6cb6ef2323f..e7f8655aa01 100644 --- a/internal/code/packages.go +++ b/internal/code/packages.go @@ -28,15 +28,37 @@ var mode = packages.NeedName | packages.NeedModule | packages.NeedDeps -// Packages is a wrapper around x/tools/go/packages that maintains a (hopefully prewarmed) cache of packages -// that can be invalidated as writes are made and packages are known to change. -type Packages struct { - packages map[string]*packages.Package - importToName map[string]string - loadErrors []error - - numLoadCalls int // stupid test steam. ignore. - numNameCalls int // stupid test steam. ignore. +type ( + // Packages is a wrapper around x/tools/go/packages that maintains a (hopefully prewarmed) cache of packages + // that can be invalidated as writes are made and packages are known to change. + Packages struct { + packages map[string]*packages.Package + importToName map[string]string + loadErrors []error + buildFlags []string + + numLoadCalls int // stupid test steam. ignore. + numNameCalls int // stupid test steam. ignore. + } + // Option is a function that can be passed to NewPackages to configure the package loader + Option func(p *Packages) +) + +// WithBuildTags adds build tags to the packages.Load call +func WithBuildTags(tags ...string) func(p *Packages) { + return func(p *Packages) { + p.buildFlags = append(p.buildFlags, "-tags", strings.Join(tags, ",")) + } +} + +// NewPackages creates a new packages cache +// It will load all packages in the current module, and any packages that are passed to Load or LoadAll +func NewPackages(opts ...Option) *Packages { + p := &Packages{} + for _, opt := range opts { + opt(p) + } + return p } func (p *Packages) CleanupUserPackages() { @@ -47,8 +69,8 @@ func (p *Packages) CleanupUserPackages() { modInfo = nil } }) - - // Don't cleanup github.com/99designs/gqlgen prefixed packages, they haven't changed and do not need to be reloaded + // Don't cleanup github.com/99designs/gqlgen prefixed packages, + // they haven't changed and do not need to be reloaded if modInfo != nil { var toRemove []string for k := range p.packages { @@ -56,7 +78,6 @@ func (p *Packages) CleanupUserPackages() { toRemove = append(toRemove, k) } } - for _, k := range toRemove { delete(p.packages, k) } @@ -91,7 +112,10 @@ func (p *Packages) LoadAll(importPaths ...string) []*packages.Package { if len(missing) > 0 { p.numLoadCalls++ - pkgs, err := packages.Load(&packages.Config{Mode: mode}, missing...) + pkgs, err := packages.Load(&packages.Config{ + Mode: mode, + BuildFlags: p.buildFlags, + }, missing...) if err != nil { p.loadErrors = append(p.loadErrors, err) } @@ -140,7 +164,10 @@ func (p *Packages) LoadWithTypes(importPath string) *packages.Package { pkg := p.Load(importPath) if pkg == nil || pkg.TypesInfo == nil { p.numLoadCalls++ - pkgs, err := packages.Load(&packages.Config{Mode: mode}, importPath) + pkgs, err := packages.Load(&packages.Config{ + Mode: mode, + BuildFlags: p.buildFlags, + }, importPath) if err != nil { p.loadErrors = append(p.loadErrors, err) return nil @@ -173,7 +200,10 @@ func (p *Packages) NameForPackage(importPath string) string { if pkg == nil { // otherwise do a name only lookup for it but don't put it in the package cache. p.numNameCalls++ - pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, importPath) + pkgs, err := packages.Load(&packages.Config{ + Mode: packages.NeedName, + BuildFlags: p.buildFlags, + }, importPath) if err != nil { p.loadErrors = append(p.loadErrors, err) } else { diff --git a/internal/code/packages_test.go b/internal/code/packages_test.go index 2fbf780c55e..1d67419eaea 100644 --- a/internal/code/packages_test.go +++ b/internal/code/packages_test.go @@ -38,6 +38,14 @@ func TestPackages(t *testing.T) { require.Equal(t, "b", p.Load("github.com/99designs/gqlgen/internal/code/testdata/b").Name) require.Equal(t, 3, p.numLoadCalls) }) + t.Run("able to load private package with build tags", func(t *testing.T) { + p := initialState(t, WithBuildTags("private")) + p.Evict("github.com/99designs/gqlgen/internal/code/testdata/a") + require.Equal(t, "a", p.Load("github.com/99designs/gqlgen/internal/code/testdata/a").Name) + require.Equal(t, 2, p.numLoadCalls) + require.Equal(t, "p", p.Load("github.com/99designs/gqlgen/internal/code/testdata/p").Name) + require.Equal(t, 3, p.numLoadCalls) + }) } func TestNameForPackage(t *testing.T) { @@ -50,8 +58,8 @@ func TestNameForPackage(t *testing.T) { assert.Equal(t, "github_com", p.NameForPackage("github.com")) } -func initialState(t *testing.T) *Packages { - p := &Packages{} +func initialState(t *testing.T, opts ...Option) *Packages { + p := NewPackages(opts...) pkgs := p.LoadAll( "github.com/99designs/gqlgen/internal/code/testdata/a", "github.com/99designs/gqlgen/internal/code/testdata/b", diff --git a/internal/code/testdata/p/p.go b/internal/code/testdata/p/p.go new file mode 100644 index 00000000000..bf516e9813f --- /dev/null +++ b/internal/code/testdata/p/p.go @@ -0,0 +1,13 @@ +//go:build private +// +build private + +// This file is excluded from the build unless the "private" build tag is set. +// This is used to test loading private packages. +// See internal/code/packages_test.go for more details. +package p + +import ( + "github.com/99designs/gqlgen/internal/code/testdata/b" +) + +var P = b.C + " P" diff --git a/internal/imports/prune_test.go b/internal/imports/prune_test.go index 15af13dcd2b..0006aec9ac0 100644 --- a/internal/imports/prune_test.go +++ b/internal/imports/prune_test.go @@ -12,7 +12,7 @@ import ( func TestPrune(t *testing.T) { // prime the packages cache so that it's not considered uninitialized - b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"), &code.Packages{}) + b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"), code.NewPackages()) require.NoError(t, err) require.Equal(t, strings.ReplaceAll(string(mustReadFile("testdata/unused.expected.go")), "\r\n", "\n"), string(b)) }