diff --git a/docs/generated/sql/bnf/stmt_block.bnf b/docs/generated/sql/bnf/stmt_block.bnf index 0a2a202fc34d..a1f23f1e5beb 100644 --- a/docs/generated/sql/bnf/stmt_block.bnf +++ b/docs/generated/sql/bnf/stmt_block.bnf @@ -193,7 +193,7 @@ truncate_stmt ::= 'TRUNCATE' opt_table relation_expr_list opt_drop_behavior update_stmt ::= - opt_with_clause 'UPDATE' table_name_expr_opt_alias_idx 'SET' set_clause_list opt_where_clause opt_sort_clause opt_limit_clause returning_clause + opt_with_clause 'UPDATE' table_name_expr_opt_alias_idx 'SET' set_clause_list opt_from_list opt_where_clause opt_sort_clause opt_limit_clause returning_clause upsert_stmt ::= opt_with_clause 'UPSERT' 'INTO' insert_target insert_rest returning_clause @@ -555,6 +555,10 @@ opt_drop_behavior ::= set_clause_list ::= ( set_clause ) ( ( ',' set_clause ) )* +opt_from_list ::= + 'FROM' from_list + | + db_object_name ::= simple_db_object_name | complex_db_object_name @@ -1185,6 +1189,9 @@ set_clause ::= single_set_clause | multiple_set_clause +from_list ::= + ( table_ref ) ( ( ',' table_ref ) )* + simple_db_object_name ::= db_object_name_component @@ -1545,6 +1552,16 @@ single_set_clause ::= multiple_set_clause ::= '(' insert_column_list ')' '=' in_expr +table_ref ::= + relation_expr opt_index_flags opt_ordinality opt_alias_clause + | select_with_parens opt_ordinality opt_alias_clause + | 'LATERAL' select_with_parens opt_ordinality opt_alias_clause + | joined_table + | '(' joined_table ')' opt_ordinality alias_clause + | func_table opt_ordinality opt_alias_clause + | 'LATERAL' func_table opt_ordinality opt_alias_clause + | '[' row_source_extension_stmt ']' opt_ordinality opt_alias_clause + cockroachdb_extra_type_func_name_keyword ::= 'FAMILY' @@ -1847,16 +1864,6 @@ distinct_clause ::= distinct_on_clause ::= 'DISTINCT' 'ON' '(' expr_list ')' -table_ref ::= - relation_expr opt_index_flags opt_ordinality opt_alias_clause - | select_with_parens opt_ordinality opt_alias_clause - | 'LATERAL' select_with_parens opt_ordinality opt_alias_clause - | joined_table - | '(' joined_table ')' opt_ordinality alias_clause - | func_table opt_ordinality opt_alias_clause - | 'LATERAL' func_table opt_ordinality opt_alias_clause - | '[' row_source_extension_stmt ']' opt_ordinality opt_alias_clause - all_or_distinct ::= 'ALL' | 'DISTINCT' @@ -1865,6 +1872,39 @@ all_or_distinct ::= var_list ::= ( var_value ) ( ( ',' var_value ) )* +opt_ordinality ::= + 'WITH' 'ORDINALITY' + | + +opt_alias_clause ::= + alias_clause + | + +joined_table ::= + '(' joined_table ')' + | table_ref 'CROSS' opt_join_hint 'JOIN' table_ref + | table_ref join_type opt_join_hint 'JOIN' table_ref join_qual + | table_ref 'JOIN' table_ref join_qual + | table_ref 'NATURAL' join_type opt_join_hint 'JOIN' table_ref + | table_ref 'NATURAL' 'JOIN' table_ref + +alias_clause ::= + 'AS' table_alias_name opt_column_list + | table_alias_name opt_column_list + +func_table ::= + func_expr_windowless + | 'ROWS' 'FROM' '(' rowsfrom_list ')' + +row_source_extension_stmt ::= + delete_stmt + | explain_stmt + | insert_stmt + | select_stmt + | show_stmt + | update_stmt + | upsert_stmt + user_priority ::= 'LOW' | 'NORMAL' @@ -2060,44 +2100,31 @@ character_base ::= | 'VARCHAR' | 'STRING' -from_list ::= - ( table_ref ) ( ( ',' table_ref ) )* - window_definition_list ::= ( window_definition ) ( ( ',' window_definition ) )* -opt_ordinality ::= - 'WITH' 'ORDINALITY' - | - -opt_alias_clause ::= - alias_clause +opt_join_hint ::= + 'HASH' + | 'MERGE' + | 'LOOKUP' | -joined_table ::= - '(' joined_table ')' - | table_ref 'CROSS' opt_join_hint 'JOIN' table_ref - | table_ref join_type opt_join_hint 'JOIN' table_ref join_qual - | table_ref 'JOIN' table_ref join_qual - | table_ref 'NATURAL' join_type opt_join_hint 'JOIN' table_ref - | table_ref 'NATURAL' 'JOIN' table_ref +join_type ::= + 'FULL' join_outer + | 'LEFT' join_outer + | 'RIGHT' join_outer + | 'INNER' -alias_clause ::= - 'AS' table_alias_name opt_column_list - | table_alias_name opt_column_list +join_qual ::= + 'USING' '(' name_list ')' + | 'ON' a_expr -func_table ::= - func_expr_windowless - | 'ROWS' 'FROM' '(' rowsfrom_list ')' +func_expr_windowless ::= + func_application + | func_expr_common_subexpr -row_source_extension_stmt ::= - delete_stmt - | explain_stmt - | insert_stmt - | select_stmt - | show_stmt - | update_stmt - | upsert_stmt +rowsfrom_list ::= + ( rowsfrom_item ) ( ( ',' rowsfrom_item ) )* opt_column ::= 'COLUMN' @@ -2229,28 +2256,12 @@ char_aliases ::= window_definition ::= window_name 'AS' window_specification -opt_join_hint ::= - 'HASH' - | 'MERGE' - | 'LOOKUP' +join_outer ::= + 'OUTER' | -join_type ::= - 'FULL' join_outer - | 'LEFT' join_outer - | 'RIGHT' join_outer - | 'INNER' - -join_qual ::= - 'USING' '(' name_list ')' - | 'ON' a_expr - -func_expr_windowless ::= - func_application - | func_expr_common_subexpr - -rowsfrom_list ::= - ( rowsfrom_item ) ( ( ',' rowsfrom_item ) )* +rowsfrom_item ::= + func_expr_windowless create_as_col_qualification_elem ::= 'PRIMARY' 'KEY' @@ -2316,13 +2327,6 @@ trim_list ::= | 'FROM' expr_list | expr_list -join_outer ::= - 'OUTER' - | - -rowsfrom_item ::= - func_expr_windowless - create_as_param ::= column_name diff --git a/docs/generated/sql/bnf/update_stmt.bnf b/docs/generated/sql/bnf/update_stmt.bnf index fd5bf8e88dbe..9a2b313b1a0e 100644 --- a/docs/generated/sql/bnf/update_stmt.bnf +++ b/docs/generated/sql/bnf/update_stmt.bnf @@ -1,2 +1,2 @@ update_stmt ::= - ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'UPDATE' ( ( table_name opt_index_flags ) | ( table_name opt_index_flags ) table_alias_name | ( table_name opt_index_flags ) 'AS' table_alias_name ) 'SET' ( ( ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) ( ( ',' ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) )* ) ( ( 'WHERE' a_expr ) | ) ( sort_clause | ) ( limit_clause | ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) + ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'UPDATE' ( ( table_name opt_index_flags ) | ( table_name opt_index_flags ) table_alias_name | ( table_name opt_index_flags ) 'AS' table_alias_name ) 'SET' ( ( ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) ( ( ',' ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) )* ) opt_from_list ( ( 'WHERE' a_expr ) | ) ( sort_clause | ) ( limit_clause | ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) diff --git a/pkg/sql/logictest/testdata/logic_test/update_from b/pkg/sql/logictest/testdata/logic_test/update_from new file mode 100644 index 000000000000..81edefd6ef09 --- /dev/null +++ b/pkg/sql/logictest/testdata/logic_test/update_from @@ -0,0 +1,61 @@ +# LogicTest: local-opt + +statement ok +CREATE TABLE abc (a int primary key, b int, c int) + +statement ok +INSERT INTO abc VALUES (1, 2, 3), (2, 3, 4) + +# Updating using self join. +statement ok +UPDATE abc SET b = other.b + 1, c = other.c + 1 FROM abc AS other WHERE abc.a = other.a + +query III +SELECT * FROM abc +---- +1 3 4 +2 4 5 + +# Update from another table. +statement ok +CREATE TABLE new_abc (a int, b int, c int) + +statement ok +INSERT INTO new_abc VALUES (1, 2, 3), (2, 3, 4) + +statement ok +UPDATE abc SET b = other.b, c = other.c FROM new_abc AS other WHERE abc.a = other.a + +query III +SELECT * FROM abc +---- +1 2 3 +2 3 4 + +# Multiple matching values for a given row. +statement ok +INSERT INTO new_abc VALUES (1, 1, 1) + +statement ok +UPDATE abc SET b = other.b, c = other.c FROM new_abc AS other WHERE abc.a = other.a + +query III +SELECT * FROM abc +---- +1 2 3 +2 3 4 + +# Returning old values. +query IIIII colnames +UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a RETURNING abc.a, abc.b AS new_b, old.b as old_b, abc.c as new_c, old.c as old_c +---- +a new_b old_b new_c old_c +1 3 2 5 3 +2 4 3 6 4 + +# Check if RETURNING * returns everything +query IIIIII +UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a RETURNING * +---- +1 4 7 1 3 5 +2 5 8 2 4 6 diff --git a/pkg/sql/opt/bench/stub_factory.go b/pkg/sql/opt/bench/stub_factory.go index d62d7a405896..f11a78a30dfe 100644 --- a/pkg/sql/opt/bench/stub_factory.go +++ b/pkg/sql/opt/bench/stub_factory.go @@ -236,6 +236,7 @@ func (f *stubFactory) ConstructUpdate( updateCols exec.ColumnOrdinalSet, returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, + passthrough sqlbase.ResultColumns, ) (exec.Node, error) { return struct{}{}, nil } diff --git a/pkg/sql/opt/exec/execbuilder/mutation.go b/pkg/sql/opt/exec/execbuilder/mutation.go index 79c8b14a5ee1..6658df7d279b 100644 --- a/pkg/sql/opt/exec/execbuilder/mutation.go +++ b/pkg/sql/opt/exec/execbuilder/mutation.go @@ -22,6 +22,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/errors" ) @@ -108,8 +109,16 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (execPlan, error) { colList := make(opt.ColList, 0, len(upd.FetchCols)+len(upd.UpdateCols)+len(upd.CheckCols)) colList = appendColsWhenPresent(colList, upd.FetchCols) colList = appendColsWhenPresent(colList, upd.UpdateCols) - colList = appendColsWhenPresent(colList, upd.CheckCols) + // The RETURNING clause of the Update can refer to the columns + // in any of the FROM tables. As a result, the Update may need + // to passthrough those columns so the projection above can use + // them. + if upd.NeedResults() { + colList = appendColsWhenPresent(colList, upd.PassthroughCols) + } + + colList = appendColsWhenPresent(colList, upd.CheckCols) input, err := b.buildMutationInput(upd.Input, colList, &upd.MutationPrivate) if err != nil { return execPlan{}, err @@ -122,6 +131,16 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (execPlan, error) { updateColOrds := ordinalSetFromColList(upd.UpdateCols) returnColOrds := ordinalSetFromColList(upd.ReturnCols) checkOrds := ordinalSetFromColList(upd.CheckCols) + + // Construct the result columns for the passthrough set. + var passthroughCols sqlbase.ResultColumns + if upd.NeedResults() { + for _, passthroughCol := range upd.PassthroughCols { + colMeta := b.mem.Metadata().ColumnMeta(passthroughCol) + passthroughCols = append(passthroughCols, sqlbase.ResultColumn{Name: colMeta.Alias, Typ: colMeta.Type}) + } + } + node, err := b.factory.ConstructUpdate( input.root, tab, @@ -129,6 +148,7 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (execPlan, error) { updateColOrds, returnColOrds, checkOrds, + passthroughCols, ) if err != nil { return execPlan{}, err @@ -354,6 +374,16 @@ func mutationOutputColMap(mutation memo.RelExpr) opt.ColMap { ord++ } } + + // The output columns of the mutation will also include all + // columns it allowed to pass through. + for _, colID := range private.PassthroughCols { + if colID != 0 { + colMap.Set(int(colID), ord) + ord++ + } + } + return colMap } diff --git a/pkg/sql/opt/exec/factory.go b/pkg/sql/opt/exec/factory.go index ac14e6786bb9..02a4a4ac4df6 100644 --- a/pkg/sql/opt/exec/factory.go +++ b/pkg/sql/opt/exec/factory.go @@ -328,6 +328,7 @@ type Factory interface { updateCols ColumnOrdinalSet, returnCols ColumnOrdinalSet, checks CheckOrdinalSet, + passthrough sqlbase.ResultColumns, ) (Node, error) // ConstructUpsert creates a node that implements an INSERT..ON CONFLICT or diff --git a/pkg/sql/opt/memo/logical_props_builder.go b/pkg/sql/opt/memo/logical_props_builder.go index dbcaf0002dba..725c2c481cec 100644 --- a/pkg/sql/opt/memo/logical_props_builder.go +++ b/pkg/sql/opt/memo/logical_props_builder.go @@ -1195,6 +1195,14 @@ func (b *logicalPropsBuilder) buildMutationProps(mutation RelExpr, rel *props.Re } } + // The output columns of the mutation will also include all + // columns it allowed to pass through. + for _, col := range private.PassthroughCols { + if col != 0 { + rel.OutputCols.Add(col) + } + } + // Not Null Columns // ---------------- // A column should be marked as not-null if the target table column is not diff --git a/pkg/sql/opt/norm/prune_cols.go b/pkg/sql/opt/norm/prune_cols.go index 0b6c61c9e173..0c41e9f96533 100644 --- a/pkg/sql/opt/norm/prune_cols.go +++ b/pkg/sql/opt/norm/prune_cols.go @@ -59,6 +59,7 @@ func (c *CustomFuncs) NeededMutationCols(private *memo.MutationPrivate) opt.ColS addCols(private.UpdateCols) addCols(private.CheckCols) addCols(private.ReturnCols) + addCols(private.PassthroughCols) if private.CanaryCol != 0 { cols.Add(private.CanaryCol) } diff --git a/pkg/sql/opt/ops/mutation.opt b/pkg/sql/opt/ops/mutation.opt index 47082e072457..3ed7a839a09e 100644 --- a/pkg/sql/opt/ops/mutation.opt +++ b/pkg/sql/opt/ops/mutation.opt @@ -121,6 +121,13 @@ define MutationPrivate { # then ReturnCols is nil. ReturnCols ColList + # PassthroughCols are columns that the mutation needs to passthrough from + # its input. Its similar to the passthrough columns in projections. This + # is useful for `UPDATE .. FROM` mutations where the `RETURNING` clause + # references columns from tables in the `FROM` clause. When this happens + # the update will need to pass through those refenced columns from its input. + PassthroughCols ColList + # Mutation operators can act similarly to a With operator: they buffer their # input, making it accessible to FK queries. If this is not required, WithID # is zero. diff --git a/pkg/sql/opt/optbuilder/delete.go b/pkg/sql/opt/optbuilder/delete.go index eb6de57f9546..63cf40af0600 100644 --- a/pkg/sql/opt/optbuilder/delete.go +++ b/pkg/sql/opt/optbuilder/delete.go @@ -68,7 +68,7 @@ func (b *Builder) buildDelete(del *tree.Delete, inScope *scope) (outScope *scope // ORDER BY LIMIT // // All columns from the delete table will be projected. - mb.buildInputForUpdateOrDelete(inScope, del.Where, del.Limit, del.OrderBy) + mb.buildInputForDelete(inScope, del.Where, del.Limit, del.OrderBy) // Build the final delete statement, including any returned expressions. if resultsNeeded(del.Returning) { @@ -90,5 +90,5 @@ func (mb *mutationBuilder) buildDelete(returning tree.ReturningExprs) { private := mb.makeMutationPrivate(returning != nil) mb.outScope.expr = mb.b.factory.ConstructDelete(mb.outScope.expr, mb.checks, private) - mb.buildReturning(returning) + mb.buildReturning(returning, nil) } diff --git a/pkg/sql/opt/optbuilder/insert.go b/pkg/sql/opt/optbuilder/insert.go index 0f6070ca89a6..b1efdce91331 100644 --- a/pkg/sql/opt/optbuilder/insert.go +++ b/pkg/sql/opt/optbuilder/insert.go @@ -620,7 +620,7 @@ func (mb *mutationBuilder) buildInsert(returning tree.ReturningExprs) { private := mb.makeMutationPrivate(returning != nil) mb.outScope.expr = mb.b.factory.ConstructInsert(mb.outScope.expr, mb.checks, private) - mb.buildReturning(returning) + mb.buildReturning(returning, nil) } // buildInputForDoNothing wraps the input expression in LEFT OUTER JOIN @@ -871,7 +871,7 @@ func (mb *mutationBuilder) buildUpsert(returning tree.ReturningExprs) { private := mb.makeMutationPrivate(returning != nil) mb.outScope.expr = mb.b.factory.ConstructUpsert(mb.outScope.expr, mb.checks, private) - mb.buildReturning(returning) + mb.buildReturning(returning, nil) } // projectUpsertColumns projects a set of merged columns that will be either diff --git a/pkg/sql/opt/optbuilder/mutation_builder.go b/pkg/sql/opt/optbuilder/mutation_builder.go index ce545a8b439a..3ee9bd46cd28 100644 --- a/pkg/sql/opt/optbuilder/mutation_builder.go +++ b/pkg/sql/opt/optbuilder/mutation_builder.go @@ -176,8 +176,8 @@ func (mb *mutationBuilder) fetchColID(tabOrd int) opt.ColumnID { return mb.scopeOrdToColID(mb.fetchOrds[tabOrd]) } -// buildInputForUpdateOrDelete constructs a Select expression from the fields in -// the Update or Delete operator, similar to this: +// buildInputForUpdate constructs a Select expression from the fields in +// the Update operator, similar to this: // // SELECT // FROM @@ -186,8 +186,106 @@ func (mb *mutationBuilder) fetchColID(tabOrd int) opt.ColumnID { // LIMIT // // All columns from the table to update are added to fetchColList. +// If a FROM clause is defined, the columns in the FROM +// clause are returned. // TODO(andyk): Do needed column analysis to project fewer columns if possible. -func (mb *mutationBuilder) buildInputForUpdateOrDelete( +func (mb *mutationBuilder) buildInputForUpdate( + inScope *scope, from tree.TableExprs, where *tree.Where, limit *tree.Limit, orderBy tree.OrderBy, +) (cols []scopeColumn) { + // Fetch columns from different instance of the table metadata, so that it's + // possible to remap columns, as in this example: + // + // UPDATE abc SET a=b + // + + // FROM + mb.outScope = mb.b.buildScan( + mb.b.addTable(mb.tab, &mb.alias), + nil, /* ordinals */ + nil, /* indexFlags */ + includeMutations, + inScope, + ) + + fromClausePresent := len(from) > 0 + numCols := len(mb.outScope.cols) + + // If there is a from clause present, we must join all the tables together with the + // table being updated. + if fromClausePresent { + fromScope := mb.b.buildFromTables(from, inScope) + + // Check that the same table name is not used multiple times. + mb.b.validateJoinTableNames(mb.outScope, fromScope) + + cols = fromScope.cols + mb.outScope.appendColumnsFromScope(fromScope) + + left := mb.outScope.expr.(memo.RelExpr) + right := fromScope.expr.(memo.RelExpr) + mb.outScope.expr = mb.b.factory.ConstructInnerJoin(left, right, memo.TrueFilter, memo.EmptyJoinPrivate) + } + + // WHERE + mb.b.buildWhere(where, mb.outScope) + + // SELECT + ORDER BY (which may add projected expressions) + projectionsScope := mb.outScope.replace() + projectionsScope.appendColumnsFromScope(mb.outScope) + orderByScope := mb.b.analyzeOrderBy(orderBy, mb.outScope, projectionsScope) + mb.b.buildOrderBy(mb.outScope, projectionsScope, orderByScope) + mb.b.constructProjectForScope(mb.outScope, projectionsScope) + + // LIMIT + if limit != nil { + mb.b.buildLimit(limit, inScope, projectionsScope) + } + + mb.outScope = projectionsScope + + // Build a distinct on to ensure there is at most one row in the joined output + // for every row in the table. + if fromClausePresent { + var pkCols opt.ColSet + + // We need to ensure that the join has a maximum of one row for every row in the + // table and we ensure this by constructing a distinct on the primary key columns. + primaryIndex := mb.tab.Index(cat.PrimaryIndex) + for i := 0; i < primaryIndex.KeyColumnCount(); i++ { + pkCol := mb.outScope.cols[primaryIndex.Column(i).Ordinal] + + // If the primary key column is hidden, then we don't need to use it + // for the distinct on. + if !pkCol.hidden { + pkCols.Add(pkCol.id) + } + } + + if !pkCols.Empty() { + mb.outScope = mb.b.buildDistinctOn(pkCols, mb.outScope) + } + } + + // Set list of columns that will be fetched by the input expression. + for i := 0; i < numCols; i++ { + mb.fetchOrds[i] = scopeOrdinal(i) + } + + return cols +} + +// buildInputForDelete constructs a Select expression from the fields in +// the Delete operator, similar to this: +// +// SELECT +// FROM
+// WHERE +// ORDER BY +// LIMIT +// +// All columns from the table to update are added to fetchColList. +// TODO(andyk): Do needed column analysis to project fewer columns if possible. +func (mb *mutationBuilder) buildInputForDelete( inScope *scope, where *tree.Where, limit *tree.Limit, orderBy tree.OrderBy, ) { // Fetch columns from different instance of the table metadata, so that it's @@ -652,7 +750,7 @@ func (mb *mutationBuilder) mapToReturnScopeOrd(tabOrd int) scopeOrdinal { // buildReturning wraps the input expression with a Project operator that // projects the given RETURNING expressions. -func (mb *mutationBuilder) buildReturning(returning tree.ReturningExprs) { +func (mb *mutationBuilder) buildReturning(returning tree.ReturningExprs, fromCols []scopeColumn) { // Handle case of no RETURNING clause. if returning == nil { mb.outScope = &scope{builder: mb.b, expr: mb.outScope.expr} @@ -682,6 +780,10 @@ func (mb *mutationBuilder) buildReturning(returning tree.ReturningExprs) { }) } + // The returning columns can reference the columns defined in the FROM clause of + // an Update. + inScope.appendColumns(fromCols) + // Construct the Project operator that projects the RETURNING expressions. outScope := inScope.replace() mb.b.analyzeReturningList(returning, nil /* desiredTypes */, inScope, outScope) diff --git a/pkg/sql/opt/optbuilder/testdata/update_from b/pkg/sql/opt/optbuilder/testdata/update_from new file mode 100644 index 000000000000..2d25d7e04a8a --- /dev/null +++ b/pkg/sql/opt/optbuilder/testdata/update_from @@ -0,0 +1,168 @@ +exec-ddl +CREATE TABLE abc (a int primary key, b int, c int) +---- + +exec-ddl +CREATE TABLE new_abc (a int, b int, c int) +---- + +# Test a self join. +opt +UPDATE abc SET b = other.b + 1, c = other.c + 1 FROM abc AS other WHERE abc.a = other.a +---- +update abc + ├── columns: + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── column10:10 => abc.b:2 + │ └── column11:11 => abc.c:3 + └── project + ├── columns: column10:10(int) column11:11(int) abc.a:4(int!null) abc.b:5(int) abc.c:6(int) other.a:7(int!null) other.b:8(int) other.c:9(int) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) other.a:7(int!null) other.b:8(int) other.c:9(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── scan abc + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── ordering: +4 + │ ├── scan other + │ │ ├── columns: other.a:7(int!null) other.b:8(int) other.c:9(int) + │ │ └── ordering: +7 + │ └── filters (true) + └── projections + ├── plus [type=int] + │ ├── variable: other.b [type=int] + │ └── const: 1 [type=int] + └── plus [type=int] + ├── variable: other.c [type=int] + └── const: 1 [type=int] + +# Test when Update uses multiple tables. +opt +UPDATE abc SET b = other.b, c = other.c FROM new_abc AS other WHERE abc.a = other.a +---- +update abc + ├── columns: + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── other.b:8 => abc.b:2 + │ └── other.c:9 => abc.c:3 + └── distinct-on + ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) other.a:7(int) other.b:8(int) other.c:9(int) + ├── grouping columns: abc.a:4(int!null) + ├── inner-join (hash) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) other.a:7(int!null) other.b:8(int) other.c:9(int) + │ ├── scan abc + │ │ └── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ ├── scan other + │ │ └── columns: other.a:7(int) other.b:8(int) other.c:9(int) + │ └── filters + │ └── eq [type=bool] + │ ├── variable: abc.a [type=int] + │ └── variable: other.a [type=int] + └── aggregations + ├── first-agg [type=int] + │ └── variable: abc.b [type=int] + ├── first-agg [type=int] + │ └── variable: abc.c [type=int] + ├── first-agg [type=int] + │ └── variable: other.a [type=int] + ├── first-agg [type=int] + │ └── variable: other.b [type=int] + └── first-agg [type=int] + └── variable: other.c [type=int] + +# Check if UPDATE FROM works well with RETURNING expressions that reference the FROM tables. +opt +UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a RETURNING abc.a, abc.b AS new_b, old.b as old_b, abc.c as new_c, old.c as old_c +---- +project + ├── columns: a:1(int!null) new_b:2(int) old_b:8(int) new_c:3(int) old_c:9(int) + └── update abc + ├── columns: abc.a:1(int!null) abc.b:2(int) abc.c:3(int) old.a:7(int) old.b:8(int) old.c:9(int) + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── column10:10 => abc.b:2 + │ └── column11:11 => abc.c:3 + └── project + ├── columns: column10:10(int) column11:11(int) abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── scan abc + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── ordering: +4 + │ ├── scan old + │ │ ├── columns: old.a:7(int!null) old.b:8(int) old.c:9(int) + │ │ └── ordering: +7 + │ └── filters (true) + └── projections + ├── plus [type=int] + │ ├── variable: old.b [type=int] + │ └── const: 1 [type=int] + └── plus [type=int] + ├── variable: old.c [type=int] + └── const: 2 [type=int] + +# Check if RETURNING * returns everything +opt +UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a RETURNING * +---- +update abc + ├── columns: a:1(int!null) b:2(int) c:3(int) a:7(int) b:8(int) c:9(int) + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── column10:10 => abc.b:2 + │ └── column11:11 => abc.c:3 + └── project + ├── columns: column10:10(int) column11:11(int) abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── scan abc + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── ordering: +4 + │ ├── scan old + │ │ ├── columns: old.a:7(int!null) old.b:8(int) old.c:9(int) + │ │ └── ordering: +7 + │ └── filters (true) + └── projections + ├── plus [type=int] + │ ├── variable: old.b [type=int] + │ └── const: 1 [type=int] + └── plus [type=int] + ├── variable: old.c [type=int] + └── const: 2 [type=int] + +# Check if the joins are optimized (check if the filters are pushed down). +opt +UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a AND abc.a = 2 +---- +update abc + ├── columns: + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── column10:10 => abc.b:2 + │ └── column11:11 => abc.c:3 + └── project + ├── columns: column10:10(int) column11:11(int) abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── scan abc + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── constraint: /4: [/2 - /2] + │ ├── scan old + │ │ ├── columns: old.a:7(int!null) old.b:8(int) old.c:9(int) + │ │ └── constraint: /7: [/2 - /2] + │ └── filters (true) + └── projections + ├── plus [type=int] + │ ├── variable: old.b [type=int] + │ └── const: 1 [type=int] + └── plus [type=int] + ├── variable: old.c [type=int] + └── const: 2 [type=int] diff --git a/pkg/sql/opt/optbuilder/update.go b/pkg/sql/opt/optbuilder/update.go index 94c12a362a68..56c5e5db7ac9 100644 --- a/pkg/sql/opt/optbuilder/update.go +++ b/pkg/sql/opt/optbuilder/update.go @@ -103,7 +103,7 @@ func (b *Builder) buildUpdate(upd *tree.Update, inScope *scope) (outScope *scope // ORDER BY LIMIT // // All columns from the update table will be projected. - mb.buildInputForUpdateOrDelete(inScope, upd.Where, upd.Limit, upd.OrderBy) + fromCols := mb.buildInputForUpdate(inScope, upd.From, upd.Where, upd.Limit, upd.OrderBy) // Derive the columns that will be updated from the SET expressions. mb.addTargetColsForUpdate(upd.Exprs) @@ -117,9 +117,9 @@ func (b *Builder) buildUpdate(upd *tree.Update, inScope *scope) (outScope *scope // Build the final update statement, including any returned expressions. if resultsNeeded(upd.Returning) { - mb.buildUpdate(*upd.Returning.(*tree.ReturningExprs)) + mb.buildUpdate(*upd.Returning.(*tree.ReturningExprs), fromCols) } else { - mb.buildUpdate(nil /* returning */) + mb.buildUpdate(nil /* returning */, fromCols) } mb.outScope.expr = b.wrapWithCTEs(mb.outScope.expr, ctes) @@ -321,11 +321,15 @@ func (mb *mutationBuilder) addComputedColsForUpdate() { // buildUpdate constructs an Update operator, possibly wrapped by a Project // operator that corresponds to the given RETURNING clause. -func (mb *mutationBuilder) buildUpdate(returning tree.ReturningExprs) { +func (mb *mutationBuilder) buildUpdate(returning tree.ReturningExprs, fromCols []scopeColumn) { mb.addCheckConstraintCols() private := mb.makeMutationPrivate(returning != nil) + for _, col := range fromCols { + if col.id != 0 { + private.PassthroughCols = append(private.PassthroughCols, col.id) + } + } mb.outScope.expr = mb.b.factory.ConstructUpdate(mb.outScope.expr, mb.checks, private) - - mb.buildReturning(returning) + mb.buildReturning(returning, fromCols) } diff --git a/pkg/sql/opt_exec_factory.go b/pkg/sql/opt_exec_factory.go index f0ea9af649c9..3c3f2d7ab3cc 100644 --- a/pkg/sql/opt_exec_factory.go +++ b/pkg/sql/opt_exec_factory.go @@ -1328,6 +1328,7 @@ func (ef *execFactory) ConstructUpdate( updateColOrdSet exec.ColumnOrdinalSet, returnColOrdSet exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, + passthrough sqlbase.ResultColumns, ) (exec.Node, error) { // Derive table and column descriptors. rowsNeeded := !returnColOrdSet.Empty() @@ -1400,6 +1401,9 @@ func (ef *execFactory) ConstructUpdate( returnCols = sqlbase.ResultColumnsFromColDescs(returnColDescs) + // Add the passthrough columns to the returning columns. + returnCols = append(returnCols, passthrough...) + // Update the rowIdxToRetIdx for the mutation. Update returns // the non-mutation columns specified, in the same order they are // defined in the table. @@ -1436,6 +1440,7 @@ func (ef *execFactory) ConstructUpdate( updateValues: make(tree.Datums, len(ru.UpdateCols)), updateColsIdx: updateColsIdx, rowIdxToRetIdx: rowIdxToRetIdx, + numPassthrough: len(passthrough), }, } diff --git a/pkg/sql/parser/parse_test.go b/pkg/sql/parser/parse_test.go index f2cad82b9a82..7a7699c84814 100644 --- a/pkg/sql/parser/parse_test.go +++ b/pkg/sql/parser/parse_test.go @@ -1033,6 +1033,9 @@ func TestParse(t *testing.T) { {`UPDATE a.b SET b = 3`}, {`UPDATE a.b@c SET b = 3`}, {`UPDATE a SET b = 3, c = DEFAULT`}, + {`UPDATE a SET b = 3, c = DEFAULT FROM b`}, + {`UPDATE a SET b = 3, c = DEFAULT FROM a AS other`}, + {`UPDATE a SET b = 3, c = DEFAULT FROM a AS other, b`}, {`UPDATE a SET b = 3 + 4`}, {`UPDATE a SET (b, c) = (3, DEFAULT)`}, {`UPDATE a SET (b, c) = (SELECT 3, 4)`}, @@ -1045,6 +1048,7 @@ func TestParse(t *testing.T) { {`UPDATE a SET b = 3 WHERE a = b RETURNING a, a + b`}, {`UPDATE a SET b = 3 WHERE a = b RETURNING NOTHING`}, {`UPDATE a SET b = 3 WHERE a = b ORDER BY c LIMIT d RETURNING e`}, + {`UPDATE a SET b = 3 FROM other WHERE a = b ORDER BY c LIMIT d RETURNING e`}, {`UPDATE t AS "0" SET k = ''`}, // "0" lost its quotes {`SELECT * FROM "0" JOIN "0" USING (id, "0")`}, // last "0" lost its quotes. @@ -3056,7 +3060,6 @@ func TestUnimplementedSyntax(t *testing.T) { {`UPDATE foo SET (a, a.b) = (1, 2)`, 27792, ``}, {`UPDATE foo SET a.b = 1`, 27792, ``}, - {`UPDATE foo SET x = y FROM a, b`, 7841, ``}, {`UPDATE Foo SET x.y = z`, 27792, ``}, {`UPSERT INTO foo(a, a.b) VALUES (1,2)`, 27792, ``}, diff --git a/pkg/sql/parser/sql.y b/pkg/sql/parser/sql.y index ce00e0d4d782..2a561d63e9f8 100644 --- a/pkg/sql/parser/sql.y +++ b/pkg/sql/parser/sql.y @@ -820,8 +820,8 @@ func newNameFromStr(s string) *tree.Name { %type index_params create_as_params %type name_list privilege_list %type <[]int32> opt_array_bounds -%type from_clause update_from_clause -%type from_list rowsfrom_list +%type from_clause +%type from_list rowsfrom_list opt_from_list %type table_pattern_list single_table_pattern_list %type table_name_list %type expr_list opt_expr_list tuple1_ambiguous_values tuple1_unambiguous_values @@ -5523,12 +5523,13 @@ returning_clause: // %SeeAlso: INSERT, UPSERT, DELETE, WEBDOCS/update.html update_stmt: opt_with_clause UPDATE table_name_expr_opt_alias_idx - SET set_clause_list update_from_clause opt_where_clause opt_sort_clause opt_limit_clause returning_clause + SET set_clause_list opt_from_list opt_where_clause opt_sort_clause opt_limit_clause returning_clause { $$.val = &tree.Update{ With: $1.with(), Table: $3.tblExpr(), Exprs: $5.updateExprs(), + From: $6.tblExprs(), Where: tree.NewWhere(tree.AstWhere, $7.expr()), OrderBy: $8.orderBy(), Limit: $9.limit(), @@ -5537,10 +5538,13 @@ update_stmt: } | opt_with_clause UPDATE error // SHOW HELP: UPDATE -// Mark this as unimplemented until the normal from_clause is supported here. -update_from_clause: - FROM from_list { return unimplementedWithIssue(sqllex, 7841) } -| /* EMPTY */ {} +opt_from_list: + FROM from_list { + $$.val = $2.tblExprs() + } +| /* EMPTY */ { + $$.val = tree.TableExprs{} +} set_clause_list: set_clause diff --git a/pkg/sql/sem/tree/pretty.go b/pkg/sql/sem/tree/pretty.go index 0a967c573cc9..487d10a70221 100644 --- a/pkg/sql/sem/tree/pretty.go +++ b/pkg/sql/sem/tree/pretty.go @@ -1023,7 +1023,12 @@ func (node *Update) doc(p *PrettyCfg) pretty.Doc { items = append(items, node.With.docRow(p), p.row("UPDATE", p.Doc(node.Table)), - p.row("SET", p.Doc(&node.Exprs)), + p.row("SET", p.Doc(&node.Exprs))) + if len(node.From) > 0 { + items = append(items, + p.row("FROM", p.Doc(&node.From))) + } + items = append(items, node.Where.docRow(p), node.OrderBy.docRow(p)) items = append(items, node.Limit.docTable(p)...) diff --git a/pkg/sql/sem/tree/update.go b/pkg/sql/sem/tree/update.go index 801b571ce9cb..ce812e227ada 100644 --- a/pkg/sql/sem/tree/update.go +++ b/pkg/sql/sem/tree/update.go @@ -24,6 +24,7 @@ type Update struct { With *With Table TableExpr Exprs UpdateExprs + From TableExprs Where *Where OrderBy OrderBy Limit *Limit @@ -37,6 +38,10 @@ func (node *Update) Format(ctx *FmtCtx) { ctx.FormatNode(node.Table) ctx.WriteString(" SET ") ctx.FormatNode(&node.Exprs) + if len(node.From) > 0 { + ctx.WriteString(" FROM ") + ctx.FormatNode(&node.From) + } if node.Where != nil { ctx.WriteByte(' ') ctx.FormatNode(node.Where) diff --git a/pkg/sql/update.go b/pkg/sql/update.go index a035b3fba27f..06458998b8b7 100644 --- a/pkg/sql/update.go +++ b/pkg/sql/update.go @@ -494,6 +494,11 @@ type updateRun struct { // of the mutation. Otherwise, the value at the i-th index refers to the // index of the resultRowBuffer where the i-th column is to be returned. rowIdxToRetIdx []int + + // numPassthrough is the number of columns in addition to the set of + // columns of the target table being returned, that we must pass through + // from the input node. + numPassthrough int } // maxUpdateBatchSize is the max number of entries in the KV batch for @@ -693,7 +698,7 @@ func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums) return err } } else { - checkVals := sourceVals[len(u.run.tu.ru.FetchCols)+len(u.run.tu.ru.UpdateCols):] + checkVals := sourceVals[len(u.run.tu.ru.FetchCols)+len(u.run.tu.ru.UpdateCols)+u.run.numPassthrough:] if err := u.run.checkHelper.CheckInput(checkVals); err != nil { return err } @@ -716,13 +721,31 @@ func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums) // MakeUpdater guarantees that the first columns of the new values // are those specified u.columns. resultValues := make([]tree.Datum, len(u.columns)) + largestRetIdx := -1 for i := range u.run.rowIdxToRetIdx { retIdx := u.run.rowIdxToRetIdx[i] if retIdx >= 0 { + if retIdx >= largestRetIdx { + largestRetIdx = retIdx + } resultValues[retIdx] = newValues[i] } } + // At this point we've extracted all the RETURNING values that are part + // of the target table. We must now extract the columns in the RETURNING + // clause that refer to other tables (from the FROM clause of the update). + if u.run.numPassthrough > 0 { + passthroughBegin := len(u.run.tu.ru.FetchCols) + len(u.run.tu.ru.UpdateCols) + passthroughEnd := passthroughBegin + u.run.numPassthrough + passthroughValues := sourceVals[passthroughBegin:passthroughEnd] + + for i := 0; i < u.run.numPassthrough; i++ { + largestRetIdx++ + resultValues[largestRetIdx] = passthroughValues[i] + } + } + if _, err := u.run.rows.AddRow(params.ctx, resultValues); err != nil { return err }