Skip to content

Commit

Permalink
codegen: add support for go_build_tags option in gqlgen.yaml (#2784)
Browse files Browse the repository at this point in the history
* codegen: support go_build_tags option in gqlgen.yaml

* chore: added test

* docs/content: update config example

* chore: more comment
  • Loading branch information
giautm authored Sep 8, 2023
1 parent bee47dc commit 11bb9b1
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 32 deletions.
2 changes: 1 addition & 1 deletion codegen/config/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func createBinder(cfg Config) (*Binder, *ast.Schema) {
Model: []string{"github.com/99designs/gqlgen/graphql.String"},
},
}
cfg.Packages = &code.Packages{}
cfg.Packages = code.NewPackages()

cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: `
type Message { id: ID }
Expand Down
9 changes: 7 additions & 2 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Config struct {
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
GoBuildTags StringList `yaml:"go_build_tags,omitempty"`
GoInitialisms GoInitialismsConfig `yaml:"go_initialisms,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
OmitGetters bool `yaml:"omit_getters,omitempty"`
Expand Down Expand Up @@ -211,7 +212,9 @@ func CompleteConfig(config *Config) error {

func (c *Config) Init() error {
if c.Packages == nil {
c.Packages = &code.Packages{}
c.Packages = code.NewPackages(
code.WithBuildTags(c.GoBuildTags...),
)
}

if c.Schema == nil {
Expand Down Expand Up @@ -671,7 +674,9 @@ func (c *Config) injectBuiltins() {

func (c *Config) LoadSchema() error {
if c.Packages != nil {
c.Packages = &code.Packages{}
c.Packages = code.NewPackages(
code.WithBuildTags(c.GoBuildTags...),
)
}

if err := c.check(); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions codegen/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func TestAutobinding(t *testing.T) {
"github.com/99designs/gqlgen/codegen/config/testdata/autobinding/chat",
"github.com/99designs/gqlgen/codegen/config/testdata/autobinding/scalars/model",
},
Packages: &code.Packages{},
Packages: code.NewPackages(),
}

cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: `
Expand All @@ -212,7 +212,7 @@ func TestAutobinding(t *testing.T) {
AutoBind: []string{
"../chat",
},
Packages: &code.Packages{},
Packages: code.NewPackages(),
}

cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: `
Expand Down
14 changes: 7 additions & 7 deletions codegen/templates/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ func TestImports(t *testing.T) {
mismatch := "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"

t.Run("multiple lookups is ok", func(t *testing.T) {
a := Imports{destDir: wd, packages: &code.Packages{}}
a := Imports{destDir: wd, packages: code.NewPackages()}

require.Equal(t, "bar", a.Lookup(aBar))
require.Equal(t, "bar", a.Lookup(aBar))
})

t.Run("lookup by type", func(t *testing.T) {
a := Imports{destDir: wd, packages: &code.Packages{}}
a := Imports{destDir: wd, packages: code.NewPackages()}

pkg := types.NewPackage("github.com/99designs/gqlgen/codegen/templates/testdata/b/bar", "bar")
typ := types.NewNamed(types.NewTypeName(0, pkg, "Boolean", types.Typ[types.Bool]), types.Typ[types.Bool], nil)
Expand All @@ -36,7 +36,7 @@ func TestImports(t *testing.T) {
})

t.Run("duplicates are decollisioned", func(t *testing.T) {
a := Imports{destDir: wd, packages: &code.Packages{}}
a := Imports{destDir: wd, packages: code.NewPackages()}

require.Equal(t, "bar", a.Lookup(aBar))
require.Equal(t, "bar1", a.Lookup(bBar))
Expand All @@ -47,7 +47,7 @@ func TestImports(t *testing.T) {
})

t.Run("duplicates above 10 are decollisioned", func(t *testing.T) {
a := Imports{destDir: wd, packages: &code.Packages{}}
a := Imports{destDir: wd, packages: code.NewPackages()}
for i := 0; i < 100; i++ {
cBar := fmt.Sprintf("github.com/99designs/gqlgen/codegen/templates/testdata/%d/bar", i)
if i > 0 {
Expand All @@ -59,13 +59,13 @@ func TestImports(t *testing.T) {
})

t.Run("package name defined in code will be used", func(t *testing.T) {
a := Imports{destDir: wd, packages: &code.Packages{}}
a := Imports{destDir: wd, packages: code.NewPackages()}

require.Equal(t, "turtles", a.Lookup(mismatch))
})

t.Run("string printing for import block", func(t *testing.T) {
a := Imports{destDir: wd, packages: &code.Packages{}}
a := Imports{destDir: wd, packages: code.NewPackages()}
a.Lookup(aBar)
a.Lookup(bBar)
a.Lookup(mismatch)
Expand All @@ -80,7 +80,7 @@ turtles "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"`,
})

t.Run("aliased imports will not collide", func(t *testing.T) {
a := Imports{destDir: wd, packages: &code.Packages{}}
a := Imports{destDir: wd, packages: code.NewPackages()}

_, _ = a.Reserve(aBar, "abar")
_, _ = a.Reserve(bBar, "bbar")
Expand Down
4 changes: 2 additions & 2 deletions codegen/templates/templates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ func TestTemplateOverride(t *testing.T) {
}
defer f.Close()
defer os.RemoveAll(f.Name())
err = Render(Options{Template: "hello", Filename: f.Name(), Packages: &code.Packages{}})
err = Render(Options{Template: "hello", Filename: f.Name(), Packages: code.NewPackages()})
if err != nil {
t.Fatal(err)
}
Expand All @@ -346,7 +346,7 @@ func TestRenderFS(t *testing.T) {
}
defer f.Close()
defer os.RemoveAll(f.Name())
err = Render(Options{TemplateFS: templateFS, Filename: f.Name(), Packages: &code.Packages{}})
err = Render(Options{TemplateFS: templateFS, Filename: f.Name(), Packages: code.NewPackages()})
if err != nil {
t.Fatal(err)
}
Expand Down
5 changes: 5 additions & 0 deletions docs/content/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ resolver:
# Optional: set to skip running `go mod tidy` when generating server code
# skip_mod_tidy: true

# Optional: set build tags that will be used to load packages
# go_build_tags:
# - private
# - enterprise

# Optional: set to modify the initialisms regarded for Go names
# go_initialisms:
# replace_defaults: false # if true, the default initialisms will get dropped in favor of the new ones instead of being added
Expand Down
60 changes: 45 additions & 15 deletions internal/code/packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,37 @@ var mode = packages.NeedName |
packages.NeedModule |
packages.NeedDeps

// Packages is a wrapper around x/tools/go/packages that maintains a (hopefully prewarmed) cache of packages
// that can be invalidated as writes are made and packages are known to change.
type Packages struct {
packages map[string]*packages.Package
importToName map[string]string
loadErrors []error

numLoadCalls int // stupid test steam. ignore.
numNameCalls int // stupid test steam. ignore.
type (
// Packages is a wrapper around x/tools/go/packages that maintains a (hopefully prewarmed) cache of packages
// that can be invalidated as writes are made and packages are known to change.
Packages struct {
packages map[string]*packages.Package
importToName map[string]string
loadErrors []error
buildFlags []string

numLoadCalls int // stupid test steam. ignore.
numNameCalls int // stupid test steam. ignore.
}
// Option is a function that can be passed to NewPackages to configure the package loader
Option func(p *Packages)
)

// WithBuildTags adds build tags to the packages.Load call
func WithBuildTags(tags ...string) func(p *Packages) {
return func(p *Packages) {
p.buildFlags = append(p.buildFlags, "-tags", strings.Join(tags, ","))
}
}

// NewPackages creates a new packages cache
// It will load all packages in the current module, and any packages that are passed to Load or LoadAll
func NewPackages(opts ...Option) *Packages {
p := &Packages{}
for _, opt := range opts {
opt(p)
}
return p
}

func (p *Packages) CleanupUserPackages() {
Expand All @@ -47,16 +69,15 @@ func (p *Packages) CleanupUserPackages() {
modInfo = nil
}
})

// Don't cleanup github.com/99designs/gqlgen prefixed packages, they haven't changed and do not need to be reloaded
// Don't cleanup github.com/99designs/gqlgen prefixed packages,
// they haven't changed and do not need to be reloaded
if modInfo != nil {
var toRemove []string
for k := range p.packages {
if !strings.HasPrefix(k, modInfo.Main.Path) {
toRemove = append(toRemove, k)
}
}

for _, k := range toRemove {
delete(p.packages, k)
}
Expand Down Expand Up @@ -91,7 +112,10 @@ func (p *Packages) LoadAll(importPaths ...string) []*packages.Package {

if len(missing) > 0 {
p.numLoadCalls++
pkgs, err := packages.Load(&packages.Config{Mode: mode}, missing...)
pkgs, err := packages.Load(&packages.Config{
Mode: mode,
BuildFlags: p.buildFlags,
}, missing...)
if err != nil {
p.loadErrors = append(p.loadErrors, err)
}
Expand Down Expand Up @@ -140,7 +164,10 @@ func (p *Packages) LoadWithTypes(importPath string) *packages.Package {
pkg := p.Load(importPath)
if pkg == nil || pkg.TypesInfo == nil {
p.numLoadCalls++
pkgs, err := packages.Load(&packages.Config{Mode: mode}, importPath)
pkgs, err := packages.Load(&packages.Config{
Mode: mode,
BuildFlags: p.buildFlags,
}, importPath)
if err != nil {
p.loadErrors = append(p.loadErrors, err)
return nil
Expand Down Expand Up @@ -173,7 +200,10 @@ func (p *Packages) NameForPackage(importPath string) string {
if pkg == nil {
// otherwise do a name only lookup for it but don't put it in the package cache.
p.numNameCalls++
pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, importPath)
pkgs, err := packages.Load(&packages.Config{
Mode: packages.NeedName,
BuildFlags: p.buildFlags,
}, importPath)
if err != nil {
p.loadErrors = append(p.loadErrors, err)
} else {
Expand Down
12 changes: 10 additions & 2 deletions internal/code/packages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ func TestPackages(t *testing.T) {
require.Equal(t, "b", p.Load("github.com/99designs/gqlgen/internal/code/testdata/b").Name)
require.Equal(t, 3, p.numLoadCalls)
})
t.Run("able to load private package with build tags", func(t *testing.T) {
p := initialState(t, WithBuildTags("private"))
p.Evict("github.com/99designs/gqlgen/internal/code/testdata/a")
require.Equal(t, "a", p.Load("github.com/99designs/gqlgen/internal/code/testdata/a").Name)
require.Equal(t, 2, p.numLoadCalls)
require.Equal(t, "p", p.Load("github.com/99designs/gqlgen/internal/code/testdata/p").Name)
require.Equal(t, 3, p.numLoadCalls)
})
}

func TestNameForPackage(t *testing.T) {
Expand All @@ -50,8 +58,8 @@ func TestNameForPackage(t *testing.T) {
assert.Equal(t, "github_com", p.NameForPackage("github.com"))
}

func initialState(t *testing.T) *Packages {
p := &Packages{}
func initialState(t *testing.T, opts ...Option) *Packages {
p := NewPackages(opts...)
pkgs := p.LoadAll(
"github.com/99designs/gqlgen/internal/code/testdata/a",
"github.com/99designs/gqlgen/internal/code/testdata/b",
Expand Down
13 changes: 13 additions & 0 deletions internal/code/testdata/p/p.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//go:build private
// +build private

// This file is excluded from the build unless the "private" build tag is set.
// This is used to test loading private packages.
// See internal/code/packages_test.go for more details.
package p

import (
"github.com/99designs/gqlgen/internal/code/testdata/b"
)

var P = b.C + " P"
2 changes: 1 addition & 1 deletion internal/imports/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func TestPrune(t *testing.T) {
// prime the packages cache so that it's not considered uninitialized

b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"), &code.Packages{})
b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"), code.NewPackages())
require.NoError(t, err)
require.Equal(t, strings.ReplaceAll(string(mustReadFile("testdata/unused.expected.go")), "\r\n", "\n"), string(b))
}
Expand Down

0 comments on commit 11bb9b1

Please sign in to comment.