diff --git a/pkg/sql/logictest/testdata/logic_test/upsert b/pkg/sql/logictest/testdata/logic_test/upsert index ba0084b11550..c41a304be655 100644 --- a/pkg/sql/logictest/testdata/logic_test/upsert +++ b/pkg/sql/logictest/testdata/logic_test/upsert @@ -1234,3 +1234,18 @@ statement ok RESET CLUSTER SETTING kv.raft.command.max_size; DROP TABLE src; DROP TABLE dest + +# Regression test for finishing UPSERT too early (#54456). +statement ok +CREATE TABLE t54456 (c INT PRIMARY KEY); +UPSERT INTO t54456 SELECT i FROM generate_series(1, 25000) AS i + +query I +SELECT count(*) FROM t54456 +---- +25000 + +query I +WITH cte(c) AS (UPSERT INTO t54456 SELECT i FROM generate_series(25001, 40000) AS i RETURNING c) SELECT count(*) FROM cte +---- +15000 diff --git a/pkg/sql/tablewriter_upsert_opt.go b/pkg/sql/tablewriter_upsert_opt.go index 9df944dc0e49..7da02e4656c1 100644 --- a/pkg/sql/tablewriter_upsert_opt.go +++ b/pkg/sql/tablewriter_upsert_opt.go @@ -62,9 +62,10 @@ type optTableUpserter struct { // collectRows is true. insertReorderingRequired bool - // resultCount is the number of upserts. Mirrors rowsUpserted.Len() if - // collectRows is set, counted separately otherwise. - resultCount int + // rowsInLastProcessedBatch tracks the number of upserts that were + // performed in the last processed batch. If collectRows is true, it will + // be equal to rowsUpserted.Len() after the batch has been created. + rowsInLastProcessedBatch int // fetchCols indicate which columns need to be fetched from the target table, // in order to detect whether a conflict has occurred, as well as to provide @@ -148,24 +149,9 @@ func (tu *optTableUpserter) init( // flushAndStartNewBatch is part of the tableWriter interface. func (tu *optTableUpserter) flushAndStartNewBatch(ctx context.Context) error { - tu.resultCount = 0 - if tu.collectRows { - tu.rowsUpserted.Clear(ctx) - } return tu.tableWriterBase.flushAndStartNewBatch(ctx, tu.tableDesc()) } -// batchedCount is part of the batchedTableWriter interface. -func (tu *optTableUpserter) batchedCount() int { return tu.resultCount } - -// batchedValues is part of the batchedTableWriter interface. -func (tu *optTableUpserter) batchedValues(rowIdx int) tree.Datums { - if !tu.collectRows { - panic("return row requested but collect rows was not set") - } - return tu.rowsUpserted.At(rowIdx) -} - // close is part of the tableWriter interface. func (tu *optTableUpserter) close(ctx context.Context) { if tu.rowsUpserted != nil { @@ -210,7 +196,6 @@ func (*optTableUpserter) desc() string { return "opt upserter" } // row is part of the tableWriter interface. func (tu *optTableUpserter) row(ctx context.Context, row tree.Datums, traceKV bool) error { tu.batchSize++ - tu.resultCount++ // Consult the canary column to determine whether to insert or update. For // more details on how canary columns work, see the block comment on diff --git a/pkg/sql/upsert.go b/pkg/sql/upsert.go index 09a11ea6d8cd..486509ca993a 100644 --- a/pkg/sql/upsert.go +++ b/pkg/sql/upsert.go @@ -17,6 +17,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/util/tracing" + "github.com/cockroachdb/errors" ) var upsertNodePool = sync.Pool{ @@ -83,6 +84,11 @@ func (n *upsertNode) BatchedNext(params runParams) (bool, error) { tracing.AnnotateTrace() + // Advance one batch. First, clear the current batch. + if n.run.tw.collectRows { + n.run.tw.rowsUpserted.Clear(params.ctx) + } + // Now consume/accumulate the rows for this batch. lastBatch := false for { @@ -136,13 +142,17 @@ func (n *upsertNode) BatchedNext(params runParams) (bool, error) { n.run.done = true } + // We've just finished processing this batch, and we need to remember how + // many rows were in it. + n.run.tw.rowsInLastProcessedBatch = batchSize + // Possibly initiate a run of CREATE STATISTICS. params.ExecCfg().StatsRefresher.NotifyMutation( n.run.tw.tableDesc().ID, - n.run.tw.batchedCount(), + n.run.tw.rowsInLastProcessedBatch, ) - return n.run.tw.batchedCount() > 0, nil + return n.run.tw.rowsInLastProcessedBatch > 0, nil } // processSourceRow processes one row from the source for upsertion. @@ -172,10 +182,15 @@ func (n *upsertNode) processSourceRow(params runParams, rowVals tree.Datums) err } // BatchedCount implements the batchedPlanNode interface. -func (n *upsertNode) BatchedCount() int { return n.run.tw.batchedCount() } +func (n *upsertNode) BatchedCount() int { return n.run.tw.rowsInLastProcessedBatch } // BatchedValues implements the batchedPlanNode interface. -func (n *upsertNode) BatchedValues(rowIdx int) tree.Datums { return n.run.tw.batchedValues(rowIdx) } +func (n *upsertNode) BatchedValues(rowIdx int) tree.Datums { + if !n.run.tw.collectRows { + panic(errors.AssertionFailedf("return row requested but collect rows was not set")) + } + return n.run.tw.rowsUpserted.At(rowIdx) +} func (n *upsertNode) Close(ctx context.Context) { n.source.Close(ctx)