diff --git a/docs/generated/sql/bnf/stmt_block.bnf b/docs/generated/sql/bnf/stmt_block.bnf index d2a359b134ed..940d99450de0 100644 --- a/docs/generated/sql/bnf/stmt_block.bnf +++ b/docs/generated/sql/bnf/stmt_block.bnf @@ -193,7 +193,7 @@ truncate_stmt ::= 'TRUNCATE' opt_table relation_expr_list opt_drop_behavior update_stmt ::= - opt_with_clause 'UPDATE' table_name_expr_opt_alias_idx 'SET' set_clause_list opt_where_clause opt_sort_clause opt_limit_clause returning_clause + opt_with_clause 'UPDATE' table_name_expr_opt_alias_idx 'SET' set_clause_list opt_from_list opt_where_clause opt_sort_clause opt_limit_clause returning_clause upsert_stmt ::= opt_with_clause 'UPSERT' 'INTO' insert_target insert_rest returning_clause @@ -555,6 +555,10 @@ opt_drop_behavior ::= set_clause_list ::= ( set_clause ) ( ( ',' set_clause ) )* +opt_from_list ::= + 'FROM' from_list + | + db_object_name ::= simple_db_object_name | complex_db_object_name @@ -1185,6 +1189,9 @@ set_clause ::= single_set_clause | multiple_set_clause +from_list ::= + ( table_ref ) ( ( ',' table_ref ) )* + simple_db_object_name ::= db_object_name_component @@ -1545,6 +1552,16 @@ single_set_clause ::= multiple_set_clause ::= '(' insert_column_list ')' '=' in_expr +table_ref ::= + relation_expr opt_index_flags opt_ordinality opt_alias_clause + | select_with_parens opt_ordinality opt_alias_clause + | 'LATERAL' select_with_parens opt_ordinality opt_alias_clause + | joined_table + | '(' joined_table ')' opt_ordinality alias_clause + | func_table opt_ordinality opt_alias_clause + | 'LATERAL' func_table opt_ordinality opt_alias_clause + | '[' row_source_extension_stmt ']' opt_ordinality opt_alias_clause + cockroachdb_extra_type_func_name_keyword ::= 'FAMILY' @@ -1847,16 +1864,6 @@ distinct_clause ::= distinct_on_clause ::= 'DISTINCT' 'ON' '(' expr_list ')' -table_ref ::= - relation_expr opt_index_flags opt_ordinality opt_alias_clause - | select_with_parens opt_ordinality opt_alias_clause - | 'LATERAL' select_with_parens opt_ordinality opt_alias_clause - | joined_table - | '(' joined_table ')' opt_ordinality alias_clause - | func_table opt_ordinality opt_alias_clause - | 'LATERAL' func_table opt_ordinality opt_alias_clause - | '[' row_source_extension_stmt ']' opt_ordinality opt_alias_clause - all_or_distinct ::= 'ALL' | 'DISTINCT' @@ -1865,6 +1872,39 @@ all_or_distinct ::= var_list ::= ( var_value ) ( ( ',' var_value ) )* +opt_ordinality ::= + 'WITH' 'ORDINALITY' + | + +opt_alias_clause ::= + alias_clause + | + +joined_table ::= + '(' joined_table ')' + | table_ref 'CROSS' opt_join_hint 'JOIN' table_ref + | table_ref join_type opt_join_hint 'JOIN' table_ref join_qual + | table_ref 'JOIN' table_ref join_qual + | table_ref 'NATURAL' join_type opt_join_hint 'JOIN' table_ref + | table_ref 'NATURAL' 'JOIN' table_ref + +alias_clause ::= + 'AS' table_alias_name opt_column_list + | table_alias_name opt_column_list + +func_table ::= + func_expr_windowless + | 'ROWS' 'FROM' '(' rowsfrom_list ')' + +row_source_extension_stmt ::= + delete_stmt + | explain_stmt + | insert_stmt + | select_stmt + | show_stmt + | update_stmt + | upsert_stmt + user_priority ::= 'LOW' | 'NORMAL' @@ -2061,44 +2101,31 @@ character_base ::= | 'VARCHAR' | 'STRING' -from_list ::= - ( table_ref ) ( ( ',' table_ref ) )* - window_definition_list ::= ( window_definition ) ( ( ',' window_definition ) )* -opt_ordinality ::= - 'WITH' 'ORDINALITY' - | - -opt_alias_clause ::= - alias_clause +opt_join_hint ::= + 'HASH' + | 'MERGE' + | 'LOOKUP' | -joined_table ::= - '(' joined_table ')' - | table_ref 'CROSS' opt_join_hint 'JOIN' table_ref - | table_ref join_type opt_join_hint 'JOIN' table_ref join_qual - | table_ref 'JOIN' table_ref join_qual - | table_ref 'NATURAL' join_type opt_join_hint 'JOIN' table_ref - | table_ref 'NATURAL' 'JOIN' table_ref +join_type ::= + 'FULL' join_outer + | 'LEFT' join_outer + | 'RIGHT' join_outer + | 'INNER' -alias_clause ::= - 'AS' table_alias_name opt_column_list - | table_alias_name opt_column_list +join_qual ::= + 'USING' '(' name_list ')' + | 'ON' a_expr -func_table ::= - func_expr_windowless - | 'ROWS' 'FROM' '(' rowsfrom_list ')' +func_expr_windowless ::= + func_application + | func_expr_common_subexpr -row_source_extension_stmt ::= - delete_stmt - | explain_stmt - | insert_stmt - | select_stmt - | show_stmt - | update_stmt - | upsert_stmt +rowsfrom_list ::= + ( rowsfrom_item ) ( ( ',' rowsfrom_item ) )* opt_column ::= 'COLUMN' @@ -2230,28 +2257,12 @@ char_aliases ::= window_definition ::= window_name 'AS' window_specification -opt_join_hint ::= - 'HASH' - | 'MERGE' - | 'LOOKUP' +join_outer ::= + 'OUTER' | -join_type ::= - 'FULL' join_outer - | 'LEFT' join_outer - | 'RIGHT' join_outer - | 'INNER' - -join_qual ::= - 'USING' '(' name_list ')' - | 'ON' a_expr - -func_expr_windowless ::= - func_application - | func_expr_common_subexpr - -rowsfrom_list ::= - ( rowsfrom_item ) ( ( ',' rowsfrom_item ) )* +rowsfrom_item ::= + func_expr_windowless create_as_col_qualification_elem ::= 'PRIMARY' 'KEY' @@ -2317,13 +2328,6 @@ trim_list ::= | 'FROM' expr_list | expr_list -join_outer ::= - 'OUTER' - | - -rowsfrom_item ::= - func_expr_windowless - create_as_param ::= column_name diff --git a/docs/generated/sql/bnf/update_stmt.bnf b/docs/generated/sql/bnf/update_stmt.bnf index fd5bf8e88dbe..9a2b313b1a0e 100644 --- a/docs/generated/sql/bnf/update_stmt.bnf +++ b/docs/generated/sql/bnf/update_stmt.bnf @@ -1,2 +1,2 @@ update_stmt ::= - ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'UPDATE' ( ( table_name opt_index_flags ) | ( table_name opt_index_flags ) table_alias_name | ( table_name opt_index_flags ) 'AS' table_alias_name ) 'SET' ( ( ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) ( ( ',' ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) )* ) ( ( 'WHERE' a_expr ) | ) ( sort_clause | ) ( limit_clause | ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) + ( ( 'WITH' ( ( common_table_expr ) ( ( ',' common_table_expr ) )* ) ) | ) 'UPDATE' ( ( table_name opt_index_flags ) | ( table_name opt_index_flags ) table_alias_name | ( table_name opt_index_flags ) 'AS' table_alias_name ) 'SET' ( ( ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) ( ( ',' ( ( column_name '=' a_expr ) | ( '(' ( ( ( column_name ) ) ( ( ',' ( column_name ) ) )* ) ')' '=' ( '(' select_stmt ')' | ( '(' ')' | '(' ( a_expr | a_expr ',' | a_expr ',' ( ( a_expr ) ( ( ',' a_expr ) )* ) ) ')' ) ) ) ) ) )* ) opt_from_list ( ( 'WHERE' a_expr ) | ) ( sort_clause | ) ( limit_clause | ) ( 'RETURNING' target_list | 'RETURNING' 'NOTHING' | ) diff --git a/pkg/sql/logictest/testdata/logic_test/update_from b/pkg/sql/logictest/testdata/logic_test/update_from new file mode 100644 index 000000000000..c13cc9bc84ba --- /dev/null +++ b/pkg/sql/logictest/testdata/logic_test/update_from @@ -0,0 +1,174 @@ +# LogicTest: local-opt fakedist-opt + +statement ok +CREATE TABLE abc (a int primary key, b int, c int) + +statement ok +INSERT INTO abc VALUES (1, 20, 300), (2, 30, 400) + +# Updating using self join. +statement ok +UPDATE abc SET b = other.b + 1, c = other.c + 1 FROM abc AS other WHERE abc.a = other.a + +query III rowsort +SELECT * FROM abc +---- +1 21 301 +2 31 401 + +# Update only some columns. +statement ok +UPDATE abc SET b = other.b + 1 FROM abc AS other WHERE abc.a = other.a + +query III rowsort +SELECT * FROM abc +---- +1 22 301 +2 32 401 + +# Update only some rows. +statement ok +UPDATE abc SET b = other.b + 1 FROM abc AS other WHERE abc.a = other.a AND abc.a = 1 + +query III rowsort +SELECT * FROM abc +---- +1 23 301 +2 32 401 + +# Update from another table. +statement ok +CREATE TABLE new_abc (a int, b int, c int) + +statement ok +INSERT INTO new_abc VALUES (1, 2, 3), (2, 3, 4) + +statement ok +UPDATE abc SET b = new_abc.b, c = new_abc.c FROM new_abc WHERE abc.a = new_abc.a + +query III rowsort +SELECT * FROM abc +---- +1 2 3 +2 3 4 + +# Multiple matching values for a given row. When this happens, we pick +# the first matching value for the row (this is arbitrary). This behavior +# is consistent with Postgres. +statement ok +INSERT INTO new_abc VALUES (1, 1, 1) + +statement ok +UPDATE abc SET b = new_abc.b, c = new_abc.c FROM new_abc WHERE abc.a = new_abc.a + +query III rowsort +SELECT * FROM abc +---- +1 2 3 +2 3 4 + +# Returning old values. +query IIIII colnames,rowsort +UPDATE abc +SET + b = old.b + 1, c = old.c + 2 +FROM + abc AS old +WHERE + abc.a = old.a +RETURNING + abc.a, abc.b AS new_b, old.b as old_b, abc.c as new_c, old.c as old_c +---- +a new_b old_b new_c old_c +1 3 2 5 3 +2 4 3 6 4 + +# Check if RETURNING * returns everything +query IIIIII colnames,rowsort +UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a RETURNING * +---- +a b c a b c +1 4 7 1 3 5 +2 5 8 2 4 6 + +# Make sure UPDATE FROM works properly in the presence of check columns. +statement ok +CREATE TABLE abc_check (a int primary key, b int, c int, check (a > 0), check (b > 0 AND b < 10)) + +statement ok +INSERT INTO abc_check VALUES (1, 2, 3), (2, 3, 4) + +query III colnames,rowsort +UPDATE abc_check +SET + b = other.b, c = other.c +FROM + abc AS other +WHERE + abc_check.a = other.a +RETURNING + abc_check.a, abc_check.b, abc_check.c +---- +a b c +1 4 7 +2 5 8 + +query III rowsort +SELECT * FROM abc +---- +1 4 7 +2 5 8 + +# Update values of table from values expression +statement ok +UPDATE abc SET b = other.b, c = other.c FROM (values (1, 2, 3), (2, 3, 4)) as other ("a", "b", "c") WHERE abc.a = other.a + +query III rowsort +SELECT * FROM abc +---- +1 2 3 +2 3 4 + +# Check if UPDATE ... FROM works with multiple tables. +statement ok +CREATE TABLE ab (a INT, b INT) + +statement ok +CREATE TABLE ac (a INT, c INT) + +statement ok +INSERT INTO ab VALUES (1, 200), (2, 300) + +statement ok +INSERT INTO ac VALUES (1, 300), (2, 400) + +statement ok +UPDATE abc SET b = ab.b, c = ac.c FROM ab, ac WHERE abc.a = ab.a AND abc.a = ac.a + +query III rowsort +SELECT * FROM abc +---- +1 200 300 +2 300 400 + +# Make sure UPDATE ... FROM works with LATERAL. +query IIIIIII colnames,rowsort +UPDATE abc +SET + b=ab.b, c = other.c +FROM + ab, LATERAL + (SELECT * FROM ac WHERE ab.a=ac.a) AS other +WHERE + abc.a=ab.a +RETURNING + * +---- +a b c a b a c +1 200 300 1 200 1 300 +2 300 400 2 300 2 400 + + +# Make sure the FROM clause cannot reference the target table. +statement error no data source matches prefix: abc +UPDATE abc SET a = other.a FROM (SELECT abc.a FROM abc AS x) AS other WHERE abc.a=other.a diff --git a/pkg/sql/opt/bench/stub_factory.go b/pkg/sql/opt/bench/stub_factory.go index f0f10927f838..40ae5fd8284a 100644 --- a/pkg/sql/opt/bench/stub_factory.go +++ b/pkg/sql/opt/bench/stub_factory.go @@ -236,6 +236,7 @@ func (f *stubFactory) ConstructUpdate( updateCols exec.ColumnOrdinalSet, returnCols exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, + passthrough sqlbase.ResultColumns, ) (exec.Node, error) { return struct{}{}, nil } diff --git a/pkg/sql/opt/exec/execbuilder/mutation.go b/pkg/sql/opt/exec/execbuilder/mutation.go index 79c8b14a5ee1..96e1c2533b10 100644 --- a/pkg/sql/opt/exec/execbuilder/mutation.go +++ b/pkg/sql/opt/exec/execbuilder/mutation.go @@ -22,6 +22,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/errors" ) @@ -105,11 +106,19 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (execPlan, error) { // // TODO(andyk): Using ensureColumns here can result in an extra Render. // Upgrade execution engine to not require this. - colList := make(opt.ColList, 0, len(upd.FetchCols)+len(upd.UpdateCols)+len(upd.CheckCols)) + colList := make(opt.ColList, 0, len(upd.FetchCols)+len(upd.UpdateCols)+len(upd.CheckCols)+len(upd.PassthroughCols)) colList = appendColsWhenPresent(colList, upd.FetchCols) colList = appendColsWhenPresent(colList, upd.UpdateCols) - colList = appendColsWhenPresent(colList, upd.CheckCols) + // The RETURNING clause of the Update can refer to the columns + // in any of the FROM tables. As a result, the Update may need + // to passthrough those columns so the projection above can use + // them. + if upd.NeedResults() { + colList = appendColsWhenPresent(colList, upd.PassthroughCols) + } + + colList = appendColsWhenPresent(colList, upd.CheckCols) input, err := b.buildMutationInput(upd.Input, colList, &upd.MutationPrivate) if err != nil { return execPlan{}, err @@ -122,6 +131,16 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (execPlan, error) { updateColOrds := ordinalSetFromColList(upd.UpdateCols) returnColOrds := ordinalSetFromColList(upd.ReturnCols) checkOrds := ordinalSetFromColList(upd.CheckCols) + + // Construct the result columns for the passthrough set. + var passthroughCols sqlbase.ResultColumns + if upd.NeedResults() { + for _, passthroughCol := range upd.PassthroughCols { + colMeta := b.mem.Metadata().ColumnMeta(passthroughCol) + passthroughCols = append(passthroughCols, sqlbase.ResultColumn{Name: colMeta.Alias, Typ: colMeta.Type}) + } + } + node, err := b.factory.ConstructUpdate( input.root, tab, @@ -129,6 +148,7 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (execPlan, error) { updateColOrds, returnColOrds, checkOrds, + passthroughCols, ) if err != nil { return execPlan{}, err @@ -354,6 +374,16 @@ func mutationOutputColMap(mutation memo.RelExpr) opt.ColMap { ord++ } } + + // The output columns of the mutation will also include all + // columns it allowed to pass through. + for _, colID := range private.PassthroughCols { + if colID != 0 { + colMap.Set(int(colID), ord) + ord++ + } + } + return colMap } diff --git a/pkg/sql/opt/exec/execbuilder/testdata/update_from b/pkg/sql/opt/exec/execbuilder/testdata/update_from new file mode 100644 index 000000000000..42555bcb0e47 --- /dev/null +++ b/pkg/sql/opt/exec/execbuilder/testdata/update_from @@ -0,0 +1,210 @@ +# LogicTest: local-opt + +statement ok +CREATE TABLE abc (a int primary key, b int, c int) + +# Updating using self join. +query TTT +EXPLAIN UPDATE abc SET b = other.b + 1, c = other.c + 1 FROM abc AS other WHERE abc.a = other.a +---- +count · · + └── update · · + │ table abc + │ set b, c + │ strategy updater + └── render · · + └── merge-join · · + │ type inner + │ equality (a) = (a) + │ mergeJoinOrder +"(a=a)" + ├── scan · · + │ table abc@primary + │ spans ALL + └── scan · · +· table abc@primary +· spans ALL + +# Update from another table. +statement ok +CREATE TABLE new_abc (a int, b int, c int) + +query TTT +EXPLAIN UPDATE abc SET b = other.b, c = other.c FROM new_abc AS other WHERE abc.a = other.a +---- +count · · + └── update · · + │ table abc + │ set b, c + │ strategy updater + └── render · · + └── distinct · · + │ distinct on a + └── hash-join · · + │ type inner + │ equality (a) = (a) + │ left cols are key · + ├── scan · · + │ table abc@primary + │ spans ALL + └── scan · · +· table new_abc@primary +· spans ALL + +# Returning old values. +query TTT +EXPLAIN UPDATE abc +SET + b = old.b + 1, c = old.c + 2 +FROM + abc AS old +WHERE + abc.a = old.a +RETURNING + abc.a, abc.b AS new_b, old.b as old_b, abc.c as new_c, old.c as old_c +---- +render · · + └── render · · + └── run · · + └── update · · + │ table abc + │ set b, c + │ strategy updater + └── render · · + └── merge-join · · + │ type inner + │ equality (a) = (a) + │ mergeJoinOrder +"(a=a)" + ├── scan · · + │ table abc@primary + │ spans ALL + └── scan · · +· table abc@primary +· spans ALL + +# Check if RETURNING * returns everything +query TTTTT +EXPLAIN (VERBOSE) UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a RETURNING * +---- +run · · (a, b, c, a, b, c) · + └── update · · (a, b, c, a, b, c) · + │ table abc · · + │ set b, c · · + │ strategy updater · · + └── render · · (a, b, c, column10, column11, a, b, c) · + │ render 0 a · · + │ render 1 b · · + │ render 2 c · · + │ render 3 b + 1 · · + │ render 4 c + 2 · · + │ render 5 a · · + │ render 6 b · · + │ render 7 c · · + └── merge-join · · (a, b, c, a, b, c) · + │ type inner · · + │ equality (a) = (a) · · + │ mergeJoinOrder +"(a=a)" · · + ├── scan · · (a, b, c) +a + │ table abc@primary · · + │ spans ALL · · + └── scan · · (a, b, c) +a +· table abc@primary · · +· spans ALL · · + +# Update values of table from values expression +query TTT +EXPLAIN UPDATE abc SET b = other.b, c = other.c FROM (values (1, 2, 3), (2, 3, 4)) as other ("a", "b", "c") WHERE abc.a = other.a +---- +count · · + └── update · · + │ table abc + │ set b, c + │ strategy updater + └── render · · + └── distinct · · + │ distinct on a + └── lookup-join · · + │ table abc@primary + │ type inner + │ equality (column1) = (a) + │ equality cols are key · + └── values · · +· size 3 columns, 2 rows + +# Check if UPDATE ... FROM works with multiple tables. +statement ok +CREATE TABLE ab (a INT, b INT) + +statement ok +CREATE TABLE ac (a INT, c INT) + +query TTT +EXPLAIN UPDATE abc SET b = ab.b, c = ac.c FROM ab, ac WHERE abc.a = ab.a AND abc.a = ac.a +---- +count · · + └── update · · + │ table abc + │ set b, c + │ strategy updater + └── render · · + └── distinct · · + │ distinct on a + │ order key a + └── merge-join · · + │ type inner + │ equality (a) = (a) + │ mergeJoinOrder +"(a=a)" + ├── merge-join · · + │ │ type inner + │ │ equality (a) = (a) + │ │ mergeJoinOrder +"(a=a)" + │ ├── scan · · + │ │ table abc@primary + │ │ spans ALL + │ └── sort · · + │ │ order +a + │ └── scan · · + │ table ac@primary + │ spans ALL + └── sort · · + │ order +a + └── scan · · +· table ab@primary +· spans ALL + +# Make sure UPDATE ... FROM works with LATERAL. +query TTT +EXPLAIN UPDATE abc +SET + b=ab.b, c = other.c +FROM + ab, LATERAL + (SELECT * FROM ac WHERE ab.a=ac.a) AS other +WHERE + abc.a=ab.a +RETURNING + * +---- +run · · + └── update · · + │ table abc + │ set b, c + │ strategy updater + └── render · · + └── distinct · · + │ distinct on a + └── hash-join · · + │ type inner + │ equality (a) = (a) + ├── scan · · + │ table ac@primary + │ spans ALL + └── hash-join · · + │ type inner + │ equality (a) = (a) + │ right cols are key · + ├── scan · · + │ table ab@primary + │ spans ALL + └── scan · · +· table abc@primary +· spans ALL diff --git a/pkg/sql/opt/exec/factory.go b/pkg/sql/opt/exec/factory.go index 7822cf384e2d..f6afef651cdb 100644 --- a/pkg/sql/opt/exec/factory.go +++ b/pkg/sql/opt/exec/factory.go @@ -321,6 +321,10 @@ type Factory interface { // columns in the same order as they appear in the table schema, with the // fetch columns first and the update columns second. The rowsNeeded parameter // is true if a RETURNING clause needs the updated row(s) as output. + // The passthrough parameter contains all the result columns that are part + // of the input node that the update node needs to return (passing through + // from the input). The pass through columns are used to return any column + // from the FROM tables that are referenced in the RETURNING clause. ConstructUpdate( input Node, table cat.Table, @@ -328,6 +332,7 @@ type Factory interface { updateCols ColumnOrdinalSet, returnCols ColumnOrdinalSet, checks CheckOrdinalSet, + passthrough sqlbase.ResultColumns, ) (Node, error) // ConstructUpsert creates a node that implements an INSERT..ON CONFLICT or diff --git a/pkg/sql/opt/memo/logical_props_builder.go b/pkg/sql/opt/memo/logical_props_builder.go index 369c09ef7621..7ca26312f797 100644 --- a/pkg/sql/opt/memo/logical_props_builder.go +++ b/pkg/sql/opt/memo/logical_props_builder.go @@ -1192,6 +1192,14 @@ func (b *logicalPropsBuilder) buildMutationProps(mutation RelExpr, rel *props.Re } } + // The output columns of the mutation will also include all + // columns it allowed to pass through. + for _, col := range private.PassthroughCols { + if col != 0 { + rel.OutputCols.Add(col) + } + } + // Not Null Columns // ---------------- // A column should be marked as not-null if the target table column is not diff --git a/pkg/sql/opt/norm/prune_cols.go b/pkg/sql/opt/norm/prune_cols.go index 0b6c61c9e173..0c41e9f96533 100644 --- a/pkg/sql/opt/norm/prune_cols.go +++ b/pkg/sql/opt/norm/prune_cols.go @@ -59,6 +59,7 @@ func (c *CustomFuncs) NeededMutationCols(private *memo.MutationPrivate) opt.ColS addCols(private.UpdateCols) addCols(private.CheckCols) addCols(private.ReturnCols) + addCols(private.PassthroughCols) if private.CanaryCol != 0 { cols.Add(private.CanaryCol) } diff --git a/pkg/sql/opt/ops/mutation.opt b/pkg/sql/opt/ops/mutation.opt index 47082e072457..cb60823229c8 100644 --- a/pkg/sql/opt/ops/mutation.opt +++ b/pkg/sql/opt/ops/mutation.opt @@ -121,6 +121,13 @@ define MutationPrivate { # then ReturnCols is nil. ReturnCols ColList + # PassthroughCols are columns that the mutation needs to passthrough from + # its input. It's similar to the passthrough columns in projections. This + # is useful for `UPDATE .. FROM` mutations where the `RETURNING` clause + # references columns from tables in the `FROM` clause. When this happens + # the update will need to pass through those refenced columns from its input. + PassthroughCols ColList + # Mutation operators can act similarly to a With operator: they buffer their # input, making it accessible to FK queries. If this is not required, WithID # is zero. diff --git a/pkg/sql/opt/optbuilder/delete.go b/pkg/sql/opt/optbuilder/delete.go index eb6de57f9546..16b6bcfc35cf 100644 --- a/pkg/sql/opt/optbuilder/delete.go +++ b/pkg/sql/opt/optbuilder/delete.go @@ -68,7 +68,7 @@ func (b *Builder) buildDelete(del *tree.Delete, inScope *scope) (outScope *scope // ORDER BY LIMIT // // All columns from the delete table will be projected. - mb.buildInputForUpdateOrDelete(inScope, del.Where, del.Limit, del.OrderBy) + mb.buildInputForDelete(inScope, del.Where, del.Limit, del.OrderBy) // Build the final delete statement, including any returned expressions. if resultsNeeded(del.Returning) { diff --git a/pkg/sql/opt/optbuilder/mutation_builder.go b/pkg/sql/opt/optbuilder/mutation_builder.go index 80586e9ab3d5..c0b5352661bc 100644 --- a/pkg/sql/opt/optbuilder/mutation_builder.go +++ b/pkg/sql/opt/optbuilder/mutation_builder.go @@ -128,6 +128,12 @@ type mutationBuilder struct { // withID is nonzero if we need to buffer the input for FK checks. withID opt.WithID + + // extraAccessibleCols stores all the columns that are available to the + // mutation that are not part of the target table. This is useful for + // UPDATE ... FROM queries, as the columns from the FROM tables must be + // made accessible to the RETURNING clause. + extraAccessibleCols []scopeColumn } func (mb *mutationBuilder) init(b *Builder, op opt.Operator, tab cat.Table, alias tree.TableName) { @@ -176,8 +182,8 @@ func (mb *mutationBuilder) fetchColID(tabOrd int) opt.ColumnID { return mb.scopeOrdToColID(mb.fetchOrds[tabOrd]) } -// buildInputForUpdateOrDelete constructs a Select expression from the fields in -// the Update or Delete operator, similar to this: +// buildInputForUpdate constructs a Select expression from the fields in +// the Update operator, similar to this: // // SELECT // FROM @@ -186,9 +192,22 @@ func (mb *mutationBuilder) fetchColID(tabOrd int) opt.ColumnID { // LIMIT // // All columns from the table to update are added to fetchColList. +// If a FROM clause is defined, we build out each of the table +// expressions required and JOIN them together (LATERAL joins between +// the tables are allowed). We then JOIN the result with the target +// table (the FROM tables can't reference this table) and apply the +// appropriate WHERE conditions. +// +// It is the responsibility of the user to guarantee that the JOIN +// produces a maximum of one row per row of the target table. If multiple +// are found, an arbitrary one is chosen (this row is not readily +// predictable, consistent with the POSTGRES implementation). +// buildInputForUpdate stores the columns of the FROM tables in the +// mutation builder so they can be made accessible to other parts of +// the query (RETURNING clause). // TODO(andyk): Do needed column analysis to project fewer columns if possible. -func (mb *mutationBuilder) buildInputForUpdateOrDelete( - inScope *scope, where *tree.Where, limit *tree.Limit, orderBy tree.OrderBy, +func (mb *mutationBuilder) buildInputForUpdate( + inScope *scope, from tree.TableExprs, where *tree.Where, limit *tree.Limit, orderBy tree.OrderBy, ) { // Fetch columns from different instance of the table metadata, so that it's // possible to remap columns, as in this example: @@ -205,6 +224,102 @@ func (mb *mutationBuilder) buildInputForUpdateOrDelete( inScope, ) + fromClausePresent := len(from) > 0 + numCols := len(mb.outScope.cols) + + // If there is a FROM clause present, we must join all the tables + // together with the table being updated. + if fromClausePresent { + fromScope := mb.b.buildFromTables(from, inScope) + + // Check that the same table name is not used multiple times. + mb.b.validateJoinTableNames(mb.outScope, fromScope) + + // The FROM table columns can be accessed by the RETURNING clause of the + // query and so we have to make them accessible. + mb.extraAccessibleCols = fromScope.cols + + // Add the columns in the FROM scope. + mb.outScope.appendColumnsFromScope(fromScope) + + left := mb.outScope.expr.(memo.RelExpr) + right := fromScope.expr.(memo.RelExpr) + mb.outScope.expr = mb.b.factory.ConstructInnerJoin(left, right, memo.TrueFilter, memo.EmptyJoinPrivate) + } + + // WHERE + mb.b.buildWhere(where, mb.outScope) + + // SELECT + ORDER BY (which may add projected expressions) + projectionsScope := mb.outScope.replace() + projectionsScope.appendColumnsFromScope(mb.outScope) + orderByScope := mb.b.analyzeOrderBy(orderBy, mb.outScope, projectionsScope) + mb.b.buildOrderBy(mb.outScope, projectionsScope, orderByScope) + mb.b.constructProjectForScope(mb.outScope, projectionsScope) + + // LIMIT + if limit != nil { + mb.b.buildLimit(limit, inScope, projectionsScope) + } + + mb.outScope = projectionsScope + + // Build a distinct on to ensure there is at most one row in the joined output + // for every row in the table. + if fromClausePresent { + var pkCols opt.ColSet + + // We need to ensure that the join has a maximum of one row for every row in the + // table and we ensure this by constructing a distinct on the primary key columns. + primaryIndex := mb.tab.Index(cat.PrimaryIndex) + for i := 0; i < primaryIndex.KeyColumnCount(); i++ { + pkCol := mb.outScope.cols[primaryIndex.Column(i).Ordinal] + + // If the primary key column is hidden, then we don't need to use it + // for the distinct on. + if !pkCol.hidden { + pkCols.Add(pkCol.id) + } + } + + if !pkCols.Empty() { + mb.outScope = mb.b.buildDistinctOn(pkCols, mb.outScope) + } + } + + // Set list of columns that will be fetched by the input expression. + for i := 0; i < numCols; i++ { + mb.fetchOrds[i] = scopeOrdinal(i) + } +} + +// buildInputForDelete constructs a Select expression from the fields in +// the Delete operator, similar to this: +// +// SELECT +// FROM
+// WHERE +// ORDER BY +// LIMIT +// +// All columns from the table to update are added to fetchColList. +// TODO(andyk): Do needed column analysis to project fewer columns if possible. +func (mb *mutationBuilder) buildInputForDelete( + inScope *scope, where *tree.Where, limit *tree.Limit, orderBy tree.OrderBy, +) { + // Fetch columns from different instance of the table metadata, so that it's + // possible to remap columns, as in this example: + // + // DELETE FROM abc WHERE a=b + // + mb.outScope = mb.b.buildScan( + mb.b.addTable(mb.tab, &mb.alias), + nil, /* ordinals */ + nil, /* indexFlags */ + includeMutations, + inScope, + ) + // WHERE mb.b.buildWhere(where, mb.outScope) @@ -682,6 +797,12 @@ func (mb *mutationBuilder) buildReturning(returning tree.ReturningExprs) { }) } + // extraAccessibleCols contains all the columns that the RETURNING + // clause can refer to in addition to the table columns. This is useful for + // UPDATE ... FROM statements, where all columns from tables in the FROM clause + // are in scope for the RETURNING clause. + inScope.appendColumns(mb.extraAccessibleCols) + // Construct the Project operator that projects the RETURNING expressions. outScope := inScope.replace() mb.b.analyzeReturningList(returning, nil /* desiredTypes */, inScope, outScope) diff --git a/pkg/sql/opt/optbuilder/testdata/update_from b/pkg/sql/opt/optbuilder/testdata/update_from new file mode 100644 index 000000000000..84882363e58d --- /dev/null +++ b/pkg/sql/opt/optbuilder/testdata/update_from @@ -0,0 +1,330 @@ +exec-ddl +CREATE TABLE abc (a int primary key, b int, c int) +---- + +exec-ddl +CREATE TABLE new_abc (a int, b int, c int) +---- + +# Test a self join. +opt +UPDATE abc SET b = other.b + 1, c = other.c + 1 FROM abc AS other WHERE abc.a = other.a +---- +update abc + ├── columns: + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── column10:10 => abc.b:2 + │ └── column11:11 => abc.c:3 + └── project + ├── columns: column10:10(int) column11:11(int) abc.a:4(int!null) abc.b:5(int) abc.c:6(int) other.a:7(int!null) other.b:8(int) other.c:9(int) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) other.a:7(int!null) other.b:8(int) other.c:9(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── scan abc + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── ordering: +4 + │ ├── scan other + │ │ ├── columns: other.a:7(int!null) other.b:8(int) other.c:9(int) + │ │ └── ordering: +7 + │ └── filters (true) + └── projections + ├── plus [type=int] + │ ├── variable: other.b [type=int] + │ └── const: 1 [type=int] + └── plus [type=int] + ├── variable: other.c [type=int] + └── const: 1 [type=int] + +# Test when Update uses multiple tables. +opt +UPDATE abc SET b = other.b, c = other.c FROM new_abc AS other WHERE abc.a = other.a +---- +update abc + ├── columns: + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── other.b:8 => abc.b:2 + │ └── other.c:9 => abc.c:3 + └── distinct-on + ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) other.a:7(int) other.b:8(int) other.c:9(int) + ├── grouping columns: abc.a:4(int!null) + ├── inner-join (hash) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) other.a:7(int!null) other.b:8(int) other.c:9(int) + │ ├── scan abc + │ │ └── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ ├── scan other + │ │ └── columns: other.a:7(int) other.b:8(int) other.c:9(int) + │ └── filters + │ └── eq [type=bool] + │ ├── variable: abc.a [type=int] + │ └── variable: other.a [type=int] + └── aggregations + ├── first-agg [type=int] + │ └── variable: abc.b [type=int] + ├── first-agg [type=int] + │ └── variable: abc.c [type=int] + ├── first-agg [type=int] + │ └── variable: other.a [type=int] + ├── first-agg [type=int] + │ └── variable: other.b [type=int] + └── first-agg [type=int] + └── variable: other.c [type=int] + +# Check if UPDATE FROM works well with RETURNING expressions that reference the FROM tables. +opt +UPDATE abc +SET + b = old.b + 1, c = old.c + 2 +FROM + abc AS old +WHERE + abc.a = old.a +RETURNING + abc.a, abc.b AS new_b, old.b as old_b, abc.c as new_c, old.c as old_c +---- +project + ├── columns: a:1(int!null) new_b:2(int) old_b:8(int) new_c:3(int) old_c:9(int) + └── update abc + ├── columns: abc.a:1(int!null) abc.b:2(int) abc.c:3(int) old.a:7(int) old.b:8(int) old.c:9(int) + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── column10:10 => abc.b:2 + │ └── column11:11 => abc.c:3 + └── project + ├── columns: column10:10(int) column11:11(int) abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── scan abc + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── ordering: +4 + │ ├── scan old + │ │ ├── columns: old.a:7(int!null) old.b:8(int) old.c:9(int) + │ │ └── ordering: +7 + │ └── filters (true) + └── projections + ├── plus [type=int] + │ ├── variable: old.b [type=int] + │ └── const: 1 [type=int] + └── plus [type=int] + ├── variable: old.c [type=int] + └── const: 2 [type=int] + +# Check if RETURNING * returns everything +opt +UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a RETURNING * +---- +update abc + ├── columns: a:1(int!null) b:2(int) c:3(int) a:7(int) b:8(int) c:9(int) + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── column10:10 => abc.b:2 + │ └── column11:11 => abc.c:3 + └── project + ├── columns: column10:10(int) column11:11(int) abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── scan abc + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── ordering: +4 + │ ├── scan old + │ │ ├── columns: old.a:7(int!null) old.b:8(int) old.c:9(int) + │ │ └── ordering: +7 + │ └── filters (true) + └── projections + ├── plus [type=int] + │ ├── variable: old.b [type=int] + │ └── const: 1 [type=int] + └── plus [type=int] + ├── variable: old.c [type=int] + └── const: 2 [type=int] + +# Check if the joins are optimized (check if the filters are pushed down). +opt +UPDATE abc SET b = old.b + 1, c = old.c + 2 FROM abc AS old WHERE abc.a = old.a AND abc.a = 2 +---- +update abc + ├── columns: + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── column10:10 => abc.b:2 + │ └── column11:11 => abc.c:3 + └── project + ├── columns: column10:10(int) column11:11(int) abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) old.a:7(int!null) old.b:8(int) old.c:9(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── scan abc + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── constraint: /4: [/2 - /2] + │ ├── scan old + │ │ ├── columns: old.a:7(int!null) old.b:8(int) old.c:9(int) + │ │ └── constraint: /7: [/2 - /2] + │ └── filters (true) + └── projections + ├── plus [type=int] + │ ├── variable: old.b [type=int] + │ └── const: 1 [type=int] + └── plus [type=int] + ├── variable: old.c [type=int] + └── const: 2 [type=int] + +# Update values of table from values expression +opt +UPDATE abc SET b = other.b, c = other.c FROM (values (1, 2, 3), (2, 3, 4)) as other ("a", "b", "c") WHERE abc.a = other.a +---- +update abc + ├── columns: + ├── fetch columns: a:4(int) b:5(int) c:6(int) + ├── update-mapping: + │ ├── column2:8 => b:2 + │ └── column3:9 => c:3 + └── distinct-on + ├── columns: a:4(int!null) b:5(int) c:6(int) column1:7(int) column2:8(int) column3:9(int) + ├── grouping columns: a:4(int!null) + ├── inner-join (lookup abc) + │ ├── columns: a:4(int!null) b:5(int) c:6(int) column1:7(int!null) column2:8(int!null) column3:9(int!null) + │ ├── key columns: [7] = [4] + │ ├── values + │ │ ├── columns: column1:7(int!null) column2:8(int!null) column3:9(int!null) + │ │ ├── tuple [type=tuple{int, int, int}] + │ │ │ ├── const: 1 [type=int] + │ │ │ ├── const: 2 [type=int] + │ │ │ └── const: 3 [type=int] + │ │ └── tuple [type=tuple{int, int, int}] + │ │ ├── const: 2 [type=int] + │ │ ├── const: 3 [type=int] + │ │ └── const: 4 [type=int] + │ └── filters (true) + └── aggregations + ├── first-agg [type=int] + │ └── variable: b [type=int] + ├── first-agg [type=int] + │ └── variable: c [type=int] + ├── first-agg [type=int] + │ └── variable: column1 [type=int] + ├── first-agg [type=int] + │ └── variable: column2 [type=int] + └── first-agg [type=int] + └── variable: column3 [type=int] + +# Check if UPDATE ... FROM works with multiple tables. +exec-ddl +CREATE TABLE ab (a INT, b INT) +---- + +exec-ddl +CREATE TABLE ac (a INT, c INT) +---- + +opt +UPDATE abc SET b = ab.b, c = ac.c FROM ab, ac WHERE abc.a = ab.a AND abc.a = ac.a +---- +update abc + ├── columns: + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── ab.b:8 => abc.b:2 + │ └── ac.c:11 => abc.c:3 + └── distinct-on + ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) ab.a:7(int) ab.b:8(int) ac.a:10(int) ac.c:11(int) + ├── grouping columns: abc.a:4(int!null) + ├── internal-ordering: +(4|7|10) + ├── inner-join (merge) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) ab.a:7(int!null) ab.b:8(int) ac.a:10(int!null) ac.c:11(int) + │ ├── left ordering: +4 + │ ├── right ordering: +7 + │ ├── ordering: +(4|7|10) + │ ├── inner-join (merge) + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) ac.a:10(int!null) ac.c:11(int) + │ │ ├── left ordering: +4 + │ │ ├── right ordering: +10 + │ │ ├── ordering: +(4|10) + │ │ ├── scan abc + │ │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ │ └── ordering: +4 + │ │ ├── sort + │ │ │ ├── columns: ac.a:10(int) ac.c:11(int) + │ │ │ ├── ordering: +10 + │ │ │ └── scan ac + │ │ │ └── columns: ac.a:10(int) ac.c:11(int) + │ │ └── filters (true) + │ ├── sort + │ │ ├── columns: ab.a:7(int) ab.b:8(int) + │ │ ├── ordering: +7 + │ │ └── scan ab + │ │ └── columns: ab.a:7(int) ab.b:8(int) + │ └── filters (true) + └── aggregations + ├── first-agg [type=int] + │ └── variable: abc.b [type=int] + ├── first-agg [type=int] + │ └── variable: abc.c [type=int] + ├── first-agg [type=int] + │ └── variable: ab.a [type=int] + ├── first-agg [type=int] + │ └── variable: ab.b [type=int] + ├── first-agg [type=int] + │ └── variable: ac.a [type=int] + └── first-agg [type=int] + └── variable: ac.c [type=int] + +# Make sure UPDATE ... FROM works with LATERAL. +opt +UPDATE abc +SET + b=ab.b, c = other.c +FROM + ab, LATERAL + (SELECT * FROM ac WHERE ab.a=ac.a) AS other +WHERE + abc.a=ab.a +RETURNING + * +---- +update abc + ├── columns: a:1(int!null) b:2(int) c:3(int) a:7(int) b:8(int) a:10(int) c:11(int) + ├── fetch columns: abc.a:4(int) abc.b:5(int) abc.c:6(int) + ├── update-mapping: + │ ├── ab.b:8 => abc.b:2 + │ └── ac.c:11 => abc.c:3 + └── distinct-on + ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) ab.a:7(int) ab.b:8(int) ac.a:10(int) ac.c:11(int) + ├── grouping columns: abc.a:4(int!null) + ├── inner-join (hash) + │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) ab.a:7(int!null) ab.b:8(int) ac.a:10(int!null) ac.c:11(int) + │ ├── scan ac + │ │ └── columns: ac.a:10(int) ac.c:11(int) + │ ├── inner-join (hash) + │ │ ├── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) ab.a:7(int!null) ab.b:8(int) + │ │ ├── scan ab + │ │ │ └── columns: ab.a:7(int) ab.b:8(int) + │ │ ├── scan abc + │ │ │ └── columns: abc.a:4(int!null) abc.b:5(int) abc.c:6(int) + │ │ └── filters + │ │ └── eq [type=bool] + │ │ ├── variable: abc.a [type=int] + │ │ └── variable: ab.a [type=int] + │ └── filters + │ └── eq [type=bool] + │ ├── variable: ab.a [type=int] + │ └── variable: ac.a [type=int] + └── aggregations + ├── first-agg [type=int] + │ └── variable: abc.b [type=int] + ├── first-agg [type=int] + │ └── variable: abc.c [type=int] + ├── first-agg [type=int] + │ └── variable: ab.a [type=int] + ├── first-agg [type=int] + │ └── variable: ab.b [type=int] + ├── first-agg [type=int] + │ └── variable: ac.a [type=int] + └── first-agg [type=int] + └── variable: ac.c [type=int] diff --git a/pkg/sql/opt/optbuilder/update.go b/pkg/sql/opt/optbuilder/update.go index 94c12a362a68..ba1c2eed5351 100644 --- a/pkg/sql/opt/optbuilder/update.go +++ b/pkg/sql/opt/optbuilder/update.go @@ -103,7 +103,7 @@ func (b *Builder) buildUpdate(upd *tree.Update, inScope *scope) (outScope *scope // ORDER BY LIMIT // // All columns from the update table will be projected. - mb.buildInputForUpdateOrDelete(inScope, upd.Where, upd.Limit, upd.OrderBy) + mb.buildInputForUpdate(inScope, upd.From, upd.Where, upd.Limit, upd.OrderBy) // Derive the columns that will be updated from the SET expressions. mb.addTargetColsForUpdate(upd.Exprs) @@ -325,7 +325,11 @@ func (mb *mutationBuilder) buildUpdate(returning tree.ReturningExprs) { mb.addCheckConstraintCols() private := mb.makeMutationPrivate(returning != nil) + for _, col := range mb.extraAccessibleCols { + if col.id != 0 && !col.hidden { + private.PassthroughCols = append(private.PassthroughCols, col.id) + } + } mb.outScope.expr = mb.b.factory.ConstructUpdate(mb.outScope.expr, mb.checks, private) - mb.buildReturning(returning) } diff --git a/pkg/sql/opt_exec_factory.go b/pkg/sql/opt_exec_factory.go index a0bb5fd28e68..0214bfb3e053 100644 --- a/pkg/sql/opt_exec_factory.go +++ b/pkg/sql/opt_exec_factory.go @@ -1328,6 +1328,7 @@ func (ef *execFactory) ConstructUpdate( updateColOrdSet exec.ColumnOrdinalSet, returnColOrdSet exec.ColumnOrdinalSet, checks exec.CheckOrdinalSet, + passthrough sqlbase.ResultColumns, ) (exec.Node, error) { // Derive table and column descriptors. rowsNeeded := !returnColOrdSet.Empty() @@ -1400,6 +1401,9 @@ func (ef *execFactory) ConstructUpdate( returnCols = sqlbase.ResultColumnsFromColDescs(returnColDescs) + // Add the passthrough columns to the returning columns. + returnCols = append(returnCols, passthrough...) + // Update the rowIdxToRetIdx for the mutation. Update returns // the non-mutation columns specified, in the same order they are // defined in the table. @@ -1436,6 +1440,7 @@ func (ef *execFactory) ConstructUpdate( updateValues: make(tree.Datums, len(ru.UpdateCols)), updateColsIdx: updateColsIdx, rowIdxToRetIdx: rowIdxToRetIdx, + numPassthrough: len(passthrough), }, } diff --git a/pkg/sql/parser/parse_test.go b/pkg/sql/parser/parse_test.go index 750fc75c2e6a..a55ef705b919 100644 --- a/pkg/sql/parser/parse_test.go +++ b/pkg/sql/parser/parse_test.go @@ -1033,6 +1033,9 @@ func TestParse(t *testing.T) { {`UPDATE a.b SET b = 3`}, {`UPDATE a.b@c SET b = 3`}, {`UPDATE a SET b = 3, c = DEFAULT`}, + {`UPDATE a SET b = 3, c = DEFAULT FROM b`}, + {`UPDATE a SET b = 3, c = DEFAULT FROM a AS other`}, + {`UPDATE a SET b = 3, c = DEFAULT FROM a AS other, b`}, {`UPDATE a SET b = 3 + 4`}, {`UPDATE a SET (b, c) = (3, DEFAULT)`}, {`UPDATE a SET (b, c) = (SELECT 3, 4)`}, @@ -1045,6 +1048,7 @@ func TestParse(t *testing.T) { {`UPDATE a SET b = 3 WHERE a = b RETURNING a, a + b`}, {`UPDATE a SET b = 3 WHERE a = b RETURNING NOTHING`}, {`UPDATE a SET b = 3 WHERE a = b ORDER BY c LIMIT d RETURNING e`}, + {`UPDATE a SET b = 3 FROM other WHERE a = b ORDER BY c LIMIT d RETURNING e`}, {`UPDATE t AS "0" SET k = ''`}, // "0" lost its quotes {`SELECT * FROM "0" JOIN "0" USING (id, "0")`}, // last "0" lost its quotes. @@ -3053,7 +3057,6 @@ func TestUnimplementedSyntax(t *testing.T) { {`UPDATE foo SET (a, a.b) = (1, 2)`, 27792, ``}, {`UPDATE foo SET a.b = 1`, 27792, ``}, - {`UPDATE foo SET x = y FROM a, b`, 7841, ``}, {`UPDATE Foo SET x.y = z`, 27792, ``}, {`UPSERT INTO foo(a, a.b) VALUES (1,2)`, 27792, ``}, diff --git a/pkg/sql/parser/sql.y b/pkg/sql/parser/sql.y index 49fa2a4983c1..2b087fe59381 100644 --- a/pkg/sql/parser/sql.y +++ b/pkg/sql/parser/sql.y @@ -820,8 +820,8 @@ func newNameFromStr(s string) *tree.Name { %type index_params create_as_params %type name_list privilege_list %type <[]int32> opt_array_bounds -%type from_clause update_from_clause -%type from_list rowsfrom_list +%type from_clause +%type from_list rowsfrom_list opt_from_list %type table_pattern_list single_table_pattern_list %type table_name_list %type expr_list opt_expr_list tuple1_ambiguous_values tuple1_unambiguous_values @@ -5523,12 +5523,13 @@ returning_clause: // %SeeAlso: INSERT, UPSERT, DELETE, WEBDOCS/update.html update_stmt: opt_with_clause UPDATE table_name_expr_opt_alias_idx - SET set_clause_list update_from_clause opt_where_clause opt_sort_clause opt_limit_clause returning_clause + SET set_clause_list opt_from_list opt_where_clause opt_sort_clause opt_limit_clause returning_clause { $$.val = &tree.Update{ With: $1.with(), Table: $3.tblExpr(), Exprs: $5.updateExprs(), + From: $6.tblExprs(), Where: tree.NewWhere(tree.AstWhere, $7.expr()), OrderBy: $8.orderBy(), Limit: $9.limit(), @@ -5537,10 +5538,13 @@ update_stmt: } | opt_with_clause UPDATE error // SHOW HELP: UPDATE -// Mark this as unimplemented until the normal from_clause is supported here. -update_from_clause: - FROM from_list { return unimplementedWithIssue(sqllex, 7841) } -| /* EMPTY */ {} +opt_from_list: + FROM from_list { + $$.val = $2.tblExprs() + } +| /* EMPTY */ { + $$.val = tree.TableExprs{} +} set_clause_list: set_clause diff --git a/pkg/sql/sem/tree/pretty.go b/pkg/sql/sem/tree/pretty.go index 0a967c573cc9..487d10a70221 100644 --- a/pkg/sql/sem/tree/pretty.go +++ b/pkg/sql/sem/tree/pretty.go @@ -1023,7 +1023,12 @@ func (node *Update) doc(p *PrettyCfg) pretty.Doc { items = append(items, node.With.docRow(p), p.row("UPDATE", p.Doc(node.Table)), - p.row("SET", p.Doc(&node.Exprs)), + p.row("SET", p.Doc(&node.Exprs))) + if len(node.From) > 0 { + items = append(items, + p.row("FROM", p.Doc(&node.From))) + } + items = append(items, node.Where.docRow(p), node.OrderBy.docRow(p)) items = append(items, node.Limit.docTable(p)...) diff --git a/pkg/sql/sem/tree/testdata/pretty/22.align-deindent.golden.short b/pkg/sql/sem/tree/testdata/pretty/22.align-deindent.golden.short new file mode 100644 index 000000000000..e2f1c72ea0e2 --- /dev/null +++ b/pkg/sql/sem/tree/testdata/pretty/22.align-deindent.golden.short @@ -0,0 +1,13 @@ +// Code generated by TestPretty. DO NOT EDIT. +// GENERATED FILE DO NOT EDIT +1: +- + UPDATE abc + SET b = old.b + 1, c = old.c + 1 + FROM abc AS old + WHERE abc.a = 2 AND abc.b = abc.c + ORDER BY abc.v DESC + LIMIT 1 +RETURNING abc.b, c, 4 AS d + + diff --git a/pkg/sql/sem/tree/testdata/pretty/22.align-only.golden.short b/pkg/sql/sem/tree/testdata/pretty/22.align-only.golden.short new file mode 100644 index 000000000000..e2f1c72ea0e2 --- /dev/null +++ b/pkg/sql/sem/tree/testdata/pretty/22.align-only.golden.short @@ -0,0 +1,13 @@ +// Code generated by TestPretty. DO NOT EDIT. +// GENERATED FILE DO NOT EDIT +1: +- + UPDATE abc + SET b = old.b + 1, c = old.c + 1 + FROM abc AS old + WHERE abc.a = 2 AND abc.b = abc.c + ORDER BY abc.v DESC + LIMIT 1 +RETURNING abc.b, c, 4 AS d + + diff --git a/pkg/sql/sem/tree/testdata/pretty/22.ref.golden.short b/pkg/sql/sem/tree/testdata/pretty/22.ref.golden.short new file mode 100644 index 000000000000..3a97001d051a --- /dev/null +++ b/pkg/sql/sem/tree/testdata/pretty/22.ref.golden.short @@ -0,0 +1,20 @@ +// Code generated by TestPretty. DO NOT EDIT. +// GENERATED FILE DO NOT EDIT +1: +- +UPDATE + abc +SET + b = old.b + 1, c = old.c + 1 +FROM + abc AS old +WHERE + abc.a = 2 AND abc.b = abc.c +ORDER BY + abc.v DESC +LIMIT + 1 +RETURNING + abc.b, c, 4 AS d + + diff --git a/pkg/sql/sem/tree/testdata/pretty/22.sql b/pkg/sql/sem/tree/testdata/pretty/22.sql new file mode 100644 index 000000000000..90eaee338906 --- /dev/null +++ b/pkg/sql/sem/tree/testdata/pretty/22.sql @@ -0,0 +1 @@ +UPDATE abc SET b = old.b + 1, c = old.c + 1 FROM abc AS old WHERE abc.a=2 and abc.b=abc.c ORDER BY abc.v DESC LIMIT 1 RETURNING abc.b, c, 4 AS d diff --git a/pkg/sql/sem/tree/update.go b/pkg/sql/sem/tree/update.go index 801b571ce9cb..ce812e227ada 100644 --- a/pkg/sql/sem/tree/update.go +++ b/pkg/sql/sem/tree/update.go @@ -24,6 +24,7 @@ type Update struct { With *With Table TableExpr Exprs UpdateExprs + From TableExprs Where *Where OrderBy OrderBy Limit *Limit @@ -37,6 +38,10 @@ func (node *Update) Format(ctx *FmtCtx) { ctx.FormatNode(node.Table) ctx.WriteString(" SET ") ctx.FormatNode(&node.Exprs) + if len(node.From) > 0 { + ctx.WriteString(" FROM ") + ctx.FormatNode(&node.From) + } if node.Where != nil { ctx.WriteByte(' ') ctx.FormatNode(node.Where) diff --git a/pkg/sql/update.go b/pkg/sql/update.go index a035b3fba27f..06458998b8b7 100644 --- a/pkg/sql/update.go +++ b/pkg/sql/update.go @@ -494,6 +494,11 @@ type updateRun struct { // of the mutation. Otherwise, the value at the i-th index refers to the // index of the resultRowBuffer where the i-th column is to be returned. rowIdxToRetIdx []int + + // numPassthrough is the number of columns in addition to the set of + // columns of the target table being returned, that we must pass through + // from the input node. + numPassthrough int } // maxUpdateBatchSize is the max number of entries in the KV batch for @@ -693,7 +698,7 @@ func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums) return err } } else { - checkVals := sourceVals[len(u.run.tu.ru.FetchCols)+len(u.run.tu.ru.UpdateCols):] + checkVals := sourceVals[len(u.run.tu.ru.FetchCols)+len(u.run.tu.ru.UpdateCols)+u.run.numPassthrough:] if err := u.run.checkHelper.CheckInput(checkVals); err != nil { return err } @@ -716,13 +721,31 @@ func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums) // MakeUpdater guarantees that the first columns of the new values // are those specified u.columns. resultValues := make([]tree.Datum, len(u.columns)) + largestRetIdx := -1 for i := range u.run.rowIdxToRetIdx { retIdx := u.run.rowIdxToRetIdx[i] if retIdx >= 0 { + if retIdx >= largestRetIdx { + largestRetIdx = retIdx + } resultValues[retIdx] = newValues[i] } } + // At this point we've extracted all the RETURNING values that are part + // of the target table. We must now extract the columns in the RETURNING + // clause that refer to other tables (from the FROM clause of the update). + if u.run.numPassthrough > 0 { + passthroughBegin := len(u.run.tu.ru.FetchCols) + len(u.run.tu.ru.UpdateCols) + passthroughEnd := passthroughBegin + u.run.numPassthrough + passthroughValues := sourceVals[passthroughBegin:passthroughEnd] + + for i := 0; i < u.run.numPassthrough; i++ { + largestRetIdx++ + resultValues[largestRetIdx] = passthroughValues[i] + } + } + if _, err := u.run.rows.AddRow(params.ctx, resultValues); err != nil { return err }