diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 0e553b64ff..0fda1f3ac9 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -200,6 +200,8 @@ type Query struct { SourceName string Ret QueryValue Arg QueryValue + ManyKey string // Taken from CmdParams.ManyKey, the map-key for :many. If empty a slice should be returned. + ManyKeyType string // The Go type of ManyKey // Used for :copyfrom Table *plugin.Identifier } @@ -219,3 +221,23 @@ func (q Query) TableIdentifier() string { } return "[]string{" + strings.Join(escapedNames, ", ") + "}" } + +func (v Query) DefineRetTypeMultiple() string { + if v.ManyKey != "" { + return "map[" + v.ManyKeyType + "]" + v.Ret.DefineType() + } + return "[]" + v.Ret.DefineType() +} + +func (v Query) HasManyKey() bool { + return v.ManyKey != "" +} + +func (v Query) ManyKeyField() string { + for _, f := range v.Ret.Struct.Fields { + if f.DBName == v.ManyKey { + return v.Ret.Name + "." + f.Name + } + } + panic("couldn't find :many-key in struct fields") +} diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index f5ecd124a1..004d6fcbbd 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -247,6 +247,9 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) } if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil { + if query.CmdParams.GetManyKey() != "" { + return nil, fmt.Errorf(":many key=%s query %s has only one column", query.CmdParams.GetManyKey(), query.Name) + } c := query.Columns[0] name := columnName(c, 0) if c.IsFuncCall { @@ -305,6 +308,18 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) SQLDriver: sqlpkg, EmitPointer: req.Settings.Go.EmitResultStructPointers, } + if query.CmdParams.GetManyKey() != "" { + gq.ManyKey = query.CmdParams.GetManyKey() + for _, c := range gs.Fields { + if c.DBName == gq.ManyKey { + gq.ManyKeyType = c.Type + break + } + } + if gq.ManyKeyType == "" { + return nil, fmt.Errorf("can not find key column %q in query %s", gq.ManyKey, gq.MethodName) + } + } } qs = append(qs, gq) diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index cde37d81ed..0e4f278ff0 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -35,23 +35,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{if eq .Cmd ":many"}} {{range .Comments}}//{{.}} {{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ({{.DefineRetTypeMultiple}}, error) { {{- template "queryCodeStdExec" . }} if err != nil { return nil, err } defer rows.Close() - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} + {{- if (or $.EmitEmptySlices .HasManyKey)}} + items := {{.DefineRetTypeMultiple}}{} {{else}} - var items []{{.Ret.DefineType}} + var items {{.DefineRetTypeMultiple}} {{end -}} for rows.Next() { var {{.Ret.Name}} {{.Ret.Type}} if err := rows.Scan({{.Ret.Scan}}); err != nil { return nil, err } + {{- if .HasManyKey}} + items[{{.ManyKeyField}}] = {{.Ret.ReturnName}} + {{else}} items = append(items, {{.Ret.ReturnName}}) + {{end -}} } if err := rows.Close(); err != nil { return nil, err diff --git a/internal/endtoend/testdata/manykey/stdlib/go/db.go b/internal/endtoend/testdata/manykey/stdlib/go/db.go new file mode 100644 index 0000000000..8c5b31f933 --- /dev/null +++ b/internal/endtoend/testdata/manykey/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.18.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/manykey/stdlib/go/models.go b/internal/endtoend/testdata/manykey/stdlib/go/models.go new file mode 100644 index 0000000000..c5ce3259cd --- /dev/null +++ b/internal/endtoend/testdata/manykey/stdlib/go/models.go @@ -0,0 +1,12 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.18.0 + +package querytest + +import () + +type Foo struct { + GroupID int32 + Score int32 +} diff --git a/internal/endtoend/testdata/manykey/stdlib/go/query.sql.go b/internal/endtoend/testdata/manykey/stdlib/go/query.sql.go new file mode 100644 index 0000000000..d6f6494a4a --- /dev/null +++ b/internal/endtoend/testdata/manykey/stdlib/go/query.sql.go @@ -0,0 +1,42 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.18.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const selectScoreSums = `-- name: SelectScoreSums :many +SELECT group_id, SUM(score) FROM foo GROUP BY group_id +` + +type SelectScoreSumsRow struct { + GroupID int32 + Sum int64 +} + +func (q *Queries) SelectScoreSums(ctx context.Context) (map[int32]SelectScoreSumsRow, error) { + rows, err := q.db.QueryContext(ctx, selectScoreSums) + if err != nil { + return nil, err + } + defer rows.Close() + items := map[int32]SelectScoreSumsRow{} + for rows.Next() { + var i SelectScoreSumsRow + if err := rows.Scan(&i.GroupID, &i.Sum); err != nil { + return nil, err + } + items[i.GroupID] = i + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/manykey/stdlib/query.sql b/internal/endtoend/testdata/manykey/stdlib/query.sql new file mode 100644 index 0000000000..31adcde830 --- /dev/null +++ b/internal/endtoend/testdata/manykey/stdlib/query.sql @@ -0,0 +1,7 @@ +CREATE TABLE foo ( + group_id INT NOT NULL, + score INT NOT NULL +); + +-- name: SelectScoreSums :many key=group_id +SELECT group_id, SUM(score) FROM foo GROUP BY group_id; diff --git a/internal/endtoend/testdata/manykey/stdlib/sqlc.json b/internal/endtoend/testdata/manykey/stdlib/sqlc.json new file mode 100644 index 0000000000..ac7c2ed829 --- /dev/null +++ b/internal/endtoend/testdata/manykey/stdlib/sqlc.json @@ -0,0 +1,11 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +}