Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to Package Catalog to Improve Readability #1595

Merged
merged 8 commits into from
May 23, 2022
260 changes: 10 additions & 250 deletions internal/sql/catalog/catalog.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
package catalog

import (
"strings"

"github.com/kyleconroy/sqlc/internal/sql/ast"
"github.com/kyleconroy/sqlc/internal/sql/sqlerr"
)

func stringSlice(list *ast.List) []string {
items := []string{}
for _, item := range list.Items {
if n, ok := item.(*ast.String); ok {
items = append(items, n.Str)
}
}
return items
}

// Catalog describes a database instance consisting of metadata in which database objects are defined
type Catalog struct {
Comment string
DefaultSchema string
Expand All @@ -29,241 +17,20 @@ type Catalog struct {
Extensions map[string]struct{}
}

func (c *Catalog) getSchema(name string) (*Schema, error) {
for i := range c.Schemas {
if c.Schemas[i].Name == name {
return c.Schemas[i], nil
}
}
return nil, sqlerr.SchemaNotFound(name)
}

func (c *Catalog) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
ns := rel.Schema
if ns == "" {
ns = c.DefaultSchema
}
s, err := c.getSchema(ns)
if err != nil {
return nil, -1, err
}
return s.getFunc(rel, tns)
}

func (c *Catalog) getTable(name *ast.TableName) (*Schema, *Table, error) {
ns := name.Schema
if ns == "" {
ns = c.DefaultSchema
}
var s *Schema
for i := range c.Schemas {
if c.Schemas[i].Name == ns {
s = c.Schemas[i]
break
}
}
if s == nil {
return nil, nil, sqlerr.SchemaNotFound(ns)
}
t, _, err := s.getTable(name)
if err != nil {
return nil, nil, err
}
return s, t, nil
}

func (c *Catalog) getType(rel *ast.TypeName) (Type, int, error) {
ns := rel.Schema
if ns == "" {
ns = c.DefaultSchema
}
s, err := c.getSchema(ns)
if err != nil {
return nil, -1, err
}
return s.getType(rel)
}

