diff --git a/internal/postgresql/catalog_test.go b/internal/postgresql/catalog_test.go index 90ec3373a3..835b14e6f2 100644 --- a/internal/postgresql/catalog_test.go +++ b/internal/postgresql/catalog_test.go @@ -231,26 +231,20 @@ func TestUpdate(t *testing.T) { `, nil, }, - /* - { - ` - CREATE TABLE venues (); - ALTER TABLE venues RENAME TO arenas; - `, - pg.Catalog{ - Schemas: map[string]pg.Schema{ - "public": { - Types: map[string]pg.Type{}, - Tables: map[string]pg.Table{ - "arenas": pg.Table{ - Name: "arenas", - }, - }, - }, + { + ` + CREATE TABLE venues (); + ALTER TABLE venues RENAME TO arenas; + `, + &catalog.Schema{ + Name: "public", + Tables: []*catalog.Table{ + { + Rel: &ast.TableName{Name: "arenas"}, }, }, }, - */ + }, { ` CREATE TYPE status AS ENUM ('open', 'closed'); @@ -273,15 +267,13 @@ func TestUpdate(t *testing.T) { `, nil, }, - /* - { - ` - CREATE TYPE status AS ENUM ('open', 'closed'); - DROP TYPE public.status; - `, - catalog.New("public"), - }, - */ + { + ` + CREATE TYPE status AS ENUM ('open', 'closed'); + DROP TYPE public.status; + `, + nil, + }, { ` CREATE SCHEMA foo; diff --git a/internal/postgresql/parse.go b/internal/postgresql/parse.go index eb99b98990..e5bfdf9bb4 100644 --- a/internal/postgresql/parse.go +++ b/internal/postgresql/parse.go @@ -365,6 +365,22 @@ func translate(node nodes.Node) (ast.Node, error) { } return nil, errSkip + case nodes.RenameStmt: + switch n.RenameType { + + case nodes.OBJECT_TABLE: + tbl, err := parseTableName(*n.Relation) + if err != nil { + return nil, err + } + return &ast.RenameTableStmt{ + Table: tbl, + NewName: n.Newname, + }, nil + + } + return nil, errSkip + default: return nil, errSkip } diff --git a/internal/sql/ast/ast.go b/internal/sql/ast/ast.go index f00cd382cd..3b3a2768f8 100644 --- a/internal/sql/ast/ast.go +++ b/internal/sql/ast/ast.go @@ -209,3 +209,12 @@ type CommentOnColumnStmt struct { func (n *CommentOnColumnStmt) Pos() int { return 0 } + +type RenameTableStmt struct { + Table *TableName + NewName *string +} + +func (n *RenameTableStmt) Pos() int { + return 0 +} diff --git a/internal/sql/catalog/catalog.go b/internal/sql/catalog/catalog.go index 3db7f600e0..0372751d0a 100644 --- a/internal/sql/catalog/catalog.go +++ b/internal/sql/catalog/catalog.go @@ -165,6 +165,8 @@ func (c *Catalog) Build(stmts []ast.Statement) error { err = c.dropTable(n) case *ast.DropTypeStmt: err = c.dropType(n) + case *ast.RenameTableStmt: + err = c.renameTable(n) } if err != nil { return err diff --git a/internal/sql/catalog/table.go b/internal/sql/catalog/table.go index 605cc6a3a1..deb2d6309f 100644 --- a/internal/sql/catalog/table.go +++ b/internal/sql/catalog/table.go @@ -145,3 +145,14 @@ func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error { } return nil } + +func (c *Catalog) renameTable(stmt *ast.RenameTableStmt) error { + _, tbl, err := c.getTable(stmt.Table) + if err != nil { + return err + } + if stmt.NewName != nil { + tbl.Rel.Name = *stmt.NewName + } + return nil +}