Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Remove ParsedGoType from codegen.proto, pass in opts as JSON #2918

Closed
wants to merge 12 commits into from
Closed
36 changes: 20 additions & 16 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package cmd

import (
"encoding/json"
"strings"

gopluginopts "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
"github.com/sqlc-dev/sqlc/internal/compiler"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/config/convert"
Expand Down Expand Up @@ -33,19 +35,35 @@ func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
column = colParts[3]
}
}

var options []byte
var err error
if len(o.Options) == 0 {
// Send go-specific override information to the go codegen plugin
options, err = json.Marshal(gopluginopts.OverrideOptions{
GoType: o.GoType,
GoStructTag: o.GoStructTag,
})
if err != nil {
panic(err) // TODO don't panic, return err
}
} else {
options = o.Options
}

return &plugin.Override{
CodeType: "", // FIXME
DbType: o.DBType,
Nullable: o.Nullable,
Unsigned: o.Unsigned,
Column: o.Column,
ColumnName: column,
Table: &table,
GoType: pluginGoType(o),
Options: options,
}
}

func pluginSettings(r *compiler.Result, cs config.CombinedSettings) *plugin.Settings {
// TODO only send overrides meant for this plugin
var over []*plugin.Override
for _, o := range cs.Overrides {
over = append(over, pluginOverride(r, o))
Expand Down Expand Up @@ -101,20 +119,6 @@ func pluginWASM(p config.Plugin) *plugin.Codegen_WASM {
return nil
}

func pluginGoType(o config.Override) *plugin.ParsedGoType {
// Note that there is a slight mismatch between this and the
// proto api. The GoType on the override is the unparsed type,
// which could be a qualified path or an object, as per
// https://docs.sqlc.dev/en/v1.18.0/reference/config.html#type-overriding
return &plugin.ParsedGoType{
ImportPath: o.GoImportPath,
Package: o.GoPackage,
TypeName: o.GoTypeName,
BasicType: o.GoBasicType,
StructTags: o.GoStructTags,
}
}

func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
var schemas []*plugin.Schema
for _, s := range c.Schemas {
Expand Down
9 changes: 4 additions & 5 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,10 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR

func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.CodeGenResponse, error) {
i := &importer{
Settings: req.Settings,
Options: options,
Queries: queries,
Enums: enums,
Structs: structs,
Options: options,
Queries: queries,
Enums: enums,
Structs: structs,
}

tctx := tmplCtx{
Expand Down
12 changes: 6 additions & 6 deletions internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, col *plugin.Column) {
for _, oride := range req.Settings.Overrides {
func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) {
for _, oride := range options.Overrides {
if oride.GoType.StructTags == nil {
continue
}
if !sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema) {
if !oride.Matches(col.Table, req.Catalog.DefaultSchema) {
// Different table.
continue
}
Expand All @@ -34,15 +34,15 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, co

func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
// Check if the column's type has been overridden
for _, oride := range req.Settings.Overrides {
for _, oride := range options.Overrides {
if oride.GoType.TypeName == "" {
continue
}
cname := col.Name
if col.OriginalName != "" {
cname = col.OriginalName
}
sameTable := sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema)
sameTable := oride.Matches(col.Table, req.Catalog.DefaultSchema)
if oride.Column != "" && sdk.MatchString(oride.ColumnName, cname) && sameTable {
if col.IsSqlcSlice {
return "[]" + oride.GoType.TypeName
Expand All @@ -65,7 +65,7 @@ func goInnerType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.
notNull := col.NotNull || col.IsArray

// package overrides have a higher precedence
for _, oride := range req.Settings.Overrides {
for _, oride := range options.Overrides {
if oride.GoType.TypeName == "" {
continue
}
Expand Down
26 changes: 12 additions & 14 deletions internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
"github.com/sqlc-dev/sqlc/internal/metadata"
"github.com/sqlc-dev/sqlc/internal/plugin"
)

type fileImports struct {
Expand Down Expand Up @@ -59,11 +58,10 @@ func mergeImports(imps ...fileImports) [][]ImportSpec {
}

type importer struct {
Settings *plugin.Settings
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

having all overrides in Options lets us drop this field

Options *opts.Options
Queries []Query
Enums []Enum
Structs []Struct
Options *opts.Options
Queries []Query
Enums []Enum
Structs []Struct
}

func (i *importer) usesType(typ string) bool {
Expand Down Expand Up @@ -157,7 +155,7 @@ var pqtypeTypes = map[string]struct{}{
"pqtype.NullRawMessage": {},
}

func buildImports(settings *plugin.Settings, options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
pkg := make(map[ImportSpec]struct{})
std := make(map[string]struct{})

Expand Down Expand Up @@ -201,7 +199,7 @@ func buildImports(settings *plugin.Settings, options *opts.Options, queries []Qu
}

overrideTypes := map[string]string{}
for _, o := range settings.Overrides {
for _, o := range options.Overrides {
if o.GoType.BasicType || o.GoType.TypeName == "" {
continue
}
Expand All @@ -226,7 +224,7 @@ func buildImports(settings *plugin.Settings, options *opts.Options, queries []Qu
}

// Custom imports
for _, o := range settings.Overrides {
for _, o := range options.Overrides {
if o.GoType.BasicType || o.GoType.TypeName == "" {
continue
}
Expand All @@ -241,7 +239,7 @@ func buildImports(settings *plugin.Settings, options *opts.Options, queries []Qu
}

func (i *importer) interfaceImports() fileImports {
std, pkg := buildImports(i.Settings, i.Options, i.Queries, func(name string) bool {
std, pkg := buildImports(i.Options, i.Queries, func(name string) bool {
for _, q := range i.Queries {
if q.hasRetType() {
if usesBatch([]Query{q}) {
Expand All @@ -266,7 +264,7 @@ func (i *importer) interfaceImports() fileImports {
}

func (i *importer) modelImports() fileImports {
std, pkg := buildImports(i.Settings, i.Options, nil, i.usesType)
std, pkg := buildImports(i.Options, nil, i.usesType)

if len(i.Enums) > 0 {
std["fmt"] = struct{}{}
Expand Down Expand Up @@ -305,7 +303,7 @@ func (i *importer) queryImports(filename string) fileImports {
}
}

std, pkg := buildImports(i.Settings, i.Options, gq, func(name string) bool {
std, pkg := buildImports(i.Options, gq, func(name string) bool {
for _, q := range gq {
if q.hasRetType() {
if q.Ret.EmitStruct() {
Expand Down Expand Up @@ -406,7 +404,7 @@ func (i *importer) copyfromImports() fileImports {
copyFromQueries = append(copyFromQueries, q)
}
}
std, pkg := buildImports(i.Settings, i.Options, copyFromQueries, func(name string) bool {
std, pkg := buildImports(i.Options, copyFromQueries, func(name string) bool {
for _, q := range copyFromQueries {
if q.hasRetType() {
if strings.HasPrefix(q.Ret.Type(), name) {
Expand Down Expand Up @@ -441,7 +439,7 @@ func (i *importer) batchImports() fileImports {
batchQueries = append(batchQueries, q)
}
}
std, pkg := buildImports(i.Settings, i.Options, batchQueries, func(name string) bool {
std, pkg := buildImports(i.Options, batchQueries, func(name string) bool {
for _, q := range batchQueries {
if q.hasRetType() {
if q.Ret.EmitStruct() {
Expand Down
40 changes: 40 additions & 0 deletions internal/codegen/golang/opts/go_override.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package opts

import (
"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
"github.com/sqlc-dev/sqlc/internal/plugin"
)

type GoOverride struct {
*plugin.Override

GoType *ParsedGoType
}

func (o *GoOverride) Convert() *plugin.Override {
return &plugin.Override{
DbType: o.DbType,
Nullable: o.Nullable,
Column: o.Column,
Table: o.Table,
ColumnName: o.ColumnName,
Unsigned: o.Unsigned,
}
}

func (o *GoOverride) Matches(n *plugin.Identifier, defaultSchema string) bool {
return sdk.Matches(o.Convert(), n, defaultSchema)
}

func NewGoOverride(po *plugin.Override, o Override) GoOverride {
return GoOverride{
po,
&ParsedGoType{
ImportPath: o.GoImportPath,
Package: o.GoPackage,
TypeName: o.GoTypeName,
BasicType: o.GoBasicType,
StructTags: o.GoStructTags,
},
}
}
Loading
Loading