type Schema struct {
Name string
Tables []*Table
Types []Type
Funcs []*Function
// New creates a new catalog
func New(defaultSchema string) *Catalog {

Comment string
}

func sameType(a, b *ast.TypeName) bool {
if a.Catalog != b.Catalog {
return false
}
// The pg_catalog schema is searched by default, so take that into
// account when comparing schemas
aSchema := a.Schema
bSchema := b.Schema
if aSchema == "pg_catalog" {
aSchema = ""
newCatalog := &Catalog{
DefaultSchema: defaultSchema,
Schemas: make([]*Schema, 0),
Extensions: make(map[string]struct{}),
}
if bSchema == "pg_catalog" {
bSchema = ""
}
if aSchema != bSchema {
return false
}
if a.Name != b.Name {
return false
}
return true
}

func (s *Schema) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
for i := range s.Funcs {
if strings.ToLower(s.Funcs[i].Name) != strings.ToLower(rel.Name) {
continue
}

args := s.Funcs[i].InArgs()
if len(args) != len(tns) {
continue
}
found := true
for j := range args {
if !sameType(s.Funcs[i].Args[j].Type, tns[j]) {
found = false
break
}
}
if !found {
continue
}
return s.Funcs[i], i, nil
if newCatalog.DefaultSchema != "" {
newCatalog.Schemas = append(newCatalog.Schemas, &Schema{Name: defaultSchema})
}
return nil, -1, sqlerr.RelationNotFound(rel.Name)
}

func (s *Schema) getFuncByName(rel *ast.FuncName) (*Function, int, error) {
idx := -1
name := strings.ToLower(rel.Name)
for i := range s.Funcs {
lowered := strings.ToLower(s.Funcs[i].Name)
if lowered == name && idx >= 0 {
return nil, -1, sqlerr.FunctionNotUnique(rel.Name)
}
if lowered == name {
idx = i
}
}
if idx < 0 {
return nil, -1, sqlerr.RelationNotFound(rel.Name)
}
return s.Funcs[idx], idx, nil
}

func (s *Schema) getTable(rel *ast.TableName) (*Table, int, error) {
for i := range s.Tables {
if s.Tables[i].Rel.Name == rel.Name {
return s.Tables[i], i, nil
}
}
return nil, -1, sqlerr.RelationNotFound(rel.Name)
}

func (s *Schema) getType(rel *ast.TypeName) (Type, int, error) {
for i := range s.Types {
switch typ := s.Types[i].(type) {
case *Enum:
if typ.Name == rel.Name {
return s.Types[i], i, nil
}
}
}
return nil, -1, sqlerr.TypeNotFound(rel.Name)
}

type Table struct {
Rel *ast.TableName
Columns []*Column
Comment string
}

// TODO: Should this just be ast Nodes?
type Column struct {
Name string
Type ast.TypeName
IsNotNull bool
IsArray bool
Comment string
Length *int
}

type Type interface {
isType()

SetComment(string)
}

type Enum struct {
Name string
Vals []string
Comment string
}

func (e *Enum) SetComment(c string) {
e.Comment = c
}

func (e *Enum) isType() {
}

type CompositeType struct {
Name string
Comment string
}

func (ct *CompositeType) isType() {
}

func (ct *CompositeType) SetComment(c string) {
ct.Comment = c
}

type Function struct {
Name string
Args []*Argument
ReturnType *ast.TypeName
Comment string
Desc string
ReturnTypeNullable bool
}

func (f *Function) InArgs() []*Argument {
var args []*Argument
for _, a := range f.Args {
switch a.Mode {
case ast.FuncParamTable, ast.FuncParamOut:
continue
default:
args = append(args, a)
}
}
return args
}

type Argument struct {
Name string
Type *ast.TypeName
HasDefault bool
Mode ast.FuncParamMode
}

func New(def string) *Catalog {
return &Catalog{
DefaultSchema: def,
Schemas: []*Schema{
{Name: def},
},
Extensions: map[string]struct{}{},
}
return newCatalog
}

func (c *Catalog) Build(stmts []ast.Statement) error {
Expand All @@ -275,13 +42,6 @@ func (c *Catalog) Build(stmts []ast.Statement) error {
return nil
}

// An interface is used to resolve a circular import between the catalog and compiler packages.
// The createView function requires access to functions in the compiler package to parse the SELECT
// statement that defines the view.
type columnGenerator interface {
OutputColumns(node ast.Node) ([]*Column, error)
}

func (c *Catalog) Update(stmt ast.Statement, colGen columnGenerator) error {
if stmt.Raw == nil {
return nil
Expand Down
44 changes: 44 additions & 0 deletions internal/sql/catalog/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,50 @@ import (
"github.com/kyleconroy/sqlc/internal/sql/sqlerr"
)

// Function describes a database function
//
// A database function is a method written to performs specific operation on data within the database.
type Function struct {
Name string
Args []*Argument
ReturnType *ast.TypeName
Comment string
Desc string
ReturnTypeNullable bool
}

type Argument struct {
Name string
Type *ast.TypeName
HasDefault bool
Mode ast.FuncParamMode
}

func (f *Function) InArgs() []*Argument {
var args []*Argument
for _, a := range f.Args {
switch a.Mode {
case ast.FuncParamTable, ast.FuncParamOut:
continue
default:
args = append(args, a)
}
}
return args
}

func (c *Catalog) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
ns := rel.Schema
if ns == "" {
ns = c.DefaultSchema
}
s, err := c.getSchema(ns)
if err != nil {
return nil, -1, err
}
return s.getFunc(rel, tns)
}

func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error {
ns := stmt.Func.Schema
if ns == "" {
Expand Down
Loading