diff --git a/br/pkg/restore/BUILD.bazel b/br/pkg/restore/BUILD.bazel index 5351b6ee4178d..584f73869b956 100644 --- a/br/pkg/restore/BUILD.bazel +++ b/br/pkg/restore/BUILD.bazel @@ -5,16 +5,22 @@ go_library( srcs = [ "import_mode_switcher.go", "misc.go", + "restorer.go", ], importpath = "github.com/pingcap/tidb/br/pkg/restore", visibility = ["//visibility:public"], deps = [ + "//br/pkg/checkpoint", "//br/pkg/conn", "//br/pkg/conn/util", "//br/pkg/errors", "//br/pkg/logutil", "//br/pkg/pdutil", + "//br/pkg/restore/split", + "//br/pkg/restore/utils", + "//br/pkg/summary", "//br/pkg/utils", + "//br/pkg/utils/iter", "//pkg/domain", "//pkg/kv", "//pkg/meta", @@ -22,8 +28,10 @@ go_library( "//pkg/parser/model", "//pkg/util", "@com_github_go_sql_driver_mysql//:mysql", + "@com_github_opentracing_opentracing_go//:opentracing-go", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/brpb", "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_log//:log", "@com_github_tikv_client_go_v2//oracle", @@ -34,6 +42,7 @@ go_library( "@org_golang_google_grpc//credentials/insecure", "@org_golang_x_sync//errgroup", "@org_uber_go_zap//:zap", + "@org_uber_go_zap//zapcore", ], ) @@ -43,21 +52,28 @@ go_test( srcs = [ "import_mode_switcher_test.go", "misc_test.go", + "restorer_test.go", ], flaky = True, - race = "off", - shard_count = 6, + shard_count = 13, deps = [ ":restore", "//br/pkg/conn", "//br/pkg/mock", "//br/pkg/pdutil", "//br/pkg/restore/split", + "//br/pkg/restore/utils", + "//br/pkg/utils/iter", "//pkg/kv", "//pkg/parser/model", "//pkg/session", + "//pkg/tablecodec", + "//pkg/util", + "//pkg/util/codec", "@com_github_coreos_go_semver//semver", + "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/brpb", "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_stretchr_testify//require", diff --git a/br/pkg/restore/import_mode_switcher.go b/br/pkg/restore/import_mode_switcher.go index 0ae69f4a6a0af..be01389c19e5f 100644 --- a/br/pkg/restore/import_mode_switcher.go +++ b/br/pkg/restore/import_mode_switcher.go @@ -5,6 +5,7 @@ package restore import ( "context" "crypto/tls" + "sync" "time" _ "github.com/go-sql-driver/mysql" // mysql driver @@ -46,9 +47,11 @@ func NewImportModeSwitcher( } } +var closeOnce sync.Once + // switchToNormalMode switch tikv cluster to normal mode. -func (switcher *ImportModeSwitcher) switchToNormalMode(ctx context.Context) error { - close(switcher.switchCh) +func (switcher *ImportModeSwitcher) SwitchToNormalMode(ctx context.Context) error { + closeOnce.Do(func() { close(switcher.switchCh) }) return switcher.switchTiKVMode(ctx, import_sstpb.SwitchMode_Normal) } @@ -113,8 +116,8 @@ func (switcher *ImportModeSwitcher) switchTiKVMode( return nil } -// switchToImportMode switch tikv cluster to import mode. -func (switcher *ImportModeSwitcher) switchToImportMode( +// SwitchToImportMode switch tikv cluster to import mode. +func (switcher *ImportModeSwitcher) SwitchToImportMode( ctx context.Context, ) { // tikv automatically switch to normal mode in every 10 minutes @@ -163,7 +166,7 @@ func RestorePreWork( if switchToImport { // Switch TiKV cluster to import mode (adjust rocksdb configuration). 
- switcher.switchToImportMode(ctx) + switcher.SwitchToImportMode(ctx) } return mgr.RemoveSchedulersWithConfig(ctx) @@ -186,7 +189,7 @@ func RestorePostWork( ctx = context.Background() } - if err := switcher.switchToNormalMode(ctx); err != nil { + if err := switcher.SwitchToNormalMode(ctx); err != nil { log.Warn("fail to switch to normal mode", zap.Error(err)) } if err := restoreSchedulers(ctx); err != nil { diff --git a/br/pkg/restore/log_client/BUILD.bazel b/br/pkg/restore/log_client/BUILD.bazel index 85da59628f8b7..7fb781e7ad0ef 100644 --- a/br/pkg/restore/log_client/BUILD.bazel +++ b/br/pkg/restore/log_client/BUILD.bazel @@ -4,10 +4,12 @@ go_library( name = "log_client", srcs = [ "client.go", + "compacted_file_strategy.go", "import.go", "import_retry.go", "log_file_manager.go", "log_file_map.go", + "log_split_strategy.go", "migration.go", ], importpath = "github.com/pingcap/tidb/br/pkg/restore/log_client", @@ -26,6 +28,7 @@ go_library( "//br/pkg/restore/ingestrec", "//br/pkg/restore/internal/import_client", "//br/pkg/restore/internal/rawkv", + "//br/pkg/restore/snap_client", "//br/pkg/restore/split", "//br/pkg/restore/tiflashrec", "//br/pkg/restore/utils", @@ -43,6 +46,7 @@ go_library( "//pkg/util", "//pkg/util/codec", "//pkg/util/redact", + "//pkg/util/sqlexec", "//pkg/util/table-filter", "@com_github_fatih_color//:color", "@com_github_gogo_protobuf//proto", @@ -86,12 +90,13 @@ go_test( ], embed = [":log_client"], flaky = True, - shard_count = 42, + shard_count = 45, deps = [ "//br/pkg/errors", "//br/pkg/glue", "//br/pkg/gluetidb", "//br/pkg/mock", + "//br/pkg/restore", "//br/pkg/restore/internal/import_client", "//br/pkg/restore/split", "//br/pkg/restore/utils", @@ -113,6 +118,7 @@ go_test( "//pkg/util/codec", "//pkg/util/sqlexec", "//pkg/util/table-filter", + "@com_github_docker_go_units//:go-units", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/brpb", @@ -120,10 +126,8 @@ go_test( "@com_github_pingcap_kvproto//pkg/errorpb", "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_kvproto//pkg/metapb", - "@com_github_pingcap_kvproto//pkg/pdpb", "@com_github_pingcap_log//:log", "@com_github_stretchr_testify//require", - "@com_github_tikv_pd_client//:client", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//keepalive", "@org_golang_google_grpc//status", diff --git a/br/pkg/restore/log_client/client.go b/br/pkg/restore/log_client/client.go index 4160aa86a6048..4faf59a316657 100644 --- a/br/pkg/restore/log_client/client.go +++ b/br/pkg/restore/log_client/client.go @@ -47,6 +47,7 @@ import ( "github.com/pingcap/tidb/br/pkg/restore/ingestrec" importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" "github.com/pingcap/tidb/br/pkg/restore/internal/rawkv" + snapclient "github.com/pingcap/tidb/br/pkg/restore/snap_client" "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" @@ -62,6 +63,7 @@ import ( "github.com/pingcap/tidb/pkg/meta" "github.com/pingcap/tidb/pkg/meta/model" tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/sqlexec" filter "github.com/pingcap/tidb/pkg/util/table-filter" "github.com/tikv/client-go/v2/config" kvutil "github.com/tikv/client-go/v2/util" @@ -79,19 +81,122 @@ const maxSplitKeysOnce = 10240 // rawKVBatchCount specifies the count of entries that the rawkv client puts into TiKV. 
 const rawKVBatchCount = 64
+// LogRestoreManager is a comprehensive wrapper that encapsulates all logic related to log restoration,
+// including concurrency management, checkpoint handling, and file importing for efficient log processing.
+type LogRestoreManager struct {
+	fileImporter     *LogFileImporter
+	workerPool       *tidbutil.WorkerPool
+	checkpointRunner *checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType]
+}
+
+func NewLogRestoreManager(
+	ctx context.Context,
+	fileImporter *LogFileImporter,
+	poolSize uint,
+	createCheckpointSessionFn func() (glue.Session, error),
+) (*LogRestoreManager, error) {
+	// Because compacted SST restore sizes its own worker pool, the user-specified --concurrency only controls the log file restore speed.
+	log.Info("log restore worker pool", zap.Uint("size", poolSize))
+	l := &LogRestoreManager{
+		fileImporter: fileImporter,
+		workerPool:   tidbutil.NewWorkerPool(poolSize, "log manager worker pool"),
+	}
+	se, err := createCheckpointSessionFn()
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+
+	if se != nil {
+		l.checkpointRunner, err = checkpoint.StartCheckpointRunnerForLogRestore(ctx, se)
+		if err != nil {
+			return nil, errors.Trace(err)
+		}
+	}
+	return l, nil
+}
+
+func (l *LogRestoreManager) Close(ctx context.Context) {
+	if l.fileImporter != nil {
+		if err := l.fileImporter.Close(); err != nil {
+			log.Warn("failed to close file importer")
+		}
+	}
+	if l.checkpointRunner != nil {
+		l.checkpointRunner.WaitForFinish(ctx, true)
+	}
+}
+
+// SstRestoreManager is a comprehensive wrapper that encapsulates all logic related to SST restoration,
+// including concurrency management, checkpoint handling, and file importing (splitting) for efficient SST processing.
+type SstRestoreManager struct {
+	restorer         restore.SstRestorer
+	workerPool       *tidbutil.WorkerPool
+	checkpointRunner *checkpoint.CheckpointRunner[checkpoint.RestoreKeyType, checkpoint.RestoreValueType]
+}
+
+func (s *SstRestoreManager) Close(ctx context.Context) {
+	if s.restorer != nil {
+		if err := s.restorer.Close(); err != nil {
+			log.Warn("failed to close file restorer")
+		}
+	}
+	if s.checkpointRunner != nil {
+		s.checkpointRunner.WaitForFinish(ctx, true)
+	}
+}
+
+func NewSstRestoreManager(
+	ctx context.Context,
+	snapFileImporter *snapclient.SnapFileImporter,
+	concurrencyPerStore uint,
+	storeCount uint,
+	createCheckpointSessionFn func() (glue.Session, error),
+) (*SstRestoreManager, error) {
+	// This poolSize is similar to full restore, as both workflows are comparable.
+	// The poolSize should be greater than concurrencyPerStore multiplied by the number of stores.
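	// Editor's note (illustrative, not part of the original patch): with purely hypothetical
	// values such as concurrencyPerStore = 8 and storeCount = 3, the sizing below yields
	// 8 * 32 * 3 = 768 workers, comfortably above the stated minimum of
	// concurrencyPerStore * storeCount = 24.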
+ poolSize := concurrencyPerStore * 32 * storeCount + log.Info("sst restore worker pool", zap.Uint("size", poolSize)) + sstWorkerPool := tidbutil.NewWorkerPool(poolSize, "sst file") + + s := &SstRestoreManager{ + workerPool: tidbutil.NewWorkerPool(poolSize, "log manager worker pool"), + } + se, err := createCheckpointSessionFn() + if err != nil { + return nil, errors.Trace(err) + } + if se != nil { + checkpointRunner, err := checkpoint.StartCheckpointRunnerForRestore(ctx, se) + if err != nil { + return nil, errors.Trace(err) + } + s.checkpointRunner = checkpointRunner + } + // TODO implement checkpoint + s.restorer = restore.NewSimpleSstRestorer(ctx, snapFileImporter, sstWorkerPool, nil) + return s, nil +} + type LogClient struct { - cipher *backuppb.CipherInfo - pdClient pd.Client - pdHTTPClient pdhttp.Client - clusterID uint64 - dom *domain.Domain - tlsConf *tls.Config - keepaliveConf keepalive.ClientParameters + *LogFileManager + logRestoreManager *LogRestoreManager + sstRestoreManager *SstRestoreManager + + cipher *backuppb.CipherInfo + pdClient pd.Client + pdHTTPClient pdhttp.Client + clusterID uint64 + dom *domain.Domain + tlsConf *tls.Config + keepaliveConf keepalive.ClientParameters + concurrencyPerStore uint rawKVClient *rawkv.RawKVBatchClient storage storage.ExternalStorage - se glue.Session + // unsafeSession is not thread-safe. + // Currently, it is only utilized in some initialization and post-handle functions. + unsafeSession glue.Session // currentTS is used for rewrite meta kv when restore stream. // Can not use `restoreTS` directly, because schema created in `full backup` maybe is new than `restoreTS`. @@ -99,11 +204,6 @@ type LogClient struct { upstreamClusterID uint64 - *LogFileManager - - workerPool *tidbutil.WorkerPool - fileImporter *LogFileImporter - // the query to insert rows into table `gc_delete_range`, lack of ts. deleteRangeQuery []*stream.PreDelRangeQuery deleteRangeQueryCh chan *stream.PreDelRangeQuery @@ -131,21 +231,80 @@ func NewRestoreClient( } // Close a client. -func (rc *LogClient) Close() { +func (rc *LogClient) Close(ctx context.Context) { + defer func() { + if rc.logRestoreManager != nil { + rc.logRestoreManager.Close(ctx) + } + if rc.sstRestoreManager != nil { + rc.sstRestoreManager.Close(ctx) + } + }() + // close the connection, and it must be succeed when in SQL mode. - if rc.se != nil { - rc.se.Close() + if rc.unsafeSession != nil { + rc.unsafeSession.Close() } if rc.rawKVClient != nil { rc.rawKVClient.Close() } + log.Info("Restore client closed") +} - if err := rc.fileImporter.Close(); err != nil { - log.Warn("failed to close file improter") +func (rc *LogClient) RestoreCompactedSstFiles( + ctx context.Context, + compactionsIter iter.TryNextor[*backuppb.LogFileSubcompaction], + rules map[int64]*restoreutils.RewriteRules, + importModeSwitcher *restore.ImportModeSwitcher, + onProgress func(int64), +) error { + backupFileSets := make([]restore.BackupFileSet, 0, 8) + // Collect all items from the iterator in advance to avoid blocking during restoration. + // This approach ensures that we have all necessary data ready for processing, + // preventing any potential delays caused by waiting for the iterator to yield more items. 
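	// Editor's note: only sub-compaction metadata is buffered here (table IDs, rewrite rules,
	// and SST file descriptors); the SST contents themselves are not held in memory and are
	// only fetched when the restorer imports each file set.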
+ for r := compactionsIter.TryNext(ctx); !r.Finished; r = compactionsIter.TryNext(ctx) { + if r.Err != nil { + return r.Err + } + i := r.Item + rewriteRules, ok := rules[i.Meta.TableId] + if !ok { + log.Warn("[Compacted SST Restore] Skipping excluded table during restore.", zap.Int64("table_id", i.Meta.TableId)) + continue + } + set := restore.BackupFileSet{ + TableID: i.Meta.TableId, + SSTFiles: i.SstOutputs, + RewriteRules: rewriteRules, + } + backupFileSets = append(backupFileSets, set) + } + if len(backupFileSets) == 0 { + log.Info("[Compacted SST Restore] No SST files found for restoration.") + return nil } + importModeSwitcher.SwitchToImportMode(ctx) + defer func() { + switchErr := importModeSwitcher.SwitchToNormalMode(ctx) + if switchErr != nil { + log.Warn("[Compacted SST Restore] Failed to switch back to normal mode after restoration.", zap.Error(switchErr)) + } + }() - log.Info("Restore client closed") + // To optimize performance and minimize cross-region downloads, + // we are currently opting for a single restore approach instead of batch restoration. + // This decision is similar to the handling of raw and txn restores, + // where batch processing may lead to increased complexity and potential inefficiencies. + // TODO: Future enhancements may explore the feasibility of reintroducing batch restoration + // while maintaining optimal performance and resource utilization. + for _, i := range backupFileSets { + err := rc.sstRestoreManager.restorer.GoRestore(onProgress, []restore.BackupFileSet{i}) + if err != nil { + return errors.Trace(err) + } + } + return rc.sstRestoreManager.restorer.WaitUntilFinish() } func (rc *LogClient) SetRawKVBatchClient( @@ -166,11 +325,6 @@ func (rc *LogClient) SetCrypter(crypter *backuppb.CipherInfo) { rc.cipher = crypter } -func (rc *LogClient) SetConcurrency(c uint) { - log.Info("download worker pool", zap.Uint("size", c)) - rc.workerPool = tidbutil.NewWorkerPool(c, "file") -} - func (rc *LogClient) SetUpstreamClusterID(upstreamClusterID uint64) { log.Info("upstream cluster id", zap.Uint64("cluster id", upstreamClusterID)) rc.upstreamClusterID = upstreamClusterID @@ -210,28 +364,19 @@ func (rc *LogClient) CleanUpKVFiles( ) error { // Current we only have v1 prefix. // In the future, we can add more operation for this interface. - return rc.fileImporter.ClearFiles(ctx, rc.pdClient, "v1") -} - -func (rc *LogClient) StartCheckpointRunnerForLogRestore(ctx context.Context, g glue.Glue, store kv.Storage) (*checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType], error) { - se, err := g.CreateSession(store) - if err != nil { - return nil, errors.Trace(err) - } - runner, err := checkpoint.StartCheckpointRunnerForLogRestore(ctx, se) - return runner, errors.Trace(err) + return rc.logRestoreManager.fileImporter.ClearFiles(ctx, rc.pdClient, "v1") } // Init create db connection and domain for storage. 
-func (rc *LogClient) Init(g glue.Glue, store kv.Storage) error { +func (rc *LogClient) Init(ctx context.Context, g glue.Glue, store kv.Storage) error { var err error - rc.se, err = g.CreateSession(store) + rc.unsafeSession, err = g.CreateSession(store) if err != nil { return errors.Trace(err) } // Set SQL mode to None for avoiding SQL compatibility problem - err = rc.se.Execute(context.Background(), "set @@sql_mode=''") + err = rc.unsafeSession.Execute(ctx, "set @@sql_mode=''") if err != nil { return errors.Trace(err) } @@ -244,7 +389,13 @@ func (rc *LogClient) Init(g glue.Glue, store kv.Storage) error { return nil } -func (rc *LogClient) InitClients(ctx context.Context, backend *backuppb.StorageBackend) { +func (rc *LogClient) InitClients( + ctx context.Context, + backend *backuppb.StorageBackend, + createSessionFn func() (glue.Session, error), + concurrency uint, + concurrencyPerStore uint, +) error { stores, err := conn.GetAllTiKVStoresWithRetry(ctx, rc.pdClient, util.SkipTiFlash) if err != nil { log.Fatal("failed to get stores", zap.Error(err)) @@ -252,7 +403,48 @@ func (rc *LogClient) InitClients(ctx context.Context, backend *backuppb.StorageB metaClient := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, len(stores)+1) importCli := importclient.NewImportClient(metaClient, rc.tlsConf, rc.keepaliveConf) - rc.fileImporter = NewLogFileImporter(metaClient, importCli, backend) + + rc.logRestoreManager, err = NewLogRestoreManager( + ctx, + NewLogFileImporter(metaClient, importCli, backend), + concurrency, + createSessionFn, + ) + if err != nil { + return errors.Trace(err) + } + var createCallBacks []func(*snapclient.SnapFileImporter) error + var closeCallBacks []func(*snapclient.SnapFileImporter) error + createCallBacks = append(createCallBacks, func(importer *snapclient.SnapFileImporter) error { + return importer.CheckMultiIngestSupport(ctx, stores) + }) + + opt := snapclient.NewSnapFileImporterOptions( + rc.cipher, metaClient, importCli, backend, + snapclient.RewriteModeKeyspace, stores, rc.concurrencyPerStore, createCallBacks, closeCallBacks, + ) + snapFileImporter, err := snapclient.NewSnapFileImporter( + ctx, rc.dom.Store().GetCodec().GetAPIVersion(), snapclient.TiDBCompcated, opt) + if err != nil { + return errors.Trace(err) + } + rc.sstRestoreManager, err = NewSstRestoreManager( + ctx, + snapFileImporter, + concurrencyPerStore, + uint(len(stores)), + createSessionFn, + ) + return errors.Trace(err) +} + +func (rc *LogClient) InitCheckpointMetadataForCompactedSstRestore( + ctx context.Context, +) (map[string]struct{}, error) { + // get sst checkpoint to skip repeated files + sstCheckpointSets := make(map[string]struct{}) + // TODO initial checkpoint + return sstCheckpointSets, nil } func (rc *LogClient) InitCheckpointMetadataForLogRestore( @@ -267,7 +459,7 @@ func (rc *LogClient) InitCheckpointMetadataForLogRestore( // for the first time. 
if checkpoint.ExistsLogRestoreCheckpointMetadata(ctx, rc.dom) { // load the checkpoint since this is not the first time to restore - meta, err := checkpoint.LoadCheckpointMetadataForLogRestore(ctx, rc.se.GetSessionCtx().GetRestrictedSQLExecutor()) + meta, err := checkpoint.LoadCheckpointMetadataForLogRestore(ctx, rc.unsafeSession.GetSessionCtx().GetRestrictedSQLExecutor()) if err != nil { return "", errors.Trace(err) } @@ -284,7 +476,7 @@ func (rc *LogClient) InitCheckpointMetadataForLogRestore( log.Info("save gc ratio into checkpoint metadata", zap.Uint64("start-ts", startTS), zap.Uint64("restored-ts", restoredTS), zap.Uint64("rewrite-ts", rc.currentTS), zap.String("gc-ratio", gcRatio), zap.Int("tiflash-item-count", len(items))) - if err := checkpoint.SaveCheckpointMetadataForLogRestore(ctx, rc.se, &checkpoint.CheckpointMetadataForLogRestore{ + if err := checkpoint.SaveCheckpointMetadataForLogRestore(ctx, rc.unsafeSession, &checkpoint.CheckpointMetadataForLogRestore{ UpstreamClusterID: rc.upstreamClusterID, RestoredTS: restoredTS, StartTS: startTS, @@ -298,6 +490,15 @@ func (rc *LogClient) InitCheckpointMetadataForLogRestore( return gcRatio, nil } +func (rc *LogClient) GetMigrations(ctx context.Context) ([]*backuppb.Migration, error) { + ext := stream.MigerationExtension(rc.storage) + migs, err := ext.Load(ctx) + if err != nil { + return nil, errors.Trace(err) + } + return migs.ListAll(), nil +} + func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint, encryptionManager *encryption.Manager) error { init := LogFileManagerInit{ @@ -305,6 +506,10 @@ func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restore RestoreTS: restoreTS, Storage: rc.storage, + MigrationsBuilder: &WithMigrationsBuilder{ + startTS: startTS, + restoredTS: restoreTS, + }, MetadataDownloadBatchSize: metadataDownloadBatchSize, EncryptionManager: encryptionManager, } @@ -477,9 +682,7 @@ func ApplyKVFilesWithSingleMethod( func (rc *LogClient) RestoreKVFiles( ctx context.Context, rules map[int64]*restoreutils.RewriteRules, - idrules map[int64]int64, logIter LogIter, - runner *checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType], pitrBatchCount uint32, pitrBatchSize uint32, updateStats func(kvCount uint64, size uint64), @@ -522,31 +725,26 @@ func (rc *LogClient) RestoreKVFiles( // For this version we do not handle new created table after full backup. // in next version we will perform rewrite and restore meta key to restore new created tables. // so we can simply skip the file that doesn't have the rule here. 
- onProgress(int64(len(files))) + onProgress(kvCount) summary.CollectInt("FileSkip", len(files)) log.Debug("skip file due to table id not matched", zap.Int64("table-id", files[0].TableId)) skipFile += len(files) } else { applyWg.Add(1) - downstreamId := idrules[files[0].TableId] - rc.workerPool.ApplyOnErrorGroup(eg, func() (err error) { + rc.logRestoreManager.workerPool.ApplyOnErrorGroup(eg, func() (err error) { fileStart := time.Now() defer applyWg.Done() defer func() { - onProgress(int64(len(files))) + onProgress(kvCount) updateStats(uint64(kvCount), size) summary.CollectInt("File", len(files)) if err == nil { filenames := make([]string, 0, len(files)) - if runner == nil { - for _, f := range files { - filenames = append(filenames, f.Path+", ") - } - } else { - for _, f := range files { - filenames = append(filenames, f.Path+", ") - if e := checkpoint.AppendRangeForLogRestore(ectx, runner, f.MetaDataGroupName, downstreamId, f.OffsetInMetaGroup, f.OffsetInMergedGroup); e != nil { + for _, f := range files { + filenames = append(filenames, f.Path+", ") + if rc.logRestoreManager.checkpointRunner != nil { + if e := checkpoint.AppendRangeForLogRestore(ectx, rc.logRestoreManager.checkpointRunner, f.MetaDataGroupName, rule.NewTableID, f.OffsetInMetaGroup, f.OffsetInMergedGroup); e != nil { err = errors.Annotate(e, "failed to append checkpoint data") break } @@ -557,13 +755,13 @@ func (rc *LogClient) RestoreKVFiles( } }() - return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, + return rc.logRestoreManager.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, supportBatch, cipherInfo, masterKeys) }) } } - rc.workerPool.ApplyOnErrorGroup(eg, func() error { + rc.logRestoreManager.workerPool.ApplyOnErrorGroup(eg, func() error { if supportBatch { err = ApplyKVFilesWithBatchMethod(ectx, logIter, int(pitrBatchCount), uint64(pitrBatchSize), applyFunc, &applyWg) } else { @@ -590,7 +788,7 @@ func (rc *LogClient) initSchemasMap( restoreTS uint64, ) ([]*backuppb.PitrDBMap, error) { getPitrIDMapSQL := "SELECT segment_id, id_map FROM mysql.tidb_pitr_id_map WHERE restored_ts = %? and upstream_cluster_id = %? ORDER BY segment_id;" - execCtx := rc.se.GetSessionCtx().GetRestrictedSQLExecutor() + execCtx := rc.unsafeSession.GetSessionCtx().GetRestrictedSQLExecutor() rows, _, errSQL := execCtx.ExecRestrictedSQL( kv.WithInternalSourceType(ctx, kv.InternalTxnBR), nil, @@ -1246,36 +1444,45 @@ func (rc *LogClient) UpdateSchemaVersion(ctx context.Context) error { return nil } -func (rc *LogClient) WrapLogFilesIterWithSplitHelper(logIter LogIter, rules map[int64]*restoreutils.RewriteRules, g glue.Glue, store kv.Storage) (LogIter, error) { - se, err := g.CreateSession(store) - if err != nil { - return nil, errors.Trace(err) - } - execCtx := se.GetSessionCtx().GetRestrictedSQLExecutor() - splitSize, splitKeys := utils.GetRegionSplitInfo(execCtx) - log.Info("get split threshold from tikv config", zap.Uint64("split-size", splitSize), zap.Int64("split-keys", splitKeys)) +// WrapCompactedFilesIteratorWithSplit applies a splitting strategy to the compacted files iterator. +// It uses a region splitter to handle the splitting logic based on the provided rules and checkpoint sets. 
+func (rc *LogClient) WrapCompactedFilesIterWithSplitHelper( + ctx context.Context, + compactedIter iter.TryNextor[*backuppb.LogFileSubcompaction], + rules map[int64]*restoreutils.RewriteRules, + checkpointSets map[string]struct{}, + updateStatsFn func(uint64, uint64), + splitSize uint64, + splitKeys int64, +) (iter.TryNextor[*backuppb.LogFileSubcompaction], error) { client := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, 3) - return NewLogFilesIterWithSplitHelper(logIter, rules, client, splitSize, splitKeys), nil + wrapper := restore.PipelineRestorerWrapper[*backuppb.LogFileSubcompaction]{ + PipelineRegionsSplitter: split.NewPipelineRegionsSplitter(client, splitSize, splitKeys), + } + strategy := NewCompactedFileSplitStrategy(rules, checkpointSets, updateStatsFn) + return wrapper.WithSplit(ctx, compactedIter, strategy), nil } -func (rc *LogClient) generateKvFilesSkipMap(ctx context.Context, downstreamIdset map[int64]struct{}) (*LogFilesSkipMap, error) { - skipMap := NewLogFilesSkipMap() - t, err := checkpoint.LoadCheckpointDataForLogRestore( - ctx, rc.se.GetSessionCtx().GetRestrictedSQLExecutor(), func(groupKey checkpoint.LogRestoreKeyType, off checkpoint.LogRestoreValueMarshaled) { - for tableID, foffs := range off.Foffs { - // filter out the checkpoint data of dropped table - if _, exists := downstreamIdset[tableID]; exists { - for _, foff := range foffs { - skipMap.Insert(groupKey, off.Goff, foff) - } - } - } - }) +// WrapLogFilesIteratorWithSplit applies a splitting strategy to the log files iterator. +// It uses a region splitter to handle the splitting logic based on the provided rules. +func (rc *LogClient) WrapLogFilesIterWithSplitHelper( + ctx context.Context, + logIter LogIter, + execCtx sqlexec.RestrictedSQLExecutor, + rules map[int64]*restoreutils.RewriteRules, + updateStatsFn func(uint64, uint64), + splitSize uint64, + splitKeys int64, +) (LogIter, error) { + client := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, 3) + wrapper := restore.PipelineRestorerWrapper[*LogDataFileInfo]{ + PipelineRegionsSplitter: split.NewPipelineRegionsSplitter(client, splitSize, splitKeys), + } + strategy, err := NewLogSplitStrategy(ctx, rc.useCheckpoint, execCtx, rules, updateStatsFn) if err != nil { return nil, errors.Trace(err) } - summary.AdjustStartTimeToEarlierTime(t) - return skipMap, nil + return wrapper.WithSplit(ctx, logIter, strategy), nil } func WrapLogFilesIterWithCheckpointFailpoint( @@ -1304,27 +1511,6 @@ func WrapLogFilesIterWithCheckpointFailpoint( return logIter, nil } -func (rc *LogClient) WrapLogFilesIterWithCheckpoint( - ctx context.Context, - logIter LogIter, - downstreamIdset map[int64]struct{}, - updateStats func(kvCount, size uint64), - onProgress func(), -) (LogIter, error) { - skipMap, err := rc.generateKvFilesSkipMap(ctx, downstreamIdset) - if err != nil { - return nil, errors.Trace(err) - } - return iter.FilterOut(logIter, func(d *LogDataFileInfo) bool { - if skipMap.NeedSkip(d.MetaDataGroupName, d.OffsetInMetaGroup, d.OffsetInMergedGroup) { - onProgress() - updateStats(uint64(d.NumberOfEntries), d.Length) - return true - } - return false - }), nil -} - const ( alterTableDropIndexSQL = "ALTER TABLE %n.%n DROP INDEX %n" alterTableAddIndexFormat = "ALTER TABLE %%n.%%n ADD INDEX %%n(%s)" @@ -1339,7 +1525,7 @@ func (rc *LogClient) generateRepairIngestIndexSQLs( var sqls []checkpoint.CheckpointIngestIndexRepairSQL if rc.useCheckpoint { if checkpoint.ExistsCheckpointIngestIndexRepairSQLs(ctx, rc.dom) { - 
checkpointSQLs, err := checkpoint.LoadCheckpointIngestIndexRepairSQLs(ctx, rc.se.GetSessionCtx().GetRestrictedSQLExecutor()) + checkpointSQLs, err := checkpoint.LoadCheckpointIngestIndexRepairSQLs(ctx, rc.unsafeSession.GetSessionCtx().GetRestrictedSQLExecutor()) if err != nil { return sqls, false, errors.Trace(err) } @@ -1404,7 +1590,7 @@ func (rc *LogClient) generateRepairIngestIndexSQLs( } if rc.useCheckpoint && len(sqls) > 0 { - if err := checkpoint.SaveCheckpointIngestIndexRepairSQLs(ctx, rc.se, &checkpoint.CheckpointIngestIndexRepairSQLs{ + if err := checkpoint.SaveCheckpointIngestIndexRepairSQLs(ctx, rc.unsafeSession, &checkpoint.CheckpointIngestIndexRepairSQLs{ SQLs: sqls, }); err != nil { return sqls, false, errors.Trace(err) @@ -1466,7 +1652,7 @@ NEXTSQL: // only when first execution or old index id is not dropped if !fromCheckpoint || oldIndexIDFound { - if err := rc.se.ExecuteInternal(ctx, alterTableDropIndexSQL, sql.SchemaName.O, sql.TableName.O, sql.IndexName); err != nil { + if err := rc.unsafeSession.ExecuteInternal(ctx, alterTableDropIndexSQL, sql.SchemaName.O, sql.TableName.O, sql.IndexName); err != nil { return errors.Trace(err) } } @@ -1476,7 +1662,7 @@ NEXTSQL: } }) // create the repaired index when first execution or not found it - if err := rc.se.ExecuteInternal(ctx, sql.AddSQL, sql.AddArgs...); err != nil { + if err := rc.unsafeSession.ExecuteInternal(ctx, sql.AddSQL, sql.AddArgs...); err != nil { return errors.Trace(err) } w.Inc() @@ -1549,7 +1735,7 @@ func (rc *LogClient) InsertGCRows(ctx context.Context) error { // trim the ',' behind the query.Sql if exists // that's when the rewrite rule of the last table id is not exist sql := strings.TrimSuffix(query.Sql, ",") - if err := rc.se.ExecuteInternal(ctx, sql, paramsList...); err != nil { + if err := rc.unsafeSession.ExecuteInternal(ctx, sql, paramsList...); err != nil { return errors.Trace(err) } } @@ -1577,7 +1763,7 @@ func (rc *LogClient) saveIDMap( return errors.Trace(err) } // clean the dirty id map at first - err = rc.se.ExecuteInternal(ctx, "DELETE FROM mysql.tidb_pitr_id_map WHERE restored_ts = %? and upstream_cluster_id = %?;", rc.restoreTS, rc.upstreamClusterID) + err = rc.unsafeSession.ExecuteInternal(ctx, "DELETE FROM mysql.tidb_pitr_id_map WHERE restored_ts = %? 
and upstream_cluster_id = %?;", rc.restoreTS, rc.upstreamClusterID) if err != nil { return errors.Trace(err) } @@ -1587,7 +1773,7 @@ func (rc *LogClient) saveIDMap( if endIdx > len(data) { endIdx = len(data) } - err := rc.se.ExecuteInternal(ctx, replacePitrIDMapSQL, rc.restoreTS, rc.upstreamClusterID, segmentId, data[startIdx:endIdx]) + err := rc.unsafeSession.ExecuteInternal(ctx, replacePitrIDMapSQL, rc.restoreTS, rc.upstreamClusterID, segmentId, data[startIdx:endIdx]) if err != nil { return errors.Trace(err) } @@ -1596,7 +1782,7 @@ func (rc *LogClient) saveIDMap( if rc.useCheckpoint { log.Info("save checkpoint task info with InLogRestoreAndIdMapPersist status") - if err := checkpoint.SaveCheckpointProgress(ctx, rc.se, &checkpoint.CheckpointProgress{ + if err := checkpoint.SaveCheckpointProgress(ctx, rc.unsafeSession, &checkpoint.CheckpointProgress{ Progress: checkpoint.InLogRestoreAndIdMapPersist, }); err != nil { return errors.Trace(err) @@ -1613,7 +1799,6 @@ func (rc *LogClient) FailpointDoChecksumForLogRestore( ctx context.Context, kvClient kv.Client, pdClient pd.Client, - idrules map[int64]int64, rewriteRules map[int64]*restoreutils.RewriteRules, ) (finalErr error) { startTS, err := restore.GetTSWithRetry(ctx, rc.pdClient) @@ -1650,11 +1835,11 @@ func (rc *LogClient) FailpointDoChecksumForLogRestore( infoSchema := rc.GetDomain().InfoSchema() // downstream id -> upstream id reidRules := make(map[int64]int64) - for upstreamID, downstreamID := range idrules { - reidRules[downstreamID] = upstreamID + for upstreamID, r := range rewriteRules { + reidRules[r.NewTableID] = upstreamID } - for upstreamID, downstreamID := range idrules { - newTable, ok := infoSchema.TableByID(ctx, downstreamID) + for upstreamID, r := range rewriteRules { + newTable, ok := infoSchema.TableByID(ctx, r.NewTableID) if !ok { // a dropped table continue @@ -1715,52 +1900,3 @@ func (rc *LogClient) FailpointDoChecksumForLogRestore( return eg.Wait() } - -type LogFilesIterWithSplitHelper struct { - iter LogIter - helper *split.LogSplitHelper - buffer []*LogDataFileInfo - next int -} - -const SplitFilesBufferSize = 4096 - -func NewLogFilesIterWithSplitHelper(iter LogIter, rules map[int64]*restoreutils.RewriteRules, client split.SplitClient, splitSize uint64, splitKeys int64) LogIter { - return &LogFilesIterWithSplitHelper{ - iter: iter, - helper: split.NewLogSplitHelper(rules, client, splitSize, splitKeys), - buffer: nil, - next: 0, - } -} - -func (splitIter *LogFilesIterWithSplitHelper) TryNext(ctx context.Context) iter.IterResult[*LogDataFileInfo] { - if splitIter.next >= len(splitIter.buffer) { - splitIter.buffer = make([]*LogDataFileInfo, 0, SplitFilesBufferSize) - for r := splitIter.iter.TryNext(ctx); !r.Finished; r = splitIter.iter.TryNext(ctx) { - if r.Err != nil { - return r - } - f := r.Item - splitIter.helper.Merge(f.DataFileInfo) - splitIter.buffer = append(splitIter.buffer, f) - if len(splitIter.buffer) >= SplitFilesBufferSize { - break - } - } - splitIter.next = 0 - if len(splitIter.buffer) == 0 { - return iter.Done[*LogDataFileInfo]() - } - log.Info("start to split the regions") - startTime := time.Now() - if err := splitIter.helper.Split(ctx); err != nil { - return iter.Throw[*LogDataFileInfo](errors.Trace(err)) - } - log.Info("end to split the regions", zap.Duration("takes", time.Since(startTime))) - } - - res := iter.Emit(splitIter.buffer[splitIter.next]) - splitIter.next += 1 - return res -} diff --git a/br/pkg/restore/log_client/client_test.go b/br/pkg/restore/log_client/client_test.go index 
6b16ec34c28ba..504d7bb798d72 100644 --- a/br/pkg/restore/log_client/client_test.go +++ b/br/pkg/restore/log_client/client_test.go @@ -22,12 +22,15 @@ import ( "testing" "time" + "github.com/docker/go-units" "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/gluetidb" "github.com/pingcap/tidb/br/pkg/mock" + "github.com/pingcap/tidb/br/pkg/restore" logclient "github.com/pingcap/tidb/br/pkg/restore/log_client" "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/br/pkg/restore/utils" @@ -35,12 +38,14 @@ import ( "github.com/pingcap/tidb/br/pkg/utils/iter" "github.com/pingcap/tidb/br/pkg/utiltest" "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/planner/core/resolve" "github.com/pingcap/tidb/pkg/session" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" "github.com/pingcap/tidb/pkg/util/sqlexec" filter "github.com/pingcap/tidb/pkg/util/table-filter" "github.com/stretchr/testify/require" @@ -92,7 +97,7 @@ func TestDeleteRangeQueryExec(t *testing.T) { g := gluetidb.New() client := logclient.NewRestoreClient( split.NewFakePDClient(nil, false, nil), nil, nil, keepalive.ClientParameters{}) - err := client.Init(g, m.Storage) + err := client.Init(ctx, g, m.Storage) require.NoError(t, err) client.RunGCRowsLoader(ctx) @@ -111,7 +116,7 @@ func TestDeleteRangeQuery(t *testing.T) { g := gluetidb.New() client := logclient.NewRestoreClient( split.NewFakePDClient(nil, false, nil), nil, nil, keepalive.ClientParameters{}) - err := client.Init(g, m.Storage) + err := client.Init(ctx, g, m.Storage) require.NoError(t, err) client.RunGCRowsLoader(ctx) @@ -1339,7 +1344,12 @@ func TestLogFilesIterWithSplitHelper(t *testing.T) { } mockIter := &mockLogIter{} ctx := context.Background() - logIter := logclient.NewLogFilesIterWithSplitHelper(mockIter, rewriteRulesMap, split.NewFakeSplitClient(), 144*1024*1024, 1440000) + w := restore.PipelineRestorerWrapper[*logclient.LogDataFileInfo]{ + PipelineRegionsSplitter: split.NewPipelineRegionsSplitter(split.NewFakeSplitClient(), 144*1024*1024, 1440000), + } + s, err := logclient.NewLogSplitStrategy(ctx, false, nil, rewriteRulesMap, func(uint64, uint64) {}) + require.NoError(t, err) + logIter := w.WithSplit(context.Background(), mockIter, s) next := 0 for r := logIter.TryNext(ctx); !r.Finished; r = logIter.TryNext(ctx) { require.NoError(t, r.Err) @@ -1506,3 +1516,476 @@ func TestPITRIDMap(t *testing.T) { } } } + +type mockLogStrategy struct { + *logclient.LogSplitStrategy + expectSplitCount int +} + +func (m *mockLogStrategy) ShouldSplit() bool { + return m.AccumulateCount == m.expectSplitCount +} + +func TestLogSplitStrategy(t *testing.T) { + ctx := context.Background() + + // Define rewrite rules for table ID transformations. + rules := map[int64]*utils.RewriteRules{ + 1: { + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.GenTableRecordPrefix(1), + NewKeyPrefix: tablecodec.GenTableRecordPrefix(100), + }, + }, + }, + 2: { + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.GenTableRecordPrefix(2), + NewKeyPrefix: tablecodec.GenTableRecordPrefix(200), + }, + }, + }, + } + + // Define initial regions for the mock PD client. 
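	// Editor's note: these four boundary keys describe three initial regions,
	// [min, table 100), [table 100, table 200), and [table 200, table 402), matching the
	// three regions asserted by the ScanRegions check before any split is triggered below.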
+ oriRegions := [][]byte{ + {}, + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + } + + // Set up a mock PD client with predefined regions. + storesMap := make(map[uint64]*metapb.Store) + storesMap[1] = &metapb.Store{Id: 1} + mockPDCli := split.NewMockPDClientForSplit() + mockPDCli.SetStores(storesMap) + mockPDCli.SetRegions(oriRegions) + + // Create a split client with the mock PD client. + client := split.NewClient(mockPDCli, nil, nil, 100, 4) + + // Define a mock iterator with sample data files. + mockIter := iter.FromSlice([]*backuppb.DataFileInfo{ + fakeFile(1, 100, 100, 100), + fakeFile(1, 200, 2*units.MiB, 200), + fakeFile(2, 100, 3*units.MiB, 300), + fakeFile(3, 100, 10*units.MiB, 100000), + fakeFile(1, 300, 3*units.MiB, 10), + fakeFile(1, 400, 4*units.MiB, 10), + }) + logIter := toLogDataFileInfoIter(mockIter) + + // Initialize a wrapper for the file restorer with a region splitter. + wrapper := restore.PipelineRestorerWrapper[*logclient.LogDataFileInfo]{ + PipelineRegionsSplitter: split.NewPipelineRegionsSplitter(client, 4*units.MB, 400), + } + + // Create a log split strategy with the given rewrite rules. + strategy, err := logclient.NewLogSplitStrategy(ctx, false, nil, rules, func(u1, u2 uint64) {}) + require.NoError(t, err) + + // Set up a mock strategy to control split behavior. + expectSplitCount := 2 + mockStrategy := &mockLogStrategy{ + LogSplitStrategy: strategy, + // fakeFile(3, 100, 10*units.MiB, 100000) will skipped due to no rewrite rule found. + expectSplitCount: expectSplitCount, + } + + // Use the wrapper to apply the split strategy to the log iterator. + helper := wrapper.WithSplit(ctx, logIter, mockStrategy) + + // Iterate over the log items and verify the split behavior. + count := 0 + for i := helper.TryNext(ctx); !i.Finished; i = helper.TryNext(ctx) { + require.NoError(t, i.Err) + if count == expectSplitCount { + // Verify that no split occurs initially due to insufficient data. + regions, err := mockPDCli.ScanRegions(ctx, []byte{}, []byte{}, 0) + require.NoError(t, err) + require.Len(t, regions, 3) + require.Equal(t, []byte{}, regions[0].Meta.StartKey) + require.Equal(t, codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), regions[1].Meta.StartKey) + require.Equal(t, codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), regions[2].Meta.StartKey) + require.Equal(t, codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), regions[2].Meta.EndKey) + } + // iter.Filterout execute first + count += 1 + } + + // Verify that a split occurs on the second region due to excess data. 
+ regions, err := mockPDCli.ScanRegions(ctx, []byte{}, []byte{}, 0) + require.NoError(t, err) + require.Len(t, regions, 4) + require.Equal(t, fakeRowKey(100, 400), kv.Key(regions[1].Meta.EndKey)) +} + +type mockCompactedStrategy struct { + *logclient.CompactedFileSplitStrategy + expectSplitCount int +} + +func (m *mockCompactedStrategy) ShouldSplit() bool { + return m.AccumulateCount%m.expectSplitCount == 0 +} + +func TestCompactedSplitStrategy(t *testing.T) { + ctx := context.Background() + + rules := map[int64]*utils.RewriteRules{ + 1: { + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.GenTableRecordPrefix(1), + NewKeyPrefix: tablecodec.GenTableRecordPrefix(100), + }, + }, + }, + 2: { + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.GenTableRecordPrefix(2), + NewKeyPrefix: tablecodec.GenTableRecordPrefix(200), + }, + }, + }, + } + + oriRegions := [][]byte{ + {}, + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + } + + cases := []struct { + MockSubcompationIter iter.TryNextor[*backuppb.LogFileSubcompaction] + ExpectRegionEndKeys [][]byte + }{ + { + iter.FromSlice([]*backuppb.LogFileSubcompaction{ + fakeSubCompactionWithOneSst(1, 100, 16*units.MiB, 100), + fakeSubCompactionWithOneSst(1, 200, 32*units.MiB, 200), + fakeSubCompactionWithOneSst(2, 100, 48*units.MiB, 300), + fakeSubCompactionWithOneSst(3, 100, 100*units.MiB, 100000), + }), + // no split + // table 1 has not accumlate enough 400 keys or 4MB + [][]byte{ + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + }, + }, + { + iter.FromSlice([]*backuppb.LogFileSubcompaction{ + fakeSubCompactionWithOneSst(1, 100, 16*units.MiB, 100), + fakeSubCompactionWithOneSst(1, 200, 32*units.MiB, 200), + fakeSubCompactionWithOneSst(1, 100, 32*units.MiB, 10), + fakeSubCompactionWithOneSst(2, 100, 48*units.MiB, 300), + }), + // split on table 1 + // table 1 has accumlate enough keys + [][]byte{ + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + []byte(fakeRowKey(100, 200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + }, + }, + { + iter.FromSlice([]*backuppb.LogFileSubcompaction{ + fakeSubCompactionWithOneSst(1, 100, 16*units.MiB, 100), + fakeSubCompactionWithOneSst(1, 200, 32*units.MiB, 200), + fakeSubCompactionWithOneSst(2, 100, 32*units.MiB, 300), + fakeSubCompactionWithOneSst(3, 100, 10*units.MiB, 100000), + fakeSubCompactionWithOneSst(1, 300, 48*units.MiB, 13), + fakeSubCompactionWithOneSst(1, 400, 64*units.MiB, 14), + fakeSubCompactionWithOneSst(1, 100, 1*units.MiB, 15), + }), + [][]byte{ + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + []byte(fakeRowKey(100, 400)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + }, + }, + } + for _, ca := range cases { + storesMap := make(map[uint64]*metapb.Store) + storesMap[1] = &metapb.Store{Id: 1} + mockPDCli := split.NewMockPDClientForSplit() + mockPDCli.SetStores(storesMap) + mockPDCli.SetRegions(oriRegions) + + client := split.NewClient(mockPDCli, nil, nil, 100, 4) + wrapper := restore.PipelineRestorerWrapper[*backuppb.LogFileSubcompaction]{ + PipelineRegionsSplitter: split.NewPipelineRegionsSplitter(client, 4*units.MB, 400), + } 
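		// Editor's note: the production CompactedFileSplitStrategy only splits after more than
		// 4096/impactFactor accumulated SSTs, so the mock below overrides ShouldSplit to force
		// a split attempt after every third accumulated SST file.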
+ + strategy := logclient.NewCompactedFileSplitStrategy(rules, nil, nil) + mockStrategy := &mockCompactedStrategy{ + CompactedFileSplitStrategy: strategy, + expectSplitCount: 3, + } + + helper := wrapper.WithSplit(ctx, ca.MockSubcompationIter, mockStrategy) + + for i := helper.TryNext(ctx); !i.Finished; i = helper.TryNext(ctx) { + require.NoError(t, i.Err) + } + + regions, err := mockPDCli.ScanRegions(ctx, []byte{}, []byte{}, 0) + require.NoError(t, err) + require.Len(t, regions, len(ca.ExpectRegionEndKeys)) + for i, endKey := range ca.ExpectRegionEndKeys { + require.Equal(t, endKey, regions[i].Meta.EndKey) + } + } +} + +func TestCompactedSplitStrategyWithCheckpoint(t *testing.T) { + ctx := context.Background() + + rules := map[int64]*utils.RewriteRules{ + 1: { + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.GenTableRecordPrefix(1), + NewKeyPrefix: tablecodec.GenTableRecordPrefix(100), + }, + }, + }, + 2: { + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.GenTableRecordPrefix(2), + NewKeyPrefix: tablecodec.GenTableRecordPrefix(200), + }, + }, + }, + } + + oriRegions := [][]byte{ + {}, + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + } + + cases := []struct { + MockSubcompationIter iter.TryNextor[*backuppb.LogFileSubcompaction] + CheckpointSet map[string]struct{} + ProcessedKVCount int + ProcessedSize int + ExpectRegionEndKeys [][]byte + }{ + { + iter.FromSlice([]*backuppb.LogFileSubcompaction{ + fakeSubCompactionWithOneSst(1, 100, 16*units.MiB, 100), + fakeSubCompactionWithOneSst(1, 200, 32*units.MiB, 200), + fakeSubCompactionWithOneSst(2, 100, 48*units.MiB, 300), + fakeSubCompactionWithOneSst(3, 100, 100*units.MiB, 100000), + }), + map[string]struct{}{ + "1:100": {}, + "1:200": {}, + }, + 300, + 48 * units.MiB, + // no split, checkpoint files came in order + [][]byte{ + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + }, + }, + { + iter.FromSlice([]*backuppb.LogFileSubcompaction{ + fakeSubCompactionWithOneSst(1, 100, 16*units.MiB, 100), + fakeSubCompactionWithOneSst(1, 200, 32*units.MiB, 200), + fakeSubCompactionWithOneSst(1, 100, 32*units.MiB, 10), + fakeSubCompactionWithOneSst(2, 100, 48*units.MiB, 300), + }), + map[string]struct{}{ + "1:100": {}, + }, + 110, + 48 * units.MiB, + // no split, checkpoint files came in different order + [][]byte{ + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + }, + }, + { + iter.FromSlice([]*backuppb.LogFileSubcompaction{ + fakeSubCompactionWithOneSst(1, 100, 16*units.MiB, 100), + fakeSubCompactionWithOneSst(1, 200, 32*units.MiB, 200), + fakeSubCompactionWithOneSst(2, 100, 32*units.MiB, 300), + fakeSubCompactionWithOneSst(3, 100, 10*units.MiB, 100000), + fakeSubCompactionWithOneSst(1, 300, 48*units.MiB, 13), + fakeSubCompactionWithOneSst(1, 400, 64*units.MiB, 14), + fakeSubCompactionWithOneSst(1, 100, 1*units.MiB, 15), + }), + map[string]struct{}{ + "1:300": {}, + "1:400": {}, + }, + 27, + 112 * units.MiB, + // no split, the main file has skipped due to checkpoint. 
+ [][]byte{ + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + }, + }, + { + iter.FromSlice([]*backuppb.LogFileSubcompaction{ + fakeSubCompactionWithOneSst(1, 100, 16*units.MiB, 100), + fakeSubCompactionWithOneSst(1, 200, 32*units.MiB, 200), + fakeSubCompactionWithOneSst(2, 100, 32*units.MiB, 300), + fakeSubCompactionWithOneSst(3, 100, 10*units.MiB, 100000), + fakeSubCompactionWithOneSst(1, 300, 48*units.MiB, 13), + fakeSubCompactionWithOneSst(1, 400, 64*units.MiB, 14), + fakeSubCompactionWithOneSst(1, 100, 1*units.MiB, 15), + }), + map[string]struct{}{ + "1:100": {}, + "1:200": {}, + }, + 315, + 49 * units.MiB, + [][]byte{ + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + []byte(fakeRowKey(100, 400)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + }, + }, + { + iter.FromSlice([]*backuppb.LogFileSubcompaction{ + fakeSubCompactionWithOneSst(1, 100, 16*units.MiB, 100), + fakeSubCompactionWithMultiSsts(1, 200, 32*units.MiB, 200), + fakeSubCompactionWithOneSst(2, 100, 32*units.MiB, 300), + fakeSubCompactionWithOneSst(3, 100, 10*units.MiB, 100000), + fakeSubCompactionWithOneSst(1, 300, 48*units.MiB, 13), + fakeSubCompactionWithOneSst(1, 400, 64*units.MiB, 14), + fakeSubCompactionWithOneSst(1, 100, 1*units.MiB, 15), + }), + map[string]struct{}{ + "1:100": {}, + "1:200": {}, + }, + 315, + 49 * units.MiB, + [][]byte{ + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), + []byte(fakeRowKey(100, 300)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), + codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), + }, + }, + } + for _, ca := range cases { + storesMap := make(map[uint64]*metapb.Store) + storesMap[1] = &metapb.Store{Id: 1} + mockPDCli := split.NewMockPDClientForSplit() + mockPDCli.SetStores(storesMap) + mockPDCli.SetRegions(oriRegions) + + client := split.NewClient(mockPDCli, nil, nil, 100, 4) + wrapper := restore.PipelineRestorerWrapper[*backuppb.LogFileSubcompaction]{ + PipelineRegionsSplitter: split.NewPipelineRegionsSplitter(client, 4*units.MB, 400), + } + totalSize := 0 + totalKvCount := 0 + + strategy := logclient.NewCompactedFileSplitStrategy(rules, ca.CheckpointSet, func(u1, u2 uint64) { + totalKvCount += int(u1) + totalSize += int(u2) + }) + mockStrategy := &mockCompactedStrategy{ + CompactedFileSplitStrategy: strategy, + expectSplitCount: 3, + } + + helper := wrapper.WithSplit(ctx, ca.MockSubcompationIter, mockStrategy) + + for i := helper.TryNext(ctx); !i.Finished; i = helper.TryNext(ctx) { + require.NoError(t, i.Err) + } + + regions, err := mockPDCli.ScanRegions(ctx, []byte{}, []byte{}, 0) + require.NoError(t, err) + require.Len(t, regions, len(ca.ExpectRegionEndKeys)) + for i, endKey := range ca.ExpectRegionEndKeys { + require.Equal(t, endKey, regions[i].Meta.EndKey) + } + require.Equal(t, totalKvCount, ca.ProcessedKVCount) + require.Equal(t, totalSize, ca.ProcessedSize) + } +} + +func fakeSubCompactionWithMultiSsts(tableID, rowID int64, length uint64, num uint64) *backuppb.LogFileSubcompaction { + return &backuppb.LogFileSubcompaction{ + Meta: &backuppb.LogFileSubcompactionMeta{ + TableId: tableID, + }, + SstOutputs: []*backuppb.File{ + { + Name: fmt.Sprintf("%d:%d", tableID, rowID), + StartKey: fakeRowRawKey(tableID, rowID), + EndKey: fakeRowRawKey(tableID, rowID+1), + Size_: length, + TotalKvs: num, + }, + { + Name: 
fmt.Sprintf("%d:%d", tableID, rowID+1), + StartKey: fakeRowRawKey(tableID, rowID+1), + EndKey: fakeRowRawKey(tableID, rowID+2), + Size_: length, + TotalKvs: num, + }, + }, + } +} +func fakeSubCompactionWithOneSst(tableID, rowID int64, length uint64, num uint64) *backuppb.LogFileSubcompaction { + return &backuppb.LogFileSubcompaction{ + Meta: &backuppb.LogFileSubcompactionMeta{ + TableId: tableID, + }, + SstOutputs: []*backuppb.File{ + { + Name: fmt.Sprintf("%d:%d", tableID, rowID), + StartKey: fakeRowRawKey(tableID, rowID), + EndKey: fakeRowRawKey(tableID, rowID+1), + Size_: length, + TotalKvs: num, + }, + }, + } +} + +func fakeFile(tableID, rowID int64, length uint64, num int64) *backuppb.DataFileInfo { + return &backuppb.DataFileInfo{ + StartKey: fakeRowKey(tableID, rowID), + EndKey: fakeRowKey(tableID, rowID+1), + TableId: tableID, + Length: length, + NumberOfEntries: num, + } +} + +func fakeRowKey(tableID, rowID int64) kv.Key { + return codec.EncodeBytes(nil, fakeRowRawKey(tableID, rowID)) +} + +func fakeRowRawKey(tableID, rowID int64) kv.Key { + return tablecodec.EncodeRecordKey(tablecodec.GenTableRecordPrefix(tableID), kv.IntHandle(rowID)) +} diff --git a/br/pkg/restore/log_client/compacted_file_strategy.go b/br/pkg/restore/log_client/compacted_file_strategy.go new file mode 100644 index 0000000000000..9637cf2e529b6 --- /dev/null +++ b/br/pkg/restore/log_client/compacted_file_strategy.go @@ -0,0 +1,112 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package logclient + +import ( + "fmt" + + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/restore/split" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/pkg/util/codec" + "go.uber.org/zap" +) + +// The impact factor is used to reduce the size and number of MVCC entries +// in SST files, helping to optimize performance and resource usage. +const impactFactor = 16 + +type CompactedFileSplitStrategy struct { + *split.BaseSplitStrategy + checkpointSets map[string]struct{} + checkpointFileProgressFn func(uint64, uint64) +} + +var _ split.SplitStrategy[*backuppb.LogFileSubcompaction] = &CompactedFileSplitStrategy{} + +func NewCompactedFileSplitStrategy( + rules map[int64]*restoreutils.RewriteRules, + checkpointsSet map[string]struct{}, + updateStatsFn func(uint64, uint64), +) *CompactedFileSplitStrategy { + return &CompactedFileSplitStrategy{ + BaseSplitStrategy: split.NewBaseSplitStrategy(rules), + checkpointSets: checkpointsSet, + checkpointFileProgressFn: updateStatsFn, + } +} + +func (cs *CompactedFileSplitStrategy) Accumulate(subCompaction *backuppb.LogFileSubcompaction) { + splitHelper, exist := cs.TableSplitter[subCompaction.Meta.TableId] + if !exist { + splitHelper = split.NewSplitHelper() + cs.TableSplitter[subCompaction.Meta.TableId] = splitHelper + } + + for _, f := range subCompaction.SstOutputs { + startKey := codec.EncodeBytes(nil, f.StartKey) + endKey := codec.EncodeBytes(nil, f.EndKey) + cs.AccumulateCount += 1 + if f.TotalKvs == 0 || f.Size_ == 0 { + log.Error("No key-value pairs in subcompaction", zap.String("name", f.Name)) + continue + } + // The number of MVCC entries in the compacted SST files can be excessive. + // This calculation takes the MVCC impact into account to optimize performance. 
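		// Editor's note (illustrative, not part of the original patch): an SST reporting
		// TotalKvs = 3200 and Size_ = 64 MiB is accounted to the split helper as only
		// 3200/16 = 200 keys and 64/16 = 4 MiB, so MVCC-dense compacted files do not
		// force unnecessarily fine-grained region splits.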
+ calculateCount := int64(f.TotalKvs) / impactFactor + if calculateCount == 0 { + // at least consider as 1 key impact + log.Warn(fmt.Sprintf("less than %d key-value pairs in subcompaction", impactFactor), zap.String("name", f.Name)) + calculateCount = 1 + } + calculateSize := f.Size_ / impactFactor + if calculateSize == 0 { + log.Warn(fmt.Sprintf("less than %d key-value size in subcompaction", impactFactor), zap.String("name", f.Name)) + calculateSize = 1 + } + splitHelper.Merge(split.Valued{ + Key: split.Span{ + StartKey: startKey, + EndKey: endKey, + }, + Value: split.Value{ + Size: calculateSize, + Number: calculateCount, + }, + }) + } +} + +func (cs *CompactedFileSplitStrategy) ShouldSplit() bool { + return cs.AccumulateCount > (4096 / impactFactor) +} + +func (cs *CompactedFileSplitStrategy) ShouldSkip(subCompaction *backuppb.LogFileSubcompaction) bool { + _, exist := cs.Rules[subCompaction.Meta.TableId] + if !exist { + log.Info("skip for no rule files", zap.Int64("tableID", subCompaction.Meta.TableId)) + return true + } + sstOutputs := make([]*backuppb.File, 0, len(subCompaction.SstOutputs)) + for _, sst := range subCompaction.SstOutputs { + if _, ok := cs.checkpointSets[sst.Name]; !ok { + sstOutputs = append(sstOutputs, sst) + } else { + // This file is recorded in the checkpoint, indicating that it has + // already been restored to the cluster. Therefore, we will skip + // processing this file and only update the statistics. + cs.checkpointFileProgressFn(sst.TotalKvs, sst.Size_) + } + } + if len(sstOutputs) == 0 { + log.Info("all files in sub compaction skipped") + return true + } + if len(sstOutputs) != len(subCompaction.SstOutputs) { + log.Info("partial files in sub compaction skipped due to checkpoint") + subCompaction.SstOutputs = sstOutputs + return false + } + return false +} diff --git a/br/pkg/restore/log_client/export_test.go b/br/pkg/restore/log_client/export_test.go index 70a15e1ad2393..9c95409c9d754 100644 --- a/br/pkg/restore/log_client/export_test.go +++ b/br/pkg/restore/log_client/export_test.go @@ -91,7 +91,7 @@ func (rc *LogFileManager) ReadStreamMeta(ctx context.Context) ([]*MetaName, erro func TEST_NewLogClient(clusterID, startTS, restoreTS, upstreamClusterID uint64, dom *domain.Domain, se glue.Session) *LogClient { return &LogClient{ dom: dom, - se: se, + unsafeSession: se, upstreamClusterID: upstreamClusterID, LogFileManager: &LogFileManager{ startTS: startTS, diff --git a/br/pkg/restore/log_client/import_retry_test.go b/br/pkg/restore/log_client/import_retry_test.go index bcde03c69c1ed..04f2a56dc342b 100644 --- a/br/pkg/restore/log_client/import_retry_test.go +++ b/br/pkg/restore/log_client/import_retry_test.go @@ -3,13 +3,11 @@ package logclient_test import ( - "bytes" "context" "encoding/hex" "fmt" "os" "strconv" - "sync" "testing" "time" @@ -18,14 +16,12 @@ import ( "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/kvproto/pkg/pdpb" logclient "github.com/pingcap/tidb/br/pkg/restore/log_client" "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/store/pdtypes" "github.com/pingcap/tidb/pkg/util/codec" "github.com/stretchr/testify/require" - pd "github.com/tikv/pd/client" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -52,157 +48,20 @@ func assertRegions(t *testing.T, regions []*split.RegionInfo, keys ...string) { } } -type TestClient struct { - split.SplitClient - pd.Client - - 
mu sync.RWMutex - stores map[uint64]*metapb.Store - regions map[uint64]*split.RegionInfo - regionsInfo *pdtypes.RegionTree // For now it's only used in ScanRegions - nextRegionID uint64 - - scattered map[uint64]bool - InjectErr bool - InjectTimes int32 -} - -func NewTestClient( - stores map[uint64]*metapb.Store, - regions map[uint64]*split.RegionInfo, - nextRegionID uint64, -) *TestClient { - regionsInfo := &pdtypes.RegionTree{} - for _, regionInfo := range regions { - regionsInfo.SetRegion(pdtypes.NewRegionInfo(regionInfo.Region, regionInfo.Leader)) - } - return &TestClient{ - stores: stores, - regions: regions, - regionsInfo: regionsInfo, - nextRegionID: nextRegionID, - scattered: map[uint64]bool{}, - } -} - -func (c *TestClient) GetAllRegions() map[uint64]*split.RegionInfo { - c.mu.RLock() - defer c.mu.RUnlock() - return c.regions -} - -func (c *TestClient) GetPDClient() *split.FakePDClient { - stores := make([]*metapb.Store, 0, len(c.stores)) - for _, store := range c.stores { - stores = append(stores, store) - } - return split.NewFakePDClient(stores, false, nil) -} - -func (c *TestClient) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) { - c.mu.RLock() - defer c.mu.RUnlock() - store, ok := c.stores[storeID] - if !ok { - return nil, errors.Errorf("store not found") - } - return store, nil -} - -func (c *TestClient) GetRegion(ctx context.Context, key []byte) (*split.RegionInfo, error) { - c.mu.RLock() - defer c.mu.RUnlock() - for _, region := range c.regions { - if bytes.Compare(key, region.Region.StartKey) >= 0 && - (len(region.Region.EndKey) == 0 || bytes.Compare(key, region.Region.EndKey) < 0) { - return region, nil - } - } - return nil, errors.Errorf("region not found: key=%s", string(key)) -} - -func (c *TestClient) GetRegionByID(ctx context.Context, regionID uint64) (*split.RegionInfo, error) { - c.mu.RLock() - defer c.mu.RUnlock() - region, ok := c.regions[regionID] - if !ok { - return nil, errors.Errorf("region not found: id=%d", regionID) - } - return region, nil -} - -func (c *TestClient) SplitWaitAndScatter(_ context.Context, _ *split.RegionInfo, keys [][]byte) ([]*split.RegionInfo, error) { - c.mu.Lock() - defer c.mu.Unlock() - newRegions := make([]*split.RegionInfo, 0) - for _, key := range keys { - var target *split.RegionInfo - splitKey := codec.EncodeBytes([]byte{}, key) - for _, region := range c.regions { - if region.ContainsInterior(splitKey) { - target = region - } - } - if target == nil { - continue - } - newRegion := &split.RegionInfo{ - Region: &metapb.Region{ - Peers: target.Region.Peers, - Id: c.nextRegionID, - StartKey: target.Region.StartKey, - EndKey: splitKey, - }, - } - c.regions[c.nextRegionID] = newRegion - c.nextRegionID++ - target.Region.StartKey = splitKey - c.regions[target.Region.Id] = target - newRegions = append(newRegions, newRegion) - } - return newRegions, nil -} - -func (c *TestClient) GetOperator(context.Context, uint64) (*pdpb.GetOperatorResponse, error) { - return &pdpb.GetOperatorResponse{ - Header: new(pdpb.ResponseHeader), - }, nil -} - -func (c *TestClient) ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*split.RegionInfo, error) { - if c.InjectErr && c.InjectTimes > 0 { - c.InjectTimes -= 1 - return nil, status.Error(codes.Unavailable, "not leader") - } - if len(key) != 0 && bytes.Equal(key, endKey) { - return nil, status.Error(codes.Internal, "key and endKey are the same") - } - - infos := c.regionsInfo.ScanRange(key, endKey, limit) - regions := make([]*split.RegionInfo, 0, len(infos)) - for _, 
info := range infos { - regions = append(regions, &split.RegionInfo{ - Region: info.Meta, - Leader: info.Leader, - }) - } - return regions, nil -} - -func (c *TestClient) WaitRegionsScattered(context.Context, []*split.RegionInfo) (int, error) { - return 0, nil -} - // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) -func initTestClient(isRawKv bool) *TestClient { +func initTestClient(isRawKv bool) *split.TestClient { + keys := []string{"", "aay", "bba", "bbh", "cca", ""} + stores := make(map[uint64]*metapb.Store) + stores[1] = &metapb.Store{ + Id: 1, + } peers := make([]*metapb.Peer, 1) peers[0] = &metapb.Peer{ Id: 1, StoreId: 1, } - keys := [6]string{"", "aay", "bba", "bbh", "cca", ""} regions := make(map[uint64]*split.RegionInfo) - for i := uint64(1); i < 6; i++ { + for i := 1; i < len(keys); i++ { startKey := []byte(keys[i-1]) if len(startKey) != 0 { startKey = codec.EncodeBytesExt([]byte{}, startKey, isRawKv) @@ -211,24 +70,20 @@ func initTestClient(isRawKv bool) *TestClient { if len(endKey) != 0 { endKey = codec.EncodeBytesExt([]byte{}, endKey, isRawKv) } - regions[i] = &split.RegionInfo{ + regions[uint64(i)] = &split.RegionInfo{ Leader: &metapb.Peer{ - Id: i, + Id: uint64(i), StoreId: 1, }, Region: &metapb.Region{ - Id: i, + Id: uint64(i), Peers: peers, StartKey: startKey, EndKey: endKey, }, } } - stores := make(map[uint64]*metapb.Store) - stores[1] = &metapb.Store{ - Id: 1, - } - return NewTestClient(stores, regions, 6) + return split.NewTestClient(stores, regions, 6) } func TestScanSuccess(t *testing.T) { @@ -411,7 +266,7 @@ func TestEpochNotMatch(t *testing.T) { ctl := logclient.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() - printPDRegion("cli", cli.regionsInfo.Regions) + printPDRegion("cli", cli.RegionsInfo.Regions) regions, err := split.PaginateScanRegion(ctx, cli, []byte("aaz"), []byte("bbb"), 2) require.NoError(t, err) require.Len(t, regions, 2) @@ -429,8 +284,8 @@ func TestEpochNotMatch(t *testing.T) { } newRegion := pdtypes.NewRegionInfo(info.Region, info.Leader) mergeRegion := func() { - cli.regionsInfo.SetRegion(newRegion) - cli.regions[42] = &info + cli.RegionsInfo.SetRegion(newRegion) + cli.Regions[42] = &info } epochNotMatch := &import_sstpb.Error{ Message: "Epoch not match", @@ -457,7 +312,7 @@ func TestEpochNotMatch(t *testing.T) { }) printRegion("first", firstRunRegions) printRegion("second", secondRunRegions) - printPDRegion("cli", cli.regionsInfo.Regions) + printPDRegion("cli", cli.RegionsInfo.Regions) assertRegions(t, firstRunRegions, "", "aay") assertRegions(t, secondRunRegions, "", "aay", "bbh", "cca", "") require.NoError(t, err) @@ -470,7 +325,7 @@ func TestRegionSplit(t *testing.T) { ctl := logclient.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() - printPDRegion("cli", cli.regionsInfo.Regions) + printPDRegion("cli", cli.RegionsInfo.Regions) regions, err := split.PaginateScanRegion(ctx, cli, []byte("aaz"), []byte("aazz"), 1) require.NoError(t, err) require.Len(t, regions, 1) @@ -503,8 +358,8 @@ func TestRegionSplit(t *testing.T) { splitRegion := func() { for _, r := range newRegions { newRegion := pdtypes.NewRegionInfo(r.Region, r.Leader) - cli.regionsInfo.SetRegion(newRegion) - cli.regions[r.Region.Id] = r + cli.RegionsInfo.SetRegion(newRegion) + cli.Regions[r.Region.Id] = r } } epochNotMatch := &import_sstpb.Error{ @@ -535,7 +390,7 @@ func TestRegionSplit(t *testing.T) { }) printRegion("first", firstRunRegions) printRegion("second", secondRunRegions) - printPDRegion("cli", 
cli.regionsInfo.Regions) + printPDRegion("cli", cli.RegionsInfo.Regions) assertRegions(t, firstRunRegions, "", "aay") assertRegions(t, secondRunRegions, "", "aay", "aayy", "bba", "bbh", "cca", "") require.NoError(t, err) @@ -548,7 +403,7 @@ func TestRetryBackoff(t *testing.T) { ctl := logclient.OverRegionsInRange([]byte(""), []byte(""), cli, &rs) ctx := context.Background() - printPDRegion("cli", cli.regionsInfo.Regions) + printPDRegion("cli", cli.RegionsInfo.Regions) regions, err := split.PaginateScanRegion(ctx, cli, []byte("aaz"), []byte("bbb"), 2) require.NoError(t, err) require.Len(t, regions, 2) @@ -569,7 +424,7 @@ func TestRetryBackoff(t *testing.T) { } return logclient.RPCResultOK() }) - printPDRegion("cli", cli.regionsInfo.Regions) + printPDRegion("cli", cli.RegionsInfo.Regions) require.Equal(t, 1, rs.Attempt()) // we retried leader not found error. so the next backoff should be 2 * initical backoff. require.Equal(t, 2*time.Millisecond, rs.ExponentialBackoff()) diff --git a/br/pkg/restore/log_client/log_file_manager.go b/br/pkg/restore/log_client/log_file_manager.go index cbaa6a594dff4..4c2992467a2ab 100644 --- a/br/pkg/restore/log_client/log_file_manager.go +++ b/br/pkg/restore/log_client/log_file_manager.go @@ -9,6 +9,7 @@ import ( "fmt" "strings" "sync" + "time" "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" @@ -25,9 +26,13 @@ import ( "go.uber.org/zap" ) +var TotalEntryCount int64 + // MetaIter is the type of iterator of metadata files' content. type MetaIter = iter.TryNextor[*backuppb.Metadata] +type SubCompactionIter iter.TryNextor[*backuppb.LogFileSubcompaction] + type MetaName struct { meta Meta name string @@ -98,7 +103,8 @@ type LogFileManager struct { storage storage.ExternalStorage helper streamMetadataHelper - withmigrations WithMigrations + withMigraionBuilder *WithMigrationsBuilder + withMigrations *WithMigrations metadataDownloadBatchSize uint } @@ -109,7 +115,8 @@ type LogFileManagerInit struct { RestoreTS uint64 Storage storage.ExternalStorage - Migrations WithMigrations + MigrationsBuilder *WithMigrationsBuilder + Migrations *WithMigrations MetadataDownloadBatchSize uint EncryptionManager *encryption.Manager } @@ -123,11 +130,12 @@ type DDLMetaGroup struct { // Generally the config cannot be changed during its lifetime. 
func CreateLogFileManager(ctx context.Context, init LogFileManagerInit) (*LogFileManager, error) { fm := &LogFileManager{ - startTS: init.StartTS, - restoreTS: init.RestoreTS, - storage: init.Storage, - helper: stream.NewMetadataHelper(stream.WithEncryptionManager(init.EncryptionManager)), - withmigrations: init.Migrations, + startTS: init.StartTS, + restoreTS: init.RestoreTS, + storage: init.Storage, + helper: stream.NewMetadataHelper(stream.WithEncryptionManager(init.EncryptionManager)), + withMigraionBuilder: init.MigrationsBuilder, + withMigrations: init.Migrations, metadataDownloadBatchSize: init.MetadataDownloadBatchSize, } @@ -138,6 +146,11 @@ func CreateLogFileManager(ctx context.Context, init LogFileManagerInit) (*LogFil return fm, nil } +func (rc *LogFileManager) BuildMigrations(migs []*backuppb.Migration) { + w := rc.withMigraionBuilder.Build(migs) + rc.withMigrations = &w +} + func (rc *LogFileManager) ShiftTS() uint64 { return rc.shiftStartTS } @@ -171,9 +184,11 @@ func (rc *LogFileManager) loadShiftTS(ctx context.Context) error { } if !shiftTS.exists { rc.shiftStartTS = rc.startTS + rc.withMigraionBuilder.SetShiftStartTS(rc.shiftStartTS) return nil } rc.shiftStartTS = shiftTS.value + rc.withMigraionBuilder.SetShiftStartTS(rc.shiftStartTS) return nil } @@ -225,7 +240,7 @@ func (rc *LogFileManager) createMetaIterOver(ctx context.Context, s storage.Exte } func (rc *LogFileManager) FilterDataFiles(m MetaNameIter) LogIter { - ms := rc.withmigrations.Metas(m) + ms := rc.withMigrations.Metas(m) return iter.FlatMap(ms, func(m *MetaWithMigrations) LogIter { gs := m.Physicals(iter.Enumerate(iter.FromSlice(m.meta.FileGroups))) return iter.FlatMap(gs, func(gim *PhysicalWithMigrations) LogIter { @@ -266,10 +281,13 @@ func (rc *LogFileManager) collectDDLFilesAndPrepareCache( ctx context.Context, files MetaGroupIter, ) ([]Log, error) { + start := time.Now() + log.Info("start to collect all ddl files") fs := iter.CollectAll(ctx, files) if fs.Err != nil { return nil, errors.Annotatef(fs.Err, "failed to collect from files") } + log.Info("finish to collect all ddl files", zap.Duration("take", time.Since(start))) dataFileInfos := make([]*backuppb.DataFileInfo, 0) for _, g := range fs.Item { @@ -281,24 +299,12 @@ func (rc *LogFileManager) collectDDLFilesAndPrepareCache( } // LoadDDLFilesAndCountDMLFiles loads all DDL files needs to be restored in the restoration. -// At the same time, if the `counter` isn't nil, counting the DML file needs to be restored into `counter`. // This function returns all DDL files needing directly because we need sort all of them. 
-func (rc *LogFileManager) LoadDDLFilesAndCountDMLFiles(ctx context.Context, counter *int) ([]Log, error) { +func (rc *LogFileManager) LoadDDLFilesAndCountDMLFiles(ctx context.Context) ([]Log, error) { m, err := rc.streamingMeta(ctx) if err != nil { return nil, err } - if counter != nil { - m = iter.Tap(m, func(m *MetaName) { - for _, fg := range m.meta.FileGroups { - for _, f := range fg.DataFilesInfo { - if !f.IsMeta && !rc.ShouldFilterOut(f) { - *counter += 1 - } - } - } - }) - } mg := rc.FilterMetaFiles(m) return rc.collectDDLFilesAndPrepareCache(ctx, mg) @@ -324,7 +330,12 @@ func (rc *LogFileManager) FilterMetaFiles(ms MetaNameIter) MetaGroupIter { if m.meta.MetaVersion > backuppb.MetaVersion_V1 { d.Path = g.Path } - return !d.IsMeta || rc.ShouldFilterOut(d) + if rc.ShouldFilterOut(d) { + return true + } + // count the progress + TotalEntryCount += d.NumberOfEntries + return !d.IsMeta }) return DDLMetaGroup{ Path: g.Path, @@ -335,6 +346,11 @@ func (rc *LogFileManager) FilterMetaFiles(ms MetaNameIter) MetaGroupIter { }) } +// Fetch compactions that may contain file less than the TS. +func (rc *LogFileManager) GetCompactionIter(ctx context.Context) iter.TryNextor[*backuppb.LogFileSubcompaction] { + return rc.withMigrations.Compactions(ctx, rc.storage) +} + // the kv entry with ts, the ts is decoded from entry. type KvEntryWithTS struct { E kv.Entry @@ -418,3 +434,18 @@ func (rc *LogFileManager) ReadAllEntries( return kvEntries, nextKvEntries, nil } + +func Subcompactions(ctx context.Context, prefix string, s storage.ExternalStorage) SubCompactionIter { + return iter.FlatMap(storage.UnmarshalDir( + ctx, + &storage.WalkOption{SubDir: prefix}, + s, + func(t *backuppb.LogFileSubcompactions, name string, b []byte) error { return t.Unmarshal(b) }, + ), func(subcs *backuppb.LogFileSubcompactions) iter.TryNextor[*backuppb.LogFileSubcompaction] { + return iter.FromSlice(subcs.Subcompactions) + }) +} + +func LoadMigrations(ctx context.Context, s storage.ExternalStorage) iter.TryNextor[*backuppb.Migration] { + return storage.UnmarshalDir(ctx, &storage.WalkOption{SubDir: "v1/migrations/"}, s, func(t *backuppb.Migration, name string, b []byte) error { return t.Unmarshal(b) }) +} diff --git a/br/pkg/restore/log_client/log_file_manager_test.go b/br/pkg/restore/log_client/log_file_manager_test.go index 0ac289b65b8b0..79813d6ef78f2 100644 --- a/br/pkg/restore/log_client/log_file_manager_test.go +++ b/br/pkg/restore/log_client/log_file_manager_test.go @@ -235,6 +235,8 @@ func testReadMetaBetweenTSWithVersion(t *testing.T, m metaMaker) { RestoreTS: c.endTS, Storage: loc, + MigrationsBuilder: logclient.NewMigrationBuilder(0, c.startTS, c.endTS), + Migrations: emptyMigrations(), MetadataDownloadBatchSize: 32, } cli, err := logclient.CreateLogFileManager(ctx, init) @@ -469,6 +471,8 @@ func testFileManagerWithMeta(t *testing.T, m metaMaker) { RestoreTS: end, Storage: loc, + MigrationsBuilder: logclient.NewMigrationBuilder(0, start, end), + Migrations: emptyMigrations(), MetadataDownloadBatchSize: 32, }) req.NoError(err) @@ -487,15 +491,8 @@ func testFileManagerWithMeta(t *testing.T, m metaMaker) { ), ).Item } else { - var counter *int - if c.DMLFileCount != nil { - counter = new(int) - } - data, err := fm.LoadDDLFilesAndCountDMLFiles(ctx, counter) + data, err := fm.LoadDDLFilesAndCountDMLFiles(ctx) req.NoError(err) - if counter != nil { - req.Equal(*c.DMLFileCount, *counter) - } r = data } dataFileInfoMatches(t, r, c.Requires...) 
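The new LoadMigrations, BuildMigrations, and GetCompactionIter entry points above only appear piecemeal in the diff. The following is a rough, hypothetical sketch (not part of this patch) of how they are expected to be wired together: migrations are read from external storage, installed into the LogFileManager, and the resulting subcompaction iterator is drained. The names extStorage, fileMgr, and process are placeholders for values the caller would already have.

package example // illustrative sketch only

import (
	"context"

	"github.com/pingcap/errors"
	backuppb "github.com/pingcap/kvproto/pkg/brpb"
	logclient "github.com/pingcap/tidb/br/pkg/restore/log_client"
	"github.com/pingcap/tidb/br/pkg/storage"
	"github.com/pingcap/tidb/br/pkg/utils/iter"
)

func restoreCompactions(
	ctx context.Context,
	fileMgr *logclient.LogFileManager,
	extStorage storage.ExternalStorage,
	process func(*backuppb.LogFileSubcompaction),
) error {
	// Load every persisted migration under v1/migrations/ and collect them.
	migs := iter.CollectAll(ctx, logclient.LoadMigrations(ctx, extStorage))
	if migs.Err != nil {
		return errors.Trace(migs.Err)
	}
	// Rebuild the WithMigrations wrapper (skip maps + compaction dirs) on the manager.
	fileMgr.BuildMigrations(migs.Item)

	// Drain the subcompactions referenced by the surviving migrations.
	compactions := fileMgr.GetCompactionIter(ctx)
	for r := compactions.TryNext(ctx); !r.Finished; r = compactions.TryNext(ctx) {
		if r.Err != nil {
			return errors.Trace(r.Err)
		}
		process(r.Item) // e.g. feed into the compacted-file split strategy / restorer
	}
	return nil
}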
@@ -528,6 +525,7 @@ func TestFilterDataFiles(t *testing.T) { RestoreTS: 10, Storage: loc, + MigrationsBuilder: logclient.NewMigrationBuilder(0, 0, 10), Migrations: emptyMigrations(), MetadataDownloadBatchSize: 32, }) diff --git a/br/pkg/restore/log_client/log_split_strategy.go b/br/pkg/restore/log_client/log_split_strategy.go new file mode 100644 index 0000000000000..3327991db02fa --- /dev/null +++ b/br/pkg/restore/log_client/log_split_strategy.go @@ -0,0 +1,107 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package logclient + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/checkpoint" + "github.com/pingcap/tidb/br/pkg/restore/split" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "go.uber.org/zap" +) + +type LogSplitStrategy struct { + *split.BaseSplitStrategy + checkpointSkipMap *LogFilesSkipMap + checkpointFileProgressFn func(uint64, uint64) +} + +var _ split.SplitStrategy[*LogDataFileInfo] = &LogSplitStrategy{} + +func NewLogSplitStrategy( + ctx context.Context, + useCheckpoint bool, + execCtx sqlexec.RestrictedSQLExecutor, + rules map[int64]*restoreutils.RewriteRules, + updateStatsFn func(uint64, uint64), +) (*LogSplitStrategy, error) { + downstreamIdset := make(map[int64]struct{}) + for _, rule := range rules { + downstreamIdset[rule.NewTableID] = struct{}{} + } + skipMap := NewLogFilesSkipMap() + if useCheckpoint { + t, err := checkpoint.LoadCheckpointDataForLogRestore( + ctx, execCtx, func(groupKey checkpoint.LogRestoreKeyType, off checkpoint.LogRestoreValueMarshaled) { + for tableID, foffs := range off.Foffs { + // filter out the checkpoint data of dropped table + if _, exists := downstreamIdset[tableID]; exists { + for _, foff := range foffs { + skipMap.Insert(groupKey, off.Goff, foff) + } + } + } + }) + + if err != nil { + return nil, errors.Trace(err) + } + summary.AdjustStartTimeToEarlierTime(t) + } + return &LogSplitStrategy{ + BaseSplitStrategy: split.NewBaseSplitStrategy(rules), + checkpointSkipMap: skipMap, + checkpointFileProgressFn: updateStatsFn, + }, nil +} + +const splitFileThreshold = 1024 * 1024 // 1 MB + +func (ls *LogSplitStrategy) Accumulate(file *LogDataFileInfo) { + if file.Length > splitFileThreshold { + ls.AccumulateCount += 1 + } + splitHelper, exist := ls.TableSplitter[file.TableId] + if !exist { + splitHelper = split.NewSplitHelper() + ls.TableSplitter[file.TableId] = splitHelper + } + + splitHelper.Merge(split.Valued{ + Key: split.Span{ + StartKey: file.StartKey, + EndKey: file.EndKey, + }, + Value: split.Value{ + Size: file.Length, + Number: file.NumberOfEntries, + }, + }) +} + +func (ls *LogSplitStrategy) ShouldSplit() bool { + return ls.AccumulateCount > 4096 +} + +func (ls *LogSplitStrategy) ShouldSkip(file *LogDataFileInfo) bool { + if file.IsMeta { + return true + } + _, exist := ls.Rules[file.TableId] + if !exist { + log.Info("skip for no rule files", zap.Int64("tableID", file.TableId)) + return true + } + + if ls.checkpointSkipMap.NeedSkip(file.MetaDataGroupName, file.OffsetInMetaGroup, file.OffsetInMergedGroup) { + //onPcheckpointSkipMaprogress() + ls.checkpointFileProgressFn(uint64(file.NumberOfEntries), file.Length) + return true + } + return false +} diff --git a/br/pkg/restore/log_client/migration.go b/br/pkg/restore/log_client/migration.go index 19d9d3daeb3cb..a7b4307e0f568 100644 --- a/br/pkg/restore/log_client/migration.go +++ 
b/br/pkg/restore/log_client/migration.go @@ -15,7 +15,10 @@ package logclient import ( + "context" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/utils/iter" ) @@ -99,6 +102,10 @@ type WithMigrationsBuilder struct { restoredTS uint64 } +func (builder *WithMigrationsBuilder) SetShiftStartTS(ts uint64) { + builder.shiftStartTS = ts +} + func (builder *WithMigrationsBuilder) updateSkipMap(skipmap metaSkipMap, metas []*backuppb.MetaEdit) { for _, meta := range metas { if meta.DestructSelf { @@ -136,14 +143,24 @@ func (builder *WithMigrationsBuilder) coarseGrainedFilter(mig *backuppb.Migratio // Create the wrapper by migrations. func (builder *WithMigrationsBuilder) Build(migs []*backuppb.Migration) WithMigrations { skipmap := make(metaSkipMap) + compactionDirs := make([]string, 0, 8) + for _, mig := range migs { // TODO: deal with TruncatedTo and DestructPrefix if builder.coarseGrainedFilter(mig) { continue } builder.updateSkipMap(skipmap, mig.EditMeta) + + for _, c := range mig.Compactions { + compactionDirs = append(compactionDirs, c.Artifacts) + } } - return WithMigrations(skipmap) + withMigrations := WithMigrations{ + skipmap: skipmap, + compactionDirs: compactionDirs, + } + return withMigrations } type PhysicalMigrationsIter = iter.TryNextor[*PhysicalWithMigrations] @@ -190,13 +207,16 @@ func (mwm *MetaWithMigrations) Physicals(groupIndexIter GroupIndexIter) Physical }) } -type WithMigrations metaSkipMap +type WithMigrations struct { + skipmap metaSkipMap + compactionDirs []string +} -func (wm WithMigrations) Metas(metaNameIter MetaNameIter) MetaMigrationsIter { +func (wm *WithMigrations) Metas(metaNameIter MetaNameIter) MetaMigrationsIter { return iter.MapFilter(metaNameIter, func(mname *MetaName) (*MetaWithMigrations, bool) { var phySkipmap physicalSkipMap = nil - if wm != nil { - skipmap := wm[mname.name] + if wm.skipmap != nil { + skipmap := wm.skipmap[mname.name] if skipmap != nil { if skipmap.skip { return nil, true @@ -210,3 +230,11 @@ func (wm WithMigrations) Metas(metaNameIter MetaNameIter) MetaMigrationsIter { }, false }) } + +func (wm *WithMigrations) Compactions(ctx context.Context, s storage.ExternalStorage) iter.TryNextor[*backuppb.LogFileSubcompaction] { + compactionDirIter := iter.FromSlice(wm.compactionDirs) + return iter.FlatMap(compactionDirIter, func(name string) iter.TryNextor[*backuppb.LogFileSubcompaction] { + // name is the absolute path in external storage. + return Subcompactions(ctx, name, s) + }) +} diff --git a/br/pkg/restore/log_client/migration_test.go b/br/pkg/restore/log_client/migration_test.go index 0dd7b06197c7c..5368d7416dadf 100644 --- a/br/pkg/restore/log_client/migration_test.go +++ b/br/pkg/restore/log_client/migration_test.go @@ -25,8 +25,8 @@ import ( "github.com/stretchr/testify/require" ) -func emptyMigrations() logclient.WithMigrations { - return logclient.WithMigrations{} +func emptyMigrations() *logclient.WithMigrations { + return &logclient.WithMigrations{} } func nameFromID(prefix string, id uint64) string { diff --git a/br/pkg/restore/restorer.go b/br/pkg/restore/restorer.go new file mode 100644 index 0000000000000..75a21b583eb1f --- /dev/null +++ b/br/pkg/restore/restorer.go @@ -0,0 +1,367 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package restore
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/opentracing/opentracing-go"
+ "github.com/pingcap/errors"
+ backuppb "github.com/pingcap/kvproto/pkg/brpb"
+ "github.com/pingcap/log"
+ "github.com/pingcap/tidb/br/pkg/checkpoint"
+ "github.com/pingcap/tidb/br/pkg/logutil"
+ "github.com/pingcap/tidb/br/pkg/restore/split"
+ "github.com/pingcap/tidb/br/pkg/restore/utils"
+ "github.com/pingcap/tidb/br/pkg/summary"
+ "github.com/pingcap/tidb/br/pkg/utils/iter"
+ "github.com/pingcap/tidb/pkg/util"
+ "go.uber.org/zap"
+ "go.uber.org/zap/zapcore"
+ "golang.org/x/sync/errgroup"
+)
+
+// BackupFileSet represents the batch of files to be restored for a table. Currently we have the following types of files:
+// 1. Raw KV (sst files)
+// 2. Txn KV (sst files)
+// 3. Database KV backup (sst files)
+// 4. Compacted log backups (sst files)
+type BackupFileSet struct {
+ // TableID is only meaningful for per-table backups (types 3 and 4).
+ // For Raw/Txn KV, the table ID is always 0.
+ TableID int64
+
+ // For log backup changes, this field is nil.
+ SSTFiles []*backuppb.File
+
+ // RewriteRules is the rewrite rules for the specific table.
+ // Because these rules belong to a single table, we can hold them here.
+ RewriteRules *utils.RewriteRules
+}
+
+type BatchBackupFileSet []BackupFileSet
+
+type zapBatchBackupFileSetMarshaler BatchBackupFileSet
+
+// MarshalLogObjectForFiles is an internal util function to zap something that has a `Files` field.
+func MarshalLogObjectForFiles(batchFileSet BatchBackupFileSet, encoder zapcore.ObjectEncoder) error {
+ return zapBatchBackupFileSetMarshaler(batchFileSet).MarshalLogObject(encoder)
+}
+
+func (fgs zapBatchBackupFileSetMarshaler) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
+ elements := make([]string, 0)
+ total := 0
+ totalKVs := uint64(0)
+ totalBytes := uint64(0)
+ totalSize := uint64(0)
+ for _, fg := range fgs {
+ for _, f := range fg.SSTFiles {
+ total += 1
+ elements = append(elements, f.GetName())
+ totalKVs += f.GetTotalKvs()
+ totalBytes += f.GetTotalBytes()
+ totalSize += f.GetSize_()
+ }
+ }
+ encoder.AddInt("total", total)
+ _ = encoder.AddArray("files", logutil.AbbreviatedArrayMarshaler(elements))
+ encoder.AddUint64("totalKVs", totalKVs)
+ encoder.AddUint64("totalBytes", totalBytes)
+ encoder.AddUint64("totalSize", totalSize)
+ return nil
+}
+
+func ZapBatchBackupFileSet(batchFileSet BatchBackupFileSet) zap.Field {
+ return zap.Object("fileset", zapBatchBackupFileSetMarshaler(batchFileSet))
+}
+
+// CreateUniqueFileSets is used for Raw/Txn files, which carry no table ID. It
+// converts a slice of files into a slice of unique BackupFileSets,
+// where each BackupFileSet contains a single file.
+func CreateUniqueFileSets(files []*backuppb.File) []BackupFileSet { + newSet := make([]BackupFileSet, len(files)) + for i, f := range files { + newSet[i].SSTFiles = []*backuppb.File{f} + } + return newSet +} + +func NewFileSet(files []*backuppb.File, rules *utils.RewriteRules) BackupFileSet { + return BackupFileSet{ + SSTFiles: files, + RewriteRules: rules, + } +} + +// SstRestorer defines the essential methods required for restoring SST files in various backup formats: +// 1. Raw backup SST files +// 2. Transactional (Txn) backup SST files +// 3. TiDB backup SST files +// 4. Log-compacted SST files +// +// It serves as a high-level interface for restoration, supporting implementations such as simpleRestorer +// and MultiTablesRestorer. SstRestorer includes FileImporter for handling raw, transactional, and compacted SSTs, +// and MultiTablesRestorer for TiDB-specific backups. +type SstRestorer interface { + // GoRestore imports the specified backup file sets into TiKV asynchronously. + // The onProgress function is called with progress updates as files are processed. + GoRestore(onProgress func(int64), batchFileSets ...BatchBackupFileSet) error + + // WaitUntilFinish blocks until all pending restore files have completed processing. + WaitUntilFinish() error + + // Close releases any resources associated with the restoration process. + Close() error +} + +// FileImporter is a low-level interface for handling the import of backup files into storage (e.g., TiKV). +// It is primarily used by the importer client to manage raw and transactional SST file imports. +type FileImporter interface { + // Import uploads and imports the provided backup file sets into storage. + // The ctx parameter provides context for managing request scope. + Import(ctx context.Context, fileSets ...BackupFileSet) error + + // Close releases any resources used by the importer client. + Close() error +} + +// BalancedFileImporter is a wrapper around FileImporter that adds concurrency controls. +// It ensures that file imports are balanced across storage nodes, which is particularly useful +// in MultiTablesRestorer scenarios where concurrency management is critical for efficiency. +type BalancedFileImporter interface { + FileImporter + + // PauseForBackpressure manages concurrency by controlling when imports can proceed, + // ensuring load is distributed evenly across storage nodes. 
+ PauseForBackpressure() +} + +type SimpleRestorer struct { + eg *errgroup.Group + ectx context.Context + workerPool *util.WorkerPool + fileImporter FileImporter + checkpointRunner *checkpoint.CheckpointRunner[checkpoint.RestoreKeyType, checkpoint.RestoreValueType] +} + +func NewSimpleSstRestorer( + ctx context.Context, + fileImporter FileImporter, + workerPool *util.WorkerPool, + checkpointRunner *checkpoint.CheckpointRunner[checkpoint.RestoreKeyType, checkpoint.RestoreValueType], +) SstRestorer { + eg, ectx := errgroup.WithContext(ctx) + return &SimpleRestorer{ + eg: eg, + ectx: ectx, + workerPool: workerPool, + fileImporter: fileImporter, + checkpointRunner: checkpointRunner, + } +} + +func (s *SimpleRestorer) Close() error { + return s.fileImporter.Close() +} + +func (s *SimpleRestorer) WaitUntilFinish() error { + return s.eg.Wait() +} + +func (s *SimpleRestorer) GoRestore(onProgress func(int64), batchFileSets ...BatchBackupFileSet) error { + for _, sets := range batchFileSets { + for _, set := range sets { + s.workerPool.ApplyOnErrorGroup(s.eg, + func() (restoreErr error) { + fileStart := time.Now() + defer func() { + if restoreErr == nil { + log.Info("import sst files done", logutil.Files(set.SSTFiles), + zap.Duration("take", time.Since(fileStart))) + for _, f := range set.SSTFiles { + onProgress(int64(f.TotalKvs)) + } + } + }() + err := s.fileImporter.Import(s.ectx, set) + if err != nil { + return errors.Trace(err) + } + // TODO handle checkpoint + return nil + }) + } + } + return nil +} + +type MultiTablesRestorer struct { + eg *errgroup.Group + ectx context.Context + workerPool *util.WorkerPool + fileImporter BalancedFileImporter + checkpointRunner *checkpoint.CheckpointRunner[checkpoint.RestoreKeyType, checkpoint.RestoreValueType] +} + +func NewMultiTablesRestorer( + ctx context.Context, + fileImporter BalancedFileImporter, + workerPool *util.WorkerPool, + checkpointRunner *checkpoint.CheckpointRunner[checkpoint.RestoreKeyType, checkpoint.RestoreValueType], +) SstRestorer { + eg, ectx := errgroup.WithContext(ctx) + return &MultiTablesRestorer{ + eg: eg, + ectx: ectx, + workerPool: workerPool, + fileImporter: fileImporter, + checkpointRunner: checkpointRunner, + } +} + +func (m *MultiTablesRestorer) Close() error { + return m.fileImporter.Close() +} + +func (m *MultiTablesRestorer) WaitUntilFinish() error { + if err := m.eg.Wait(); err != nil { + summary.CollectFailureUnit("file", err) + log.Error("restore files failed", zap.Error(err)) + return errors.Trace(err) + } + return nil +} + +func (m *MultiTablesRestorer) GoRestore(onProgress func(int64), batchFileSets ...BatchBackupFileSet) (err error) { + start := time.Now() + fileCount := 0 + defer func() { + elapsed := time.Since(start) + if err == nil { + log.Info("Restore files", zap.Duration("take", elapsed)) + summary.CollectSuccessUnit("files", fileCount, elapsed) + } + }() + + log.Debug("start to restore files", zap.Int("files", fileCount)) + + if span := opentracing.SpanFromContext(m.ectx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("Client.RestoreSSTFiles", opentracing.ChildOf(span.Context())) + defer span1.Finish() + m.ectx = opentracing.ContextWithSpan(m.ectx, span1) + } + + for _, batchFileSet := range batchFileSets { + if m.ectx.Err() != nil { + log.Warn("Restoring encountered error and already stopped, give up remained files.", + logutil.ShortError(m.ectx.Err())) + // We will fetch the error from the errgroup then (If there were). 
+ // Also note that if the parent context has been canceled,
+ // breaking here directly is also a reasonable behavior.
+ break
+ }
+ filesReplica := batchFileSet
+ m.fileImporter.PauseForBackpressure()
+ m.workerPool.ApplyOnErrorGroup(m.eg, func() (restoreErr error) {
+ fileStart := time.Now()
+ defer func() {
+ if restoreErr == nil {
+ log.Info("import files done", zap.Duration("take", time.Since(fileStart)))
+ onProgress(int64(len(filesReplica)))
+ }
+ }()
+ if importErr := m.fileImporter.Import(m.ectx, filesReplica...); importErr != nil {
+ return errors.Trace(importErr)
+ }
+
+ // the data of this range has been imported
+ if m.checkpointRunner != nil && len(filesReplica) > 0 {
+ for _, filesGroup := range filesReplica {
+ rangeKeySet := make(map[string]struct{})
+ for _, file := range filesGroup.SSTFiles {
+ rangeKey := GetFileRangeKey(file.Name)
+ // Assert that the files having the same rangeKey are all in the current filesGroup.Files
+ rangeKeySet[rangeKey] = struct{}{}
+ }
+ for rangeKey := range rangeKeySet {
+ // The checkpoint range shows this range of kvs has been restored into
+ // the table corresponding to the table ID.
+ if err := checkpoint.AppendRangesForRestore(m.ectx, m.checkpointRunner, filesGroup.TableID, rangeKey); err != nil {
+ return errors.Trace(err)
+ }
+ }
+ }
+ }
+ return nil
+ })
+ }
+ // Once the parent context is canceled and there is no task running in the errgroup,
+ // we may break the for loop without error in the errgroup. (Will this happen?)
+ // At that time, return the error in the context here.
+ return m.ectx.Err()
+}
+
+func GetFileRangeKey(f string) string {
+ // the backup data file pattern is `{store_id}_{region_id}_{epoch_version}_{key}_{ts}_{cf}.sst`
+ // so we need to compare without the `_{cf}.sst` suffix
+ idx := strings.LastIndex(f, "_")
+ if idx < 0 {
+ panic(fmt.Sprintf("invalid backup data file name: '%s'", f))
+ }
+
+ return f[:idx]
+}
+
+type PipelineRestorerWrapper[T any] struct {
+ split.PipelineRegionsSplitter
+}
+
+// WithSplit processes items using a split strategy within a pipeline.
+// It iterates over items, accumulating them until a split condition is met.
+// When a split is required, it executes the split operation on the accumulated items.
+func (p *PipelineRestorerWrapper[T]) WithSplit(ctx context.Context, i iter.TryNextor[T], strategy split.SplitStrategy[T]) iter.TryNextor[T] {
+ return iter.TryMap(
+ iter.FilterOut(i, func(item T) bool {
+ // Items that the strategy marks as skippable are filtered out here.
+ return strategy.ShouldSkip(item)
+ }), func(item T) (T, error) {
+ // Accumulate the item for potential splitting.
+ strategy.Accumulate(item)
+
+ // Check if the accumulated items meet the criteria for splitting.
+ if strategy.ShouldSplit() {
+ log.Info("Trying to start region split with accumulations")
+ startTime := time.Now()
+
+ // Execute the split operation on the accumulated items.
+ accumulations := strategy.GetAccumulations()
+ err := p.ExecuteRegions(ctx, accumulations)
+ if err != nil {
+ log.Error("Failed to split regions in pipeline; exit restore", zap.Error(err), zap.Duration("duration", time.Since(startTime)))
+ return item, errors.Annotate(err, "Execute region split on accumulated files failed")
+ }
+ // Reset accumulations after the split operation.
+ strategy.ResetAccumulations()
+ log.Info("Completed region split in pipeline", zap.Duration("duration", time.Since(startTime)))
+ }
+ // Return the item without filtering it out.
+ return item, nil + }) +} diff --git a/br/pkg/restore/restorer_test.go b/br/pkg/restore/restorer_test.go new file mode 100644 index 0000000000000..8e26245dabf06 --- /dev/null +++ b/br/pkg/restore/restorer_test.go @@ -0,0 +1,280 @@ +// Copyright 2024 PingCAP, Inc. +package restore_test + +import ( + "context" + "testing" + + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/tidb/br/pkg/restore" + "github.com/pingcap/tidb/br/pkg/restore/split" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/utils/iter" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/stretchr/testify/require" +) + +// Helper function to create test files +func createTestFiles() []*backuppb.File { + return []*backuppb.File{ + {Name: "file1.sst", TotalKvs: 10}, + {Name: "file2.sst", TotalKvs: 20}, + } +} + +type fakeImporter struct { + restore.FileImporter + hasError bool +} + +func (f *fakeImporter) Import(ctx context.Context, fileSets ...restore.BackupFileSet) error { + if f.hasError { + return errors.New("import error") + } + return nil +} + +func (f *fakeImporter) Close() error { + return nil +} + +func TestSimpleRestorerImportAndProgress(t *testing.T) { + ctx := context.Background() + files := createTestFiles() + progressCount := int64(0) + + workerPool := util.NewWorkerPool(2, "simple-restorer") + restorer := restore.NewSimpleSstRestorer(ctx, &fakeImporter{}, workerPool, nil) + + fileSet := restore.BatchBackupFileSet{ + {SSTFiles: files}, + } + err := restorer.GoRestore(func(progress int64) { + progressCount += progress + }, fileSet) + require.NoError(t, err) + err = restorer.WaitUntilFinish() + require.Equal(t, int64(30), progressCount) + require.NoError(t, err) + + batchFileSet := restore.BatchBackupFileSet{ + {SSTFiles: files}, + {SSTFiles: files}, + } + progressCount = int64(0) + err = restorer.GoRestore(func(progress int64) { + progressCount += progress + }, batchFileSet) + require.NoError(t, err) + err = restorer.WaitUntilFinish() + require.NoError(t, err) + require.Equal(t, int64(60), progressCount) +} + +func TestSimpleRestorerWithErrorInImport(t *testing.T) { + ctx := context.Background() + + workerPool := util.NewWorkerPool(2, "simple-restorer") + restorer := restore.NewSimpleSstRestorer(ctx, &fakeImporter{hasError: true}, workerPool, nil) + + files := []*backuppb.File{ + {Name: "file_with_error.sst", TotalKvs: 15}, + } + fileSet := restore.BatchBackupFileSet{ + {SSTFiles: files}, + } + + // Run restore and expect an error + progressCount := int64(0) + restorer.GoRestore(func(progress int64) {}, fileSet) + err := restorer.WaitUntilFinish() + require.Error(t, err) + require.Contains(t, err.Error(), "import error") + require.Equal(t, int64(0), progressCount) +} + +func createSampleBatchFileSets() restore.BatchBackupFileSet { + return restore.BatchBackupFileSet{ + { + TableID: 1001, + SSTFiles: []*backuppb.File{ + {Name: "file1.sst", TotalKvs: 10}, + {Name: "file2.sst", TotalKvs: 20}, + }, + }, + { + TableID: 1002, + SSTFiles: []*backuppb.File{ + {Name: "file3.sst", TotalKvs: 15}, + }, + }, + } +} + +// FakeBalancedFileImporteris a minimal implementation for testing +type FakeBalancedFileImporter struct { + hasError bool + unblockCount int +} + +func (f *FakeBalancedFileImporter) Import(ctx context.Context, fileSets ...restore.BackupFileSet) error { + if f.hasError { + return errors.New("import error") + } + return nil +} + +func 
(f *FakeBalancedFileImporter) PauseForBackpressure() { + f.unblockCount++ +} + +func (f *FakeBalancedFileImporter) Close() error { + return nil +} + +func TestMultiTablesRestorerRestoreSuccess(t *testing.T) { + ctx := context.Background() + importer := &FakeBalancedFileImporter{} + workerPool := util.NewWorkerPool(2, "multi-tables-restorer") + + restorer := restore.NewMultiTablesRestorer(ctx, importer, workerPool, nil) + + var progress int64 + fileSets := createSampleBatchFileSets() + + restorer.GoRestore(func(p int64) { progress += p }, fileSets) + err := restorer.WaitUntilFinish() + require.NoError(t, err) + + // Ensure progress was tracked correctly + require.Equal(t, int64(2), progress) // Total files group: 2 + require.Equal(t, 1, importer.unblockCount) +} + +func TestMultiTablesRestorerRestoreWithImportError(t *testing.T) { + ctx := context.Background() + importer := &FakeBalancedFileImporter{hasError: true} + workerPool := util.NewWorkerPool(2, "multi-tables-restorer") + + restorer := restore.NewMultiTablesRestorer(ctx, importer, workerPool, nil) + fileSets := createSampleBatchFileSets() + + restorer.GoRestore(func(int64) {}, fileSets) + err := restorer.WaitUntilFinish() + require.Error(t, err) + require.Contains(t, err.Error(), "import error") +} + +func TestMultiTablesRestorerRestoreWithContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + importer := &FakeBalancedFileImporter{} + workerPool := util.NewWorkerPool(2, "multi-tables-restorer") + + restorer := restore.NewMultiTablesRestorer(ctx, importer, workerPool, nil) + + fileSets := createSampleBatchFileSets() + + // Cancel context before restore completes + cancel() + err := restorer.GoRestore(func(int64) {}, fileSets) + require.ErrorIs(t, err, context.Canceled) +} + +// FakeSplitStrategy implements split.SplitStrategy for testing purposes +type FakeSplitStrategy[T any] struct { + shouldSplit bool + accumulated []T +} + +// ShouldSkip determines if a given item should be skipped. For testing, this is hardcoded to `false`. +func (f *FakeSplitStrategy[T]) ShouldSkip(item T) bool { + return false +} + +// Accumulate adds a new item to the accumulated list. +func (f *FakeSplitStrategy[T]) Accumulate(item T) { + f.accumulated = append(f.accumulated, item) +} + +// ShouldSplit returns whether the accumulated items meet the condition for splitting. +func (f *FakeSplitStrategy[T]) ShouldSplit() bool { + return f.shouldSplit +} + +// ResetAccumulations clears the accumulated items. +func (f *FakeSplitStrategy[T]) ResetAccumulations() { + f.accumulated = []T{} +} + +// GetAccumulations returns an iterator for the accumulated items. 
+func (f *FakeSplitStrategy[T]) GetAccumulations() *split.SplitHelperIterator { + rewrites, ok := any(f.accumulated).([]*split.RewriteSplitter) + if !ok { + panic("GetAccumulations called with non-*split.RewriteSplitter type") + } + return split.NewSplitHelperIterator(rewrites) +} + +// FakeRegionsSplitter is a mock of the RegionsSplitter that records calls to ExecuteRegions +type FakeRegionsSplitter struct { + split.Splitter + executedSplitsCount int + expectedEndKeys [][]byte +} + +func (f *FakeRegionsSplitter) ExecuteRegions(ctx context.Context, items *split.SplitHelperIterator) error { + items.Traverse(func(v split.Valued, endKey []byte, rule *restoreutils.RewriteRules) bool { + f.expectedEndKeys = append(f.expectedEndKeys, endKey) + return true + }) + f.executedSplitsCount += 1 + return nil +} + +func TestWithSplitWithoutTriggersSplit(t *testing.T) { + ctx := context.Background() + fakeSplitter := &FakeRegionsSplitter{ + executedSplitsCount: 0, + } + strategy := &FakeSplitStrategy[string]{shouldSplit: false} + wrapper := &restore.PipelineRestorerWrapper[string]{PipelineRegionsSplitter: fakeSplitter} + + items := iter.FromSlice([]string{"item1", "item2", "item3"}) + splitIter := wrapper.WithSplit(ctx, items, strategy) + + for i := splitIter.TryNext(ctx); !i.Finished; i = splitIter.TryNext(ctx) { + } + + require.Equal(t, fakeSplitter.executedSplitsCount, 0) +} +func TestWithSplitAccumulateAndReset(t *testing.T) { + ctx := context.Background() + fakeSplitter := &FakeRegionsSplitter{} + strategy := &FakeSplitStrategy[*split.RewriteSplitter]{shouldSplit: true} + wrapper := &restore.PipelineRestorerWrapper[*split.RewriteSplitter]{PipelineRegionsSplitter: fakeSplitter} + + // Create RewriteSplitter items + items := iter.FromSlice([]*split.RewriteSplitter{ + split.NewRewriteSpliter([]byte("t_1"), 1, nil, split.NewSplitHelper()), + split.NewRewriteSpliter([]byte("t_2"), 2, nil, split.NewSplitHelper()), + }) + splitIter := wrapper.WithSplit(ctx, items, strategy) + + // Traverse through the split iterator + for i := splitIter.TryNext(ctx); !i.Finished; i = splitIter.TryNext(ctx) { + } + + endKeys := [][]byte{ + codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(2)), + codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(3)), + } + + // Verify that the split happened and the accumulation was reset + require.ElementsMatch(t, endKeys, fakeSplitter.expectedEndKeys) + require.Equal(t, 2, fakeSplitter.executedSplitsCount) + require.Empty(t, strategy.accumulated) +} diff --git a/br/pkg/restore/snap_client/BUILD.bazel b/br/pkg/restore/snap_client/BUILD.bazel index d77985e92b808..b9abcd2e99f7d 100644 --- a/br/pkg/restore/snap_client/BUILD.bazel +++ b/br/pkg/restore/snap_client/BUILD.bazel @@ -65,7 +65,6 @@ go_library( "@org_golang_x_sync//errgroup", "@org_uber_go_multierr//:multierr", "@org_uber_go_zap//:zap", - "@org_uber_go_zap//zapcore", ], ) diff --git a/br/pkg/restore/snap_client/client.go b/br/pkg/restore/snap_client/client.go index 957ec300cff94..a7e0ecab3d230 100644 --- a/br/pkg/restore/snap_client/client.go +++ b/br/pkg/restore/snap_client/client.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/checkpoint" "github.com/pingcap/tidb/br/pkg/checksum" @@ -66,15 +67,16 @@ const ( strictPlacementPolicyMode = "STRICT" ignorePlacementPolicyMode = "IGNORE" - defaultDDLConcurrency = 100 - maxSplitKeysOnce = 10240 
+ resetSpeedLimitRetryTimes = 3 + defaultDDLConcurrency = 100 + maxSplitKeysOnce = 10240 ) const minBatchDdlSize = 1 type SnapClient struct { + restorer restore.SstRestorer // Tool clients used by SnapClient - fileImporter *SnapFileImporter pdClient pd.Client pdHTTPClient pdhttp.Client @@ -91,8 +93,7 @@ type SnapClient struct { supportPolicy bool workerPool *tidbutil.WorkerPool - noSchema bool - hasSpeedLimited bool + noSchema bool databases map[string]*metautil.Database ddlJobs []*model.Job @@ -167,6 +168,10 @@ func NewRestoreClient( } } +func (rc *SnapClient) GetRestorer() restore.SstRestorer { + return rc.restorer +} + func (rc *SnapClient) closeConn() { // rc.db can be nil in raw kv mode. if rc.db != nil { @@ -182,8 +187,10 @@ func (rc *SnapClient) Close() { // close the connection, and it must be succeed when in SQL mode. rc.closeConn() - if err := rc.fileImporter.Close(); err != nil { - log.Warn("failed to close file importer") + if rc.restorer != nil { + if err := rc.restorer.Close(); err != nil { + log.Warn("failed to close file restorer") + } } log.Info("Restore client closed") @@ -339,18 +346,18 @@ func (rc *SnapClient) InitCheckpoint( } // t1 is the latest time the checkpoint ranges persisted to the external storage. - t1, err := checkpoint.LoadCheckpointDataForSnapshotRestore(ctx, execCtx, func(tableID int64, rangeKey checkpoint.RestoreValueType) { + t1, err := checkpoint.LoadCheckpointDataForSnapshotRestore(ctx, execCtx, func(tableID int64, v checkpoint.RestoreValueType) { checkpointSet, exists := checkpointSetWithTableID[tableID] if !exists { checkpointSet = make(map[string]struct{}) checkpointSetWithTableID[tableID] = checkpointSet } - checkpointSet[rangeKey.RangeKey] = struct{}{} + checkpointSet[v.RangeKey] = struct{}{} }) if err != nil { return checkpointSetWithTableID, nil, errors.Trace(err) } - // t2 is the latest time the checkpoint checksum persisted to the external storage. 
+ checkpointChecksum, t2, err := checkpoint.LoadCheckpointChecksumForRestore(ctx, execCtx) if err != nil { return checkpointSetWithTableID, nil, errors.Trace(err) @@ -445,7 +452,30 @@ func (rc *SnapClient) Init(g glue.Glue, store kv.Storage) error { return errors.Trace(err) } -func (rc *SnapClient) initClients(ctx context.Context, backend *backuppb.StorageBackend, isRawKvMode bool, isTxnKvMode bool) error { +func SetSpeedLimitFn(ctx context.Context, stores []*metapb.Store, pool *tidbutil.WorkerPool) func(*SnapFileImporter, uint64) error { + return func(importer *SnapFileImporter, limit uint64) error { + eg, ectx := errgroup.WithContext(ctx) + for _, store := range stores { + if err := ectx.Err(); err != nil { + return errors.Trace(err) + } + + finalStore := store + pool.ApplyOnErrorGroup(eg, + func() error { + err := importer.SetDownloadSpeedLimit(ectx, finalStore.GetId(), limit) + if err != nil { + return errors.Trace(err) + } + return nil + }) + } + return eg.Wait() + } +} + +func (rc *SnapClient) initClients(ctx context.Context, backend *backuppb.StorageBackend, isRawKvMode bool, isTxnKvMode bool, + RawStartKey, RawEndKey []byte) error { stores, err := conn.GetAllTiKVStoresWithRetry(ctx, rc.pdClient, util.SkipTiFlash) if err != nil { return errors.Annotate(err, "failed to get stores") @@ -453,15 +483,73 @@ func (rc *SnapClient) initClients(ctx context.Context, backend *backuppb.Storage rc.storeCount = len(stores) rc.updateConcurrency() + var createCallBacks []func(*SnapFileImporter) error + var closeCallBacks []func(*SnapFileImporter) error var splitClientOpts []split.ClientOptionalParameter if isRawKvMode { splitClientOpts = append(splitClientOpts, split.WithRawKV()) + createCallBacks = append(createCallBacks, func(importer *SnapFileImporter) error { + return importer.SetRawRange(RawStartKey, RawEndKey) + }) + } + createCallBacks = append(createCallBacks, func(importer *SnapFileImporter) error { + return importer.CheckMultiIngestSupport(ctx, stores) + }) + if rc.rateLimit != 0 { + setFn := SetSpeedLimitFn(ctx, stores, rc.workerPool) + createCallBacks = append(createCallBacks, func(importer *SnapFileImporter) error { + return setFn(importer, rc.rateLimit) + }) + closeCallBacks = append(closeCallBacks, func(importer *SnapFileImporter) error { + // In future we may need a mechanism to set speed limit in ttl. like what we do in switchmode. TODO + var resetErr error + for retry := 0; retry < resetSpeedLimitRetryTimes; retry++ { + resetErr = setFn(importer, 0) + if resetErr != nil { + log.Warn("failed to reset speed limit, retry it", + zap.Int("retry time", retry), logutil.ShortError(resetErr)) + time.Sleep(time.Duration(retry+3) * time.Second) + continue + } + break + } + if resetErr != nil { + log.Error("failed to reset speed limit, please reset it manually", zap.Error(resetErr)) + } + return resetErr + }) } metaClient := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, rc.storeCount+1, splitClientOpts...) 
importCli := importclient.NewImportClient(metaClient, rc.tlsConf, rc.keepaliveConf) - rc.fileImporter, err = NewSnapFileImporter(ctx, metaClient, importCli, backend, isRawKvMode, isTxnKvMode, stores, rc.rewriteMode, rc.concurrencyPerStore) - return errors.Trace(err) + + var fileImporter *SnapFileImporter + opt := NewSnapFileImporterOptions( + rc.cipher, metaClient, importCli, backend, + rc.rewriteMode, stores, rc.concurrencyPerStore, createCallBacks, closeCallBacks, + ) + if isRawKvMode || isTxnKvMode { + mode := Raw + if isTxnKvMode { + mode = Txn + } + // for raw/txn mode. use backupMeta.ApiVersion to create fileImporter + fileImporter, err = NewSnapFileImporter(ctx, rc.backupMeta.ApiVersion, mode, opt) + if err != nil { + return errors.Trace(err) + } + // Raw/Txn restore are not support checkpoint for now + rc.restorer = restore.NewSimpleSstRestorer(ctx, fileImporter, rc.workerPool, nil) + } else { + // or create a fileImporter with the cluster API version + fileImporter, err = NewSnapFileImporter( + ctx, rc.dom.Store().GetCodec().GetAPIVersion(), TiDBFull, opt) + if err != nil { + return errors.Trace(err) + } + rc.restorer = restore.NewMultiTablesRestorer(ctx, fileImporter, rc.workerPool, rc.checkpointRunner) + } + return nil } func (rc *SnapClient) needLoadSchemas(backupMeta *backuppb.BackupMeta) bool { @@ -474,7 +562,10 @@ func (rc *SnapClient) LoadSchemaIfNeededAndInitClient( backupMeta *backuppb.BackupMeta, backend *backuppb.StorageBackend, reader *metautil.MetaReader, - loadStats bool) error { + loadStats bool, + RawStartKey []byte, + RawEndKey []byte, +) error { if rc.needLoadSchemas(backupMeta) { databases, err := metautil.LoadBackupTables(c, reader, loadStats) if err != nil { @@ -499,7 +590,7 @@ func (rc *SnapClient) LoadSchemaIfNeededAndInitClient( rc.backupMeta = backupMeta log.Info("load backupmeta", zap.Int("databases", len(rc.databases)), zap.Int("jobs", len(rc.ddlJobs))) - return rc.initClients(c, backend, backupMeta.IsRawKv, backupMeta.IsTxnKv) + return rc.initClients(c, backend, backupMeta.IsRawKv, backupMeta.IsTxnKv, RawStartKey, RawEndKey) } // IsRawKvMode checks whether the backup data is in raw kv format, in which case transactional recover is forbidden. 
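The initClients changes above move per-importer setup and teardown into create/close callbacks instead of methods on SnapClient. Below is a rough sketch of the same wiring done by an external caller. It assumes the code lives inside the br/pkg/restore tree (the import_client package is internal) and that all inputs are prepared elsewhere; none of it is part of this patch.

package example // illustrative sketch only

import (
	"context"

	"github.com/pingcap/errors"
	backuppb "github.com/pingcap/kvproto/pkg/brpb"
	"github.com/pingcap/kvproto/pkg/kvrpcpb"
	"github.com/pingcap/kvproto/pkg/metapb"
	"github.com/pingcap/tidb/br/pkg/restore"
	importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client"
	snapclient "github.com/pingcap/tidb/br/pkg/restore/snap_client"
	"github.com/pingcap/tidb/br/pkg/restore/split"
	tidbutil "github.com/pingcap/tidb/pkg/util"
)

// buildRestorer shows how the new options/callbacks replace the removed
// setSpeedLimit/ResetSpeedLimit methods. All parameters are assumed inputs.
func buildRestorer(
	ctx context.Context,
	stores []*metapb.Store,
	metaClient split.SplitClient,
	importCli importclient.ImporterClient,
	backend *backuppb.StorageBackend,
	cipher *backuppb.CipherInfo,
	rewriteMode snapclient.RewriteMode,
	workerPool *tidbutil.WorkerPool,
	rateLimit uint64,
	concurrencyPerStore uint,
) (restore.SstRestorer, error) {
	setFn := snapclient.SetSpeedLimitFn(ctx, stores, workerPool)
	createCallBacks := []func(*snapclient.SnapFileImporter) error{
		// Apply the download speed limit right after the importer is created.
		func(importer *snapclient.SnapFileImporter) error { return setFn(importer, rateLimit) },
	}
	closeCallBacks := []func(*snapclient.SnapFileImporter) error{
		// Reset the limit when the importer is closed, mirroring the removed ResetSpeedLimit.
		func(importer *snapclient.SnapFileImporter) error { return setFn(importer, 0) },
	}
	opt := snapclient.NewSnapFileImporterOptions(
		cipher, metaClient, importCli, backend,
		rewriteMode, stores, concurrencyPerStore, createCallBacks, closeCallBacks,
	)
	fileImporter, err := snapclient.NewSnapFileImporter(ctx, kvrpcpb.APIVersion_V1, snapclient.TiDBFull, opt)
	if err != nil {
		return nil, errors.Trace(err)
	}
	// Raw/Txn restores would use restore.NewSimpleSstRestorer instead.
	return restore.NewMultiTablesRestorer(ctx, fileImporter, workerPool, nil), nil
}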
@@ -949,47 +1040,6 @@ func (rc *SnapClient) ExecDDLs(ctx context.Context, ddlJobs []*model.Job) error return nil } -func (rc *SnapClient) ResetSpeedLimit(ctx context.Context) error { - rc.hasSpeedLimited = false - err := rc.setSpeedLimit(ctx, 0) - if err != nil { - return errors.Trace(err) - } - return nil -} - -func (rc *SnapClient) setSpeedLimit(ctx context.Context, rateLimit uint64) error { - if !rc.hasSpeedLimited { - stores, err := util.GetAllTiKVStores(ctx, rc.pdClient, util.SkipTiFlash) - if err != nil { - return errors.Trace(err) - } - - eg, ectx := errgroup.WithContext(ctx) - for _, store := range stores { - if err := ectx.Err(); err != nil { - return errors.Trace(err) - } - - finalStore := store - rc.workerPool.ApplyOnErrorGroup(eg, - func() error { - err := rc.fileImporter.SetDownloadSpeedLimit(ectx, finalStore.GetId(), rateLimit) - if err != nil { - return errors.Trace(err) - } - return nil - }) - } - - if err := eg.Wait(); err != nil { - return errors.Trace(err) - } - rc.hasSpeedLimited = true - } - return nil -} - func (rc *SnapClient) execAndValidateChecksum( ctx context.Context, tbl *CreatedTable, @@ -1068,54 +1118,3 @@ func (rc *SnapClient) execAndValidateChecksum( logger.Info("success in validating checksum") return nil } - -func (rc *SnapClient) WaitForFilesRestored(ctx context.Context, files []*backuppb.File, updateCh glue.Progress) error { - errCh := make(chan error, len(files)) - eg, ectx := errgroup.WithContext(ctx) - defer close(errCh) - - for _, file := range files { - fileReplica := file - rc.workerPool.ApplyOnErrorGroup(eg, - func() error { - defer func() { - log.Info("import sst files done", logutil.Files(files)) - updateCh.Inc() - }() - return rc.fileImporter.ImportSSTFiles(ectx, []TableIDWithFiles{{Files: []*backuppb.File{fileReplica}, RewriteRules: restoreutils.EmptyRewriteRule()}}, rc.cipher, rc.backupMeta.ApiVersion) - }) - } - if err := eg.Wait(); err != nil { - return errors.Trace(err) - } - return nil -} - -// RestoreRaw tries to restore raw keys in the specified range. 
-func (rc *SnapClient) RestoreRaw( - ctx context.Context, startKey []byte, endKey []byte, files []*backuppb.File, updateCh glue.Progress, -) error { - start := time.Now() - defer func() { - elapsed := time.Since(start) - log.Info("Restore Raw", - logutil.Key("startKey", startKey), - logutil.Key("endKey", endKey), - zap.Duration("take", elapsed)) - }() - err := rc.fileImporter.SetRawRange(startKey, endKey) - if err != nil { - return errors.Trace(err) - } - - err = rc.WaitForFilesRestored(ctx, files, updateCh) - if err != nil { - return errors.Trace(err) - } - log.Info( - "finish to restore raw range", - logutil.Key("startKey", startKey), - logutil.Key("endKey", endKey), - ) - return nil -} diff --git a/br/pkg/restore/snap_client/client_test.go b/br/pkg/restore/snap_client/client_test.go index dd919646ddf61..4b96877949e23 100644 --- a/br/pkg/restore/snap_client/client_test.go +++ b/br/pkg/restore/snap_client/client_test.go @@ -314,7 +314,7 @@ func TestSetSpeedLimit(t *testing.T) { recordStores = NewRecordStores() start := time.Now() - err := snapclient.MockCallSetSpeedLimit(ctx, FakeImporterClient{}, client, 10) + err := snapclient.MockCallSetSpeedLimit(ctx, mockStores, FakeImporterClient{}, client, 10) cost := time.Since(start) require.NoError(t, err) @@ -337,7 +337,7 @@ func TestSetSpeedLimit(t *testing.T) { split.NewFakePDClient(mockStores, false, nil), nil, nil, split.DefaultTestKeepaliveCfg) // Concurrency needs to be less than the number of stores - err = snapclient.MockCallSetSpeedLimit(ctx, FakeImporterClient{}, client, 2) + err = snapclient.MockCallSetSpeedLimit(ctx, mockStores, FakeImporterClient{}, client, 2) require.Error(t, err) t.Log(err) diff --git a/br/pkg/restore/snap_client/export_test.go b/br/pkg/restore/snap_client/export_test.go index 22b8868217933..09a47b843e279 100644 --- a/br/pkg/restore/snap_client/export_test.go +++ b/br/pkg/restore/snap_client/export_test.go @@ -20,7 +20,10 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/restore" importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" "github.com/pingcap/tidb/pkg/domain" @@ -46,15 +49,26 @@ func MockClient(dbs map[string]*metautil.Database) *SnapClient { } // Mock the call of setSpeedLimit function -func MockCallSetSpeedLimit(ctx context.Context, fakeImportClient importclient.ImporterClient, rc *SnapClient, concurrency uint) (err error) { +func MockCallSetSpeedLimit(ctx context.Context, stores []*metapb.Store, fakeImportClient importclient.ImporterClient, rc *SnapClient, concurrency uint) (err error) { rc.SetRateLimit(42) rc.workerPool = tidbutil.NewWorkerPool(128, "set-speed-limit") - rc.hasSpeedLimited = false - rc.fileImporter, err = NewSnapFileImporter(ctx, nil, fakeImportClient, nil, false, false, nil, rc.rewriteMode, 128) + setFn := SetSpeedLimitFn(ctx, stores, rc.workerPool) + var createCallBacks []func(*SnapFileImporter) error + var closeCallBacks []func(*SnapFileImporter) error + + createCallBacks = append(createCallBacks, func(importer *SnapFileImporter) error { + return setFn(importer, rc.rateLimit) + }) + closeCallBacks = append(createCallBacks, func(importer *SnapFileImporter) error { + return setFn(importer, 0) + }) + opt := NewSnapFileImporterOptions(nil, nil, fakeImportClient, nil, rc.rewriteMode, nil, 128, createCallBacks, 
closeCallBacks) + fileImporter, err := NewSnapFileImporter(ctx, kvrpcpb.APIVersion(0), TiDBFull, opt) + rc.restorer = restore.NewSimpleSstRestorer(ctx, fileImporter, rc.workerPool, nil) if err != nil { return errors.Trace(err) } - return rc.setSpeedLimit(ctx, rc.rateLimit) + return nil } // CreateTables creates multiple tables, and returns their rewrite rules. diff --git a/br/pkg/restore/snap_client/import.go b/br/pkg/restore/snap_client/import.go index a335f651977d7..4e71c6fbe0cc4 100644 --- a/br/pkg/restore/snap_client/import.go +++ b/br/pkg/restore/snap_client/import.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/log" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/restore" importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" "github.com/pingcap/tidb/br/pkg/restore/split" restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" @@ -52,9 +53,10 @@ import ( type KvMode int const ( - TiDB KvMode = iota + TiDBFull KvMode = iota Raw Txn + TiDBCompcated ) const ( @@ -129,6 +131,9 @@ func newStoreTokenChannelMap(stores []*metapb.Store, bufferSize uint) *storeToke } type SnapFileImporter struct { + cipher *backuppb.CipherInfo + apiVersion kvrpcpb.APIVersion + metaClient split.SplitClient importClient importclient.ImporterClient backend *backuppb.StorageBackend @@ -136,6 +141,8 @@ type SnapFileImporter struct { downloadTokensMap *storeTokenChannelMap ingestTokensMap *storeTokenChannelMap + closeCallbacks []func(*SnapFileImporter) error + concurrencyPerStore uint kvMode KvMode @@ -147,43 +154,92 @@ type SnapFileImporter struct { cond *sync.Cond } -func NewSnapFileImporter( - ctx context.Context, +type SnapFileImporterOptions struct { + cipher *backuppb.CipherInfo + metaClient split.SplitClient + importClient importclient.ImporterClient + backend *backuppb.StorageBackend + rewriteMode RewriteMode + tikvStores []*metapb.Store + + concurrencyPerStore uint + createCallBacks []func(*SnapFileImporter) error + closeCallbacks []func(*SnapFileImporter) error +} + +func NewSnapFileImporterOptions( + cipher *backuppb.CipherInfo, metaClient split.SplitClient, importClient importclient.ImporterClient, backend *backuppb.StorageBackend, - isRawKvMode bool, - isTxnKvMode bool, - tikvStores []*metapb.Store, rewriteMode RewriteMode, + tikvStores []*metapb.Store, concurrencyPerStore uint, -) (*SnapFileImporter, error) { - kvMode := TiDB - if isRawKvMode { - kvMode = Raw - } - if isTxnKvMode { - kvMode = Txn - } - - fileImporter := &SnapFileImporter{ + createCallbacks []func(*SnapFileImporter) error, + closeCallbacks []func(*SnapFileImporter) error, +) *SnapFileImporterOptions { + return &SnapFileImporterOptions{ + cipher: cipher, metaClient: metaClient, + importClient: importClient, backend: backend, + rewriteMode: rewriteMode, + tikvStores: tikvStores, + concurrencyPerStore: concurrencyPerStore, + createCallBacks: createCallbacks, + closeCallbacks: closeCallbacks, + } +} + +func NewSnapFileImporterOptionsForTest( + splitClient split.SplitClient, + importClient importclient.ImporterClient, + tikvStores []*metapb.Store, + rewriteMode RewriteMode, + concurrencyPerStore uint, +) *SnapFileImporterOptions { + return &SnapFileImporterOptions{ + metaClient: splitClient, importClient: importClient, - downloadTokensMap: newStoreTokenChannelMap(tikvStores, concurrencyPerStore), - ingestTokensMap: newStoreTokenChannelMap(tikvStores, concurrencyPerStore), - kvMode: kvMode, + tikvStores: tikvStores, rewriteMode: rewriteMode, 
- cacheKey: fmt.Sprintf("BR-%s-%d", time.Now().Format("20060102150405"), rand.Int63()), concurrencyPerStore: concurrencyPerStore, + } +} + +func NewSnapFileImporter( + ctx context.Context, + apiVersion kvrpcpb.APIVersion, + kvMode KvMode, + options *SnapFileImporterOptions, +) (*SnapFileImporter, error) { + fileImporter := &SnapFileImporter{ + apiVersion: apiVersion, + kvMode: kvMode, + + cipher: options.cipher, + metaClient: options.metaClient, + backend: options.backend, + importClient: options.importClient, + downloadTokensMap: newStoreTokenChannelMap(options.tikvStores, options.concurrencyPerStore), + ingestTokensMap: newStoreTokenChannelMap(options.tikvStores, options.concurrencyPerStore), + rewriteMode: options.rewriteMode, + cacheKey: fmt.Sprintf("BR-%s-%d", time.Now().Format("20060102150405"), rand.Int63()), + concurrencyPerStore: options.concurrencyPerStore, cond: sync.NewCond(new(sync.Mutex)), + closeCallbacks: options.closeCallbacks, } - err := fileImporter.checkMultiIngestSupport(ctx, tikvStores) - return fileImporter, errors.Trace(err) + for _, f := range options.createCallBacks { + err := f(fileImporter) + if err != nil { + return nil, errors.Trace(err) + } + } + return fileImporter, nil } -func (importer *SnapFileImporter) WaitUntilUnblock() { +func (importer *SnapFileImporter) PauseForBackpressure() { importer.cond.L.Lock() for importer.ShouldBlock() { // wait for download worker notified @@ -209,6 +265,12 @@ func (importer *SnapFileImporter) releaseToken(tokenCh chan struct{}) { func (importer *SnapFileImporter) Close() error { if importer != nil && importer.importClient != nil { + for _, f := range importer.closeCallbacks { + err := f(importer) + if err != nil { + log.Warn("failed on close snap importer", zap.Error(err)) + } + } return importer.importClient.CloseGrpcClient() } return nil @@ -222,8 +284,8 @@ func (importer *SnapFileImporter) SetDownloadSpeedLimit(ctx context.Context, sto return errors.Trace(err) } -// checkMultiIngestSupport checks whether all stores support multi-ingest -func (importer *SnapFileImporter) checkMultiIngestSupport(ctx context.Context, tikvStores []*metapb.Store) error { +// CheckMultiIngestSupport checks whether all stores support multi-ingest +func (importer *SnapFileImporter) CheckMultiIngestSupport(ctx context.Context, tikvStores []*metapb.Store) error { storeIDs := make([]uint64, 0, len(tikvStores)) for _, s := range tikvStores { if s.State != metapb.StoreState_Up { @@ -274,7 +336,7 @@ func getKeyRangeByMode(mode KvMode) func(f *backuppb.File, rules *restoreutils.R // getKeyRangeForFiles gets the maximum range on files. func (importer *SnapFileImporter) getKeyRangeForFiles( - filesGroup []TableIDWithFiles, + filesGroup []restore.BackupFileSet, ) ([]byte, []byte, error) { var ( startKey, endKey []byte @@ -283,7 +345,7 @@ func (importer *SnapFileImporter) getKeyRangeForFiles( ) getRangeFn := getKeyRangeByMode(importer.kvMode) for _, files := range filesGroup { - for _, f := range files.Files { + for _, f := range files.SSTFiles { start, end, err = getRangeFn(f, files.RewriteRules) if err != nil { return nil, nil, errors.Trace(err) @@ -300,17 +362,15 @@ func (importer *SnapFileImporter) getKeyRangeForFiles( return startKey, endKey, nil } -// ImportSSTFiles tries to import a file. +// Import tries to import a file. // Assert 1: All rewrite rules must contain raw key prefix. // Assert 2: len(filesGroup[any].Files) > 0. 
-func (importer *SnapFileImporter) ImportSSTFiles( +func (importer *SnapFileImporter) Import( ctx context.Context, - filesGroup []TableIDWithFiles, - cipher *backuppb.CipherInfo, - apiVersion kvrpcpb.APIVersion, + backupFileSets ...restore.BackupFileSet, ) error { // Rewrite the start key and end key of file to scan regions - startKey, endKey, err := importer.getKeyRangeForFiles(filesGroup) + startKey, endKey, err := importer.getKeyRangeForFiles(backupFileSets) if err != nil { return errors.Trace(err) } @@ -329,7 +389,7 @@ func (importer *SnapFileImporter) ImportSSTFiles( for _, regionInfo := range regionInfos { info := regionInfo // Try to download file. - downloadMetas, errDownload := importer.download(ctx, info, filesGroup, cipher, apiVersion) + downloadMetas, errDownload := importer.download(ctx, info, backupFileSets, importer.cipher, importer.apiVersion) if errDownload != nil { log.Warn("download file failed, retry later", logutil.Region(info.Region), @@ -355,11 +415,11 @@ func (importer *SnapFileImporter) ImportSSTFiles( return nil }, utils.NewImportSSTBackoffer()) if err != nil { - log.Error("import sst file failed after retry, stop the whole progress", zapFilesGroup(filesGroup), zap.Error(err)) + log.Error("import sst file failed after retry, stop the whole progress", restore.ZapBatchBackupFileSet(backupFileSets), zap.Error(err)) return errors.Trace(err) } - for _, files := range filesGroup { - for _, f := range files.Files { + for _, files := range backupFileSets { + for _, f := range files.SSTFiles { summary.CollectSuccessUnit(summary.TotalKV, 1, f.TotalKvs) summary.CollectSuccessUnit(summary.TotalBytes, 1, f.TotalBytes) } @@ -452,7 +512,7 @@ func getSSTMetaFromFile( func (importer *SnapFileImporter) download( ctx context.Context, regionInfo *split.RegionInfo, - filesGroup []TableIDWithFiles, + filesGroup []restore.BackupFileSet, cipher *backuppb.CipherInfo, apiVersion kvrpcpb.APIVersion, ) ([]*import_sstpb.SSTMeta, error) { @@ -476,7 +536,7 @@ func (importer *SnapFileImporter) download( e = status.Error(codes.Unavailable, "the connection to TiKV has been cut by a neko, meow :3") }) if isDecryptSstErr(e) { - log.Info("fail to decrypt when download sst, try again with no-crypt", zapFilesGroup(filesGroup)) + log.Info("fail to decrypt when download sst, try again with no-crypt") if importer.kvMode == Raw || importer.kvMode == Txn { downloadMetas, e = importer.downloadRawKVSST(ctx, regionInfo, filesGroup, nil, apiVersion) } else { @@ -558,7 +618,7 @@ func (importer *SnapFileImporter) buildDownloadRequest( func (importer *SnapFileImporter) downloadSST( ctx context.Context, regionInfo *split.RegionInfo, - filesGroup []TableIDWithFiles, + filesGroup []restore.BackupFileSet, cipher *backuppb.CipherInfo, apiVersion kvrpcpb.APIVersion, ) ([]*import_sstpb.SSTMeta, error) { @@ -567,7 +627,7 @@ func (importer *SnapFileImporter) downloadSST( resultMetasMap := make(map[string]*import_sstpb.SSTMeta) downloadReqsMap := make(map[string]*import_sstpb.DownloadRequest) for _, files := range filesGroup { - for _, file := range files.Files { + for _, file := range files.SSTFiles { req, sstMeta, err := importer.buildDownloadRequest(file, files.RewriteRules, regionInfo, cipher) if err != nil { return nil, errors.Trace(err) @@ -650,13 +710,13 @@ func (importer *SnapFileImporter) downloadSST( func (importer *SnapFileImporter) downloadRawKVSST( ctx context.Context, regionInfo *split.RegionInfo, - filesGroup []TableIDWithFiles, + filesGroup []restore.BackupFileSet, cipher *backuppb.CipherInfo, apiVersion 
kvrpcpb.APIVersion, ) ([]*import_sstpb.SSTMeta, error) { downloadMetas := make([]*import_sstpb.SSTMeta, 0, len(filesGroup)*2+1) for _, files := range filesGroup { - for _, file := range files.Files { + for _, file := range files.SSTFiles { // Empty rule var rule import_sstpb.RewriteRule sstMeta, err := getSSTMetaFromFile(file, regionInfo.Region, &rule, RewriteModeLegacy) diff --git a/br/pkg/restore/snap_client/import_test.go b/br/pkg/restore/snap_client/import_test.go index 324b2ec9a007a..9d9c79fe1a6f6 100644 --- a/br/pkg/restore/snap_client/import_test.go +++ b/br/pkg/restore/snap_client/import_test.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/br/pkg/restore" importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" snapclient "github.com/pingcap/tidb/br/pkg/restore/snap_client" "github.com/pingcap/tidb/br/pkg/restore/split" @@ -74,7 +75,7 @@ func TestGetKeyRangeByMode(t *testing.T) { require.Equal(t, []byte(""), end) // normal kv: the keys must be encoded. - testFn := snapclient.GetKeyRangeByMode(snapclient.TiDB) + testFn := snapclient.GetKeyRangeByMode(snapclient.TiDBFull) start, end, err = testFn(file, rule) require.NoError(t, err) require.Equal(t, codec.EncodeBytes(nil, []byte("t2a")), start) @@ -161,7 +162,8 @@ func TestSnapImporter(t *testing.T) { splitClient.AppendPdRegion(region) } importClient := newFakeImporterClient() - importer, err := snapclient.NewSnapFileImporter(ctx, splitClient, importClient, nil, false, false, generateStores(), snapclient.RewriteModeKeyspace, 10) + opt := snapclient.NewSnapFileImporterOptionsForTest(splitClient, importClient, generateStores(), snapclient.RewriteModeKeyspace, 10) + importer, err := snapclient.NewSnapFileImporter(ctx, kvrpcpb.APIVersion_V1, snapclient.TiDBFull, opt) require.NoError(t, err) err = importer.SetDownloadSpeedLimit(ctx, 1, 5) require.NoError(t, err) @@ -170,8 +172,8 @@ func TestSnapImporter(t *testing.T) { require.Error(t, err) files, rules := generateFiles() for _, file := range files { - importer.WaitUntilUnblock() - err = importer.ImportSSTFiles(ctx, []snapclient.TableIDWithFiles{{Files: []*backuppb.File{file}, RewriteRules: rules}}, nil, kvrpcpb.APIVersion_V1) + importer.PauseForBackpressure() + err = importer.Import(ctx, restore.BackupFileSet{SSTFiles: []*backuppb.File{file}, RewriteRules: rules}) require.NoError(t, err) } err = importer.Close() @@ -185,14 +187,15 @@ func TestSnapImporterRaw(t *testing.T) { splitClient.AppendPdRegion(region) } importClient := newFakeImporterClient() - importer, err := snapclient.NewSnapFileImporter(ctx, splitClient, importClient, nil, true, false, generateStores(), snapclient.RewriteModeKeyspace, 10) + opt := snapclient.NewSnapFileImporterOptionsForTest(splitClient, importClient, generateStores(), snapclient.RewriteModeKeyspace, 10) + importer, err := snapclient.NewSnapFileImporter(ctx, kvrpcpb.APIVersion_V1, snapclient.Raw, opt) require.NoError(t, err) err = importer.SetRawRange([]byte(""), []byte("")) require.NoError(t, err) files, rules := generateFiles() for _, file := range files { - importer.WaitUntilUnblock() - err = importer.ImportSSTFiles(ctx, []snapclient.TableIDWithFiles{{Files: []*backuppb.File{file}, RewriteRules: rules}}, nil, kvrpcpb.APIVersion_V1) + importer.PauseForBackpressure() + err = importer.Import(ctx, restore.BackupFileSet{SSTFiles: []*backuppb.File{file}, RewriteRules: rules}) require.NoError(t, err) } err = 
importer.Close() diff --git a/br/pkg/restore/snap_client/pipeline_items.go b/br/pkg/restore/snap_client/pipeline_items.go index 3f74434e72f02..8f1b7c129202f 100644 --- a/br/pkg/restore/snap_client/pipeline_items.go +++ b/br/pkg/restore/snap_client/pipeline_items.go @@ -19,10 +19,8 @@ import ( "time" "github.com/pingcap/errors" - backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/metautil" restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" "github.com/pingcap/tidb/br/pkg/storage" @@ -34,7 +32,6 @@ import ( "github.com/pingcap/tidb/pkg/util/engine" pdhttp "github.com/tikv/pd/client/http" "go.uber.org/zap" - "go.uber.org/zap/zapcore" "golang.org/x/sync/errgroup" ) @@ -58,50 +55,6 @@ type PhysicalTable struct { RewriteRules *restoreutils.RewriteRules } -type TableIDWithFiles struct { - TableID int64 - - Files []*backuppb.File - // RewriteRules is the rewrite rules for the specify table. - // because these rules belongs to the *one table*. - // we can hold them here. - RewriteRules *restoreutils.RewriteRules -} - -type zapFilesGroupMarshaler []TableIDWithFiles - -// MarshalLogObjectForFiles is an internal util function to zap something having `Files` field. -func MarshalLogObjectForFiles(files []TableIDWithFiles, encoder zapcore.ObjectEncoder) error { - return zapFilesGroupMarshaler(files).MarshalLogObject(encoder) -} - -func (fgs zapFilesGroupMarshaler) MarshalLogObject(encoder zapcore.ObjectEncoder) error { - elements := make([]string, 0) - total := 0 - totalKVs := uint64(0) - totalBytes := uint64(0) - totalSize := uint64(0) - for _, fg := range fgs { - for _, f := range fg.Files { - total += 1 - elements = append(elements, f.GetName()) - totalKVs += f.GetTotalKvs() - totalBytes += f.GetTotalBytes() - totalSize += f.GetSize_() - } - } - encoder.AddInt("total", total) - _ = encoder.AddArray("files", logutil.AbbreviatedArrayMarshaler(elements)) - encoder.AddUint64("totalKVs", totalKVs) - encoder.AddUint64("totalBytes", totalBytes) - encoder.AddUint64("totalSize", totalSize) - return nil -} - -func zapFilesGroup(filesGroup []TableIDWithFiles) zap.Field { - return zap.Object("files", zapFilesGroupMarshaler(filesGroup)) -} - func defaultOutputTableChan() chan *CreatedTable { return make(chan *CreatedTable, defaultChannelSize) } diff --git a/br/pkg/restore/snap_client/tikv_sender.go b/br/pkg/restore/snap_client/tikv_sender.go index 85aeac0d76f24..57f73835beda7 100644 --- a/br/pkg/restore/snap_client/tikv_sender.go +++ b/br/pkg/restore/snap_client/tikv_sender.go @@ -25,15 +25,13 @@ import ( "github.com/pingcap/failpoint" backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/checkpoint" - "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/restore" "github.com/pingcap/tidb/br/pkg/restore/split" restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" "github.com/pingcap/tidb/br/pkg/summary" "github.com/pingcap/tidb/pkg/tablecodec" "go.uber.org/zap" - "golang.org/x/sync/errgroup" ) func getSortedPhysicalTables(createdTables []*CreatedTable) []*PhysicalTable { @@ -91,7 +89,7 @@ func mapTableToFiles(files []*backuppb.File) (map[int64][]*backuppb.File, int) { } // filterOutFiles filters out files that exist in the checkpoint set. 
-func filterOutFiles(checkpointSet map[string]struct{}, files []*backuppb.File, updateCh glue.Progress) []*backuppb.File { +func filterOutFiles(checkpointSet map[string]struct{}, files []*backuppb.File, onProgress func(int64)) []*backuppb.File { progress := int(0) totalKVs := uint64(0) totalBytes := uint64(0) @@ -110,7 +108,7 @@ func filterOutFiles(checkpointSet map[string]struct{}, files []*backuppb.File, u } if progress > 0 { // (split/scatter + download/ingest) / (default cf + write cf) - updateCh.IncBy(int64(progress) * 2 / 2) + onProgress(int64(progress) * 2 / 2) summary.CollectSuccessUnit(summary.TotalKV, progress, totalKVs) summary.CollectSuccessUnit(summary.SkippedKVCountByCheckpoint, progress, totalKVs) summary.CollectSuccessUnit(summary.TotalBytes, progress, totalBytes) @@ -130,8 +128,8 @@ func SortAndValidateFileRanges( checkpointSetWithTableID map[int64]map[string]struct{}, splitSizeBytes, splitKeyCount uint64, splitOnTable bool, - updateCh glue.Progress, -) ([][]byte, [][]TableIDWithFiles, error) { + onProgress func(int64), +) ([][]byte, []restore.BatchBackupFileSet, error) { sortedPhysicalTables := getSortedPhysicalTables(createdTables) // mapping table ID to its backup files fileOfTable, hintSplitKeyCount := mapTableToFiles(allFiles) @@ -144,8 +142,8 @@ func SortAndValidateFileRanges( lastKey []byte = nil // group the files by the generated split keys - tableIDWithFilesGroup = make([][]TableIDWithFiles, 0, hintSplitKeyCount) - lastFilesGroup []TableIDWithFiles = nil + tableIDWithFilesGroup = make([]restore.BatchBackupFileSet, 0, hintSplitKeyCount) + lastFilesGroup restore.BatchBackupFileSet = nil // statistic mergedRangeCount = 0 @@ -232,17 +230,17 @@ func SortAndValidateFileRanges( // checkpoint filter out the import done files in the previous restore executions. // Notice that skip ranges after select split keys in order to make the split keys // always the same. - newFiles := filterOutFiles(checkpointSet, rg.Files, updateCh) + newFiles := filterOutFiles(checkpointSet, rg.Files, onProgress) // append the new files into the group if len(newFiles) > 0 { if len(lastFilesGroup) == 0 || lastFilesGroup[len(lastFilesGroup)-1].TableID != table.NewPhysicalID { - lastFilesGroup = append(lastFilesGroup, TableIDWithFiles{ + lastFilesGroup = append(lastFilesGroup, restore.BackupFileSet{ TableID: table.NewPhysicalID, - Files: nil, + SSTFiles: nil, RewriteRules: table.RewriteRules, }) } - lastFilesGroup[len(lastFilesGroup)-1].Files = append(lastFilesGroup[len(lastFilesGroup)-1].Files, newFiles...) + lastFilesGroup[len(lastFilesGroup)-1].SSTFiles = append(lastFilesGroup[len(lastFilesGroup)-1].SSTFiles, newFiles...) 
} } @@ -286,7 +284,7 @@ func (rc *SnapClient) RestoreTables( checkpointSetWithTableID map[int64]map[string]struct{}, splitSizeBytes, splitKeyCount uint64, splitOnTable bool, - updateCh glue.Progress, + onProgress func(int64), ) error { if err := placementRuleManager.SetPlacementRule(ctx, createdTables); err != nil { return errors.Trace(err) @@ -299,20 +297,21 @@ func (rc *SnapClient) RestoreTables( }() start := time.Now() - sortedSplitKeys, tableIDWithFilesGroup, err := SortAndValidateFileRanges(createdTables, allFiles, checkpointSetWithTableID, splitSizeBytes, splitKeyCount, splitOnTable, updateCh) + sortedSplitKeys, tableIDWithFilesGroup, err := SortAndValidateFileRanges(createdTables, allFiles, checkpointSetWithTableID, splitSizeBytes, splitKeyCount, splitOnTable, onProgress) if err != nil { return errors.Trace(err) } log.Info("Restore Stage Duration", zap.String("stage", "merge ranges"), zap.Duration("take", time.Since(start))) + newProgress := func(i int64) { onProgress(i) } start = time.Now() - if err = rc.SplitPoints(ctx, sortedSplitKeys, updateCh, false); err != nil { + if err = rc.SplitPoints(ctx, sortedSplitKeys, newProgress, false); err != nil { return errors.Trace(err) } log.Info("Restore Stage Duration", zap.String("stage", "split regions"), zap.Duration("take", time.Since(start))) start = time.Now() - if err = rc.RestoreSSTFiles(ctx, tableIDWithFilesGroup, updateCh); err != nil { + if err = rc.RestoreSSTFiles(ctx, tableIDWithFilesGroup, newProgress); err != nil { return errors.Trace(err) } elapsed := time.Since(start) @@ -327,15 +326,14 @@ func (rc *SnapClient) RestoreTables( func (rc *SnapClient) SplitPoints( ctx context.Context, sortedSplitKeys [][]byte, - updateCh glue.Progress, + onProgress func(int64), isRawKv bool, ) error { splitClientOpts := make([]split.ClientOptionalParameter, 0, 2) splitClientOpts = append(splitClientOpts, split.WithOnSplit(func(keys [][]byte) { - for range keys { - updateCh.Inc() - } + onProgress(int64(len(keys))) })) + // TODO seems duplicate with metaClient. if isRawKv { splitClientOpts = append(splitClientOpts, split.WithRawKV()) } @@ -366,13 +364,9 @@ func getFileRangeKey(f string) string { // RestoreSSTFiles tries to do something prepare work, such as set speed limit, and restore the files. 
func (rc *SnapClient) RestoreSSTFiles( ctx context.Context, - tableIDWithFilesGroup [][]TableIDWithFiles, - updateCh glue.Progress, + tableIDWithFilesGroup []restore.BatchBackupFileSet, + onProgress func(int64), ) (retErr error) { - if err := rc.setSpeedLimit(ctx, rc.rateLimit); err != nil { - return errors.Trace(err) - } - failpoint.Inject("corrupt-files", func(v failpoint.Value) { if cmd, ok := v.(string); ok { switch cmd { @@ -382,7 +376,7 @@ func (rc *SnapClient) RestoreSSTFiles( case "only-last-table-files": // check whether all the files, except last table files, are skipped by checkpoint for _, tableIDWithFiless := range tableIDWithFilesGroup[:len(tableIDWithFilesGroup)-1] { for _, tableIDWithFiles := range tableIDWithFiless { - if len(tableIDWithFiles.Files) > 0 { + if len(tableIDWithFiles.SSTFiles) > 0 { log.Panic("has files but not the last table files") } } @@ -391,69 +385,9 @@ func (rc *SnapClient) RestoreSSTFiles( } }) - return rc.restoreSSTFilesInternal(ctx, tableIDWithFilesGroup, updateCh) -} - -func (rc *SnapClient) restoreSSTFilesInternal( - ctx context.Context, - tableIDWithFilesGroup [][]TableIDWithFiles, - updateCh glue.Progress, -) error { - eg, ectx := errgroup.WithContext(ctx) - for _, tableIDWithFiles := range tableIDWithFilesGroup { - if ectx.Err() != nil { - log.Warn("Restoring encountered error and already stopped, give up remained files.", - logutil.ShortError(ectx.Err())) - // We will fetch the error from the errgroup then (If there were). - // Also note if the parent context has been canceled or something, - // breaking here directly is also a reasonable behavior. - break - } - filesReplica := tableIDWithFiles - rc.fileImporter.WaitUntilUnblock() - rc.workerPool.ApplyOnErrorGroup(eg, func() (restoreErr error) { - fileStart := time.Now() - defer func() { - if restoreErr == nil { - log.Info("import files done", zapFilesGroup(filesReplica), - zap.Duration("take", time.Since(fileStart))) - updateCh.Inc() - } - }() - if importErr := rc.fileImporter.ImportSSTFiles(ectx, filesReplica, rc.cipher, rc.dom.Store().GetCodec().GetAPIVersion()); importErr != nil { - return errors.Trace(importErr) - } - - // the data of this range has been import done - if rc.checkpointRunner != nil && len(filesReplica) > 0 { - for _, filesGroup := range filesReplica { - rangeKeySet := make(map[string]struct{}) - for _, file := range filesGroup.Files { - rangeKey := getFileRangeKey(file.Name) - // Assert that the files having the same rangeKey are all in the current filesGroup.Files - rangeKeySet[rangeKey] = struct{}{} - } - for rangeKey := range rangeKeySet { - // The checkpoint range shows this ranges of kvs has been restored into - // the table corresponding to the table-id. - if err := checkpoint.AppendRangesForRestore(ectx, rc.checkpointRunner, filesGroup.TableID, rangeKey); err != nil { - return errors.Trace(err) - } - } - } - } - - return nil - }) - } - - if err := eg.Wait(); err != nil { - summary.CollectFailureUnit("file", err) - log.Error("restore files failed", zap.Error(err)) - return errors.Trace(err) + retErr = rc.restorer.GoRestore(onProgress, tableIDWithFilesGroup...) + if retErr != nil { + return retErr } - // Once the parent context canceled and there is no task running in the errgroup, - // we may break the for loop without error in the errgroup. (Will this happen?) - // At that time, return the error in the context here. 
- return ctx.Err() + return rc.restorer.WaitUntilFinish() } diff --git a/br/pkg/restore/snap_client/tikv_sender_test.go b/br/pkg/restore/snap_client/tikv_sender_test.go index b23ec6298d40f..b5a38ffc839a3 100644 --- a/br/pkg/restore/snap_client/tikv_sender_test.go +++ b/br/pkg/restore/snap_client/tikv_sender_test.go @@ -22,6 +22,7 @@ import ( backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/restore" snapclient "github.com/pingcap/tidb/br/pkg/restore/snap_client" restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" "github.com/pingcap/tidb/pkg/kv" @@ -207,14 +208,14 @@ func key(tableID int64, row int) []byte { return tablecodec.EncodeRowKeyWithHandle(downstreamID(tableID), kv.IntHandle(row)) } -func files(physicalTableID int64, startRows []int, cfs []string) snapclient.TableIDWithFiles { +func files(physicalTableID int64, startRows []int, cfs []string) restore.BackupFileSet { files := make([]*backuppb.File, 0, len(startRows)) for i, startRow := range startRows { files = append(files, &backuppb.File{Name: fmt.Sprintf("file_%d_%d_%s.sst", physicalTableID, startRow, cfs[i])}) } - return snapclient.TableIDWithFiles{ - TableID: downstreamID(physicalTableID), - Files: files, + return restore.BackupFileSet{ + TableID: downstreamID(physicalTableID), + SSTFiles: files, } } @@ -247,7 +248,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { // expected result splitKeys [][]byte - tableIDWithFilesGroups [][]snapclient.TableIDWithFiles + tableIDWithFilesGroups [][]restore.BackupFileSet }{ { // large sst, split-on-table, no checkpoint upstreamTableIDs: []int64{100, 200, 300}, @@ -270,7 +271,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ /*split table key*/ key(202, 2), /*split table key*/ }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, {files(202, []int{1, 1}, []string{w, d})}, {files(202, []int{2, 2}, []string{w, d})}, @@ -298,7 +299,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ /*split table key*/ key(202, 2), /*split table key*/ }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, //{files(202, []int{1, 1}, []string{w, d})}, {files(202, []int{2, 2}, []string{w, d})}, @@ -323,7 +324,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(102, 2), key(202, 2), key(202, 3), key(302, 2), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, {files(202, []int{1, 1}, []string{w, d})}, {files(202, []int{2, 2}, []string{w, d})}, @@ -351,7 +352,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(102, 2), key(202, 2), key(202, 3), key(302, 2), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, //{files(202, []int{1, 1}, []string{w, d})}, {files(202, []int{2, 2}, []string{w, d})}, @@ -376,7 +377,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(202, 2), /*split table key*/ }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, 
{files(202, []int{1, 1}, []string{w, d})}, {files(202, []int{2, 2}, []string{w, d})}, @@ -404,7 +405,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(202, 2), /*split table key*/ }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, // {files(202, []int{1, 1}, []string{w, d})}, {files(202, []int{2, 2}, []string{w, d})}, @@ -429,7 +430,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(202, 2), key(302, 2), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w}), files(202, []int{1, 1}, []string{w, d})}, {files(202, []int{2, 2}, []string{w, d}), files(302, []int{1}, []string{w})}, {files(100, []int{1, 1}, []string{w, d})}, @@ -455,7 +456,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(202, 2), key(302, 2), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, {files(202, []int{2, 2}, []string{w, d}), files(302, []int{1}, []string{w})}, }, @@ -475,7 +476,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeyCount: 450, splitOnTable: true, splitKeys: [][]byte{}, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, {files(202, []int{1, 1, 2, 2}, []string{w, d, w, d})}, {files(302, []int{1}, []string{w})}, @@ -500,7 +501,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeyCount: 450, splitOnTable: true, splitKeys: [][]byte{}, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, {files(202, []int{2, 2}, []string{w, d})}, {files(302, []int{1}, []string{w})}, @@ -523,7 +524,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(102, 2), key(202, 3), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, {files(202, []int{1, 1, 2, 2}, []string{w, d, w, d})}, {files(302, []int{1}, []string{w}), files(100, []int{1, 1}, []string{w, d})}, @@ -549,7 +550,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(102, 2), key(202, 3), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, {files(202, []int{2, 2}, []string{w, d})}, {files(302, []int{1}, []string{w})}, @@ -572,7 +573,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(202, 3), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w}), files(202, []int{1, 1, 2, 2}, []string{w, d, w, d})}, {files(302, []int{1}, []string{w}), files(100, []int{1, 1}, []string{w, d})}, }, @@ -597,7 +598,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(202, 3), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w}), files(202, []int{2, 2}, []string{w, d, w, d})}, {files(302, []int{1}, []string{w})}, }, @@ -619,7 
+620,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(202, 2), key(302, 2), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w}), files(202, []int{1, 1}, []string{w, d})}, {files(202, []int{2, 2}, []string{w, d}), files(302, []int{1}, []string{w})}, {files(100, []int{1, 1}, []string{w, d})}, @@ -645,7 +646,7 @@ func TestSortAndValidateFileRanges(t *testing.T) { splitKeys: [][]byte{ key(202, 2), key(302, 2), key(100, 2), }, - tableIDWithFilesGroups: [][]snapclient.TableIDWithFiles{ + tableIDWithFilesGroups: [][]restore.BackupFileSet{ {files(102, []int{1}, []string{w})}, {files(202, []int{2, 2}, []string{w, d}), files(302, []int{1}, []string{w})}, }, @@ -655,7 +656,8 @@ func TestSortAndValidateFileRanges(t *testing.T) { for i, cs := range cases { t.Log(i) createdTables := generateCreatedTables(t, cs.upstreamTableIDs, cs.upstreamPartitionIDs, downstreamID) - splitKeys, tableIDWithFilesGroups, err := snapclient.SortAndValidateFileRanges(createdTables, cs.files, cs.checkpointSetWithTableID, cs.splitSizeBytes, cs.splitKeyCount, cs.splitOnTable, updateCh) + onProgress := func(i int64) { updateCh.IncBy(i) } + splitKeys, tableIDWithFilesGroups, err := snapclient.SortAndValidateFileRanges(createdTables, cs.files, cs.checkpointSetWithTableID, cs.splitSizeBytes, cs.splitKeyCount, cs.splitOnTable, onProgress) require.NoError(t, err) require.Equal(t, cs.splitKeys, splitKeys) require.Equal(t, len(cs.tableIDWithFilesGroups), len(tableIDWithFilesGroups)) @@ -665,8 +667,8 @@ func TestSortAndValidateFileRanges(t *testing.T) { for j, expectFiles := range expectFilesGroup { actualFiles := actualFilesGroup[j] require.Equal(t, expectFiles.TableID, actualFiles.TableID) - for k, expectFile := range expectFiles.Files { - actualFile := actualFiles.Files[k] + for k, expectFile := range expectFiles.SSTFiles { + actualFile := actualFiles.SSTFiles[k] require.Equal(t, expectFile.Name, actualFile.Name) } } diff --git a/br/pkg/restore/split/BUILD.bazel b/br/pkg/restore/split/BUILD.bazel index e27bb1834d7ac..34eb266023809 100644 --- a/br/pkg/restore/split/BUILD.bazel +++ b/br/pkg/restore/split/BUILD.bazel @@ -7,6 +7,7 @@ go_library( "mock_pd_client.go", "region.go", "split.go", + "splitter.go", "sum_sorted.go", ], importpath = "github.com/pingcap/tidb/br/pkg/restore/split", @@ -32,7 +33,6 @@ go_library( "@com_github_google_btree//:btree", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", - "@com_github_pingcap_kvproto//pkg/brpb", "@com_github_pingcap_kvproto//pkg/errorpb", "@com_github_pingcap_kvproto//pkg/kvrpcpb", "@com_github_pingcap_kvproto//pkg/metapb", @@ -63,7 +63,7 @@ go_test( ], embed = [":split"], flaky = True, - shard_count = 27, + shard_count = 26, deps = [ "//br/pkg/errors", "//br/pkg/restore/utils", @@ -75,10 +75,8 @@ go_test( "//pkg/tablecodec", "//pkg/types", "//pkg/util/codec", - "@com_github_docker_go_units//:go-units", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", - "@com_github_pingcap_kvproto//pkg/brpb", "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_kvproto//pkg/kvrpcpb", "@com_github_pingcap_kvproto//pkg/metapb", diff --git a/br/pkg/restore/split/mock_pd_client.go b/br/pkg/restore/split/mock_pd_client.go index 6df6cd56f94b0..535b064700d41 100644 --- a/br/pkg/restore/split/mock_pd_client.go +++ b/br/pkg/restore/split/mock_pd_client.go @@ -23,6 +23,8 @@ import ( 
"google.golang.org/grpc/status" ) +// TODO consilodate TestClient and MockPDClientForSplit and FakePDClient +// into one test client. type TestClient struct { SplitClient pd.Client diff --git a/br/pkg/restore/split/split.go b/br/pkg/restore/split/split.go index f7df83cd2e3e9..726c4b89794fc 100644 --- a/br/pkg/restore/split/split.go +++ b/br/pkg/restore/split/split.go @@ -7,25 +7,18 @@ import ( "context" "encoding/hex" goerrors "errors" - "sort" - "sync" "time" "github.com/pingcap/errors" "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/log" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/logutil" - restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/codec" "github.com/pingcap/tidb/pkg/util/redact" "go.uber.org/zap" - "golang.org/x/sync/errgroup" ) var ( @@ -44,347 +37,6 @@ const ( ScanRegionPaginationLimit = 128 ) -type rewriteSplitter struct { - rewriteKey []byte - tableID int64 - rule *restoreutils.RewriteRules - splitter *SplitHelper -} - -type splitHelperIterator struct { - tableSplitters []*rewriteSplitter -} - -func (iter *splitHelperIterator) Traverse(fn func(v Valued, endKey []byte, rule *restoreutils.RewriteRules) bool) { - for _, entry := range iter.tableSplitters { - endKey := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(entry.tableID+1)) - rule := entry.rule - entry.splitter.Traverse(func(v Valued) bool { - return fn(v, endKey, rule) - }) - } -} - -type LogSplitHelper struct { - tableSplitter map[int64]*SplitHelper - rules map[int64]*restoreutils.RewriteRules - client SplitClient - pool *util.WorkerPool - eg *errgroup.Group - regionsCh chan []*RegionInfo - - splitThresholdSize uint64 - splitThresholdKeys int64 -} - -func NewLogSplitHelper(rules map[int64]*restoreutils.RewriteRules, client SplitClient, splitSize uint64, splitKeys int64) *LogSplitHelper { - return &LogSplitHelper{ - tableSplitter: make(map[int64]*SplitHelper), - rules: rules, - client: client, - pool: util.NewWorkerPool(128, "split region"), - eg: nil, - - splitThresholdSize: splitSize, - splitThresholdKeys: splitKeys, - } -} - -func (helper *LogSplitHelper) iterator() *splitHelperIterator { - tableSplitters := make([]*rewriteSplitter, 0, len(helper.tableSplitter)) - for tableID, splitter := range helper.tableSplitter { - delete(helper.tableSplitter, tableID) - rewriteRule, exists := helper.rules[tableID] - if !exists { - log.Info("skip splitting due to no table id matched", zap.Int64("tableID", tableID)) - continue - } - newTableID := restoreutils.GetRewriteTableID(tableID, rewriteRule) - if newTableID == 0 { - log.Warn("failed to get the rewrite table id", zap.Int64("tableID", tableID)) - continue - } - tableSplitters = append(tableSplitters, &rewriteSplitter{ - rewriteKey: codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(newTableID)), - tableID: newTableID, - rule: rewriteRule, - splitter: splitter, - }) - } - sort.Slice(tableSplitters, func(i, j int) bool { - return bytes.Compare(tableSplitters[i].rewriteKey, tableSplitters[j].rewriteKey) < 0 - }) - return &splitHelperIterator{ - tableSplitters: tableSplitters, - } -} - -const splitFileThreshold = 1024 * 1024 // 1 MB - -func (helper *LogSplitHelper) skipFile(file *backuppb.DataFileInfo) bool { - _, exist := helper.rules[file.TableId] - return file.Length 
< splitFileThreshold || file.IsMeta || !exist -} - -func (helper *LogSplitHelper) Merge(file *backuppb.DataFileInfo) { - if helper.skipFile(file) { - return - } - splitHelper, exist := helper.tableSplitter[file.TableId] - if !exist { - splitHelper = NewSplitHelper() - helper.tableSplitter[file.TableId] = splitHelper - } - - splitHelper.Merge(Valued{ - Key: Span{ - StartKey: file.StartKey, - EndKey: file.EndKey, - }, - Value: Value{ - Size: file.Length, - Number: file.NumberOfEntries, - }, - }) -} - -type splitFunc = func(context.Context, *RegionSplitter, uint64, int64, *RegionInfo, []Valued) error - -func (helper *LogSplitHelper) splitRegionByPoints( - ctx context.Context, - regionSplitter *RegionSplitter, - initialLength uint64, - initialNumber int64, - region *RegionInfo, - valueds []Valued, -) error { - var ( - splitPoints [][]byte = make([][]byte, 0) - lastKey []byte = region.Region.StartKey - length uint64 = initialLength - number int64 = initialNumber - ) - for _, v := range valueds { - // decode will discard ts behind the key, which results in the same key for consecutive ranges - if !bytes.Equal(lastKey, v.GetStartKey()) && (v.Value.Size+length > helper.splitThresholdSize || v.Value.Number+number > helper.splitThresholdKeys) { - _, rawKey, _ := codec.DecodeBytes(v.GetStartKey(), nil) - splitPoints = append(splitPoints, rawKey) - length = 0 - number = 0 - } - lastKey = v.GetStartKey() - length += v.Value.Size - number += v.Value.Number - } - - if len(splitPoints) == 0 { - return nil - } - - helper.pool.ApplyOnErrorGroup(helper.eg, func() error { - newRegions, errSplit := regionSplitter.ExecuteOneRegion(ctx, region, splitPoints) - if errSplit != nil { - log.Warn("failed to split the scaned region", zap.Error(errSplit)) - sort.Slice(splitPoints, func(i, j int) bool { - return bytes.Compare(splitPoints[i], splitPoints[j]) < 0 - }) - return regionSplitter.ExecuteSortedKeys(ctx, splitPoints) - } - select { - case <-ctx.Done(): - return nil - case helper.regionsCh <- newRegions: - } - log.Info("split the region", zap.Uint64("region-id", region.Region.Id), zap.Int("split-point-number", len(splitPoints))) - return nil - }) - return nil -} - -// SplitPoint selects ranges overlapped with each region, and calls `splitF` to split the region -func SplitPoint( - ctx context.Context, - iter *splitHelperIterator, - client SplitClient, - splitF splitFunc, -) (err error) { - // common status - var ( - regionSplitter *RegionSplitter = NewRegionSplitter(client) - ) - // region traverse status - var ( - // the region buffer of each scan - regions []*RegionInfo = nil - regionIndex int = 0 - ) - // region split status - var ( - // range span +----------------+------+---+-------------+ - // region span +------------------------------------+ - // +initial length+ +end valued+ - // regionValueds is the ranges array overlapped with `regionInfo` - regionValueds []Valued = nil - // regionInfo is the region to be split - regionInfo *RegionInfo = nil - // intialLength is the length of the part of the first range overlapped with the region - initialLength uint64 = 0 - initialNumber int64 = 0 - ) - // range status - var ( - // regionOverCount is the number of regions overlapped with the range - regionOverCount uint64 = 0 - ) - - iter.Traverse(func(v Valued, endKey []byte, rule *restoreutils.RewriteRules) bool { - if v.Value.Number == 0 || v.Value.Size == 0 { - return true - } - var ( - vStartKey []byte - vEndKey []byte - ) - // use `vStartKey` and `vEndKey` to compare with region's key - vStartKey, vEndKey, err = 
restoreutils.GetRewriteEncodedKeys(v, rule) - if err != nil { - return false - } - // traverse to the first region overlapped with the range - for ; regionIndex < len(regions); regionIndex++ { - if bytes.Compare(vStartKey, regions[regionIndex].Region.EndKey) < 0 { - break - } - } - // cannot find any regions overlapped with the range - // need to scan regions again - if regionIndex == len(regions) { - regions = nil - } - regionOverCount = 0 - for { - if regionIndex >= len(regions) { - var startKey []byte - if len(regions) > 0 { - // has traversed over the region buffer, should scan from the last region's end-key of the region buffer - startKey = regions[len(regions)-1].Region.EndKey - } else { - // scan from the range's start-key - startKey = vStartKey - } - // scan at most 64 regions into the region buffer - regions, err = ScanRegionsWithRetry(ctx, client, startKey, endKey, 64) - if err != nil { - return false - } - regionIndex = 0 - } - - region := regions[regionIndex] - // this region must be overlapped with the range - regionOverCount++ - // the region is the last one overlapped with the range, - // should split the last recorded region, - // and then record this region as the region to be split - if bytes.Compare(vEndKey, region.Region.EndKey) < 0 { - endLength := v.Value.Size / regionOverCount - endNumber := v.Value.Number / int64(regionOverCount) - if len(regionValueds) > 0 && regionInfo != region { - // add a part of the range as the end part - if bytes.Compare(vStartKey, regionInfo.Region.EndKey) < 0 { - regionValueds = append(regionValueds, NewValued(vStartKey, regionInfo.Region.EndKey, Value{Size: endLength, Number: endNumber})) - } - // try to split the region - err = splitF(ctx, regionSplitter, initialLength, initialNumber, regionInfo, regionValueds) - if err != nil { - return false - } - regionValueds = make([]Valued, 0) - } - if regionOverCount == 1 { - // the region completely contains the range - regionValueds = append(regionValueds, Valued{ - Key: Span{ - StartKey: vStartKey, - EndKey: vEndKey, - }, - Value: v.Value, - }) - } else { - // the region is overlapped with the last part of the range - initialLength = endLength - initialNumber = endNumber - } - regionInfo = region - // try the next range - return true - } - - // try the next region - regionIndex++ - } - }) - - if err != nil { - return errors.Trace(err) - } - if len(regionValueds) > 0 { - // try to split the region - err = splitF(ctx, regionSplitter, initialLength, initialNumber, regionInfo, regionValueds) - if err != nil { - return errors.Trace(err) - } - } - - return nil -} - -func (helper *LogSplitHelper) Split(ctx context.Context) error { - var ectx context.Context - var wg sync.WaitGroup - helper.eg, ectx = errgroup.WithContext(ctx) - helper.regionsCh = make(chan []*RegionInfo, 1024) - wg.Add(1) - go func() { - defer wg.Done() - scatterRegions := make([]*RegionInfo, 0) - receiveNewRegions: - for { - select { - case <-ctx.Done(): - return - case newRegions, ok := <-helper.regionsCh: - if !ok { - break receiveNewRegions - } - - scatterRegions = append(scatterRegions, newRegions...) - } - } - - regionSplitter := NewRegionSplitter(helper.client) - // It is too expensive to stop recovery and wait for a small number of regions - // to complete scatter, so the maximum waiting time is reduced to 1 minute. 
- _ = regionSplitter.WaitForScatterRegionsTimeout(ctx, scatterRegions, time.Minute) - }() - - iter := helper.iterator() - if err := SplitPoint(ectx, iter, helper.client, helper.splitRegionByPoints); err != nil { - return errors.Trace(err) - } - - // wait for completion of splitting regions - if err := helper.eg.Wait(); err != nil { - return errors.Trace(err) - } - - // wait for completion of scattering regions - close(helper.regionsCh) - wg.Wait() - - return nil -} - // RegionSplitter is a executor of region split by rules. type RegionSplitter struct { client SplitClient @@ -397,8 +49,8 @@ func NewRegionSplitter(client SplitClient) *RegionSplitter { } } -// ExecuteOneRegion expose the function `SplitWaitAndScatter` of split client. -func (rs *RegionSplitter) ExecuteOneRegion(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) { +// ExecuteSortedKeysOnRegion expose the function `SplitWaitAndScatter` of split client. +func (rs *RegionSplitter) ExecuteSortedKeysOnRegion(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) { return rs.client.SplitWaitAndScatter(ctx, region, keys) } diff --git a/br/pkg/restore/split/split_test.go b/br/pkg/restore/split/split_test.go index 4acef5dac84b7..ee53cce560187 100644 --- a/br/pkg/restore/split/split_test.go +++ b/br/pkg/restore/split/split_test.go @@ -11,10 +11,8 @@ import ( "testing" "time" - "github.com/docker/go-units" "github.com/pingcap/errors" "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" @@ -936,8 +934,8 @@ func TestSplitPoint(t *testing.T) { client.AppendRegion(keyWithTablePrefix(tableID, "h"), keyWithTablePrefix(tableID, "j")) client.AppendRegion(keyWithTablePrefix(tableID, "j"), keyWithTablePrefix(tableID+1, "a")) - iter := NewSplitHelperIteratorForTest(splitHelper, tableID, rewriteRules) - err := SplitPoint(ctx, iter, client, func(ctx context.Context, rs *RegionSplitter, u uint64, o int64, ri *RegionInfo, v []Valued) error { + iter := NewSplitHelperIterator([]*RewriteSplitter{{tableID: tableID, rule: rewriteRules, splitter: splitHelper}}) + err := SplitPoint(ctx, iter, client, func(ctx context.Context, u uint64, o int64, ri *RegionInfo, v []Valued) error { require.Equal(t, u, uint64(0)) require.Equal(t, o, int64(0)) require.Equal(t, ri.Region.StartKey, keyWithTablePrefix(tableID, "a")) @@ -994,8 +992,8 @@ func TestSplitPoint2(t *testing.T) { client.AppendRegion(keyWithTablePrefix(tableID, "o"), keyWithTablePrefix(tableID+1, "a")) firstSplit := true - iter := NewSplitHelperIteratorForTest(splitHelper, tableID, rewriteRules) - err := SplitPoint(ctx, iter, client, func(ctx context.Context, rs *RegionSplitter, u uint64, o int64, ri *RegionInfo, v []Valued) error { + iter := NewSplitHelperIterator([]*RewriteSplitter{{tableID: tableID, rule: rewriteRules, splitter: splitHelper}}) + err := SplitPoint(ctx, iter, client, func(ctx context.Context, u uint64, o int64, ri *RegionInfo, v []Valued) error { if firstSplit { require.Equal(t, u, uint64(0)) require.Equal(t, o, int64(0)) @@ -1028,86 +1026,3 @@ func TestSplitPoint2(t *testing.T) { }) require.NoError(t, err) } - -func fakeFile(tableID, rowID int64, length uint64, num int64) *backuppb.DataFileInfo { - return &backuppb.DataFileInfo{ - StartKey: fakeRowKey(tableID, rowID), - EndKey: fakeRowKey(tableID, rowID+1), - TableId: tableID, - Length: length, - NumberOfEntries: num, - } -} - -func 
fakeRowKey(tableID, rowID int64) kv.Key { - return codec.EncodeBytes(nil, tablecodec.EncodeRecordKey(tablecodec.GenTableRecordPrefix(tableID), kv.IntHandle(rowID))) -} - -func TestLogSplitHelper(t *testing.T) { - ctx := context.Background() - rules := map[int64]*restoreutils.RewriteRules{ - 1: { - Data: []*import_sstpb.RewriteRule{ - { - OldKeyPrefix: tablecodec.GenTableRecordPrefix(1), - NewKeyPrefix: tablecodec.GenTableRecordPrefix(100), - }, - }, - }, - 2: { - Data: []*import_sstpb.RewriteRule{ - { - OldKeyPrefix: tablecodec.GenTableRecordPrefix(2), - NewKeyPrefix: tablecodec.GenTableRecordPrefix(200), - }, - }, - }, - } - oriRegions := [][]byte{ - {}, - codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), - codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), - codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), - } - mockPDCli := NewMockPDClientForSplit() - mockPDCli.SetRegions(oriRegions) - client := NewClient(mockPDCli, nil, nil, 100, 4) - helper := NewLogSplitHelper(rules, client, 4*units.MiB, 400) - - helper.Merge(fakeFile(1, 100, 100, 100)) - helper.Merge(fakeFile(1, 200, 2*units.MiB, 200)) - helper.Merge(fakeFile(2, 100, 3*units.MiB, 300)) - helper.Merge(fakeFile(3, 100, 10*units.MiB, 100000)) - // different regions, no split happens - err := helper.Split(ctx) - require.NoError(t, err) - regions, err := mockPDCli.ScanRegions(ctx, []byte{}, []byte{}, 0) - require.NoError(t, err) - require.Len(t, regions, 3) - require.Equal(t, []byte{}, regions[0].Meta.StartKey) - require.Equal(t, codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(100)), regions[1].Meta.StartKey) - require.Equal(t, codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(200)), regions[2].Meta.StartKey) - require.Equal(t, codec.EncodeBytes(nil, tablecodec.EncodeTablePrefix(402)), regions[2].Meta.EndKey) - - helper.Merge(fakeFile(1, 300, 3*units.MiB, 10)) - helper.Merge(fakeFile(1, 400, 4*units.MiB, 10)) - // trigger to split regions for table 1 - err = helper.Split(ctx) - require.NoError(t, err) - regions, err = mockPDCli.ScanRegions(ctx, []byte{}, []byte{}, 0) - require.NoError(t, err) - require.Len(t, regions, 4) - require.Equal(t, fakeRowKey(100, 400), kv.Key(regions[1].Meta.EndKey)) -} - -func NewSplitHelperIteratorForTest(helper *SplitHelper, tableID int64, rule *restoreutils.RewriteRules) *splitHelperIterator { - return &splitHelperIterator{ - tableSplitters: []*rewriteSplitter{ - { - tableID: tableID, - rule: rule, - splitter: helper, - }, - }, - } -} diff --git a/br/pkg/restore/split/splitter.go b/br/pkg/restore/split/splitter.go new file mode 100644 index 0000000000000..5faec981d4224 --- /dev/null +++ b/br/pkg/restore/split/splitter.go @@ -0,0 +1,400 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package split + +import ( + "bytes" + "context" + "sort" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/log" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/codec" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +// Splitter defines the interface for basic splitting strategies. +type Splitter interface { + // ExecuteSortedKeysOnRegion splits the keys within a single region and initiates scattering + // after the region has been successfully split. 
+ ExecuteSortedKeysOnRegion(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) + + // ExecuteSortedKeys splits all provided keys while ensuring that the newly created + // regions are balanced. + ExecuteSortedKeys(ctx context.Context, keys [][]byte) error + + // WaitForScatterRegionsTimeout blocks until all regions have finished scattering, + // or until the specified timeout duration has elapsed. + WaitForScatterRegionsTimeout(ctx context.Context, regionInfos []*RegionInfo, timeout time.Duration) int +} + +// SplitStrategy defines how values should be accumulated and when to trigger a split. +type SplitStrategy[T any] interface { + // Accumulate adds a new value into the split strategy's internal state. + // This method accumulates data points or values, preparing them for potential splitting. + Accumulate(T) + // ShouldSplit checks if the accumulated values meet the criteria for triggering a split. + ShouldSplit() bool + // Skip the file by checkpoints or invalid files + ShouldSkip(T) bool + // GetAccumulations returns an iterator for the accumulated values. + GetAccumulations() *SplitHelperIterator + // Reset the buffer for next round + ResetAccumulations() +} + +type BaseSplitStrategy struct { + AccumulateCount int + TableSplitter map[int64]*SplitHelper + Rules map[int64]*restoreutils.RewriteRules +} + +func NewBaseSplitStrategy(rules map[int64]*restoreutils.RewriteRules) *BaseSplitStrategy { + return &BaseSplitStrategy{ + AccumulateCount: 0, + TableSplitter: make(map[int64]*SplitHelper), + Rules: rules, + } +} + +func (b *BaseSplitStrategy) GetAccumulations() *SplitHelperIterator { + tableSplitters := make([]*RewriteSplitter, 0, len(b.TableSplitter)) + for tableID, splitter := range b.TableSplitter { + rewriteRule, exists := b.Rules[tableID] + if !exists { + log.Fatal("[unreachable] no table id matched", zap.Int64("tableID", tableID)) + } + newTableID := restoreutils.GetRewriteTableID(tableID, rewriteRule) + if newTableID == 0 { + log.Warn("failed to get the rewrite table id", zap.Int64("tableID", tableID)) + continue + } + tableSplitters = append(tableSplitters, NewRewriteSpliter( + // TODO remove this field. 
sort by newTableID + codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(newTableID)), + newTableID, + rewriteRule, + splitter, + )) + } + sort.Slice(tableSplitters, func(i, j int) bool { + return bytes.Compare(tableSplitters[i].RewriteKey, tableSplitters[j].RewriteKey) < 0 + }) + return NewSplitHelperIterator(tableSplitters) +} + +func (b *BaseSplitStrategy) ResetAccumulations() { + // always assume all previous files has been processed + // so we should reset after handling one batch accumulations + log.Info("reset accumulations") + clear(b.TableSplitter) + b.AccumulateCount = 0 +} + +type RewriteSplitter struct { + RewriteKey []byte + tableID int64 + rule *restoreutils.RewriteRules + splitter *SplitHelper +} + +func NewRewriteSpliter( + rewriteKey []byte, + tableID int64, + rule *restoreutils.RewriteRules, + splitter *SplitHelper, +) *RewriteSplitter { + return &RewriteSplitter{ + RewriteKey: rewriteKey, + tableID: tableID, + rule: rule, + splitter: splitter, + } +} + +type SplitHelperIterator struct { + tableSplitters []*RewriteSplitter +} + +func NewSplitHelperIterator(tableSplitters []*RewriteSplitter) *SplitHelperIterator { + return &SplitHelperIterator{tableSplitters: tableSplitters} +} + +func (iter *SplitHelperIterator) Traverse(fn func(v Valued, endKey []byte, rule *restoreutils.RewriteRules) bool) { + for _, entry := range iter.tableSplitters { + endKey := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(entry.tableID+1)) + rule := entry.rule + entry.splitter.Traverse(func(v Valued) bool { + return fn(v, endKey, rule) + }) + } +} + +// PipelineRegionsSplitter defines the interface for advanced (pipeline) splitting strategies. +// log / compacted sst files restore need to use this to split after full restore. +// and the splitter must perform with a control. +// so we choose to split and restore in a continuous flow. +type PipelineRegionsSplitter interface { + Splitter + ExecuteRegions(ctx context.Context, splitHelper *SplitHelperIterator) error // Method for executing pipeline-based splitting +} + +type PipelineRegionsSplitterImpl struct { + *RegionSplitter + pool *util.WorkerPool + splitThresholdSize uint64 + splitThresholdKeys int64 + + eg *errgroup.Group + regionsCh chan []*RegionInfo +} + +func NewPipelineRegionsSplitter( + client SplitClient, + splitSize uint64, + splitKeys int64, +) PipelineRegionsSplitter { + pool := util.NewWorkerPool(128, "split") + return &PipelineRegionsSplitterImpl{ + pool: pool, + RegionSplitter: NewRegionSplitter(client), + splitThresholdSize: splitSize, + splitThresholdKeys: splitKeys, + } +} + +func (r *PipelineRegionsSplitterImpl) ExecuteRegions(ctx context.Context, splitHelper *SplitHelperIterator) error { + var ectx context.Context + var wg sync.WaitGroup + r.eg, ectx = errgroup.WithContext(ctx) + r.regionsCh = make(chan []*RegionInfo, 1024) + wg.Add(1) + go func() { + defer wg.Done() + scatterRegions := make([]*RegionInfo, 0) + receiveNewRegions: + for { + select { + case <-ctx.Done(): + return + case newRegions, ok := <-r.regionsCh: + if !ok { + break receiveNewRegions + } + + scatterRegions = append(scatterRegions, newRegions...) + } + } + // It is too expensive to stop recovery and wait for a small number of regions + // to complete scatter, so the maximum waiting time is reduced to 1 minute. 
+ _ = r.WaitForScatterRegionsTimeout(ectx, scatterRegions, time.Minute) + }() + + err := SplitPoint(ectx, splitHelper, r.client, r.splitRegionByPoints) + if err != nil { + return errors.Trace(err) + } + + // wait for completion of splitting regions + if err := r.eg.Wait(); err != nil { + return errors.Trace(err) + } + + // wait for completion of scattering regions + close(r.regionsCh) + wg.Wait() + + return nil +} + +type splitFunc = func(context.Context, uint64, int64, *RegionInfo, []Valued) error + +// SplitPoint selects ranges overlapped with each region, and calls `splitF` to split the region +func SplitPoint( + ctx context.Context, + iter *SplitHelperIterator, + client SplitClient, + splitF splitFunc, +) (err error) { + // region traverse status + var ( + // the region buffer of each scan + regions []*RegionInfo = nil + regionIndex int = 0 + ) + // region split status + var ( + // range span +----------------+------+---+-------------+ + // region span +------------------------------------+ + // +initial length+ +end valued+ + // regionValueds is the ranges array overlapped with `regionInfo` + regionValueds []Valued = nil + // regionInfo is the region to be split + regionInfo *RegionInfo = nil + // initialLength is the length of the part of the first range overlapped with the region + initialLength uint64 = 0 + initialNumber int64 = 0 + ) + // range status + var ( + // regionOverCount is the number of regions overlapped with the range + regionOverCount uint64 = 0 + ) + + iter.Traverse(func(v Valued, endKey []byte, rule *restoreutils.RewriteRules) bool { + if v.Value.Number == 0 || v.Value.Size == 0 { + return true + } + var ( + vStartKey []byte + vEndKey []byte + ) + // use `vStartKey` and `vEndKey` to compare with region's key + vStartKey, vEndKey, err = restoreutils.GetRewriteEncodedKeys(v, rule) + if err != nil { + return false + } + // traverse to the first region overlapped with the range + for ; regionIndex < len(regions); regionIndex++ { + if bytes.Compare(vStartKey, regions[regionIndex].Region.EndKey) < 0 { + break + } + } + // cannot find any regions overlapped with the range + // need to scan regions again + if regionIndex == len(regions) { + regions = nil + } + regionOverCount = 0 + for { + if regionIndex >= len(regions) { + var startKey []byte + if len(regions) > 0 { + // has traversed over the region buffer, should scan from the last region's end-key of the region buffer + startKey = regions[len(regions)-1].Region.EndKey + } else { + // scan from the range's start-key + startKey = vStartKey + } + // scan at most 64 regions into the region buffer + regions, err = ScanRegionsWithRetry(ctx, client, startKey, endKey, 64) + if err != nil { + return false + } + regionIndex = 0 + } + + region := regions[regionIndex] + // this region must be overlapped with the range + regionOverCount++ + // the region is the last one overlapped with the range, + // should split the last recorded region, + // and then record this region as the region to be split + if bytes.Compare(vEndKey, region.Region.EndKey) < 0 { + endLength := v.Value.Size / regionOverCount + endNumber := v.Value.Number / int64(regionOverCount) + if len(regionValueds) > 0 && regionInfo != region { + // add a part of the range as the end part + if bytes.Compare(vStartKey, regionInfo.Region.EndKey) < 0 { + regionValueds = append(regionValueds, NewValued(vStartKey, regionInfo.Region.EndKey, Value{Size: endLength, Number: endNumber})) + } + // try to split the region + err = splitF(ctx, initialLength, initialNumber, regionInfo, 
regionValueds) + if err != nil { + return false + } + regionValueds = make([]Valued, 0) + } + if regionOverCount == 1 { + // the region completely contains the range + regionValueds = append(regionValueds, Valued{ + Key: Span{ + StartKey: vStartKey, + EndKey: vEndKey, + }, + Value: v.Value, + }) + } else { + // the region is overlapped with the last part of the range + initialLength = endLength + initialNumber = endNumber + } + regionInfo = region + // try the next range + return true + } + + // try the next region + regionIndex++ + } + }) + + if err != nil { + return errors.Trace(err) + } + if len(regionValueds) > 0 { + // try to split the region + err = splitF(ctx, initialLength, initialNumber, regionInfo, regionValueds) + if err != nil { + return errors.Trace(err) + } + } + + return nil +} + +func (r *PipelineRegionsSplitterImpl) splitRegionByPoints( + ctx context.Context, + initialLength uint64, + initialNumber int64, + region *RegionInfo, + valueds []Valued, +) error { + var ( + splitPoints [][]byte = make([][]byte, 0) + lastKey []byte = region.Region.StartKey + length uint64 = initialLength + number int64 = initialNumber + ) + for _, v := range valueds { + // decode will discard ts behind the key, which results in the same key for consecutive ranges + if !bytes.Equal(lastKey, v.GetStartKey()) && (v.Value.Size+length > r.splitThresholdSize || v.Value.Number+number > r.splitThresholdKeys) { + _, rawKey, _ := codec.DecodeBytes(v.GetStartKey(), nil) + splitPoints = append(splitPoints, rawKey) + length = 0 + number = 0 + } + lastKey = v.GetStartKey() + length += v.Value.Size + number += v.Value.Number + } + + if len(splitPoints) == 0 { + return nil + } + + r.pool.ApplyOnErrorGroup(r.eg, func() error { + newRegions, errSplit := r.ExecuteSortedKeysOnRegion(ctx, region, splitPoints) + if errSplit != nil { + log.Warn("failed to split the scaned region", zap.Error(errSplit)) + sort.Slice(splitPoints, func(i, j int) bool { + return bytes.Compare(splitPoints[i], splitPoints[j]) < 0 + }) + return r.ExecuteSortedKeys(ctx, splitPoints) + } + select { + case <-ctx.Done(): + return nil + case r.regionsCh <- newRegions: + } + log.Info("split the region", zap.Uint64("region-id", region.Region.Id), zap.Int("split-point-number", len(splitPoints))) + return nil + }) + return nil +} diff --git a/br/pkg/restore/utils/rewrite_rule.go b/br/pkg/restore/utils/rewrite_rule.go index eca06a58bee6a..a664d97a5f11d 100644 --- a/br/pkg/restore/utils/rewrite_rule.go +++ b/br/pkg/restore/utils/rewrite_rule.go @@ -16,6 +16,7 @@ package utils import ( "bytes" + "strings" "github.com/pingcap/errors" backuppb "github.com/pingcap/kvproto/pkg/brpb" @@ -44,6 +45,8 @@ type RewriteRules struct { Data []*import_sstpb.RewriteRule OldKeyspace []byte NewKeyspace []byte + // used to record checkpoint data + NewTableID int64 } // Append append its argument to this rewrite rules. @@ -171,7 +174,7 @@ func GetRewriteRuleOfTable( }) } - return &RewriteRules{Data: dataRules} + return &RewriteRules{Data: dataRules, NewTableID: newTableID} } // ValidateFileRewriteRule uses rewrite rules to validate the ranges of a file. 
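Reviewer note: the sketch below illustrates the split-point selection rule used by splitRegionByPoints above, on simplified stand-in types. It is not part of the patch; the helper name pickSplitPoints and the string keys are hypothetical, chosen only to show the accumulation logic: a split key is emitted before any range whose addition would push the running size or key count over a per-region threshold, except when its start key repeats the previous one (as can happen after the ts suffix is discarded by decoding).

package main

import "fmt"

// rangeStat is a simplified stand-in for a Valued range: its start key plus the
// size and key count attributed to it.
type rangeStat struct {
	startKey string
	size     uint64
	number   int64
}

// pickSplitPoints mirrors the accumulation loop: start from the size/keys already
// attributed to this region, then cut a split point right before any range that
// would overflow either threshold, resetting the counters afterwards.
func pickSplitPoints(initialSize uint64, initialNumber int64, thresholdSize uint64, thresholdKeys int64, ranges []rangeStat) []string {
	var (
		points  []string
		size    = initialSize
		number  = initialNumber
		lastKey string
	)
	for _, r := range ranges {
		if lastKey != r.startKey && (size+r.size > thresholdSize || number+r.number > thresholdKeys) {
			points = append(points, r.startKey) // split before this range
			size, number = 0, 0
		}
		lastKey = r.startKey
		size += r.size
		number += r.number
	}
	return points
}

func main() {
	ranges := []rangeStat{
		{startKey: "a", size: 40, number: 400},
		{startKey: "b", size: 40, number: 400},
		{startKey: "c", size: 40, number: 400},
	}
	// With a 96-byte / 960-key threshold the third range overflows the size limit,
	// so a single split point is emitted before "c". Prints: [c]
	fmt.Println(pickSplitPoints(0, 0, 96, 960, ranges))
}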
@@ -282,6 +285,26 @@ func FindMatchedRewriteRule(file AppliedFile, rules *RewriteRules) *import_sstpb return rule } +func (r *RewriteRules) String() string { + var out strings.Builder + out.WriteRune('[') + if len(r.OldKeyspace) != 0 { + out.WriteString(redact.Key(r.OldKeyspace)) + out.WriteString(" =[ks]=> ") + out.WriteString(redact.Key(r.NewKeyspace)) + } + for i, d := range r.Data { + if i > 0 { + out.WriteString(",") + } + out.WriteString(redact.Key(d.OldKeyPrefix)) + out.WriteString(" => ") + out.WriteString(redact.Key(d.NewKeyPrefix)) + } + out.WriteRune(']') + return out.String() +} + // GetRewriteRawKeys rewrites rules to the raw key. func GetRewriteRawKeys(file AppliedFile, rewriteRules *RewriteRules) (startKey, endKey []byte, err error) { startID := tablecodec.DecodeTableID(file.GetStartKey()) @@ -290,7 +313,7 @@ func GetRewriteRawKeys(file AppliedFile, rewriteRules *RewriteRules) (startKey, if startID == endID { startKey, rule = rewriteRawKey(file.GetStartKey(), rewriteRules) if rewriteRules != nil && rule == nil { - err = errors.Annotatef(berrors.ErrRestoreInvalidRewrite, "cannot find raw rewrite rule for start key, startKey: %s", redact.Key(file.GetStartKey())) + err = errors.Annotatef(berrors.ErrRestoreInvalidRewrite, "cannot find raw rewrite rule for start key, startKey: %s; self = %s", redact.Key(file.GetStartKey()), rewriteRules) return } endKey, rule = rewriteRawKey(file.GetEndKey(), rewriteRules) @@ -364,7 +387,7 @@ func RewriteRange(rg *rtree.Range, rewriteRules *RewriteRules) (*rtree.Range, er } rg.StartKey, rule = replacePrefix(rg.StartKey, rewriteRules) if rule == nil { - log.Warn("cannot find rewrite rule", logutil.Key("key", rg.StartKey)) + log.Warn("cannot find rewrite rule", logutil.Key("start key", rg.StartKey)) } else { log.Debug( "rewrite start key", @@ -373,7 +396,7 @@ func RewriteRange(rg *rtree.Range, rewriteRules *RewriteRules) (*rtree.Range, er oldKey := rg.EndKey rg.EndKey, rule = replacePrefix(rg.EndKey, rewriteRules) if rule == nil { - log.Warn("cannot find rewrite rule", logutil.Key("key", rg.EndKey)) + log.Warn("cannot find rewrite rule", logutil.Key("end key", rg.EndKey)) } else { log.Debug( "rewrite end key", diff --git a/br/pkg/stream/stream_metas.go b/br/pkg/stream/stream_metas.go index b51923a9638c1..14c65da097472 100644 --- a/br/pkg/stream/stream_metas.go +++ b/br/pkg/stream/stream_metas.go @@ -652,6 +652,17 @@ func (migs Migrations) MergeToBy(seq int, merge func(m1, m2 *pb.Migration) *pb.M return newBase } +// ListAll returns a slice of all migrations in protobuf format. +// This includes the base migration and any additional layers. 
+func (migs Migrations) ListAll() []*pb.Migration { + pbMigs := make([]*pb.Migration, 0, len(migs.Layers)+1) + pbMigs = append(pbMigs, migs.Base) + for _, m := range migs.Layers { + pbMigs = append(pbMigs, &m.Content) + } + return pbMigs +} + type mergeAndMigrateToConfig struct { interactiveCheck func(context.Context, *pb.Migration) bool alwaysRunTruncate bool diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 77893aaa913ec..90f85cf15a7c7 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -100,6 +100,8 @@ const ( defaultStatsConcurrency = 12 defaultBatchFlushInterval = 16 * time.Second defaultFlagDdlBatchSize = 128 + maxRestoreBatchSizeLimit = 10240 + pb = 1024 * 1024 * 1024 * 1024 * 1024 resetSpeedLimitRetryTimes = 3 ) @@ -532,6 +534,10 @@ func (cfg *RestoreConfig) adjustRestoreConfigForStreamRestore() { cfg.PitrConcurrency += 1 log.Info("set restore kv files concurrency", zap.Int("concurrency", int(cfg.PitrConcurrency))) cfg.Config.Concurrency = cfg.PitrConcurrency + if cfg.ConcurrencyPerStore.Value > 0 { + log.Info("set restore compacted sst files concurrency per store", + zap.Int("concurrency", int(cfg.ConcurrencyPerStore.Value))) + } } func configureRestoreClient(ctx context.Context, client *snapclient.SnapClient, cfg *RestoreConfig) error { @@ -765,8 +771,10 @@ func runSnapshotRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName s keepaliveCfg := GetKeepalive(&cfg.Config) keepaliveCfg.PermitWithoutStream = true client := snapclient.NewRestoreClient(mgr.GetPDClient(), mgr.GetPDHTTPClient(), mgr.GetTLSConfig(), keepaliveCfg) + // set to cfg so that restoreStream can use it. + cfg.ConcurrencyPerStore = kvConfigs.ImportGoroutines // using tikv config to set the concurrency-per-store for client. - client.SetConcurrencyPerStore(kvConfigs.ImportGoroutines.Value) + client.SetConcurrencyPerStore(cfg.ConcurrencyPerStore.Value) err := configureRestoreClient(ctx, client, cfg) if err != nil { return errors.Trace(err) @@ -802,7 +810,7 @@ func runSnapshotRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName s } reader := metautil.NewMetaReader(backupMeta, s, &cfg.CipherInfo) - if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, cfg.LoadStats); err != nil { + if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, cfg.LoadStats, nil, nil); err != nil { return errors.Trace(err) } @@ -1093,6 +1101,13 @@ func runSnapshotRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName s if err != nil { return errors.Trace(err) } + onProgress := func(n int64) { + if n == 0 { + updateCh.Inc() + return + } + updateCh.IncBy(n) + } if err := client.RestoreTables(ctx, placementRuleManager, createdTables, files, checkpointSetWithTableID, kvConfigs.MergeRegionSize.Value, kvConfigs.MergeRegionKeyCount.Value, // If the command is from BR binary, the ddl.EnableSplitTableRegion is always 0, @@ -1100,7 +1115,7 @@ func runSnapshotRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName s // Notice that `split-region-on-table` configure from TiKV split on the region having data, it may trigger after restore done. // It's recommended to enable TiDB configure `split-table` instead. atomic.LoadUint32(&ddl.EnableSplitTableRegion) == 1, - updateCh, + onProgress, ); err != nil { return errors.Trace(err) } @@ -1126,25 +1141,6 @@ func runSnapshotRestore(c context.Context, mgr *conn.Mgr, g glue.Glue, cmdName s finish := dropToBlackhole(ctx, postHandleCh, errCh) - // Reset speed limit. 
ResetSpeedLimit must be called after client.LoadSchemaIfNeededAndInitClient has been called. - defer func() { - var resetErr error - // In future we may need a mechanism to set speed limit in ttl. like what we do in switchmode. TODO - for retry := 0; retry < resetSpeedLimitRetryTimes; retry++ { - resetErr = client.ResetSpeedLimit(ctx) - if resetErr != nil { - log.Warn("failed to reset speed limit, retry it", - zap.Int("retry time", retry), logutil.ShortError(resetErr)) - time.Sleep(time.Duration(retry+3) * time.Second) - continue - } - break - } - if resetErr != nil { - log.Error("failed to reset speed limit, please reset it manually", zap.Error(resetErr)) - } - }() - select { case err = <-errCh: err = multierr.Append(err, multierr.Combine(Exhaust(errCh)...)) diff --git a/br/pkg/task/restore_raw.go b/br/pkg/task/restore_raw.go index cbcda5679fb57..acb2e48041e64 100644 --- a/br/pkg/task/restore_raw.go +++ b/br/pkg/task/restore_raw.go @@ -4,6 +4,7 @@ package task import ( "context" + "time" "github.com/pingcap/errors" "github.com/pingcap/log" @@ -12,6 +13,7 @@ import ( berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/glue" "github.com/pingcap/tidb/br/pkg/httputil" + "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/metautil" "github.com/pingcap/tidb/br/pkg/restore" snapclient "github.com/pingcap/tidb/br/pkg/restore/snap_client" @@ -20,6 +22,7 @@ import ( "github.com/pingcap/tidb/br/pkg/summary" "github.com/spf13/cobra" "github.com/spf13/pflag" + "go.uber.org/zap" ) // RestoreRawConfig is the configuration specific for raw kv restore tasks. @@ -109,7 +112,7 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR return errors.Trace(err) } reader := metautil.NewMetaReader(backupMeta, s, &cfg.CipherInfo) - if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, true); err != nil { + if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, true, cfg.StartKey, cfg.EndKey); err != nil { return errors.Trace(err) } @@ -145,8 +148,9 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR int64(len(ranges)+len(files)), !cfg.LogProgress) + onProgress := func(i int64) { updateCh.IncBy(i) } // RawKV restore does not need to rewrite keys. - err = client.SplitPoints(ctx, getEndKeys(ranges), updateCh, true) + err = client.SplitPoints(ctx, getEndKeys(ranges), onProgress, true) if err != nil { return errors.Trace(err) } @@ -158,10 +162,20 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR } defer restore.RestorePostWork(ctx, importModeSwitcher, restoreSchedulers, cfg.Online) - err = client.RestoreRaw(ctx, cfg.StartKey, cfg.EndKey, files, updateCh) + start := time.Now() + err = client.GetRestorer().GoRestore(onProgress, restore.CreateUniqueFileSets(files)) if err != nil { return errors.Trace(err) } + err = client.GetRestorer().WaitUntilFinish() + if err != nil { + return errors.Trace(err) + } + elapsed := time.Since(start) + log.Info("Restore Raw", + logutil.Key("startKey", cfg.StartKey), + logutil.Key("endKey", cfg.EndKey), + zap.Duration("take", elapsed)) // Restore has finished. 
updateCh.Close() diff --git a/br/pkg/task/restore_txn.go b/br/pkg/task/restore_txn.go index d686eb97e2154..2af64a59602cc 100644 --- a/br/pkg/task/restore_txn.go +++ b/br/pkg/task/restore_txn.go @@ -15,6 +15,7 @@ import ( snapclient "github.com/pingcap/tidb/br/pkg/restore/snap_client" restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" "github.com/pingcap/tidb/br/pkg/summary" + "go.uber.org/zap" ) // RunRestoreTxn starts a txn kv restore task inside the current goroutine. @@ -54,7 +55,7 @@ func RunRestoreTxn(c context.Context, g glue.Glue, cmdName string, cfg *Config) return errors.Trace(err) } reader := metautil.NewMetaReader(backupMeta, s, &cfg.CipherInfo) - if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, true); err != nil { + if err = client.LoadSchemaIfNeededAndInitClient(c, backupMeta, u, reader, true, nil, nil); err != nil { return errors.Trace(err) } @@ -72,6 +73,7 @@ func RunRestoreTxn(c context.Context, g glue.Glue, cmdName string, cfg *Config) } summary.CollectInt("restore files", len(files)) + log.Info("restore files", zap.Int("count", len(files))) ranges, _, err := restoreutils.MergeAndRewriteFileRanges( files, nil, conn.DefaultMergeRegionSizeBytes, conn.DefaultMergeRegionKeyCount) if err != nil { @@ -86,8 +88,9 @@ func RunRestoreTxn(c context.Context, g glue.Glue, cmdName string, cfg *Config) int64(len(ranges)+len(files)), !cfg.LogProgress) + onProgress := func(i int64) { updateCh.IncBy(i) } // RawKV restore does not need to rewrite keys. - err = client.SplitPoints(ctx, getEndKeys(ranges), updateCh, false) + err = client.SplitPoints(ctx, getEndKeys(ranges), onProgress, false) if err != nil { return errors.Trace(err) } @@ -99,11 +102,14 @@ func RunRestoreTxn(c context.Context, g glue.Glue, cmdName string, cfg *Config) } defer restore.RestorePostWork(ctx, importModeSwitcher, restoreSchedulers, false) - err = client.WaitForFilesRestored(ctx, files, updateCh) + err = client.GetRestorer().GoRestore(onProgress, restore.CreateUniqueFileSets(files)) + if err != nil { + return errors.Trace(err) + } + err = client.GetRestorer().WaitUntilFinish() if err != nil { return errors.Trace(err) } - // Restore has finished. 
updateCh.Close() diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index 8f5f14753eac3..e849da1626e2a 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -1294,7 +1294,7 @@ func restoreStream( if err != nil { return errors.Annotate(err, "failed to create restore client") } - defer client.Close() + defer client.Close(ctx) if taskInfo != nil && taskInfo.Metadata != nil { // reuse the task's rewrite ts @@ -1346,24 +1346,18 @@ func restoreStream( log.Info("finish restoring gc") }() - var checkpointRunner *checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType] + var sstCheckpointSets map[string]struct{} if cfg.UseCheckpoint { oldRatioFromCheckpoint, err := client.InitCheckpointMetadataForLogRestore(ctx, cfg.StartTS, cfg.RestoreTS, oldRatio, cfg.tiflashRecorder) if err != nil { return errors.Trace(err) } oldRatio = oldRatioFromCheckpoint - - checkpointRunner, err = client.StartCheckpointRunnerForLogRestore(ctx, g, mgr.GetStorage()) + sstCheckpointSets, err = client.InitCheckpointMetadataForCompactedSstRestore(ctx) if err != nil { return errors.Trace(err) } - defer func() { - log.Info("wait for flush checkpoint...") - checkpointRunner.WaitForFinish(ctx, !gcDisabledRestorable) - }() } - encryptionManager, err := encryption.NewManager(&cfg.LogBackupCipherInfo, &cfg.MasterKeyConfig) if err != nil { return errors.Annotate(err, "failed to create encryption manager for log restore") @@ -1373,6 +1367,11 @@ func restoreStream( if err != nil { return err } + migs, err := client.GetMigrations(ctx) + if err != nil { + return errors.Trace(err) + } + client.BuildMigrations(migs) // get full backup meta storage to generate rewrite rules. fullBackupStorage, err := parseFullBackupTablesStorage(cfg) @@ -1413,8 +1412,7 @@ func restoreStream( totalKVCount += kvCount totalSize += size } - dataFileCount := 0 - ddlFiles, err := client.LoadDDLFilesAndCountDMLFiles(ctx, &dataFileCount) + ddlFiles, err := client.LoadDDLFilesAndCountDMLFiles(ctx) if err != nil { return err } @@ -1433,46 +1431,61 @@ func restoreStream( return errors.Trace(err) } - // generate the upstream->downstream id maps for checkpoint - idrules := make(map[int64]int64) - downstreamIdset := make(map[int64]struct{}) - for upstreamId, rule := range rewriteRules { - downstreamId := restoreutils.GetRewriteTableID(upstreamId, rule) - idrules[upstreamId] = downstreamId - downstreamIdset[downstreamId] = struct{}{} + logFilesIter, err := client.LoadDMLFiles(ctx) + if err != nil { + return errors.Trace(err) } - logFilesIter, err := client.LoadDMLFiles(ctx) + compactionIter := client.LogFileManager.GetCompactionIter(ctx) + + se, err := g.CreateSession(mgr.GetStorage()) if err != nil { return errors.Trace(err) } - pd := g.StartProgress(ctx, "Restore KV Files", int64(dataFileCount), !cfg.LogProgress) + execCtx := se.GetSessionCtx().GetRestrictedSQLExecutor() + splitSize, splitKeys := utils.GetRegionSplitInfo(execCtx) + log.Info("[Log Restore] get split threshold from tikv config", zap.Uint64("split-size", splitSize), zap.Int64("split-keys", splitKeys)) + + pd := g.StartProgress(ctx, "Restore Files(SST + KV)", logclient.TotalEntryCount, !cfg.LogProgress) err = withProgress(pd, func(p glue.Progress) (pErr error) { + updateStatsWithCheckpoint := func(kvCount, size uint64) { + mu.Lock() + defer mu.Unlock() + totalKVCount += kvCount + totalSize += size + checkpointTotalKVCount += kvCount + checkpointTotalSize += size + // increase the progress + p.IncBy(int64(kvCount)) + } + compactedSplitIter, err := 
client.WrapCompactedFilesIterWithSplitHelper( + ctx, compactionIter, rewriteRules, sstCheckpointSets, + updateStatsWithCheckpoint, splitSize, splitKeys, + ) + if err != nil { + return errors.Trace(err) + } + + err = client.RestoreCompactedSstFiles(ctx, compactedSplitIter, rewriteRules, importModeSwitcher, p.IncBy) + if err != nil { + return errors.Trace(err) + } + + logFilesIterWithSplit, err := client.WrapLogFilesIterWithSplitHelper(ctx, logFilesIter, execCtx, rewriteRules, updateStatsWithCheckpoint, splitSize, splitKeys) + if err != nil { + return errors.Trace(err) + } + if cfg.UseCheckpoint { - updateStatsWithCheckpoint := func(kvCount, size uint64) { - mu.Lock() - defer mu.Unlock() - totalKVCount += kvCount - totalSize += size - checkpointTotalKVCount += kvCount - checkpointTotalSize += size - } - logFilesIter, err = client.WrapLogFilesIterWithCheckpoint(ctx, logFilesIter, downstreamIdset, updateStatsWithCheckpoint, p.Inc) - if err != nil { - return errors.Trace(err) - } + // TODO make a failpoint iter inside the logclient. failpoint.Inject("corrupt-files", func(v failpoint.Value) { var retErr error - logFilesIter, retErr = logclient.WrapLogFilesIterWithCheckpointFailpoint(v, logFilesIter, rewriteRules) + logFilesIterWithSplit, retErr = logclient.WrapLogFilesIterWithCheckpointFailpoint(v, logFilesIterWithSplit, rewriteRules) defer func() { pErr = retErr }() }) } - logFilesIterWithSplit, err := client.WrapLogFilesIterWithSplitHelper(logFilesIter, rewriteRules, g, mgr.GetStorage()) - if err != nil { - return errors.Trace(err) - } - return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, + return client.RestoreKVFiles(ctx, rewriteRules, logFilesIterWithSplit, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy, &cfg.LogBackupCipherInfo, cfg.MasterKeyConfig.MasterKeys) }) if err != nil { @@ -1512,7 +1525,7 @@ func restoreStream( } failpoint.Inject("do-checksum-with-rewrite-rules", func(_ failpoint.Value) { - if err := client.FailpointDoChecksumForLogRestore(ctx, mgr.GetStorage().GetClient(), mgr.GetPDClient(), idrules, rewriteRules); err != nil { + if err := client.FailpointDoChecksumForLogRestore(ctx, mgr.GetStorage().GetClient(), mgr.GetPDClient(), rewriteRules); err != nil { failpoint.Return(errors.Annotate(err, "failed to do checksum")) } }) @@ -1527,13 +1540,14 @@ func createRestoreClient(ctx context.Context, g glue.Glue, cfg *RestoreConfig, m keepaliveCfg := GetKeepalive(&cfg.Config) keepaliveCfg.PermitWithoutStream = true client := logclient.NewRestoreClient(mgr.GetPDClient(), mgr.GetPDHTTPClient(), mgr.GetTLSConfig(), keepaliveCfg) - err = client.Init(g, mgr.GetStorage()) + + err = client.Init(ctx, g, mgr.GetStorage()) if err != nil { return nil, errors.Trace(err) } defer func() { if err != nil { - client.Close() + client.Close(ctx) } }() @@ -1547,9 +1561,20 @@ func createRestoreClient(ctx context.Context, g glue.Glue, cfg *RestoreConfig, m return nil, errors.Trace(err) } client.SetCrypter(&cfg.CipherInfo) - client.SetConcurrency(uint(cfg.Concurrency)) client.SetUpstreamClusterID(cfg.upstreamClusterID) - client.InitClients(ctx, u) + + createCheckpointSessionFn := func() (glue.Session, error) { + // always create a new session for checkpoint runner + // because session is not thread safe + if cfg.UseCheckpoint { + return g.CreateSession(mgr.GetStorage()) + } + return nil, nil + } + err = client.InitClients(ctx, u, createCheckpointSessionFn, uint(cfg.Concurrency), cfg.ConcurrencyPerStore.Value) + if err != nil { + return nil, 
errors.Trace(err) + } err = client.SetRawKVBatchClient(ctx, cfg.PD, cfg.TLS.ToKVSecurity()) if err != nil { diff --git a/br/pkg/utils/iter/combinator_types.go b/br/pkg/utils/iter/combinator_types.go index 34288d104236c..af71a83adf5fb 100644 --- a/br/pkg/utils/iter/combinator_types.go +++ b/br/pkg/utils/iter/combinator_types.go @@ -124,6 +124,26 @@ func (f filterMap[T, R]) TryNext(ctx context.Context) IterResult[R] { } } +type tryMap[T, R any] struct { + inner TryNextor[T] + + mapper func(T) (R, error) +} + +func (t tryMap[T, R]) TryNext(ctx context.Context) IterResult[R] { + r := t.inner.TryNext(ctx) + + if r.FinishedOrError() { + return DoneBy[R](r) + } + + res, err := t.mapper(r.Item) + if err != nil { + return Throw[R](err) + } + return Emit(res) +} + type join[T any] struct { inner TryNextor[TryNextor[T]] diff --git a/br/pkg/utils/iter/combinators.go b/br/pkg/utils/iter/combinators.go index 6247237161912..1b7c93f55708c 100644 --- a/br/pkg/utils/iter/combinators.go +++ b/br/pkg/utils/iter/combinators.go @@ -91,6 +91,13 @@ func MapFilter[T, R any](it TryNextor[T], mapper func(T) (R, bool)) TryNextor[R] } } +func TryMap[T, R any](it TryNextor[T], mapper func(T) (R, error)) TryNextor[R] { + return tryMap[T, R]{ + inner: it, + mapper: mapper, + } +} + // ConcatAll concatenates all elements yields by the iterators. // In another word, it 'chains' all the input iterators. func ConcatAll[T any](items ...TryNextor[T]) TryNextor[T] {