From 34bbc450c502919cd46c5eefcc66341ef697c0e8 Mon Sep 17 00:00:00 2001 From: Curtis Layne Date: Thu, 30 Jun 2022 10:51:07 -0400 Subject: [PATCH] Use the go:embed API to lookup templates (#2262) * Switch the templates package internally to read from TemplateFS Users are expected to pass in the FS by using the embed API. * Update all usages of templates.Render to use the TemplateFS option * Fix unit tests * Fix linter error * Commit generated changes Doesn't look like anything has changed though. Maybe just a different whitespace character. * Fix test --- .../embedding/subdir/gendir/generated.go | 2 +- _examples/embedding/subdir/root_.generated.go | 2 +- .../accounts/graph/generated/generated.go | 2 +- .../products/graph/generated/generated.go | 2 +- .../reviews/graph/generated/generated.go | 2 +- codegen/generate.go | 7 ++ codegen/templates/templates.go | 91 ++++++++++--------- codegen/templates/templates_test.go | 33 +++++++ codegen/templates/test.gotpl | 1 + codegen/templates/test_.gotpl | 1 + plugin/federation/federation.go | 7 +- .../testdata/entityresolver/generated/exec.go | 2 +- plugin/modelgen/models.go | 5 + plugin/resolvergen/resolver.go | 6 ++ plugin/servergen/server.go | 5 + plugin/stubgen/stubs.go | 5 + 16 files changed, 124 insertions(+), 49 deletions(-) create mode 100644 codegen/templates/test.gotpl create mode 100644 codegen/templates/test_.gotpl diff --git a/_examples/embedding/subdir/gendir/generated.go b/_examples/embedding/subdir/gendir/generated.go index 3e6a17a726b..eae119fdf2e 100644 --- a/_examples/embedding/subdir/gendir/generated.go +++ b/_examples/embedding/subdir/gendir/generated.go @@ -178,7 +178,7 @@ var sources = []*ast.Source{ {Name: "../federation/directives.graphql", Input: ` scalar _Any scalar _FieldSet - + directive @external on FIELD_DEFINITION directive @requires(fields: _FieldSet!) on FIELD_DEFINITION directive @provides(fields: _FieldSet!) on FIELD_DEFINITION diff --git a/_examples/embedding/subdir/root_.generated.go b/_examples/embedding/subdir/root_.generated.go index faa187408df..1aa34cc59e5 100644 --- a/_examples/embedding/subdir/root_.generated.go +++ b/_examples/embedding/subdir/root_.generated.go @@ -172,7 +172,7 @@ var sources = []*ast.Source{ {Name: "federation/directives.graphql", Input: ` scalar _Any scalar _FieldSet - + directive @external on FIELD_DEFINITION directive @requires(fields: _FieldSet!) on FIELD_DEFINITION directive @provides(fields: _FieldSet!) on FIELD_DEFINITION diff --git a/_examples/federation/accounts/graph/generated/generated.go b/_examples/federation/accounts/graph/generated/generated.go index ef06e1c44af..666f381ad08 100644 --- a/_examples/federation/accounts/graph/generated/generated.go +++ b/_examples/federation/accounts/graph/generated/generated.go @@ -266,7 +266,7 @@ type User @key(fields: "id") { {Name: "../../federation/directives.graphql", Input: ` scalar _Any scalar _FieldSet - + directive @external on FIELD_DEFINITION directive @requires(fields: _FieldSet!) on FIELD_DEFINITION directive @provides(fields: _FieldSet!) on FIELD_DEFINITION diff --git a/_examples/federation/products/graph/generated/generated.go b/_examples/federation/products/graph/generated/generated.go index 5f2577f2f66..d1ed696ca44 100644 --- a/_examples/federation/products/graph/generated/generated.go +++ b/_examples/federation/products/graph/generated/generated.go @@ -294,7 +294,7 @@ type Product @key(fields: "manufacturer { id } id") @key(fields: "upc") { {Name: "../../federation/directives.graphql", Input: ` scalar _Any scalar _FieldSet - + directive @external on FIELD_DEFINITION directive @requires(fields: _FieldSet!) on FIELD_DEFINITION directive @provides(fields: _FieldSet!) on FIELD_DEFINITION diff --git a/_examples/federation/reviews/graph/generated/generated.go b/_examples/federation/reviews/graph/generated/generated.go index 271e4cd74b9..c1ca3d0e63f 100644 --- a/_examples/federation/reviews/graph/generated/generated.go +++ b/_examples/federation/reviews/graph/generated/generated.go @@ -326,7 +326,7 @@ extend type Product @key(fields: " manufacturer{ id} id") { {Name: "../../federation/directives.graphql", Input: ` scalar _Any scalar _FieldSet - + directive @external on FIELD_DEFINITION directive @requires(fields: _FieldSet!) on FIELD_DEFINITION directive @provides(fields: _FieldSet!) on FIELD_DEFINITION diff --git a/codegen/generate.go b/codegen/generate.go index 9e3654552db..1ce8c329dc3 100644 --- a/codegen/generate.go +++ b/codegen/generate.go @@ -1,6 +1,7 @@ package codegen import ( + "embed" "errors" "fmt" "os" @@ -13,6 +14,9 @@ import ( "github.com/vektah/gqlparser/v2/ast" ) +//go:embed *.gotpl +var codegenTemplates embed.FS + func GenerateCode(data *Data) error { if !data.Config.Exec.IsDefined() { return fmt.Errorf("missing exec config") @@ -36,6 +40,7 @@ func generateSingleFile(data *Data) error { RegionTags: true, GeneratedHeader: true, Packages: data.Config.Packages, + TemplateFS: codegenTemplates, }) } @@ -82,6 +87,7 @@ func generatePerSchema(data *Data) error { RegionTags: true, GeneratedHeader: true, Packages: data.Config.Packages, + TemplateFS: codegenTemplates, }) if err != nil { return err @@ -145,6 +151,7 @@ func generateRootFile(data *Data) error { RegionTags: false, GeneratedHeader: true, Packages: data.Config.Packages, + TemplateFS: codegenTemplates, }) } diff --git a/codegen/templates/templates.go b/codegen/templates/templates.go index edd6ec18dc0..ab39c08eecf 100644 --- a/codegen/templates/templates.go +++ b/codegen/templates/templates.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "go/types" + "io/fs" "os" "path/filepath" "reflect" @@ -35,6 +36,11 @@ type Options struct { // the plugin processor will look for .gotpl files // in the same directory of where you wrote the plugin. Template string + + // Use the go:embed API to collect all the template files you want to pass into Render + // this is an alternative to passing the Template option + TemplateFS fs.FS + // Filename is the name of the file that will be // written to the system disk once the template is rendered. Filename string @@ -62,55 +68,27 @@ func Render(cfg Options) error { } CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)} - // load path relative to calling source file - _, callerFile, _, _ := runtime.Caller(1) - rootDir := filepath.Dir(callerFile) - funcs := Funcs() for n, f := range cfg.Funcs { funcs[n] = f } + t := template.New("").Funcs(funcs) + t, err := parseTemplates(cfg, t) + if err != nil { + return err + } - var roots []string - if cfg.Template != "" { - var err error - t, err = t.New("template.gotpl").Parse(cfg.Template) - if err != nil { - return fmt.Errorf("error with provided template: %w", err) + roots := make([]string, 0, len(t.Templates())) + for _, template := range t.Templates() { + // templates that end with _.gotpl are special files we don't want to include + if strings.HasSuffix(template.Name(), "_.gotpl") || + // filter out templates added with {{ template xxx }} syntax inside the template file + !strings.HasSuffix(template.Name(), ".gotpl") { + continue } - roots = append(roots, "template.gotpl") - } else { - // load all the templates in the directory - err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator))) - if !strings.HasSuffix(info.Name(), ".gotpl") { - return nil - } - // omit any templates with "_" at the end of their name, which are meant for specific contexts only - if strings.HasSuffix(info.Name(), "_.gotpl") { - return nil - } - b, err := os.ReadFile(path) - if err != nil { - return err - } - - t, err = t.New(name).Parse(string(b)) - if err != nil { - return fmt.Errorf("%s: %w", cfg.Filename, err) - } - - roots = append(roots, name) - return nil - }) - if err != nil { - return fmt.Errorf("locating templates: %w", err) - } + roots = append(roots, template.Name()) } // then execute all the important looking ones in order, adding them to the same file @@ -124,6 +102,7 @@ func Render(cfg Options) error { } return roots[i] < roots[j] }) + var buf bytes.Buffer for _, root := range roots { if cfg.RegionTags { @@ -155,7 +134,7 @@ func Render(cfg Options) error { result.WriteString("import (\n") result.WriteString(CurrentImports.String()) result.WriteString(")\n") - _, err := buf.WriteTo(&result) + _, err = buf.WriteTo(&result) if err != nil { return err } @@ -170,6 +149,34 @@ func Render(cfg Options) error { return nil } +func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) { + if cfg.Template != "" { + var err error + t, err = t.New("template.gotpl").Parse(cfg.Template) + if err != nil { + return nil, fmt.Errorf("error with provided template: %w", err) + } + return t, nil + } + + var fileSystem fs.FS + if cfg.TemplateFS != nil { + fileSystem = cfg.TemplateFS + } else { + // load path relative to calling source file + _, callerFile, _, _ := runtime.Caller(1) + rootDir := filepath.Dir(callerFile) + fileSystem = os.DirFS(rootDir) + } + + t, err := t.ParseFS(fileSystem, "*.gotpl") + if err != nil { + return nil, fmt.Errorf("locating templates: %w", err) + } + + return t, nil +} + func center(width int, pad string, s string) string { if len(s)+2 > width { return s diff --git a/codegen/templates/templates_test.go b/codegen/templates/templates_test.go index c6bf078ffce..ed69b1c7c7b 100644 --- a/codegen/templates/templates_test.go +++ b/codegen/templates/templates_test.go @@ -1,14 +1,20 @@ package templates import ( + "embed" "os" + "path/filepath" "testing" "github.com/99designs/gqlgen/internal/code" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +//go:embed *.gotpl +var templateFS embed.FS + func TestToGo(t *testing.T) { require.Equal(t, "ToCamel", ToGo("TO_CAMEL")) require.Equal(t, "ToCamel", ToGo("to_camel")) @@ -119,3 +125,30 @@ func TestTemplateOverride(t *testing.T) { t.Fatal(err) } } + +func TestRenderFS(t *testing.T) { + + tempDir := t.TempDir() + + outDir := filepath.Join(tempDir, "output") + + _ = os.Mkdir(outDir, 0o755) + + f, err := os.CreateTemp(outDir, "gqlgen.go") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.RemoveAll(f.Name()) + err = Render(Options{TemplateFS: templateFS, Filename: f.Name(), Packages: &code.Packages{}}) + if err != nil { + t.Fatal(err) + } + + expectedString := "package \n\nimport (\n)\nthis is my test package" + actualContents, _ := os.ReadFile(f.Name()) + actualContentsStr := string(actualContents) + + // don't look at last character since it's \n on Linux and \r\n on Windows + assert.Equal(t, expectedString, actualContentsStr[:len(expectedString)]) +} diff --git a/codegen/templates/test.gotpl b/codegen/templates/test.gotpl new file mode 100644 index 00000000000..07b8462a6a9 --- /dev/null +++ b/codegen/templates/test.gotpl @@ -0,0 +1 @@ +this is my test package diff --git a/codegen/templates/test_.gotpl b/codegen/templates/test_.gotpl new file mode 100644 index 00000000000..c74258f3e1d --- /dev/null +++ b/codegen/templates/test_.gotpl @@ -0,0 +1 @@ +this will not be included diff --git a/plugin/federation/federation.go b/plugin/federation/federation.go index 7f911a13825..61f905b892a 100644 --- a/plugin/federation/federation.go +++ b/plugin/federation/federation.go @@ -1,6 +1,7 @@ package federation import ( + _ "embed" "fmt" "sort" "strings" @@ -14,6 +15,9 @@ import ( "github.com/99designs/gqlgen/plugin/federation/fieldset" ) +//go:embed federation.gotpl +var federationTemplate string + type federation struct { Entities []*Entity Version int @@ -85,7 +89,7 @@ func (f *federation) InjectSourceEarly() *ast.Source { input := ` scalar _Any scalar _FieldSet - + directive @external on FIELD_DEFINITION directive @requires(fields: _FieldSet!) on FIELD_DEFINITION directive @provides(fields: _FieldSet!) on FIELD_DEFINITION @@ -274,6 +278,7 @@ func (f *federation) GenerateCode(data *codegen.Data) error { Data: f, GeneratedHeader: true, Packages: data.Config.Packages, + Template: federationTemplate, }) } diff --git a/plugin/federation/testdata/entityresolver/generated/exec.go b/plugin/federation/testdata/entityresolver/generated/exec.go index 5741f6a2f70..8a29fa3f15b 100644 --- a/plugin/federation/testdata/entityresolver/generated/exec.go +++ b/plugin/federation/testdata/entityresolver/generated/exec.go @@ -764,7 +764,7 @@ type MultiHelloMultipleRequires @key(fields: "name") @entityResolver(multi: true {Name: "../../../federation/directives.graphql", Input: ` scalar _Any scalar _FieldSet - + directive @external on FIELD_DEFINITION directive @requires(fields: _FieldSet!) on FIELD_DEFINITION directive @provides(fields: _FieldSet!) on FIELD_DEFINITION diff --git a/plugin/modelgen/models.go b/plugin/modelgen/models.go index 110b0bc3bd8..2f636701721 100644 --- a/plugin/modelgen/models.go +++ b/plugin/modelgen/models.go @@ -1,6 +1,7 @@ package modelgen import ( + _ "embed" "fmt" "go/types" "sort" @@ -12,6 +13,9 @@ import ( "github.com/vektah/gqlparser/v2/ast" ) +//go:embed models.gotpl +var modelTemplate string + type BuildMutateHook = func(b *ModelBuild) *ModelBuild type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) @@ -269,6 +273,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { Data: b, GeneratedHeader: true, Packages: cfg.Packages, + Template: modelTemplate, }) if err != nil { return err diff --git a/plugin/resolvergen/resolver.go b/plugin/resolvergen/resolver.go index 538aa714b1e..69c45f7fdd8 100644 --- a/plugin/resolvergen/resolver.go +++ b/plugin/resolvergen/resolver.go @@ -1,6 +1,7 @@ package resolvergen import ( + _ "embed" "errors" "io/fs" "os" @@ -14,6 +15,9 @@ import ( "github.com/99designs/gqlgen/plugin" ) +//go:embed resolver.gotpl +var resolverTemplate string + func New() plugin.Plugin { return &Plugin{} } @@ -76,6 +80,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error { Filename: data.Config.Resolver.Filename, Data: resolverBuild, Packages: data.Config.Packages, + Template: resolverTemplate, }) } @@ -143,6 +148,7 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error { Filename: filename, Data: resolverBuild, Packages: data.Config.Packages, + Template: resolverTemplate, }) if err != nil { return err diff --git a/plugin/servergen/server.go b/plugin/servergen/server.go index d9b35fe4178..4084f8ead6e 100644 --- a/plugin/servergen/server.go +++ b/plugin/servergen/server.go @@ -1,6 +1,7 @@ package servergen import ( + _ "embed" "errors" "io/fs" "log" @@ -11,6 +12,9 @@ import ( "github.com/99designs/gqlgen/plugin" ) +//go:embed server.gotpl +var serverTemplate string + func New(filename string) plugin.Plugin { return &Plugin{filename} } @@ -37,6 +41,7 @@ func (m *Plugin) GenerateCode(data *codegen.Data) error { Filename: m.filename, Data: serverBuild, Packages: data.Config.Packages, + Template: serverTemplate, }) } diff --git a/plugin/stubgen/stubs.go b/plugin/stubgen/stubs.go index a8d018fa2ab..01db65b1f9b 100644 --- a/plugin/stubgen/stubs.go +++ b/plugin/stubgen/stubs.go @@ -1,6 +1,7 @@ package stubgen import ( + _ "embed" "path/filepath" "syscall" @@ -12,6 +13,9 @@ import ( "github.com/99designs/gqlgen/plugin" ) +//go:embed stubs.gotpl +var stubsTemplate string + func New(filename string, typename string) plugin.Plugin { return &Plugin{filename: filename, typeName: typename} } @@ -51,6 +55,7 @@ func (m *Plugin) GenerateCode(data *codegen.Data) error { }, GeneratedHeader: true, Packages: data.Config.Packages, + Template: stubsTemplate, }) }