diff --git a/br/pkg/restore/client.go b/br/pkg/restore/client.go index c4d6481a6bd70..180fef74d36f8 100644 --- a/br/pkg/restore/client.go +++ b/br/pkg/restore/client.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/br/pkg/redact" "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" "github.com/pingcap/tidb/br/pkg/rtree" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/stream" @@ -1630,20 +1631,30 @@ func (rc *Client) GetRebasedTables() map[UniqueTableName]bool { return rc.rebasedTablesMap } +func (rc *Client) getTiFlashNodeCount(ctx context.Context) (uint64, error) { + tiFlashStores, err := util.GetAllTiKVStores(ctx, rc.pdClient, util.TiFlashOnly) + if err != nil { + return 0, errors.Trace(err) + } + return uint64(len(tiFlashStores)), nil +} + // PreCheckTableTiFlashReplica checks whether TiFlash replica is less than TiFlash node. func (rc *Client) PreCheckTableTiFlashReplica( ctx context.Context, tables []*metautil.Table, - skipTiflash bool, + recorder *tiflashrec.TiFlashRecorder, ) error { - tiFlashStores, err := util.GetAllTiKVStores(ctx, rc.pdClient, util.TiFlashOnly) + tiFlashStoreCount, err := rc.getTiFlashNodeCount(ctx) if err != nil { - return errors.Trace(err) + return err } - tiFlashStoreCount := len(tiFlashStores) for _, table := range tables { - if skipTiflash || - (table.Info.TiFlashReplica != nil && table.Info.TiFlashReplica.Count > uint64(tiFlashStoreCount)) { + if recorder != nil || + (table.Info.TiFlashReplica != nil && table.Info.TiFlashReplica.Count > tiFlashStoreCount) { + if recorder != nil && table.Info.TiFlashReplica != nil { + recorder.AddTable(table.Info.ID, *table.Info.TiFlashReplica) + } // we cannot satisfy TiFlash replica in restore cluster. so we should // set TiFlashReplica to unavailable in tableInfo, to avoid TiDB cannot sense TiFlash and make plan to TiFlash // see details at https://github.com/pingcap/br/issues/931 @@ -1986,7 +1997,8 @@ func (rc *Client) InitSchemasReplaceForDDL( }()...) } - return stream.NewSchemasReplace(dbMap, rc.currentTS, tableFilter, rc.GenGlobalID, rc.GenGlobalIDs, rc.InsertDeleteRangeForTable, rc.InsertDeleteRangeForIndex), nil + rp := stream.NewSchemasReplace(dbMap, rc.currentTS, tableFilter, rc.GenGlobalID, rc.GenGlobalIDs, rc.InsertDeleteRangeForTable, rc.InsertDeleteRangeForIndex) + return rp, nil } func SortMetaKVFiles(files []*backuppb.DataFileInfo) []*backuppb.DataFileInfo { diff --git a/br/pkg/restore/client_test.go b/br/pkg/restore/client_test.go index be6892463e1a2..1b48bb6050d89 100644 --- a/br/pkg/restore/client_test.go +++ b/br/pkg/restore/client_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/br/pkg/metautil" "github.com/pingcap/tidb/br/pkg/mock" "github.com/pingcap/tidb/br/pkg/restore" + "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" "github.com/pingcap/tidb/br/pkg/stream" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/parser/model" @@ -383,7 +384,7 @@ func TestPreCheckTableTiFlashReplicas(t *testing.T) { } } ctx := context.Background() - require.Nil(t, client.PreCheckTableTiFlashReplica(ctx, tables, false)) + require.Nil(t, client.PreCheckTableTiFlashReplica(ctx, tables, nil)) for i := 0; i < len(tables); i++ { if i == 0 || i > 2 { @@ -395,7 +396,7 @@ func TestPreCheckTableTiFlashReplicas(t *testing.T) { } } - require.Nil(t, client.PreCheckTableTiFlashReplica(ctx, tables, true)) + require.Nil(t, client.PreCheckTableTiFlashReplica(ctx, tables, tiflashrec.New())) for i := 0; i < len(tables); i++ { require.Nil(t, tables[i].Info.TiFlashReplica) } diff --git a/br/pkg/restore/tiflashrec/tiflash_recorder.go b/br/pkg/restore/tiflashrec/tiflash_recorder.go new file mode 100644 index 0000000000000..d575a26c17a93 --- /dev/null +++ b/br/pkg/restore/tiflashrec/tiflash_recorder.go @@ -0,0 +1,129 @@ +// Copyright 2022-present PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tiflashrec + +import ( + "bytes" + "fmt" + + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/format" + "github.com/pingcap/tidb/parser/model" + "go.uber.org/zap" +) + +// TiFlashRecorder records the information of TiFlash replicas +// during restoring. +// Because the limit of the current implementation, we add serval hooks +// to observe the information we need: +// - Before full restore create tables: +// We record the tiflash replica information and remove the replica info. +// Because during PiTR restore, the transaction model would be broken, which breaks TiFlash too. +// We must make sure they won't be replicated to TiFlash during the whole PiTR procedure. +// - After full restore created tables, generating rewrite rules: +// We perform the rewrite rule over our records. +// We trace table via table ID instead of table name so we can handle `RENAME` DDLs. +// - When doing PiTR restore, after rewriting table info in meta key: +// We update the replica information +type TiFlashRecorder struct { + // Table ID -> TiFlash Count + items map[int64]model.TiFlashReplicaInfo +} + +func New() *TiFlashRecorder { + return &TiFlashRecorder{ + items: map[int64]model.TiFlashReplicaInfo{}, + } +} + +func (r *TiFlashRecorder) AddTable(tableID int64, replica model.TiFlashReplicaInfo) { + log.Info("recording tiflash replica", zap.Int64("table", tableID), zap.Any("replica", replica)) + r.items[tableID] = replica +} + +func (r *TiFlashRecorder) DelTable(tableID int64) { + delete(r.items, tableID) +} + +func (r *TiFlashRecorder) Iterate(f func(tableID int64, replica model.TiFlashReplicaInfo)) { + for k, v := range r.items { + f(k, v) + } +} + +func (r *TiFlashRecorder) Rewrite(oldID int64, newID int64) { + if newID == oldID { + return + } + old, ok := r.items[oldID] + log.Info("rewriting tiflash replica", zap.Int64("old", oldID), zap.Int64("new", newID), zap.Bool("success", ok)) + if ok { + r.items[newID] = old + delete(r.items, oldID) + } +} + +func (r *TiFlashRecorder) GenerateAlterTableDDLs(info infoschema.InfoSchema) []string { + items := make([]string, 0, len(r.items)) + r.Iterate(func(id int64, replica model.TiFlashReplicaInfo) { + table, ok := info.TableByID(id) + if !ok { + log.Warn("Table do not exist, skipping", zap.Int64("id", id)) + return + } + schema, ok := info.SchemaByTable(table.Meta()) + if !ok { + log.Warn("Schema do not exist, skipping", zap.Int64("id", id), zap.Stringer("table", table.Meta().Name)) + return + } + altTableSpec, err := alterTableSpecOf(replica) + if err != nil { + log.Warn("Failed to generate the alter table spec", logutil.ShortError(err), zap.Any("replica", replica)) + return + } + items = append(items, fmt.Sprintf( + "ALTER TABLE %s %s", + utils.EncloseDBAndTable(schema.Name.O, table.Meta().Name.O), + altTableSpec), + ) + }) + return items +} + +func alterTableSpecOf(replica model.TiFlashReplicaInfo) (string, error) { + spec := &ast.AlterTableSpec{ + Tp: ast.AlterTableSetTiFlashReplica, + TiFlashReplica: &ast.TiFlashReplicaSpec{ + Count: replica.Count, + Labels: replica.LocationLabels, + }, + } + + buf := bytes.NewBuffer(make([]byte, 0, 32)) + restoreCx := format.NewRestoreCtx( + format.RestoreKeyWordUppercase| + format.RestoreNameBackQuotes| + format.RestoreStringSingleQuotes| + format.RestoreStringEscapeBackslash, + buf) + if err := spec.Restore(restoreCx); err != nil { + return "", err + } + return buf.String(), nil +} diff --git a/br/pkg/restore/tiflashrec/tiflash_recorder_test.go b/br/pkg/restore/tiflashrec/tiflash_recorder_test.go new file mode 100644 index 0000000000000..b01272caeddc5 --- /dev/null +++ b/br/pkg/restore/tiflashrec/tiflash_recorder_test.go @@ -0,0 +1,172 @@ +// Copyright 2022-present PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tiflashrec_test + +import ( + "fmt" + "testing" + + "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/parser/model" + "github.com/stretchr/testify/require" +) + +type op func(*tiflashrec.TiFlashRecorder) + +func add(tableID int64, replica int) op { + return func(tfr *tiflashrec.TiFlashRecorder) { + tfr.AddTable(tableID, model.TiFlashReplicaInfo{ + Count: uint64(replica), + }) + } +} + +func rewrite(tableID, newTableID int64) op { + return func(tfr *tiflashrec.TiFlashRecorder) { + tfr.Rewrite(tableID, newTableID) + } +} + +func del(tableID int64) op { + return func(tfr *tiflashrec.TiFlashRecorder) { + tfr.DelTable(tableID) + } +} + +func ops(ops ...op) op { + return func(tfr *tiflashrec.TiFlashRecorder) { + for _, op := range ops { + op(tfr) + } + } +} + +type table struct { + id int64 + replica int +} + +func t(id int64, replica int) table { + return table{ + id: id, + replica: replica, + } +} + +func TestRecorder(tCtx *testing.T) { + type Case struct { + o op + ts []table + } + cases := []Case{ + { + o: ops( + add(42, 1), + add(43, 2), + ), + ts: []table{ + t(42, 1), + t(43, 2), + }, + }, + { + o: ops( + add(42, 3), + add(43, 1), + del(42), + ), + ts: []table{ + t(43, 1), + }, + }, + { + o: ops( + add(41, 4), + add(42, 8), + rewrite(42, 1890), + rewrite(1890, 43), + rewrite(41, 100), + ), + ts: []table{ + t(43, 8), + t(100, 4), + }, + }, + } + + check := func(t *testing.T, c Case) { + rec := tiflashrec.New() + req := require.New(t) + c.o(rec) + tmap := map[int64]int{} + for _, t := range c.ts { + tmap[t.id] = t.replica + } + + rec.Iterate(func(tableID int64, replicaReal model.TiFlashReplicaInfo) { + replica, ok := tmap[tableID] + req.True(ok, "the key %d not recorded", tableID) + req.EqualValues(replica, replicaReal.Count, "the replica mismatch") + delete(tmap, tableID) + }) + req.Empty(tmap, "not all required are recorded") + } + + for i, c := range cases { + tCtx.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + check(t, c) + }) + } +} + +func TestGenSql(t *testing.T) { + tInfo := func(id int, name string) *model.TableInfo { + return &model.TableInfo{ + ID: int64(id), + Name: model.NewCIStr(name), + } + } + fakeInfo := infoschema.MockInfoSchema([]*model.TableInfo{ + tInfo(1, "fruits"), + tInfo(2, "whisper"), + tInfo(3, "woods"), + tInfo(4, "evils"), + }) + rec := tiflashrec.New() + rec.AddTable(1, model.TiFlashReplicaInfo{ + Count: 1, + }) + rec.AddTable(2, model.TiFlashReplicaInfo{ + Count: 2, + LocationLabels: []string{"climate"}, + }) + rec.AddTable(3, model.TiFlashReplicaInfo{ + Count: 3, + LocationLabels: []string{"leaf", "seed"}, + }) + rec.AddTable(4, model.TiFlashReplicaInfo{ + Count: 1, + LocationLabels: []string{`kIll'; OR DROP DATABASE test --`, `dEaTh with \"quoting\"`}, + }) + + sqls := rec.GenerateAlterTableDDLs(fakeInfo) + require.ElementsMatch(t, sqls, []string{ + "ALTER TABLE `test`.`whisper` SET TIFLASH REPLICA 2 LOCATION LABELS 'climate'", + "ALTER TABLE `test`.`woods` SET TIFLASH REPLICA 3 LOCATION LABELS 'leaf', 'seed'", + "ALTER TABLE `test`.`fruits` SET TIFLASH REPLICA 1", + "ALTER TABLE `test`.`evils` SET TIFLASH REPLICA 1 LOCATION LABELS 'kIll''; OR DROP DATABASE test --', 'dEaTh with " + `\\"quoting\\"` + "'", + }) +} diff --git a/br/pkg/stream/rewrite_meta_rawkv.go b/br/pkg/stream/rewrite_meta_rawkv.go index 93fefff4f8cad..b5da1dafb3013 100644 --- a/br/pkg/stream/rewrite_meta_rawkv.go +++ b/br/pkg/stream/rewrite_meta_rawkv.go @@ -62,6 +62,8 @@ type SchemasReplace struct { genGenGlobalIDs func(ctx context.Context, n int) ([]int64, error) insertDeleteRangeForTable func(jobID int64, tableIDs []int64) insertDeleteRangeForIndex func(jobID int64, elementID *int64, tableID int64, indexIDs []int64) + + AfterTableRewritten func(deleted bool, tableInfo *model.TableInfo) } // NewTableReplace creates a TableReplace struct. @@ -310,9 +312,6 @@ func (sr *SchemasReplace) rewriteTableInfo(value []byte, dbID int64) ([]byte, bo newTableInfo.Partition = tableInfo.Partition.Clone() } newTableInfo.ID = tableReplace.NewTableID - // Do not restore tiflash replica to down-stream. - //After restore meta finished, restore tiflash replica by DDL. - newTableInfo.TiFlashReplica = nil // update partition table ID partitions := newTableInfo.GetPartitionInfo() @@ -338,6 +337,10 @@ func (sr *SchemasReplace) rewriteTableInfo(value []byte, dbID int64) ([]byte, bo } } + if sr.AfterTableRewritten != nil { + sr.AfterTableRewritten(false, newTableInfo) + } + // marshal to json newValue, err := json.Marshal(&newTableInfo) if err != nil { @@ -352,23 +355,34 @@ func (sr *SchemasReplace) rewriteEntryForTable(e *kv.Entry, cf string) (*kv.Entr return nil, errors.Trace(err) } - newValue, needWrite, err := sr.rewriteValue( + result, err := sr.rewriteValueV2( e.Value, cf, func(value []byte) ([]byte, bool, error) { return sr.rewriteTableInfo(value, dbID) }, ) - if err != nil || !needWrite { + if err != nil || !result.NeedRewrite { return nil, errors.Trace(err) } - newKey, needWrite, err := sr.rewriteKeyForTable(e.Key, cf, meta.ParseTableKey, meta.TableKey) + newTableID := 0 + newKey, needWrite, err := sr.rewriteKeyForTable(e.Key, cf, meta.ParseTableKey, func(tableID int64) []byte { + newTableID = int(tableID) + return meta.TableKey(tableID) + }) if err != nil || !needWrite { return nil, errors.Trace(err) } + // NOTE: the normal path is in the `SchemaReplace.rewriteTableInfo` + // for now, we rewrite key and value separately hence we cannot + // get a view of (is_delete, table_id, table_info) at the same time :(. + // Maybe we can extract the rewrite part from rewriteTableInfo. + if result.Deleted && sr.AfterTableRewritten != nil { + sr.AfterTableRewritten(true, &model.TableInfo{ID: int64(newTableID)}) + } - return &kv.Entry{Key: newKey, Value: newValue}, nil + return &kv.Entry{Key: newKey, Value: result.NewValue}, nil } func (sr *SchemasReplace) rewriteEntryForAutoTableIDKey(e *kv.Entry, cf string) (*kv.Entry, error) { @@ -413,36 +427,74 @@ func (sr *SchemasReplace) rewriteEntryForAutoRandomTableIDKey(e *kv.Entry, cf st return &kv.Entry{Key: newKey, Value: e.Value}, nil } -func (sr *SchemasReplace) rewriteValue( - value []byte, - cf string, - cbRewrite func([]byte) ([]byte, bool, error), -) ([]byte, bool, error) { +type rewriteResult struct { + Deleted bool + NeedRewrite bool + NewValue []byte +} + +// rewriteValueV2 likes rewriteValueV1, but provides a richer return value. +func (sr *SchemasReplace) rewriteValueV2(value []byte, cf string, rewrite func([]byte) ([]byte, bool, error)) (rewriteResult, error) { switch cf { case DefaultCF: - return cbRewrite(value) + newValue, needRewrite, err := rewrite(value) + if err != nil { + return rewriteResult{}, errors.Trace(err) + } + return rewriteResult{ + NeedRewrite: needRewrite, + NewValue: newValue, + Deleted: false, + }, nil case WriteCF: rawWriteCFValue := new(RawWriteCFValue) if err := rawWriteCFValue.ParseFrom(value); err != nil { - return nil, false, errors.Trace(err) + return rewriteResult{}, errors.Trace(err) } + if rawWriteCFValue.t == WriteTypeDelete { + return rewriteResult{ + NewValue: value, + NeedRewrite: true, + Deleted: true, + }, nil + } if !rawWriteCFValue.HasShortValue() { - return value, true, nil + return rewriteResult{ + NewValue: value, + NeedRewrite: true, + }, nil } - shortValue, needWrite, err := cbRewrite(rawWriteCFValue.GetShortValue()) - if err != nil || !needWrite { - return nil, needWrite, errors.Trace(err) + shortValue, needWrite, err := rewrite(rawWriteCFValue.GetShortValue()) + if err != nil { + return rewriteResult{}, errors.Trace(err) + } + if !needWrite { + return rewriteResult{ + NeedRewrite: false, + }, nil } rawWriteCFValue.UpdateShortValue(shortValue) - return rawWriteCFValue.EncodeTo(), true, nil + return rewriteResult{NewValue: rawWriteCFValue.EncodeTo(), NeedRewrite: true}, nil default: panic(fmt.Sprintf("not support cf:%s", cf)) } } +func (sr *SchemasReplace) rewriteValue( + value []byte, + cf string, + cbRewrite func([]byte) ([]byte, bool, error), +) ([]byte, bool, error) { + r, err := sr.rewriteValueV2(value, cf, cbRewrite) + if err != nil { + return nil, false, err + } + return r.NewValue, r.NeedRewrite, nil +} + // RewriteKvEntry uses to rewrite tableID/dbID in entry.key and entry.value func (sr *SchemasReplace) RewriteKvEntry(e *kv.Entry, cf string) (*kv.Entry, error) { // skip mDDLJob diff --git a/br/pkg/stream/rewrite_meta_rawkv_test.go b/br/pkg/stream/rewrite_meta_rawkv_test.go index b0dedca7f39af..d2cbe24e8295d 100644 --- a/br/pkg/stream/rewrite_meta_rawkv_test.go +++ b/br/pkg/stream/rewrite_meta_rawkv_test.go @@ -94,6 +94,13 @@ func TestRewriteValueForTable(t *testing.T) { require.Nil(t, err) sr := MockEmptySchemasReplace(nil) + tableCount := 0 + sr.AfterTableRewritten = func(deleted bool, tableInfo *model.TableInfo) { + tableCount++ + tableInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + } + } newValue, needRewrite, err := sr.rewriteTableInfo(value, dbId) require.Nil(t, err) require.True(t, needRewrite) @@ -101,6 +108,7 @@ func TestRewriteValueForTable(t *testing.T) { err = json.Unmarshal(newValue, &tableInfo) require.Nil(t, err) require.Equal(t, tableInfo.ID, sr.DbMap[dbId].TableMap[tableID].NewTableID) + require.EqualValues(t, tableInfo.TiFlashReplica.Count, 1) newID := sr.DbMap[dbId].TableMap[tableID].NewTableID newValue, needRewrite, err = sr.rewriteTableInfo(value, dbId) @@ -111,6 +119,7 @@ func TestRewriteValueForTable(t *testing.T) { require.Nil(t, err) require.Equal(t, tableInfo.ID, sr.DbMap[dbId].TableMap[tableID].NewTableID) require.Equal(t, newID, sr.DbMap[dbId].TableMap[tableID].NewTableID) + require.EqualValues(t, tableCount, 2) } func TestRewriteValueForPartitionTable(t *testing.T) { diff --git a/br/pkg/streamhelper/collector.go b/br/pkg/streamhelper/collector.go index 8e30c22804d39..ad53acb03b577 100644 --- a/br/pkg/streamhelper/collector.go +++ b/br/pkg/streamhelper/collector.go @@ -95,6 +95,7 @@ func (c *storeCollector) begin(ctx context.Context) { func (c *storeCollector) recvLoop(ctx context.Context) (err error) { defer utils.PanicToErr(&err) + log.Debug("Begin recv loop", zap.Uint64("store", c.storeID)) for { select { case <-ctx.Done(): @@ -179,6 +180,7 @@ func (c *storeCollector) spawn(ctx context.Context) func(context.Context) (Store func (c *storeCollector) sendPendingRequests(ctx context.Context) error { log.Debug("sending batch", zap.Int("size", len(c.currentRequest.Regions)), zap.Uint64("store", c.storeID)) + defer log.Debug("sending batch done", zap.Uint64("store", c.storeID)) cli, err := c.service.GetLogBackupClient(ctx, c.storeID) if err != nil { return err diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 4cec79e7acf62..069073b11618d 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -20,11 +20,13 @@ import ( "github.com/pingcap/tidb/br/pkg/metautil" "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/br/pkg/restore" + "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" "github.com/pingcap/tidb/br/pkg/summary" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/br/pkg/version" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/mathutil" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -163,9 +165,9 @@ type RestoreConfig struct { FullBackupStorage string `json:"full-backup-storage" toml:"full-backup-storage"` // [startTs, RestoreTS] is used to `restore log` from StartTS to RestoreTS. - StartTS uint64 `json:"start-ts" toml:"start-ts"` - RestoreTS uint64 `json:"restore-ts" toml:"restore-ts"` - skipTiflash bool `json:"-" toml:"-"` + StartTS uint64 `json:"start-ts" toml:"start-ts"` + RestoreTS uint64 `json:"restore-ts" toml:"restore-ts"` + tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"` } // DefineRestoreFlags defines common flags for the restore tidb command. @@ -519,7 +521,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf ddlJobs := restore.FilterDDLJobs(client.GetDDLJobs(), tables) ddlJobs = restore.FilterDDLJobByRules(ddlJobs, restore.DDLJobBlockListRule) - err = client.PreCheckTableTiFlashReplica(ctx, tables, cfg.skipTiflash) + err = client.PreCheckTableTiFlashReplica(ctx, tables, cfg.tiflashRecorder) if err != nil { return errors.Trace(err) } @@ -580,6 +582,15 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf // don't return immediately, wait all pipeline done. } + if cfg.tiflashRecorder != nil { + tableStream = util.ChanMap(tableStream, func(t restore.CreatedTable) restore.CreatedTable { + if cfg.tiflashRecorder != nil { + cfg.tiflashRecorder.Rewrite(t.OldTable.Info.ID, t.Table.ID) + } + return t + }) + } + tableFileMap := restore.MapTableToFiles(files) log.Debug("mapped table to files", zap.Any("result map", tableFileMap)) diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index e9610e0c7c73a..0e92ec043ff34 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/metautil" "github.com/pingcap/tidb/br/pkg/restore" + "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/stream" "github.com/pingcap/tidb/br/pkg/streamhelper" @@ -45,6 +46,7 @@ import ( "github.com/pingcap/tidb/br/pkg/summary" "github.com/pingcap/tidb/br/pkg/utils" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/util/mathutil" "github.com/pingcap/tidb/util/sqlexec" "github.com/spf13/pflag" @@ -1001,6 +1003,8 @@ func RunStreamRestore( return errors.Trace(err) } + recorder := tiflashrec.New() + cfg.tiflashRecorder = recorder // restore full snapshot. if len(cfg.FullBackupStorage) > 0 { if err := checkPiTRRequirements(ctx, g, cfg); err != nil { @@ -1009,7 +1013,6 @@ func RunStreamRestore( logStorage := cfg.Config.Storage cfg.Config.Storage = cfg.FullBackupStorage // TiFlash replica is restored to down-stream on 'pitr' currently. - cfg.skipTiflash = true if err = RunRestore(ctx, g, FullRestoreCmd, cfg); err != nil { return errors.Trace(err) } @@ -1125,6 +1128,17 @@ func restoreStream( if err != nil { return errors.Trace(err) } + schemasReplace.AfterTableRewritten = func(deleted bool, tableInfo *model.TableInfo) { + // When the table replica changed to 0, the tiflash replica might be set to `nil`. + // We should remove the table if we meet. + if deleted || tableInfo.TiFlashReplica == nil { + cfg.tiflashRecorder.DelTable(tableInfo.ID) + return + } + cfg.tiflashRecorder.AddTable(tableInfo.ID, *tableInfo.TiFlashReplica) + // Remove the replica firstly. Let's restore them at the end. + tableInfo.TiFlashReplica = nil + } updateStats := func(kvCount uint64, size uint64) { mu.Lock() @@ -1167,6 +1181,26 @@ func restoreStream( return errors.Annotate(err, "failed to insert rows into gc_delete_range") } + if cfg.tiflashRecorder != nil { + sqls := cfg.tiflashRecorder.GenerateAlterTableDDLs(mgr.GetDomain().InfoSchema()) + log.Info("Generating SQLs for restoring TiFlash Replica", + zap.Strings("sqls", sqls)) + err = g.UseOneShotSession(mgr.GetStorage(), false, func(se glue.Session) error { + for _, sql := range sqls { + if errExec := se.ExecuteInternal(ctx, sql); errExec != nil { + logutil.WarnTerm("Failed to restore tiflash replica config, you may execute the sql restore it manually.", + logutil.ShortError(errExec), + zap.String("sql", sql), + ) + } + } + return nil + }) + if err != nil { + return err + } + } + return nil } diff --git a/util/misc_test.go b/util/misc_test.go index a56fee5b4208d..c1510625593f0 100644 --- a/util/misc_test.go +++ b/util/misc_test.go @@ -17,6 +17,7 @@ package util import ( "bytes" "crypto/x509/pkix" + "fmt" "testing" "time" @@ -186,3 +187,29 @@ func TestComposeURL(t *testing.T) { assert.Equal(t, ComposeURL("http://server.example.com", ""), "http://server.example.com") assert.Equal(t, ComposeURL("https://server.example.com", ""), "https://server.example.com") } + +func assertChannel[T any](t *testing.T, ch <-chan T, items ...T) { + for i, item := range items { + assert.Equal(t, <-ch, item, "the %d-th item doesn't match", i) + } + select { + case item, ok := <-ch: + assert.False(t, ok, "channel not closed: more item %v", item) + default: + t.Fatal("channel not closed: blocked") + } +} + +func TestChannelMap(t *testing.T) { + ch := make(chan int, 4) + ch <- 1 + ch <- 2 + ch <- 3 + + tableCh := ChanMap(ch, func(i int) string { + return fmt.Sprintf("table%d", i) + }) + close(ch) + + assertChannel(t, tableCh, "table1", "table2", "table3") +} diff --git a/util/util.go b/util/util.go index 9b1a9b52840fc..9bf6b085882de 100644 --- a/util/util.go +++ b/util/util.go @@ -71,3 +71,18 @@ func GetJSON(client *http.Client, url string, v interface{}) error { return errors.Trace(json.NewDecoder(resp.Body).Decode(v)) } + +// ChanMap creates a channel which applies the function over the input Channel. +// Hint of Resource Leakage: +// In golang, channel isn't an interface so we must create a goroutine for handling the inputs. +// Hence the input channel must be closed properly or this function may leak a goroutine. +func ChanMap[T, R any](c <-chan T, f func(T) R) <-chan R { + outCh := make(chan R) + go func() { + defer close(outCh) + for item := range c { + outCh <- f(item) + } + }() + return outCh +}