Skip to content

Commit

Permalink
sql: Refactors CTAS logic to allow user defined PKs.
Browse files Browse the repository at this point in the history
This change removes the need for a separate `AsColumnNames` field
in the CreateTable node, and uses `ColumnTableDefs` instead. This
is similar to a normal CREATE TABLE query, thereby allowing us to
piggyback on most of the existing logic. It also ensures that the hidden
`rowid` key is only generated if no PK is explicitly specified.
Some special handling is required in pretty printing, as column
types are not populated at the time of parsing.

Release note: None
  • Loading branch information
adityamaru27 committed Jul 9, 2019
1 parent 04e5a37 commit 83e361d
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 62 deletions.
91 changes: 57 additions & 34 deletions pkg/sql/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (p *planner) CreateTable(ctx context.Context, n *tree.CreateTable) (planNod
return nil, err
}

numColNames := len(n.AsColumnNames)
numColNames := len(n.Defs)
numColumns := len(planColumns(sourcePlan))
if numColNames != 0 && numColNames != numColumns {
sourcePlan.Close(ctx)
Expand All @@ -80,12 +80,22 @@ func (p *planner) CreateTable(ctx context.Context, n *tree.CreateTable) (planNod
}

// Synthesize an input column that provides the default value for the
// hidden rowid column.
// hidden rowid column, if none of the provided columns are specified
// as the PRIMARY KEY.
synthRowID = true
for _, def := range n.Defs {
if d, ok := def.(*tree.ColumnTableDef); !ok {
return nil, errors.Errorf("failed to cast type to ColumnTableDef\n")
} else if d.PrimaryKey {
synthRowID = false
break
}
}
}

ct := &createTableNode{n: n, dbDesc: dbDesc, sourcePlan: sourcePlan}
ct.run.synthRowID = synthRowID
ct.run.fromHeuristicPlanner = true
return ct, nil
}

Expand All @@ -95,10 +105,15 @@ type createTableRun struct {
autoCommit autoCommitOpt

// synthRowID indicates whether an input column needs to be synthesized to
// provide the default value for the hidden rowid column. The optimizer's
// plan already includes this column (so synthRowID is false), whereas the
// heuristic planner's plan does not (so synthRowID is true).
// provide the default value for the hidden rowid column. The optimizer's plan
// already includes this column (so synthRowID is false), whereas the
// heuristic planner's plan does not, and it is decided based on the existence
// of a user specified PRIMARY KEY constraint.
synthRowID bool

// fromHeuristicPlanner indicates whether the planning was performed by the
// heuristic planner instead of the optimizer.
fromHeuristicPlanner bool
}

