Skip to content

Commit

Permalink
sql/ast: Implement ALTER TABLE SET SCHEMA (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleconroy authored Mar 11, 2020
1 parent 3fb980f commit 9f84662
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 21 deletions.
32 changes: 12 additions & 20 deletions internal/postgresql/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,29 +193,21 @@ func TestUpdate(t *testing.T) {
},
},
},
/*
{
`
CREATE SCHEMA foo;
CREATE SCHEMA bar;
CREATE TABLE foo.baz ();
ALTER TABLE foo.baz SET SCHEMA bar;
`,
pg.Catalog{
Schemas: map[string]pg.Schema{
"public": {},
"foo": {},
"bar": {
Tables: map[string]pg.Table{
"baz": pg.Table{
Name: "baz",
},
},
},
{
`
CREATE SCHEMA foo;
CREATE TABLE bar ();
ALTER TABLE bar SET SCHEMA foo;
`,
&catalog.Schema{
Name: "foo",
Tables: []*catalog.Table{
{
Rel: &ast.TableName{Name: "bar"},
},
},
},
*/
},
{
`
CREATE TABLE venues ();
Expand Down
16 changes: 15 additions & 1 deletion internal/postgresql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,21 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
func translate(node nodes.Node) (ast.Node, error) {
switch n := node.(type) {

case nodes.AlterObjectSchemaStmt:
switch n.ObjectType {

case nodes.OBJECT_TABLE:
tbl, err := parseTableName(*n.Relation)
if err != nil {
return nil, err
}
return &ast.AlterTableSetSchemaStmt{
Table: tbl,
NewSchema: n.Newschema,
}, nil
}
return nil, errSkip

case nodes.AlterTableStmt:
name, err := parseTableName(*n.Relation)
if err != nil {
Expand Down Expand Up @@ -271,7 +286,6 @@ func translate(node nodes.Node) (ast.Node, error) {
}, nil

}

return nil, errSkip

case nodes.CreateStmt:
Expand Down
9 changes: 9 additions & 0 deletions internal/sql/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ func (n *AlterTableCmd) Pos() int {
return 0
}

type AlterTableSetSchemaStmt struct {
Table *TableName
NewSchema *string
}

func (n *AlterTableSetSchemaStmt) Pos() int {
return 0
}

type CreateEnumStmt struct {
TypeName *TypeName
Vals *List
Expand Down
2 changes: 2 additions & 0 deletions internal/sql/catalog/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ func (c *Catalog) Build(stmts []ast.Statement) error {
switch n := stmts[i].Raw.Stmt.(type) {
case *ast.AlterTableStmt:
err = c.alterTable(n)
case *ast.AlterTableSetSchemaStmt:
err = c.alterTableSetSchema(n)
case *ast.CommentOnColumnStmt:
err = c.commentOnColumn(n)
case *ast.CommentOnSchemaStmt:
Expand Down
26 changes: 26 additions & 0 deletions internal/sql/catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,32 @@ func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error {

return nil
}

func (c *Catalog) alterTableSetSchema(stmt *ast.AlterTableSetSchemaStmt) error {
ns := stmt.Table.Schema
if ns == "" {
ns = c.DefaultSchema
}
oldSchema, err := c.getSchema(ns)
if err != nil {
return err
}
tbl, idx, err := oldSchema.getTable(stmt.Table)
if err != nil {
return err
}
newSchema, err := c.getSchema(*stmt.NewSchema)
if err != nil {
return err
}
if _, _, err := newSchema.getTable(stmt.Table); err == nil {
return sqlerr.RelationExists(stmt.Table.Name)
}
oldSchema.Tables = append(oldSchema.Tables[:idx], oldSchema.Tables[idx+1:]...)
newSchema.Tables = append(newSchema.Tables, tbl)
return nil
}

func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error {
ns := stmt.Name.Schema
if ns == "" {
Expand Down

0 comments on commit 9f84662

Please sign in to comment.