diff --git a/pkg/ccl/streamingccl/streamingest/BUILD.bazel b/pkg/ccl/streamingccl/streamingest/BUILD.bazel index 07b90170ac07..aa42500557ee 100644 --- a/pkg/ccl/streamingccl/streamingest/BUILD.bazel +++ b/pkg/ccl/streamingccl/streamingest/BUILD.bazel @@ -131,6 +131,7 @@ go_test( "//pkg/testutils/storageutils", "//pkg/testutils/testcluster", "//pkg/upgrade/upgradebase", + "//pkg/util/ctxgroup", "//pkg/util/hlc", "//pkg/util/leaktest", "//pkg/util/limit", diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_job.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_job.go index cd6fb108006e..ce75b0d1e815 100644 --- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_job.go +++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_job.go @@ -466,8 +466,36 @@ func maybeRevertToCutoverTimestamp( } updateRunningStatus(ctx, j, fmt.Sprintf("starting to cut over to the given timestamp %s", cutoverTime)) + + origNRanges := -1 spans := []roachpb.Span{sd.Span} + updateJobProgress := func() error { + if spans == nil { + return nil + } + nRanges, err := sql.NumRangesInSpans(ctx, p.ExecCfg().DB, p.DistSQLPlanner(), spans) + if err != nil { + return err + } + if origNRanges == -1 { + origNRanges = nRanges + } + return p.ExecCfg().DB.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { + if nRanges < origNRanges { + fractionRangesFinished := float32(origNRanges-nRanges) / float32(origNRanges) + if err := j.FractionProgressed(ctx, txn, + jobs.FractionUpdater(fractionRangesFinished)); err != nil { + return jobs.SimplifyInvalidStatusError(err) + } + } + return nil + }) + } + for len(spans) != 0 { + if err := updateJobProgress(); err != nil { + log.Warningf(ctx, "failed to update replication job progress: %+v", err) + } var b kv.Batch for _, span := range spans { b.AddRawRequest(&roachpb.RevertRangeRequest{ @@ -479,6 +507,9 @@ func maybeRevertToCutoverTimestamp( }) } b.Header.MaxSpanRequestKeys = sql.RevertTableDefaultBatchSize + if p.ExecCfg().StreamingTestingKnobs != nil && p.ExecCfg().StreamingTestingKnobs.OverrideRevertRangeBatchSize != 0 { + b.Header.MaxSpanRequestKeys = p.ExecCfg().StreamingTestingKnobs.OverrideRevertRangeBatchSize + } if err := db.Run(ctx, &b); err != nil { return false, err } @@ -494,7 +525,7 @@ func maybeRevertToCutoverTimestamp( } } } - return true, j.SetProgress(ctx, nil /* txn */, *sp.StreamIngest) + return true, updateJobProgress() } func activateTenant(ctx context.Context, execCtx interface{}, newTenantID roachpb.TenantID) error { diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_job_test.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_job_test.go index 18acce6f7532..278ab8e3a02d 100644 --- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_job_test.go +++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_job_test.go @@ -10,6 +10,7 @@ package streamingest import ( "context" + "fmt" "net/url" "strings" "testing" @@ -25,6 +26,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv" + "github.com/cockroachdb/cockroach/pkg/kv/kvserver" "github.com/cockroachdb/cockroach/pkg/repstream/streampb" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/username" @@ -35,6 +37,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/testutils/testcluster" + "github.com/cockroachdb/cockroach/pkg/util/ctxgroup" 
"github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -382,3 +385,138 @@ func TestReplicationJobResumptionStartTime(t *testing.T) { c.Cutover(producerJobID, replicationJobID, srcTime.GoTime()) jobutils.WaitForJobToSucceed(t, c.DestSysSQL, jobspb.JobID(replicationJobID)) } + +func makeTableSpan(codec keys.SQLCodec, tableID uint32) roachpb.Span { + k := codec.TablePrefix(tableID) + return roachpb.Span{Key: k, EndKey: k.PrefixEnd()} +} + +func TestCutoverFractionProgressed(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + + respRecvd := make(chan struct{}) + continueRevert := make(chan struct{}) + defer close(continueRevert) + s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{ + Knobs: base.TestingKnobs{ + Store: &kvserver.StoreTestingKnobs{ + TestingResponseFilter: func(ctx context.Context, ba *roachpb.BatchRequest, br *roachpb.BatchResponse) *roachpb.Error { + for _, ru := range br.Responses { + switch ru.GetInner().(type) { + case *roachpb.RevertRangeResponse: + respRecvd <- struct{}{} + <-continueRevert + } + } + return nil + }, + }, + Streaming: &sql.StreamingTestingKnobs{ + OverrideRevertRangeBatchSize: 1, + }, + }, + DisableDefaultTestTenant: true, + }) + defer s.Stopper().Stop(ctx) + + _, err := sqlDB.Exec(`CREATE TABLE foo(id) AS SELECT generate_series(1, 10)`) + require.NoError(t, err) + + cutover := hlc.Timestamp{WallTime: timeutil.Now().UnixNano()} + + // Insert some revisions which we can revert to a timestamp before the update. + _, err = sqlDB.Exec(`UPDATE foo SET id = id + 1`) + require.NoError(t, err) + + // Split every other row into its own range. Progress updates are on a + // per-range basis so we need >1 range to see the fraction progress. + _, err = sqlDB.Exec(`ALTER TABLE foo SPLIT AT (SELECT rowid FROM foo WHERE rowid % 2 = 0)`) + require.NoError(t, err) + + var nRanges int + require.NoError(t, sqlDB.QueryRow( + `SELECT count(*) FROM [SHOW RANGES FROM TABLE foo]`).Scan(&nRanges)) + + require.Equal(t, nRanges, 6) + var id int + err = sqlDB.QueryRow(`SELECT id FROM system.namespace WHERE name = 'foo'`).Scan(&id) + require.NoError(t, err) + + // Create a mock replication job with the `foo` table span so that on cut over + // we can revert the table's ranges. 
+ execCfg := s.ExecutorConfig().(sql.ExecutorConfig)
+ jobExecCtx := &sql.FakeJobExecContext{ExecutorConfig: &execCfg}
+ mockReplicationJobDetails := jobspb.StreamIngestionDetails{
+ Span: makeTableSpan(execCfg.Codec, uint32(id)),
+ }
+ mockReplicationJobRecord := jobs.Record{
+ Details: mockReplicationJobDetails,
+ Progress: jobspb.StreamIngestionProgress{
+ CutoverTime: cutover,
+ },
+ Username: username.TestUserName(),
+ }
+ registry := execCfg.JobRegistry
+ jobID := registry.MakeJobID()
+ replicationJob, err := registry.CreateJobWithTxn(ctx, mockReplicationJobRecord, jobID, nil /* txn */)
+ require.NoError(t, err)
+ require.NoError(t, replicationJob.Update(ctx, nil /* txn */, func(txn *kv.Txn, md jobs.JobMetadata, ju *jobs.JobUpdater) error {
+ return jobs.UpdateHighwaterProgressed(cutover, md, ju)
+ }))
+
+ g := ctxgroup.WithContext(ctx)
+ g.GoCtx(func(ctx context.Context) error {
+ defer close(respRecvd)
+ revert, err := maybeRevertToCutoverTimestamp(ctx, jobExecCtx, jobID)
+ require.NoError(t, err)
+ require.True(t, revert)
+ return nil
+ })
+
+ loadProgress := func() jobspb.Progress {
+ j, err := execCfg.JobRegistry.LoadJob(ctx, jobID)
+ require.NoError(t, err)
+ return j.Progress()
+ }
+ progressMap := map[string]bool{
+ "0.00": false,
+ "0.17": false,
+ "0.33": false,
+ "0.50": false,
+ "0.67": false,
+ "0.83": false,
+ }
+ g.GoCtx(func(ctx context.Context) error {
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case _, ok := <-respRecvd:
+ if !ok {
+ return nil
+ }
+ sip := loadProgress()
+ curProgress := sip.GetFractionCompleted()
+ s := fmt.Sprintf("%.2f", curProgress)
+ if _, ok := progressMap[s]; !ok {
+ t.Fatalf("unexpected progress fraction %s", s)
+ }
+ progressMap[s] = true
+ continueRevert <- struct{}{}
+ }
+ }
+ })
+ require.NoError(t, g.Wait())
+ sip := loadProgress()
+ require.Equal(t, float32(1), sip.GetFractionCompleted())
+
+ // Ensure we have hit all our expected progress fractions.
+ for k, v := range progressMap {
+ if !v {
+ t.Fatalf("failed to see progress fraction %s", k)
+ }
+ }
+}
diff --git a/pkg/sql/backfill.go b/pkg/sql/backfill.go
index 7ded154c3f63..15a640a71eee 100644
--- a/pkg/sql/backfill.go
+++ b/pkg/sql/backfill.go
@@ -843,11 +843,11 @@ func getJobIDForMutationWithDescriptor(
 "job not found for table id %d, mutation %d", tableDesc.GetID(), mutationID)
 }
 
-// numRangesInSpans returns the number of ranges that cover a set of spans.
+// NumRangesInSpans returns the number of ranges that cover a set of spans.
 //
-// It operates entirely on the current goroutine and is thus able to
-// reuse an existing kv.Txn safely.
-func numRangesInSpans(
+// It operates entirely on the current goroutine and is thus able to reuse an
+// existing kv.Txn safely.
+func NumRangesInSpans(
 ctx context.Context, db *kv.DB, distSQLPlanner *DistSQLPlanner, spans []roachpb.Span,
 ) (int, error) {
 txn := db.NewTxn(ctx, "num-ranges-in-spans")
@@ -1099,7 +1099,7 @@ func (sc *SchemaChanger) distIndexBackfill(
 if updatedTodoSpans == nil {
 return nil
 }
- nRanges, err := numRangesInSpans(ctx, sc.db, sc.distSQLPlanner, updatedTodoSpans)
+ nRanges, err := NumRangesInSpans(ctx, sc.db, sc.distSQLPlanner, updatedTodoSpans)
 if err != nil {
 return err
 }
@@ -1252,7 +1252,7 @@ func (sc *SchemaChanger) distColumnBackfill(
 // schema change state machine or from a previous backfill attempt,
 // we scale that fraction of ranges completed by the remaining fraction
 // of the job's progress bar.
- nRanges, err := numRangesInSpans(ctx, sc.db, sc.distSQLPlanner, todoSpans) + nRanges, err := NumRangesInSpans(ctx, sc.db, sc.distSQLPlanner, todoSpans) if err != nil { return err } @@ -2889,7 +2889,7 @@ func (sc *SchemaChanger) distIndexMerge( // TODO(rui): these can be initialized along with other new schema changer dependencies. planner := NewIndexBackfillerMergePlanner(sc.execCfg) rc := func(ctx context.Context, spans []roachpb.Span) (int, error) { - return numRangesInSpans(ctx, sc.db, sc.distSQLPlanner, spans) + return NumRangesInSpans(ctx, sc.db, sc.distSQLPlanner, spans) } tracker := NewIndexMergeTracker(progress, sc.job, rc, fractionScaler) periodicFlusher := newPeriodicProgressFlusher(sc.settings) diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index 3c1a9900c881..c6e362bac1d7 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -1662,6 +1662,10 @@ type StreamingTestingKnobs struct { // frontier specs generated for the replication job. AfterReplicationFlowPlan func([]*execinfrapb.StreamIngestionDataSpec, *execinfrapb.StreamIngestionFrontierSpec) + + // OverrideRevertRangeBatchSize allows overriding the `MaxSpanRequestKeys` + // used when sending a RevertRange request. + OverrideRevertRangeBatchSize int64 } var _ base.ModuleTestingKnobs = &StreamingTestingKnobs{} diff --git a/pkg/sql/job_exec_context_test_util.go b/pkg/sql/job_exec_context_test_util.go index 044d5c73d5b3..0d82c6fe6c48 100644 --- a/pkg/sql/job_exec_context_test_util.go +++ b/pkg/sql/job_exec_context_test_util.go @@ -52,7 +52,10 @@ func (p *FakeJobExecContext) SessionDataMutatorIterator() *sessionDataMutatorIte // DistSQLPlanner implements the JobExecContext interface. func (p *FakeJobExecContext) DistSQLPlanner() *DistSQLPlanner { - panic("unimplemented") + if p.ExecutorConfig == nil { + panic("unimplemented") + } + return p.ExecutorConfig.DistSQLPlanner } // LeaseMgr implements the JobExecContext interface.
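
Notes below are illustrative sketches, not part of the patch. The cutover progress fraction comes from counting how many of the original ranges still need reverting: updateJobProgress captures origNRanges on its first call and thereafter reports (origNRanges - nRanges) / origNRanges. A minimal standalone sketch of that arithmetic (fractionRangesFinished here is a hypothetical helper mirroring the patch, not an exported API):

    package main

    import "fmt"

    // fractionRangesFinished mirrors the computation in updateJobProgress:
    // the fraction completed is the share of the original ranges that no
    // longer need reverting.
    func fractionRangesFinished(origNRanges, nRanges int) float32 {
        if origNRanges <= 0 || nRanges >= origNRanges {
            return 0 // progress only moves once ranges start draining
        }
        return float32(origNRanges-nRanges) / float32(origNRanges)
    }

    func main() {
        // With the 6 ranges the test creates, draining one range per batch
        // yields exactly the fractions listed in progressMap:
        // 0.00 0.17 0.33 0.50 0.67 0.83
        for n := 6; n >= 1; n-- {
            fmt.Printf("%.2f ", fractionRangesFinished(6, n))
        }
        fmt.Println()
    }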
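
The loop that triggers those updates leans on RevertRange pagination: MaxSpanRequestKeys caps each batch, the new OverrideRevertRangeBatchSize knob pins it to 1 in the test so each range produces one response, and any unfinished portion of a span comes back to be retried on the next pass. A sketch of that resume-span pattern, with span and revertPage as local stand-ins rather than CockroachDB APIs:

    package main

    import "fmt"

    type span struct{ key, endKey byte }

    // revertPage stands in for one paged RevertRange request: it reverts at
    // most limit keys of sp and returns the remainder, or nil when done.
    func revertPage(sp span, limit byte) *span {
        if sp.endKey-sp.key <= limit {
            return nil
        }
        return &span{key: sp.key + limit, endKey: sp.endKey}
    }

    func main() {
        spans := []span{{key: 0, endKey: 6}}
        const batchSize = 1 // what the test forces via the knob
        for len(spans) != 0 {
            // The patch calls updateJobProgress() here, before each batch.
            fmt.Printf("%d span(s) left to revert\n", len(spans))
            var resume []span
            for _, sp := range spans {
                if rs := revertPage(sp, batchSize); rs != nil {
                    resume = append(resume, *rs)
                }
            }
            spans = resume
        }
    }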
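
TestCutoverFractionProgressed synchronizes with the server through a two-channel handshake: the store's TestingResponseFilter parks each RevertRange response on respRecvd until the test has sampled the job's fraction, then the test releases it on continueRevert. A self-contained sketch of that handshake, assuming a goroutine standing in for the KV response path:

    package main

    import "fmt"

    func main() {
        respRecvd := make(chan struct{})
        continueRevert := make(chan struct{})
        defer close(continueRevert)

        // Stand-in for the KV side: the response filter blocks each
        // RevertRange response until the test lets it continue.
        go func() {
            defer close(respRecvd)
            for i := 0; i < 6; i++ { // one paged response per range
                respRecvd <- struct{}{}
                <-continueRevert
            }
        }()

        // Stand-in for the test side: inspect progress between responses.
        for range respRecvd {
            fmt.Println("response observed; progress would be checked here")
            continueRevert <- struct{}{}
        }
    }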