func (n *createTableNode) startExec(params runParams) error {
Expand Down Expand Up @@ -139,15 +154,18 @@ func (n *createTableNode) startExec(params runParams) error {
}

asCols = planColumns(n.sourcePlan)
if !n.run.synthRowID {
// rowID column is already present in the input as the last column, so
// ignore it for the purpose of creating column metadata (because
if !n.run.fromHeuristicPlanner && !n.n.AsHasUserSpecifiedPrimaryKey() {
// rowID column is already present in the input as the last column if it
// was planned by the optimizer and the user did not specify a PRIMARY
// KEY. So ignore it for the purpose of creating column metadata (because
// makeTableDescIfAs does it automatically).
asCols = asCols[:len(asCols)-1]
}
desc, err = makeTableDescIfAs(

desc, err = makeTableDescIfAs(params,
n.n, n.dbDesc.ID, id, creationTime, asCols,
privs, &params.p.semaCtx, params.p.EvalContext())
privs, params.p.EvalContext(), nil /* affected */)

if err != nil {
return err
}
Expand Down Expand Up @@ -259,9 +277,9 @@ func (n *createTableNode) startExec(params runParams) error {
return err
}

// Prepare the buffer for row values. At this point, one more
// column has been added by ensurePrimaryKey() to the list of
// columns in sourcePlan.
// Prepare the buffer for row values. At this point, one more column has
// been added by ensurePrimaryKey() to the list of columns in sourcePlan, if
// a PRIMARY KEY is not specified by the user.
rowBuffer := make(tree.Datums, len(desc.Columns))
pkColIdx := len(desc.Columns) - 1

Expand Down Expand Up @@ -975,38 +993,43 @@ func getFinalSourceQuery(source *tree.Select, evalCtx *tree.EvalContext) string
// makeTableDescIfAs is the MakeTableDesc method for when we have a table
// that is created with the CREATE AS format.
func makeTableDescIfAs(
params runParams,
p *tree.CreateTable,
parentID, id sqlbase.ID,
creationTime hlc.Timestamp,
resultColumns []sqlbase.ResultColumn,
privileges *sqlbase.PrivilegeDescriptor,
semaCtx *tree.SemaContext,
evalContext *tree.EvalContext,
affected map[sqlbase.ID]*sqlbase.MutableTableDescriptor,
) (desc sqlbase.MutableTableDescriptor, err error) {
desc = InitTableDescriptor(id, parentID, p.Table.Table(), creationTime, privileges)
desc.CreateQuery = getFinalSourceQuery(p.AsSource, evalContext)

hasColumnDefs := len(p.Defs) == len(resultColumns)
for i, colRes := range resultColumns {
columnTableDef := tree.ColumnTableDef{Name: tree.Name(colRes.Name), Type: colRes.Typ}
columnTableDef.Nullable.Nullability = tree.SilentNull
if len(p.AsColumnNames) > i {
columnTableDef.Name = p.AsColumnNames[i]
}

// The new types in the CREATE TABLE AS column specs never use
// SERIAL so we need not process SERIAL types here.
col, _, _, err := sqlbase.MakeColumnDefDescs(&columnTableDef, semaCtx)
if err != nil {
return desc, err
var d *tree.ColumnTableDef
var ok bool
if hasColumnDefs {
if d, ok = p.Defs[i].(*tree.ColumnTableDef); !ok {
return desc, errors.Errorf("failed to cast type to ColumnTableDef\n")
}
d.Type = colRes.Typ
} else {
var tableDef tree.TableDef = &tree.ColumnTableDef{Name: tree.Name(colRes.Name), Type: colRes.Typ}
if d, ok = tableDef.(*tree.ColumnTableDef); !ok {
return desc, errors.Errorf("failed to cast type to ColumnTableDef\n")
}
d.Nullable.Nullability = tree.SilentNull
p.Defs = append(p.Defs, tableDef)
}
desc.AddColumn(col)
}

// AllocateIDs mutates its receiver. `return desc, desc.AllocateIDs()`
// happens to work in gc, but does not work in gccgo.
//
// See https://github.com/golang/go/issues/23188.
err = desc.AllocateIDs()
desc, err = makeTableDesc(
params,
p,
parentID, id,
creationTime,
privileges,
affected,
)
desc.CreateQuery = getFinalSourceQuery(p.AsSource, evalContext)
return desc, err
}

Expand Down
46 changes: 46 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/create_as
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,49 @@ SELECT * FROM baz
----
a b c
1 2 4

# Check that CREATE TABLE AS allows users to specify primary key (#20940)
statement ok
CREATE TABLE foo5 (a, b PRIMARY KEY, c) AS SELECT * FROM baz

query TT
SHOW CREATE TABLE foo5
----
foo5 CREATE TABLE foo5 (
a INT8 NULL,
b INT8 NOT NULL,
c INT8 NULL,
CONSTRAINT "primary" PRIMARY KEY (b ASC),
FAMILY "primary" (a, b, c)
)

statement ok
SET OPTIMIZER=ON; CREATE TABLE foo6 (a PRIMARY KEY, b, c) AS SELECT * FROM baz; SET OPTIMIZER=OFF

query TT
SHOW CREATE TABLE foo6
----
foo6 CREATE TABLE foo6 (
a INT8 NOT NULL,
b INT8 NULL,
c INT8 NULL,
CONSTRAINT "primary" PRIMARY KEY (a ASC),
FAMILY "primary" (a, b, c)
)

statement error generate insert row: null value in column "x" violates not-null constraint
CREATE TABLE foo7 (x PRIMARY KEY) AS VALUES (1), (NULL);

statement ok
BEGIN; CREATE TABLE foo8 (item PRIMARY KEY, qty) AS SELECT * FROM stock UNION VALUES ('spoons', 25), ('knives', 50); END

query TT
SHOW CREATE TABLE foo8
----
foo8 CREATE TABLE foo8 (
item STRING NOT NULL,
qty INT8 NULL,
CONSTRAINT "primary" PRIMARY KEY (item ASC),
FAMILY "primary" (item, qty)
)

25 changes: 14 additions & 11 deletions pkg/sql/opt/optbuilder/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (b *Builder) buildCreateTable(ct *tree.CreateTable, inScope *scope) (outSco
// Build the input query.
outScope := b.buildSelect(ct.AsSource, nil /* desiredTypes */, inScope)

numColNames := len(ct.AsColumnNames)
numColNames := len(ct.Defs)
numColumns := len(outScope.cols)
if numColNames != 0 && numColNames != numColumns {
panic(builderError{sqlbase.NewSyntaxError(fmt.Sprintf(
Expand All @@ -50,17 +50,20 @@ func (b *Builder) buildCreateTable(ct *tree.CreateTable, inScope *scope) (outSco
numColumns, util.Pluralize(int64(numColumns))))})
}

// Synthesize rowid column, and append to end of column list.
props, overloads := builtins.GetBuiltinProperties("unique_rowid")
private := &memo.FunctionPrivate{
Name: "unique_rowid",
Typ: types.Int,
Properties: props,
Overload: &overloads[0],
input = outScope.expr
if !ct.AsHasUserSpecifiedPrimaryKey() {
// Synthesize rowid column, and append to end of column list.
props, overloads := builtins.GetBuiltinProperties("unique_rowid")
private := &memo.FunctionPrivate{
Name: "unique_rowid",
Typ: types.Int,
Properties: props,
Overload: &overloads[0],
}
fn := b.factory.ConstructFunction(memo.EmptyScalarListExpr, private)
scopeCol := b.synthesizeColumn(outScope, "rowid", types.Int, nil /* expr */, fn)
input = b.factory.CustomFuncs().ProjectExtraCol(outScope.expr, fn, scopeCol.id)
}
fn := b.factory.ConstructFunction(memo.EmptyScalarListExpr, private)
scopeCol := b.synthesizeColumn(outScope, "rowid", types.Int, nil /* expr */, fn)
input = b.factory.CustomFuncs().ProjectExtraCol(outScope.expr, fn, scopeCol.id)
inputCols = outScope.makePhysicalProps().Presentation
} else {
// Create dummy empty input.
Expand Down
42 changes: 31 additions & 11 deletions pkg/sql/sem/tree/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,14 @@ func (node *ColumnTableDef) HasColumnFamily() bool {
// Format implements the NodeFormatter interface.
func (node *ColumnTableDef) Format(ctx *FmtCtx) {
ctx.FormatNode(&node.Name)
ctx.WriteByte(' ')
ctx.WriteString(node.columnTypeString())

// ColumnTableDef node type will not be specified if it represents a CREATE
// TABLE ... AS query.
if node.Type != nil {
ctx.WriteByte(' ')
ctx.WriteString(node.columnTypeString())
}

if node.Nullable.Nullability != SilentNull && node.Nullable.ConstraintName != "" {
ctx.WriteString(" CONSTRAINT ")
ctx.FormatNode(&node.Nullable.ConstraintName)
Expand Down Expand Up @@ -885,13 +891,12 @@ func (node *RangePartition) Format(ctx *FmtCtx) {

// CreateTable represents a CREATE TABLE statement.
type CreateTable struct {
IfNotExists bool
Table TableName
Interleave *InterleaveDef
PartitionBy *PartitionBy
Defs TableDefs
AsSource *Select
AsColumnNames NameList // Only to be used in conjunction with AsSource
IfNotExists bool
Table TableName
Interleave *InterleaveDef
PartitionBy *PartitionBy
Defs TableDefs
AsSource *Select
}

// As returns true if this table represents a CREATE TABLE ... AS statement,
Expand All @@ -900,6 +905,21 @@ func (node *CreateTable) As() bool {
return node.AsSource != nil
}

// AsHasUserSpecifiedPrimaryKey reports whether node is a CREATE TABLE ... AS
// statement in which one of the column definitions is marked PRIMARY KEY.
//
// The definitions are scanned in order. The scan gives up and reports false as
// soon as it meets an entry of Defs that is not a *ColumnTableDef, even if a
// later column would have been marked as the primary key.
func (node *CreateTable) AsHasUserSpecifiedPrimaryKey() bool {
	// Statements without an AS source never qualify.
	if !node.As() {
		return false
	}
	for _, tableDef := range node.Defs {
		colDef, isColumn := tableDef.(*ColumnTableDef)
		if !isColumn {
			// Anything other than a plain column definition ends the search.
			return false
		}
		if colDef.PrimaryKey {
			return true
		}
	}
	return false
}

// Format implements the NodeFormatter interface.
func (node *CreateTable) Format(ctx *FmtCtx) {
ctx.WriteString("CREATE TABLE ")
Expand All @@ -914,9 +934,9 @@ func (node *CreateTable) Format(ctx *FmtCtx) {
// but the CREATE TABLE tableName part.
func (node *CreateTable) FormatBody(ctx *FmtCtx) {
if node.As() {
if len(node.AsColumnNames) > 0 {
if len(node.Defs) > 0 {
ctx.WriteString(" (")
ctx.FormatNode(&node.AsColumnNames)
ctx.FormatNode(&node.Defs)
ctx.WriteByte(')')
}
ctx.WriteString(" AS ")
Expand Down
32 changes: 26 additions & 6 deletions pkg/sql/sem/tree/pretty.go
Original file line number Diff line number Diff line change
Expand Up @@ -1110,9 +1110,9 @@ func (node *CreateTable) doc(p *PrettyCfg) pretty.Doc {
title = pretty.ConcatSpace(title, p.Doc(&node.Table))

if node.As() {
if len(node.AsColumnNames) > 0 {
if len(node.Defs) > 0 {
title = pretty.ConcatSpace(title,
p.bracket("(", p.Doc(&node.AsColumnNames), ")"))
p.bracket("(", p.Doc(&node.Defs), ")"))
}
title = pretty.ConcatSpace(title, pretty.Keyword("AS"))
} else {
Expand Down Expand Up @@ -1609,7 +1609,11 @@ func (node *ColumnTableDef) docRow(p *PrettyCfg) pretty.TableRow {
clauses := make([]pretty.Doc, 0, 7)

// Column type.
clauses = append(clauses, pretty.Text(node.columnTypeString()))
// ColumnTableDef node type will not be specified if it represents a CREATE
// TABLE ... AS query.
if node.Type != nil {
clauses = append(clauses, pretty.Text(node.columnTypeString()))
}

// Compute expression (for computed columns).
if node.IsComputed() {
Expand Down Expand Up @@ -1690,10 +1694,26 @@ func (node *ColumnTableDef) docRow(p *PrettyCfg) pretty.TableRow {
clauses = append(clauses, p.maybePrependConstraintName(&node.References.ConstraintName, fk))
}

return pretty.TableRow{
Label: node.Name.String(),
Doc: pretty.Group(pretty.Stack(clauses...)),
// Prevents an additional space from being appended at the end of every column
// name in the case of CREATE TABLE ... AS query. The additional space is
// being caused due to the absence of column type qualifiers in CTAS queries.
//
// TODO(adityamaru): Consult someone with more knowledge about the pretty
// printer architecture to find a cleaner solution.
var tblRow pretty.TableRow
if node.Type == nil {
tblRow = pretty.TableRow{
Label: node.Name.String(),
Doc: pretty.Stack(clauses...),
}
} else {
tblRow = pretty.TableRow{
Label: node.Name.String(),
Doc: pretty.Group(pretty.Stack(clauses...)),
}
}

return tblRow
}

func (node *CheckConstraintTableDef) doc(p *PrettyCfg) pretty.Doc {
Expand Down

0 comments on commit 83e361d

Please sign in to comment.