Skip to content

Commit

Permalink
Use the go:embed API to lookup templates (#2262)
Browse files Browse the repository at this point in the history
* Switch the templates package internally to read from TemplateFS

Users are expected to pass in the FS by using the embed API.

* Update all usages of templates.Render to use the TemplateFS option

* Fix unit tests

* Fix linter error

* Commit generated changes

Doesn't look like anything has changed though. Maybe just a different
whitespace character.

* Fix test
  • Loading branch information
clayne11 authored Jun 30, 2022
1 parent 53ca207 commit 34bbc45
Show file tree
Hide file tree
Showing 16 changed files with 124 additions and 49 deletions.
2 changes: 1 addition & 1 deletion _examples/embedding/subdir/gendir/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/embedding/subdir/root_.generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/federation/accounts/graph/generated/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/federation/products/graph/generated/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/federation/reviews/graph/generated/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions codegen/generate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package codegen

import (
"embed"
"errors"
"fmt"
"os"
Expand All @@ -13,6 +14,9 @@ import (
"github.com/vektah/gqlparser/v2/ast"
)

//go:embed *.gotpl
var codegenTemplates embed.FS

func GenerateCode(data *Data) error {
if !data.Config.Exec.IsDefined() {
return fmt.Errorf("missing exec config")
Expand All @@ -36,6 +40,7 @@ func generateSingleFile(data *Data) error {
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
}

Expand Down Expand Up @@ -82,6 +87,7 @@ func generatePerSchema(data *Data) error {
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
if err != nil {
return err
Expand Down Expand Up @@ -145,6 +151,7 @@ func generateRootFile(data *Data) error {
RegionTags: false,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
}

Expand Down
91 changes: 49 additions & 42 deletions codegen/templates/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"go/types"
"io/fs"
"os"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -35,6 +36,11 @@ type Options struct {
// the plugin processor will look for .gotpl files
// in the same directory of where you wrote the plugin.
Template string

// Use the go:embed API to collect all the template files you want to pass into Render
// this is an alternative to passing the Template option
TemplateFS fs.FS

// Filename is the name of the file that will be
// written to the system disk once the template is rendered.
Filename string
Expand Down Expand Up @@ -62,55 +68,27 @@ func Render(cfg Options) error {
}
CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}

// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
rootDir := filepath.Dir(callerFile)

funcs := Funcs()
for n, f := range cfg.Funcs {
funcs[n] = f
}

t := template.New("").Funcs(funcs)
t, err := parseTemplates(cfg, t)
if err != nil {
return err
}

var roots []string
if cfg.Template != "" {
var err error
t, err = t.New("template.gotpl").Parse(cfg.Template)
if err != nil {
return fmt.Errorf("error with provided template: %w", err)
roots := make([]string, 0, len(t.Templates()))
for _, template := range t.Templates() {
// templates that end with _.gotpl are special files we don't want to include
if strings.HasSuffix(template.Name(), "_.gotpl") ||
// filter out templates added with {{ template xxx }} syntax inside the template file
!strings.HasSuffix(template.Name(), ".gotpl") {
continue
}
roots = append(roots, "template.gotpl")
} else {
// load all the templates in the directory
err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
if !strings.HasSuffix(info.Name(), ".gotpl") {
return nil
}
// omit any templates with "_" at the end of their name, which are meant for specific contexts only
if strings.HasSuffix(info.Name(), "_.gotpl") {
return nil
}
b, err := os.ReadFile(path)
if err != nil {
return err
}

t, err = t.New(name).Parse(string(b))
if err != nil {
return fmt.Errorf("%s: %w", cfg.Filename, err)
}

roots = append(roots, name)

return nil
})
if err != nil {
return fmt.Errorf("locating templates: %w", err)
}
roots = append(roots, template.Name())
}

// then execute all the important looking ones in order, adding them to the same file
Expand All @@ -124,6 +102,7 @@ func Render(cfg Options) error {
}
return roots[i] < roots[j]
})

var buf bytes.Buffer
for _, root := range roots {
if cfg.RegionTags {
Expand Down Expand Up @@ -155,7 +134,7 @@ func Render(cfg Options) error {
result.WriteString("import (\n")
result.WriteString(CurrentImports.String())
result.WriteString(")\n")
_, err := buf.WriteTo(&result)
_, err = buf.WriteTo(&result)
if err != nil {
return err
}
Expand All @@ -170,6 +149,34 @@ func Render(cfg Options) error {
return nil
}

func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) {
if cfg.Template != "" {
var err error
t, err = t.New("template.gotpl").Parse(cfg.Template)
if err != nil {
return nil, fmt.Errorf("error with provided template: %w", err)
}
return t, nil
}

var fileSystem fs.FS
if cfg.TemplateFS != nil {
fileSystem = cfg.TemplateFS
} else {
// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
rootDir := filepath.Dir(callerFile)
fileSystem = os.DirFS(rootDir)
}

t, err := t.ParseFS(fileSystem, "*.gotpl")
if err != nil {
return nil, fmt.Errorf("locating templates: %w", err)
}

return t, nil
}

func center(width int, pad string, s string) string {
if len(s)+2 > width {
return s
Expand Down
33 changes: 33 additions & 0 deletions codegen/templates/templates_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package templates

import (
"embed"
"os"
"path/filepath"
"testing"

"github.com/99designs/gqlgen/internal/code"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

//go:embed *.gotpl
var templateFS embed.FS

func TestToGo(t *testing.T) {
require.Equal(t, "ToCamel", ToGo("TO_CAMEL"))
require.Equal(t, "ToCamel", ToGo("to_camel"))
Expand Down Expand Up @@ -119,3 +125,30 @@ func TestTemplateOverride(t *testing.T) {
t.Fatal(err)
}
}

func TestRenderFS(t *testing.T) {

tempDir := t.TempDir()

outDir := filepath.Join(tempDir, "output")

_ = os.Mkdir(outDir, 0o755)

f, err := os.CreateTemp(outDir, "gqlgen.go")
if err != nil {
t.Fatal(err)
}
defer f.Close()
defer os.RemoveAll(f.Name())
err = Render(Options{TemplateFS: templateFS, Filename: f.Name(), Packages: &code.Packages{}})
if err != nil {
t.Fatal(err)
}

expectedString := "package \n\nimport (\n)\nthis is my test package"
actualContents, _ := os.ReadFile(f.Name())
actualContentsStr := string(actualContents)

// don't look at last character since it's \n on Linux and \r\n on Windows
assert.Equal(t, expectedString, actualContentsStr[:len(expectedString)])
}
1 change: 1 addition & 0 deletions codegen/templates/test.gotpl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
this is my test package
1 change: 1 addition & 0 deletions codegen/templates/test_.gotpl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
this will not be included
7 changes: 6 additions & 1 deletion plugin/federation/federation.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package federation

import (
_ "embed"
"fmt"
"sort"
"strings"
Expand All @@ -14,6 +15,9 @@ import (
"github.com/99designs/gqlgen/plugin/federation/fieldset"
)

//go:embed federation.gotpl
var federationTemplate string

type federation struct {
Entities []*Entity
Version int
Expand Down Expand Up @@ -85,7 +89,7 @@ func (f *federation) InjectSourceEarly() *ast.Source {
input := `
scalar _Any
scalar _FieldSet
directive @external on FIELD_DEFINITION
directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
Expand Down Expand Up @@ -274,6 +278,7 @@ func (f *federation) GenerateCode(data *codegen.Data) error {
Data: f,
GeneratedHeader: true,
Packages: data.Config.Packages,
Template: federationTemplate,
})
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package modelgen

import (
_ "embed"
"fmt"
"go/types"
"sort"
Expand All @@ -12,6 +13,9 @@ import (
"github.com/vektah/gqlparser/v2/ast"
)

//go:embed models.gotpl
var modelTemplate string

type BuildMutateHook = func(b *ModelBuild) *ModelBuild

type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
Expand Down Expand Up @@ -269,6 +273,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
Data: b,
GeneratedHeader: true,
Packages: cfg.Packages,
Template: modelTemplate,
})
if err != nil {
return err
Expand Down
6 changes: 6 additions & 0 deletions plugin/resolvergen/resolver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resolvergen

import (
_ "embed"
"errors"
"io/fs"
"os"
Expand All @@ -14,6 +15,9 @@ import (
"github.com/99designs/gqlgen/plugin"
)

//go:embed resolver.gotpl
var resolverTemplate string

func New() plugin.Plugin {
return &Plugin{}
}
Expand Down Expand Up @@ -76,6 +80,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error {
Filename: data.Config.Resolver.Filename,
Data: resolverBuild,
Packages: data.Config.Packages,
Template: resolverTemplate,
})
}

Expand Down Expand Up @@ -143,6 +148,7 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
Filename: filename,
Data: resolverBuild,
Packages: data.Config.Packages,
Template: resolverTemplate,
})
if err != nil {
return err
Expand Down
Loading

0 comments on commit 34bbc45

Please sign in to comment.