From 33ff2fa079e2a2d898b4c5dd5ff429f06672e373 Mon Sep 17 00:00:00 2001 From: dany74q Date: Thu, 28 Sep 2023 17:13:05 +0300 Subject: [PATCH] Consider go type name when autobinding Currently, generated schema type names are normalized, for instance - SomeTYPE in the schema will be generated as SomeType in the model. When autobinding, however, we only consider the schema type name when searching for it in the relevant package(s), thus type names that differ post normalizations aren't auto-bound properly and are instead re-generated. This commit suggests a fix where we'd try to autobind for both the schema type name (first, to maintain back compat), or the go type name if the former isn't found. --- codegen/config/binder.go | 44 +-------- codegen/config/config.go | 23 ++++- codegen/config/config_test.go | 25 +++++ codegen/config/initialisms.go | 66 ++----------- codegen/config/initialisms_test.go | 6 +- .../autobinding/chat/{message.go => model.go} | 10 ++ codegen/templates/templates.go | 99 ++++++++++++++++++- 7 files changed, 165 insertions(+), 108 deletions(-) rename codegen/config/testdata/autobinding/chat/{message.go => model.go} (62%) diff --git a/codegen/config/binder.go b/codegen/config/binder.go index 0483afdbdf..22cc855db0 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -5,12 +5,12 @@ import ( "fmt" "go/token" "go/types" - "strings" + "github.com/vektah/gqlparser/v2/ast" "golang.org/x/tools/go/packages" + "github.com/99designs/gqlgen/codegen/templates" "github.com/99designs/gqlgen/internal/code" - "github.com/vektah/gqlparser/v2/ast" ) var ErrTypeNotFound = errors.New("unable to find type") @@ -285,7 +285,7 @@ func (ref *TypeReference) UniquenessKey() string { // Fix for #896 elemNullability = "ᚄ" } - return nullability + ref.Definition.Name + "2" + TypeIdentifier(ref.GO) + elemNullability + return nullability + ref.Definition.Name + "2" + templates.TypeIdentifier(ref.GO) + elemNullability } func (ref *TypeReference) MarshalFunc() string { @@ -540,41 +540,3 @@ func basicUnderlying(it types.Type) *types.Basic { return nil } - -var pkgReplacer = strings.NewReplacer( - "/", "ᚋ", - ".", "ᚗ", - "-", "ᚑ", - "~", "א", -) - -func TypeIdentifier(t types.Type) string { - res := "" - for { - switch it := t.(type) { - case *types.Pointer: - t.Underlying() - res += "ᚖ" - t = it.Elem() - case *types.Slice: - res += "ᚕ" - t = it.Elem() - case *types.Named: - res += pkgReplacer.Replace(it.Obj().Pkg().Path()) - res += "ᚐ" - res += it.Obj().Name() - return res - case *types.Basic: - res += it.Name() - return res - case *types.Map: - res += "map" - return res - case *types.Interface: - res += "interface" - return res - default: - panic(fmt.Errorf("unexpected type %T", it)) - } - } -} diff --git a/codegen/config/config.go b/codegen/config/config.go index 1e844f4696..50ccaeffda 100644 --- a/codegen/config/config.go +++ b/codegen/config/config.go @@ -3,6 +3,7 @@ package config import ( "bytes" "fmt" + "go/types" "io" "os" "path/filepath" @@ -10,10 +11,13 @@ import ( "sort" "strings" - "github.com/99designs/gqlgen/internal/code" "github.com/vektah/gqlparser/v2" "github.com/vektah/gqlparser/v2/ast" + "golang.org/x/tools/go/packages" "gopkg.in/yaml.v3" + + "github.com/99designs/gqlgen/codegen/templates" + "github.com/99designs/gqlgen/internal/code" ) type Config struct { @@ -608,8 +612,10 @@ func (c *Config) autobind() error { if p == nil || p.Module == nil { return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i]) } - if t := p.Types.Scope().Lookup(t.Name); t != nil { - c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name()) + + autobindType := c.lookupAutobindType(p, t) + if autobindType != nil { + c.Models.Add(t.Name, autobindType.Pkg().Path()+"."+autobindType.Name()) break } } @@ -643,6 +649,17 @@ func (c *Config) autobind() error { return nil } +func (c *Config) lookupAutobindType(p *packages.Package, schemaType *ast.Definition) types.Object { + // Try binding to either the original schema type name, or the normalized go type name + for _, lookupName := range []string{schemaType.Name, templates.ToGo(schemaType.Name)} { + if t := p.Types.Scope().Lookup(lookupName); t != nil { + return t + } + } + + return nil +} + func (c *Config) injectBuiltins() { builtins := TypeMap{ "__Directive": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}}, diff --git a/codegen/config/config_test.go b/codegen/config/config_test.go index 68e30d4b95..c8cb42ea3a 100644 --- a/codegen/config/config_test.go +++ b/codegen/config/config_test.go @@ -206,6 +206,31 @@ func TestAutobinding(t *testing.T) { require.Equal(t, "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/chat.Message", cfg.Models["Message"].Model[0]) }) + t.Run("normalized type names", func(t *testing.T) { + cfg := Config{ + Models: TypeMap{}, + AutoBind: []string{ + "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/chat", + "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/scalars/model", + }, + Packages: code.NewPackages(), + } + + cfg.Schema = gqlparser.MustLoadSchema(&ast.Source{Name: "TestAutobinding.schema", Input: ` + scalar Banned + type Message { id: ID } + enum ProductSKU { ProductSkuTrial } + type ChatAPI { id: ID } + `}) + + require.NoError(t, cfg.autobind()) + + require.Equal(t, "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/scalars/model.Banned", cfg.Models["Banned"].Model[0]) + require.Equal(t, "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/chat.Message", cfg.Models["Message"].Model[0]) + require.Equal(t, "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/chat.ProductSku", cfg.Models["ProductSKU"].Model[0]) + require.Equal(t, "github.com/99designs/gqlgen/codegen/config/testdata/autobinding/chat.ChatAPI", cfg.Models["ChatAPI"].Model[0]) + }) + t.Run("with file path", func(t *testing.T) { cfg := Config{ Models: TypeMap{}, diff --git a/codegen/config/initialisms.go b/codegen/config/initialisms.go index 5c169c8900..25e7331f8a 100644 --- a/codegen/config/initialisms.go +++ b/codegen/config/initialisms.go @@ -1,62 +1,10 @@ package config -import "strings" +import ( + "strings" -// commonInitialisms is a set of common initialisms. -// Only add entries that are highly unlikely to be non-initialisms. -// For instance, "ID" is fine (Freudian code is rare), but "AND" is not. -var commonInitialisms = map[string]bool{ - "ACL": true, - "API": true, - "ASCII": true, - "CPU": true, - "CSS": true, - "CSV": true, - "DNS": true, - "EOF": true, - "GUID": true, - "HTML": true, - "HTTP": true, - "HTTPS": true, - "ICMP": true, - "ID": true, - "IP": true, - "JSON": true, - "KVK": true, - "LHS": true, - "PDF": true, - "PGP": true, - "QPS": true, - "QR": true, - "RAM": true, - "RHS": true, - "RPC": true, - "SLA": true, - "SMTP": true, - "SQL": true, - "SSH": true, - "SVG": true, - "TCP": true, - "TLS": true, - "TTL": true, - "UDP": true, - "UI": true, - "UID": true, - "URI": true, - "URL": true, - "UTF8": true, - "UUID": true, - "VM": true, - "XML": true, - "XMPP": true, - "XSRF": true, - "XSS": true, -} - -// GetInitialisms returns the initialisms to capitalize in Go names. If unchanged, default initialisms will be returned -var GetInitialisms = func() map[string]bool { - return commonInitialisms -} + "github.com/99designs/gqlgen/codegen/templates" +) // GoInitialismsConfig allows to modify the default behavior of naming Go methods, types and properties type GoInitialismsConfig struct { @@ -69,7 +17,7 @@ type GoInitialismsConfig struct { // setInitialisms adjustes GetInitialisms based on its settings. func (i GoInitialismsConfig) setInitialisms() { toUse := i.determineGoInitialisms() - GetInitialisms = func() map[string]bool { + templates.GetInitialisms = func() map[string]bool { return toUse } } @@ -82,8 +30,8 @@ func (i GoInitialismsConfig) determineGoInitialisms() (initialismsToUse map[stri initialismsToUse[strings.ToUpper(initialism)] = true } } else { - initialismsToUse = make(map[string]bool, len(commonInitialisms)+len(i.Initialisms)) - for initialism, value := range commonInitialisms { + initialismsToUse = make(map[string]bool, len(templates.CommonInitialisms)+len(i.Initialisms)) + for initialism, value := range templates.CommonInitialisms { initialismsToUse[strings.ToUpper(initialism)] = value } for _, initialism := range i.Initialisms { diff --git a/codegen/config/initialisms_test.go b/codegen/config/initialisms_test.go index 5bea561a3a..13c0da2465 100644 --- a/codegen/config/initialisms_test.go +++ b/codegen/config/initialisms_test.go @@ -5,6 +5,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/99designs/gqlgen/codegen/templates" ) func TestGoInitialismsConfig(t *testing.T) { @@ -17,12 +19,12 @@ func TestGoInitialismsConfig(t *testing.T) { t.Run("empty initialism config doesn't change anything", func(t *testing.T) { tt := GoInitialismsConfig{} result := tt.determineGoInitialisms() - assert.Equal(t, len(commonInitialisms), len(result)) + assert.Equal(t, len(templates.CommonInitialisms), len(result)) }) t.Run("initialism config appends if desired", func(t *testing.T) { tt := GoInitialismsConfig{ReplaceDefaults: false, Initialisms: []string{"ASDF"}} result := tt.determineGoInitialisms() - assert.Equal(t, len(commonInitialisms)+1, len(result)) + assert.Equal(t, len(templates.CommonInitialisms)+1, len(result)) assert.True(t, result["ASDF"]) }) t.Run("initialism config replaces if desired", func(t *testing.T) { diff --git a/codegen/config/testdata/autobinding/chat/message.go b/codegen/config/testdata/autobinding/chat/model.go similarity index 62% rename from codegen/config/testdata/autobinding/chat/message.go rename to codegen/config/testdata/autobinding/chat/model.go index b35be48c93..fa621ee90a 100644 --- a/codegen/config/testdata/autobinding/chat/message.go +++ b/codegen/config/testdata/autobinding/chat/model.go @@ -10,3 +10,13 @@ type Message struct { CreatedBy string `json:"createdBy"` CreatedAt time.Time `json:"createdAt"` } + +type ProductSku string + +const ( + ProductSkuTrial ProductSku = "Trial" +) + +type ChatAPI struct { + ID string `json:"id"` +} diff --git a/codegen/templates/templates.go b/codegen/templates/templates.go index 976527578d..4a4e91594e 100644 --- a/codegen/templates/templates.go +++ b/codegen/templates/templates.go @@ -17,7 +17,6 @@ import ( "text/template" "unicode" - "github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/internal/code" "github.com/99designs/gqlgen/internal/imports" ) @@ -202,7 +201,7 @@ func Funcs() template.FuncMap { "rawQuote": rawQuote, "dump": Dump, "ref": ref, - "ts": config.TypeIdentifier, + "ts": TypeIdentifier, "call": Call, "prefixLines": prefixLines, "notNil": notNil, @@ -465,7 +464,7 @@ func wordWalker(str string, f func(*wordInfo)) { } i++ - initialisms := config.GetInitialisms() + initialisms := GetInitialisms() // [w,i) is a word. word := string(runes[w:i]) if !eow && initialisms[word] && !unicode.IsLower(runes[i]) { @@ -668,3 +667,97 @@ func write(filename string, b []byte, packages *code.Packages) error { return nil } + +var pkgReplacer = strings.NewReplacer( + "/", "ᚋ", + ".", "ᚗ", + "-", "ᚑ", + "~", "א", +) + +func TypeIdentifier(t types.Type) string { + res := "" + for { + switch it := t.(type) { + case *types.Pointer: + t.Underlying() + res += "ᚖ" + t = it.Elem() + case *types.Slice: + res += "ᚕ" + t = it.Elem() + case *types.Named: + res += pkgReplacer.Replace(it.Obj().Pkg().Path()) + res += "ᚐ" + res += it.Obj().Name() + return res + case *types.Basic: + res += it.Name() + return res + case *types.Map: + res += "map" + return res + case *types.Interface: + res += "interface" + return res + default: + panic(fmt.Errorf("unexpected type %T", it)) + } + } +} + +// CommonInitialisms is a set of common initialisms. +// Only add entries that are highly unlikely to be non-initialisms. +// For instance, "ID" is fine (Freudian code is rare), but "AND" is not. +var CommonInitialisms = map[string]bool{ + "ACL": true, + "API": true, + "ASCII": true, + "CPU": true, + "CSS": true, + "CSV": true, + "DNS": true, + "EOF": true, + "GUID": true, + "HTML": true, + "HTTP": true, + "HTTPS": true, + "ICMP": true, + "ID": true, + "IP": true, + "JSON": true, + "KVK": true, + "LHS": true, + "PDF": true, + "PGP": true, + "QPS": true, + "QR": true, + "RAM": true, + "RHS": true, + "RPC": true, + "SLA": true, + "SMTP": true, + "SQL": true, + "SSH": true, + "SVG": true, + "TCP": true, + "TLS": true, + "TTL": true, + "UDP": true, + "UI": true, + "UID": true, + "URI": true, + "URL": true, + "UTF8": true, + "UUID": true, + "VM": true, + "XML": true, + "XMPP": true, + "XSRF": true, + "XSS": true, +} + +// GetInitialisms returns the initialisms to capitalize in Go names. If unchanged, default initialisms will be returned +var GetInitialisms = func() map[string]bool { + return CommonInitialisms +}