diff --git a/pkg/sql/logictest/testdata/logic_test/upsert b/pkg/sql/logictest/testdata/logic_test/upsert index f846bd83abd6..04276c46ee04 100644 --- a/pkg/sql/logictest/testdata/logic_test/upsert +++ b/pkg/sql/logictest/testdata/logic_test/upsert @@ -1231,3 +1231,19 @@ statement ok RESET CLUSTER SETTING kv.raft.command.max_size; DROP TABLE src; DROP TABLE dest + +# Regression test for finishing UPSERT too early (#54456). +statement ok +CREATE TABLE t54456 (c INT PRIMARY KEY); +UPSERT INTO t54456 SELECT i FROM generate_series(1, 25000) AS i + +query I +SELECT count(*) FROM t54456 +---- +25000 + +# Regression test for clearing up upserted rows too early (#54465). +query I +WITH cte(c) AS (UPSERT INTO t54456 SELECT i FROM generate_series(25001, 40000) AS i RETURNING c) SELECT count(*) FROM cte +---- +15000 diff --git a/pkg/sql/opt_exec_factory.go b/pkg/sql/opt_exec_factory.go index e972b3cf724c..f6a9cad5c81f 100644 --- a/pkg/sql/opt_exec_factory.go +++ b/pkg/sql/opt_exec_factory.go @@ -1474,7 +1474,7 @@ func (ef *execFactory) ConstructUpsert( // in the table. ups.run.tw.tabColIdxToRetIdx = row.ColMapping(tabDesc.Columns, returnColDescs) ups.run.tw.returnCols = returnColDescs - ups.run.tw.collectRows = true + ups.run.tw.rowsNeeded = true } if autoCommit { diff --git a/pkg/sql/tablewriter_upsert_opt.go b/pkg/sql/tablewriter_upsert_opt.go index b4fa3569a860..cfe1b61c40d9 100644 --- a/pkg/sql/tablewriter_upsert_opt.go +++ b/pkg/sql/tablewriter_upsert_opt.go @@ -50,18 +50,15 @@ type optTableUpserter struct { ri row.Inserter // Should we collect the rows for a RETURNING clause? - collectRows bool - - // Rows returned if collectRows is true. - rowsUpserted *rowcontainer.RowContainer + rowsNeeded bool // A mapping of column IDs to the return index used to shape the resulting // rows to those required by the returning clause. Only required if - // collectRows is true. + // rowsNeeded is true. colIDToReturnIndex map[descpb.ColumnID]int // Do the result rows have a different order than insert rows. Only set if - // collectRows is true. + // rowsNeeded is true. insertReorderingRequired bool // fetchCols indicate which columns need to be fetched from the target table, @@ -104,11 +101,11 @@ func (tu *optTableUpserter) init( ) error { tu.tableWriterBase.init(txn, tu.ri.Helper.TableDesc) - // collectRows, set upon initialization, indicates whether or not we want + // rowsNeeded, set upon initialization, indicates whether or not we want // rows returned from the operation. - if tu.collectRows { + if tu.rowsNeeded { tu.resultRow = make(tree.Datums, len(tu.returnCols)) - tu.rowsUpserted = rowcontainer.NewRowContainer( + tu.rows = rowcontainer.NewRowContainer( evalCtx.Mon.MakeBoundAccount(), colinfo.ColTypeInfoFromColDescs(tu.returnCols), ) @@ -135,30 +132,6 @@ func (tu *optTableUpserter) init( return nil } -// flushAndStartNewBatch is part of the tableWriter interface. -func (tu *optTableUpserter) flushAndStartNewBatch(ctx context.Context) error { - if tu.collectRows { - tu.rowsUpserted.Clear(ctx) - } - return tu.tableWriterBase.flushAndStartNewBatch(ctx) -} - -// batchedValues is a helper in implementing batchedPlanNode interface. -func (tu *optTableUpserter) batchedValues(rowIdx int) tree.Datums { - if !tu.collectRows { - panic("return row requested but collect rows was not set") - } - return tu.rowsUpserted.At(rowIdx) -} - -// close is part of the tableWriter interface. -func (tu *optTableUpserter) close(ctx context.Context) { - tu.tableWriterBase.close(ctx) - if tu.rowsUpserted != nil { - tu.rowsUpserted.Close(ctx) - } -} - // makeResultFromRow reshapes a row that was inserted or updated to a row // suitable for storing for a RETURNING clause, shaped by the target table's // descriptor. @@ -209,10 +182,10 @@ func (tu *optTableUpserter) row( // If no columns need to be updated, then possibly collect the unchanged row. fetchEnd := insertEnd + len(tu.fetchCols) if len(tu.updateCols) == 0 { - if !tu.collectRows { + if !tu.rowsNeeded { return nil } - _, err := tu.rowsUpserted.AddRow(ctx, row[insertEnd:fetchEnd]) + _, err := tu.rows.AddRow(ctx, row[insertEnd:fetchEnd]) return err } @@ -243,7 +216,7 @@ func (tu *optTableUpserter) insertNonConflictingRow( return err } - if !tu.collectRows { + if !tu.rowsNeeded { return nil } @@ -259,7 +232,7 @@ func (tu *optTableUpserter) insertNonConflictingRow( tu.resultRow[retIdx] = tableRow[tabIdx] } } - _, err := tu.rowsUpserted.AddRow(ctx, tu.resultRow) + _, err := tu.rows.AddRow(ctx, tu.resultRow) return err } @@ -269,7 +242,7 @@ func (tu *optTableUpserter) insertNonConflictingRow( tu.resultRow[retIdx] = insertRow[tabIdx] } } - _, err := tu.rowsUpserted.AddRow(ctx, tu.resultRow) + _, err := tu.rows.AddRow(ctx, tu.resultRow) return err } @@ -307,7 +280,7 @@ func (tu *optTableUpserter) updateConflictingRow( } // We only need a result row if we're collecting rows. - if !tu.collectRows { + if !tu.rowsNeeded { return nil } @@ -335,7 +308,7 @@ func (tu *optTableUpserter) updateConflictingRow( // The resulting row may have nil values for columns that aren't // being upserted, updated or fetched. - _, err = tu.rowsUpserted.AddRow(ctx, tu.resultRow) + _, err = tu.rows.AddRow(ctx, tu.resultRow) return err } diff --git a/pkg/sql/upsert.go b/pkg/sql/upsert.go index b94e1624bcaa..7c3429c67bba 100644 --- a/pkg/sql/upsert.go +++ b/pkg/sql/upsert.go @@ -195,7 +195,7 @@ func (n *upsertNode) processSourceRow(params runParams, rowVals tree.Datums) err func (n *upsertNode) BatchedCount() int { return n.run.tw.lastBatchSize } // BatchedValues implements the batchedPlanNode interface. -func (n *upsertNode) BatchedValues(rowIdx int) tree.Datums { return n.run.tw.batchedValues(rowIdx) } +func (n *upsertNode) BatchedValues(rowIdx int) tree.Datums { return n.run.tw.rows.At(rowIdx) } func (n *upsertNode) Close(ctx context.Context) { n.source.Close(ctx)