diff --git a/lightning/pkg/importer/get_pre_info.go b/lightning/pkg/importer/get_pre_info.go index db8e569e5f95b..775cf7ec5f3eb 100644 --- a/lightning/pkg/importer/get_pre_info.go +++ b/lightning/pkg/importer/get_pre_info.go @@ -94,7 +94,7 @@ type TargetInfoGetter interface { // FetchRemoteDBModels fetches the database structures from the remote target. FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) // FetchRemoteTableModels fetches the table structures from the remote target. - FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) + FetchRemoteTableModels(ctx context.Context, schemaName string, tableNames []string) (map[string]*model.TableInfo, error) // CheckVersionRequirements performs the check whether the target satisfies the version requirements. CheckVersionRequirements(ctx context.Context) error // IsTableEmpty checks whether the specified table on the target DB contains data or not. @@ -162,8 +162,12 @@ func (g *TargetInfoGetterImpl) FetchRemoteDBModels(ctx context.Context) ([]*mode // FetchRemoteTableModels fetches the table structures from the remote target. // It implements the TargetInfoGetter interface. -func (g *TargetInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { - return g.backend.FetchRemoteTableModels(ctx, schemaName) +func (g *TargetInfoGetterImpl) FetchRemoteTableModels( + ctx context.Context, + schemaName string, + tableNames []string, +) (map[string]*model.TableInfo, error) { + return g.backend.FetchRemoteTableModels(ctx, schemaName, tableNames) } // CheckVersionRequirements performs the check whether the target satisfies the version requirements. @@ -365,6 +369,10 @@ func (p *PreImportInfoGetterImpl) GetAllTableStructures(ctx context.Context, opt func (p *PreImportInfoGetterImpl) getTableStructuresByFileMeta(ctx context.Context, dbSrcFileMeta *mydump.MDDatabaseMeta, getPreInfoCfg *ropts.GetPreInfoConfig) ([]*model.TableInfo, error) { dbName := dbSrcFileMeta.Name + tableNames := make([]string, 0, len(dbSrcFileMeta.Tables)) + for _, tableFileMeta := range dbSrcFileMeta.Tables { + tableNames = append(tableNames, tableFileMeta.Name) + } failpoint.Inject( "getTableStructuresByFileMeta_BeforeFetchRemoteTableModels", func(v failpoint.Value) { @@ -378,7 +386,7 @@ func (p *PreImportInfoGetterImpl) getTableStructuresByFileMeta(ctx context.Conte failpoint.Enable("github.com/pingcap/tidb/pkg/lightning/backend/tidb/FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", fmt.Sprintf("sleep(%d)", sleepMilliSeconds)) }, ) - currentTableInfosFromDB, err := p.targetInfoGetter.FetchRemoteTableModels(ctx, dbName) + currentTableInfosMap, err := p.targetInfoGetter.FetchRemoteTableModels(ctx, dbName, tableNames) if err != nil { if getPreInfoCfg != nil && getPreInfoCfg.IgnoreDBNotExist { dbNotExistErr := dbterror.ClassSchema.NewStd(errno.ErrBadDB).FastGenByArgs(dbName) @@ -394,10 +402,6 @@ func (p *PreImportInfoGetterImpl) getTableStructuresByFileMeta(ctx context.Conte return nil, errors.Trace(err) } get_struct_from_src: - currentTableInfosMap := make(map[string]*model.TableInfo) - for _, tblInfo := range currentTableInfosFromDB { - currentTableInfosMap[tblInfo.Name.L] = tblInfo - } resultInfos := make([]*model.TableInfo, len(dbSrcFileMeta.Tables)) for i, tableFileMeta := range dbSrcFileMeta.Tables { if curTblInfo, ok := currentTableInfosMap[strings.ToLower(tableFileMeta.Name)]; ok { @@ -804,8 +808,12 @@ func (p *PreImportInfoGetterImpl) FetchRemoteDBModels(ctx context.Context) ([]*m // FetchRemoteTableModels fetches the table structures from the remote target. // It implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { - return p.targetInfoGetter.FetchRemoteTableModels(ctx, schemaName) +func (p *PreImportInfoGetterImpl) FetchRemoteTableModels( + ctx context.Context, + schemaName string, + tableNames []string, +) (map[string]*model.TableInfo, error) { + return p.targetInfoGetter.FetchRemoteTableModels(ctx, schemaName, tableNames) } // CheckVersionRequirements performs the check whether the target satisfies the version requirements. diff --git a/lightning/pkg/importer/mock/mock.go b/lightning/pkg/importer/mock/mock.go index a5979998954cd..3a88b0cd2a56a 100644 --- a/lightning/pkg/importer/mock/mock.go +++ b/lightning/pkg/importer/mock/mock.go @@ -225,17 +225,25 @@ func (t *TargetInfo) FetchRemoteDBModels(_ context.Context) ([]*model.DBInfo, er // FetchRemoteTableModels fetches the table structures from the remote target. // It implements the TargetInfoGetter interface. -func (t *TargetInfo) FetchRemoteTableModels(_ context.Context, schemaName string) ([]*model.TableInfo, error) { - resultInfos := []*model.TableInfo{} +func (t *TargetInfo) FetchRemoteTableModels( + _ context.Context, + schemaName string, + tableNames []string, +) (map[string]*model.TableInfo, error) { tblMap, ok := t.dbTblInfoMap[schemaName] if !ok { dbNotExistErr := dbterror.ClassSchema.NewStd(errno.ErrBadDB).FastGenByArgs(schemaName) return nil, errors.Errorf("get xxxxxx http status code != 200, message %s", dbNotExistErr.Error()) } - for _, tblInfo := range tblMap { - resultInfos = append(resultInfos, tblInfo.TableModel) + ret := make(map[string]*model.TableInfo, len(tableNames)) + for _, tableName := range tableNames { + tblInfo, ok := tblMap[tableName] + if !ok { + continue + } + ret[tableName] = tblInfo.TableModel } - return resultInfos, nil + return ret, nil } // GetTargetSysVariablesForImport gets some important systam variables for importing on the target. diff --git a/lightning/pkg/importer/mock/mock_test.go b/lightning/pkg/importer/mock/mock_test.go index 84cf2a88a4e76..2620583bee966 100644 --- a/lightning/pkg/importer/mock/mock_test.go +++ b/lightning/pkg/importer/mock/mock_test.go @@ -185,7 +185,7 @@ func TestMockTargetInfoBasic(t *testing.T) { RowCount: 100, }, ) - tblInfos, err := ti.FetchRemoteTableModels(ctx, "testdb") + tblInfos, err := ti.FetchRemoteTableModels(ctx, "testdb", []string{"testtbl1", "testtbl2"}) require.NoError(t, err) require.Equal(t, 2, len(tblInfos)) for _, tblInfo := range tblInfos { diff --git a/pkg/lightning/backend/backend.go b/pkg/lightning/backend/backend.go index a9bbd6f33c54f..9dc9231f8e64c 100644 --- a/pkg/lightning/backend/backend.go +++ b/pkg/lightning/backend/backend.go @@ -142,9 +142,13 @@ type TargetInfoGetter interface { // the database name is filled. FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) - // FetchRemoteTableModels obtains the models of all tables given the schema - // name. The returned table info does not need to be precise if the encoder, - // is not requiring them, but must at least fill in the following fields for + // FetchRemoteTableModels obtains the TableInfo of given tables under the schema + // name. It returns a map whose key is the table name in lower case and value is + // the TableInfo. If the table does not exist, it will not be included in the + // map. + // + // The returned table info does not need to be precise if the encoder, is not + // requiring them, but must at least fill in the following fields for // TablesFromMeta to succeed: // - Name // - State (must be model.StatePublic) @@ -154,7 +158,7 @@ type TargetInfoGetter interface { // * State (must be model.StatePublic) // * Offset (must be 0, 1, 2, ...) // - PKIsHandle (true = do not generate _tidb_rowid) - FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) + FetchRemoteTableModels(ctx context.Context, schemaName string, tableNames []string) (map[string]*model.TableInfo, error) // CheckRequirements performs the check whether the backend satisfies the version requirements CheckRequirements(ctx context.Context, checkCtx *CheckCtx) error diff --git a/pkg/lightning/backend/local/local.go b/pkg/lightning/backend/local/local.go index 87f753de2ecd8..8ff151aebc821 100644 --- a/pkg/lightning/backend/local/local.go +++ b/pkg/lightning/backend/local/local.go @@ -286,8 +286,27 @@ func (g *targetInfoGetter) FetchRemoteDBModels(ctx context.Context) ([]*model.DB // FetchRemoteTableModels obtains the models of all tables given the schema name. // It implements the `TargetInfoGetter` interface. -func (g *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { - return tikv.FetchRemoteTableModelsFromTLS(ctx, g.tls, schemaName) +func (g *targetInfoGetter) FetchRemoteTableModels( + ctx context.Context, + schemaName string, + tableNames []string, +) (map[string]*model.TableInfo, error) { + allTablesInDB, err := tikv.FetchRemoteTableModelsFromTLS(ctx, g.tls, schemaName) + if err != nil { + return nil, errors.Trace(err) + } + + tableNamesSet := make(map[string]struct{}, len(tableNames)) + for _, name := range tableNames { + tableNamesSet[strings.ToLower(name)] = struct{}{} + } + ret := make(map[string]*model.TableInfo, len(tableNames)) + for _, tbl := range allTablesInDB { + if _, ok := tableNamesSet[tbl.Name.L]; ok { + ret[tbl.Name.L] = tbl + } + } + return ret, nil } // CheckRequirements performs the check whether the backend satisfies the version requirements. diff --git a/pkg/lightning/backend/tidb/BUILD.bazel b/pkg/lightning/backend/tidb/BUILD.bazel index 74ab3e1586816..1416344efab06 100644 --- a/pkg/lightning/backend/tidb/BUILD.bazel +++ b/pkg/lightning/backend/tidb/BUILD.bazel @@ -6,7 +6,6 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/lightning/backend/tidb", visibility = ["//visibility:public"], deps = [ - "//br/pkg/version", "//pkg/errno", "//pkg/lightning/backend", "//pkg/lightning/backend/encode", @@ -21,6 +20,7 @@ go_library( "//pkg/parser/mysql", "//pkg/table", "//pkg/types", + "//pkg/util", "//pkg/util/dbutil", "//pkg/util/hack", "//pkg/util/kvcache", diff --git a/pkg/lightning/backend/tidb/tidb.go b/pkg/lightning/backend/tidb/tidb.go index caf837523df90..535dc2cea1961 100644 --- a/pkg/lightning/backend/tidb/tidb.go +++ b/pkg/lightning/backend/tidb/tidb.go @@ -28,7 +28,6 @@ import ( "github.com/google/uuid" "github.com/pingcap/errors" "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/version" "github.com/pingcap/tidb/pkg/errno" "github.com/pingcap/tidb/pkg/lightning/backend" "github.com/pingcap/tidb/pkg/lightning/backend/encode" @@ -43,6 +42,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/dbutil" "github.com/pingcap/tidb/pkg/util/hack" "github.com/pingcap/tidb/pkg/util/kvcache" @@ -171,123 +171,154 @@ func (b *targetInfoGetter) FetchRemoteDBModels(ctx context.Context) ([]*model.DB return results, err } -// FetchRemoteTableModels obtains the models of all tables given the schema name. -// It implements the `backend.TargetInfoGetter` interface. -// TODO: refactor -func (b *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { - var err error - results := []*model.TableInfo{} +// exported for test. +var ( + FetchRemoteTableModelsConcurrency = 8 + FetchRemoteTableModelsBatchSize = 32 +) + +// FetchRemoteTableModels implements the `backend.TargetInfoGetter` interface. +func (b *targetInfoGetter) FetchRemoteTableModels( + ctx context.Context, + schemaName string, + tableNames []string, +) (map[string]*model.TableInfo, error) { + tableInfos := make([]*model.TableInfo, len(tableNames)) logger := log.FromContext(ctx) s := common.SQLWithRetry{ DB: b.db, Logger: logger, } - err = s.Transact(ctx, "fetch table columns", func(_ context.Context, tx *sql.Tx) error { - var versionStr string - if versionStr, err = version.FetchVersion(ctx, tx); err != nil { - return err + eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx) + eg.SetLimit(FetchRemoteTableModelsConcurrency) + for i := 0; i < len(tableNames); i += FetchRemoteTableModelsBatchSize { + start := i + end := i + FetchRemoteTableModelsBatchSize + if end > len(tableNames) { + end = len(tableNames) } - serverInfo := version.ParseServerInfo(versionStr) - - rows, e := tx.Query(` - SELECT table_name, column_name, column_type, generation_expression, extra - FROM information_schema.columns - WHERE table_schema = ? - ORDER BY table_name, ordinal_position; - `, schemaName) - if e != nil { - return e - } - defer rows.Close() - - var ( - curTableName string - curColOffset int - curTable *model.TableInfo - ) - tables := []*model.TableInfo{} - for rows.Next() { - var tableName, columnName, columnType, generationExpr, columnExtra string - if e := rows.Scan(&tableName, &columnName, &columnType, &generationExpr, &columnExtra); e != nil { - return e - } - if tableName != curTableName { - curTable = &model.TableInfo{ - Name: pmodel.NewCIStr(tableName), - State: model.StatePublic, - PKIsHandle: true, - } - tables = append(tables, curTable) - curTableName = tableName - curColOffset = 0 - } - - // see: https://github.com/pingcap/parser/blob/3b2fb4b41d73710bc6c4e1f4e8679d8be6a4863e/types/field_type.go#L185-L191 - var flag uint - if strings.HasSuffix(columnType, "unsigned") { - flag |= mysql.UnsignedFlag - } - if strings.Contains(columnExtra, "auto_increment") { - flag |= mysql.AutoIncrementFlag - } + eg.Go(func() error { + return s.Transact( + egCtx, "fetch table columns", + func(_ context.Context, tx *sql.Tx) error { + args := make([]any, 0, 1+end-start) + args = append(args, schemaName) + for _, tableName := range tableNames[start:end] { + args = append(args, tableName) + } + //nolint:gosec + rows, err := tx.Query(` + SELECT table_name, column_name, column_type, generation_expression, extra + FROM information_schema.columns + WHERE table_schema = ? AND table_name IN (?`+strings.Repeat(",?", end-start-1)+`) + ORDER BY table_name, ordinal_position; + `, args...) + if err != nil { + return err + } + defer rows.Close() + + var ( + curTableName string + curColOffset int + curTable *model.TableInfo + tableIdx = start - 1 + ) + for rows.Next() { + var tableName, columnName, columnType, generationExpr, columnExtra string + if err2 := rows.Scan(&tableName, &columnName, &columnType, &generationExpr, &columnExtra); err2 != nil { + return err2 + } + if tableName != curTableName { + tableIdx++ + curTable = &model.TableInfo{ + Name: pmodel.NewCIStr(tableName), + State: model.StatePublic, + PKIsHandle: true, + } + tableInfos[tableIdx] = curTable + curTableName = tableName + curColOffset = 0 + } - ft := types.FieldType{} - ft.SetFlag(flag) - curTable.Columns = append(curTable.Columns, &model.ColumnInfo{ - Name: pmodel.NewCIStr(columnName), - Offset: curColOffset, - State: model.StatePublic, - FieldType: ft, - GeneratedExprString: generationExpr, - }) - curColOffset++ - } - if err := rows.Err(); err != nil { - return err - } - // shard_row_id/auto random is only available after tidb v4.0.0 - // `show table next_row_id` is also not available before tidb v4.0.0 - if serverInfo.ServerType != version.ServerTypeTiDB || serverInfo.ServerVersion.Major < 4 { - results = tables - return nil - } + // see: https://github.com/pingcap/parser/blob/3b2fb4b41d73710bc6c4e1f4e8679d8be6a4863e/types/field_type.go#L185-L191 + var flag uint + if strings.HasSuffix(columnType, "unsigned") { + flag |= mysql.UnsignedFlag + } + if strings.Contains(columnExtra, "auto_increment") { + flag |= mysql.AutoIncrementFlag + } - failpoint.Inject( - "FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", - func() { - fmt.Println("failpoint: FetchRemoteTableModels_BeforeFetchTableAutoIDInfos") - }, - ) + ft := types.FieldType{} + ft.SetFlag(flag) + curTable.Columns = append(curTable.Columns, &model.ColumnInfo{ + Name: pmodel.NewCIStr(columnName), + Offset: curColOffset, + State: model.StatePublic, + FieldType: ft, + GeneratedExprString: generationExpr, + }) + curColOffset++ + } + if err := rows.Err(); err != nil { + return err + } - // init auto id column for each table - for _, tbl := range tables { - tblName := common.UniqueTable(schemaName, tbl.Name.O) - autoIDInfos, err := FetchTableAutoIDInfos(ctx, tx, tblName) - if err != nil { - logger.Warn("fetch table auto ID infos error. Ignore this table and continue.", zap.String("table_name", tblName), zap.Error(err)) - continue - } - for _, info := range autoIDInfos { - for _, col := range tbl.Columns { - if col.Name.O == info.Column { - switch info.Type { - case "AUTO_INCREMENT": - col.AddFlag(mysql.AutoIncrementFlag) - case "AUTO_RANDOM": - col.AddFlag(mysql.PriKeyFlag) - tbl.PKIsHandle = true - // set a stub here, since we don't really need the real value - tbl.AutoRandomBits = 1 + failpoint.Inject( + "FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", + func() { + fmt.Println("failpoint: FetchRemoteTableModels_BeforeFetchTableAutoIDInfos") + }, + ) + + // init auto id column for each table + for idx := start; idx <= tableIdx; idx++ { + tbl := tableInfos[idx] + tblName := common.UniqueTable(schemaName, tbl.Name.O) + autoIDInfos, err := FetchTableAutoIDInfos(ctx, tx, tblName) + if err != nil { + logger.Warn( + "fetch table auto ID infos error. Ignore this table and continue.", + zap.String("table_name", tblName), + zap.Error(err), + ) + tableInfos[idx] = nil + continue + } + for _, info := range autoIDInfos { + for _, col := range tbl.Columns { + if col.Name.O == info.Column { + switch info.Type { + case "AUTO_INCREMENT": + col.AddFlag(mysql.AutoIncrementFlag) + case "AUTO_RANDOM": + col.AddFlag(mysql.PriKeyFlag) + tbl.PKIsHandle = true + // set a stub here, since we don't really need the real value + tbl.AutoRandomBits = 1 + } + } + } } } - } - } - results = append(results, tbl) + return nil + }) + }) + } + if err := eg.Wait(); err != nil { + return nil, errors.Trace(err) + } + + ret := make(map[string]*model.TableInfo, len(tableInfos)) + for _, tbl := range tableInfos { + if tbl != nil { + ret[tbl.Name.L] = tbl } - return nil - }) - return results, err + } + + return ret, nil } // CheckRequirements performs the check whether the backend satisfies the version requirements. diff --git a/pkg/lightning/backend/tidb/tidb_test.go b/pkg/lightning/backend/tidb/tidb_test.go index 1ee7dd0024660..e0344de49d556 100644 --- a/pkg/lightning/backend/tidb/tidb_test.go +++ b/pkg/lightning/backend/tidb/tidb_test.go @@ -334,48 +334,12 @@ func testStrictMode(t *testing.T) { require.Regexp(t, "incorrect ascii value .* for column s1$", err.Error()) } -func TestFetchRemoteTableModels_3_x(t *testing.T) { - s := createMysqlSuite(t) - defer s.TearDownTest(t) - s.mockDB.ExpectBegin() - s.mockDB.ExpectQuery("SELECT version()"). - WillReturnRows(sqlmock.NewRows([]string{"version()"}).AddRow("5.7.25-TiDB-v3.0.18")) - s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? ORDER BY table_name, ordinal_position;\\E"). - WithArgs("test"). - WillReturnRows(sqlmock.NewRows([]string{"table_name", "column_name", "column_type", "generation_expression", "extra"}). - AddRow("t", "id", "int(10)", "", "auto_increment")) - s.mockDB.ExpectCommit() - - targetInfoGetter := tidb.NewTargetInfoGetter(s.dbHandle) - tableInfos, err := targetInfoGetter.FetchRemoteTableModels(context.Background(), "test") - require.NoError(t, err) - ft := types.FieldType{} - ft.SetFlag(mysql.AutoIncrementFlag) - require.Equal(t, []*model.TableInfo{ - { - Name: pmodel.NewCIStr("t"), - State: model.StatePublic, - PKIsHandle: true, - Columns: []*model.ColumnInfo{ - { - Name: pmodel.NewCIStr("id"), - Offset: 0, - State: model.StatePublic, - FieldType: ft, - }, - }, - }, - }, tableInfos) -} - func TestFetchRemoteTableModels_4_0(t *testing.T) { s := createMysqlSuite(t) defer s.TearDownTest(t) s.mockDB.ExpectBegin() - s.mockDB.ExpectQuery("SELECT version()"). - WillReturnRows(sqlmock.NewRows([]string{"version()"}).AddRow("5.7.25-TiDB-v4.0.0")) - s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? ORDER BY table_name, ordinal_position;\\E"). - WithArgs("test"). + s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? AND table_name IN (?) ORDER BY table_name, ordinal_position;\\E"). + WithArgs("test", "t"). WillReturnRows(sqlmock.NewRows([]string{"table_name", "column_name", "column_type", "generation_expression", "extra"}). AddRow("t", "id", "bigint(20) unsigned", "", "auto_increment")) s.mockDB.ExpectQuery("SHOW TABLE `test`.`t` NEXT_ROW_ID"). @@ -384,12 +348,12 @@ func TestFetchRemoteTableModels_4_0(t *testing.T) { s.mockDB.ExpectCommit() targetInfoGetter := tidb.NewTargetInfoGetter(s.dbHandle) - tableInfos, err := targetInfoGetter.FetchRemoteTableModels(context.Background(), "test") + tableInfos, err := targetInfoGetter.FetchRemoteTableModels(context.Background(), "test", []string{"t"}) require.NoError(t, err) ft := types.FieldType{} ft.SetFlag(mysql.AutoIncrementFlag | mysql.UnsignedFlag) - require.Equal(t, []*model.TableInfo{ - { + require.Equal(t, map[string]*model.TableInfo{ + "t": { Name: pmodel.NewCIStr("t"), State: model.StatePublic, PKIsHandle: true, @@ -409,10 +373,8 @@ func TestFetchRemoteTableModels_4_x_auto_increment(t *testing.T) { s := createMysqlSuite(t) defer s.TearDownTest(t) s.mockDB.ExpectBegin() - s.mockDB.ExpectQuery("SELECT version()"). - WillReturnRows(sqlmock.NewRows([]string{"version()"}).AddRow("5.7.25-TiDB-v4.0.7")) - s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? ORDER BY table_name, ordinal_position;\\E"). - WithArgs("test"). + s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? AND table_name IN (?) ORDER BY table_name, ordinal_position;\\E"). + WithArgs("test", "t"). WillReturnRows(sqlmock.NewRows([]string{"table_name", "column_name", "column_type", "generation_expression", "extra"}). AddRow("t", "id", "bigint(20)", "", "")) s.mockDB.ExpectQuery("SHOW TABLE `test`.`t` NEXT_ROW_ID"). @@ -421,12 +383,12 @@ func TestFetchRemoteTableModels_4_x_auto_increment(t *testing.T) { s.mockDB.ExpectCommit() targetInfoGetter := tidb.NewTargetInfoGetter(s.dbHandle) - tableInfos, err := targetInfoGetter.FetchRemoteTableModels(context.Background(), "test") + tableInfos, err := targetInfoGetter.FetchRemoteTableModels(context.Background(), "test", []string{"t"}) require.NoError(t, err) ft := types.FieldType{} ft.SetFlag(mysql.AutoIncrementFlag) - require.Equal(t, []*model.TableInfo{ - { + require.Equal(t, map[string]*model.TableInfo{ + "t": { Name: pmodel.NewCIStr("t"), State: model.StatePublic, PKIsHandle: true, @@ -446,10 +408,8 @@ func TestFetchRemoteTableModels_4_x_auto_random(t *testing.T) { s := createMysqlSuite(t) defer s.TearDownTest(t) s.mockDB.ExpectBegin() - s.mockDB.ExpectQuery("SELECT version()"). - WillReturnRows(sqlmock.NewRows([]string{"version()"}).AddRow("5.7.25-TiDB-v4.0.7")) - s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? ORDER BY table_name, ordinal_position;\\E"). - WithArgs("test"). + s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? AND table_name IN (?) ORDER BY table_name, ordinal_position;\\E"). + WithArgs("test", "t"). WillReturnRows(sqlmock.NewRows([]string{"table_name", "column_name", "column_type", "generation_expression", "extra"}). AddRow("t", "id", "bigint(20)", "1 + 2", "")) s.mockDB.ExpectQuery("SHOW TABLE `test`.`t` NEXT_ROW_ID"). @@ -458,12 +418,12 @@ func TestFetchRemoteTableModels_4_x_auto_random(t *testing.T) { s.mockDB.ExpectCommit() targetInfoGetter := tidb.NewTargetInfoGetter(s.dbHandle) - tableInfos, err := targetInfoGetter.FetchRemoteTableModels(context.Background(), "test") + tableInfos, err := targetInfoGetter.FetchRemoteTableModels(context.Background(), "test", []string{"t"}) require.NoError(t, err) ft := types.FieldType{} ft.SetFlag(mysql.PriKeyFlag) - require.Equal(t, []*model.TableInfo{ - { + require.Equal(t, map[string]*model.TableInfo{ + "t": { Name: pmodel.NewCIStr("t"), State: model.StatePublic, PKIsHandle: true, @@ -485,10 +445,8 @@ func TestFetchRemoteTableModelsDropTableHalfway(t *testing.T) { s := createMysqlSuite(t) defer s.TearDownTest(t) s.mockDB.ExpectBegin() - s.mockDB.ExpectQuery("SELECT tidb_version()"). - WillReturnRows(sqlmock.NewRows([]string{"tidb_version()"}).AddRow(`Release Version: v99.0.0`)) // this is a fake version number - s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? ORDER BY table_name, ordinal_position;\\E"). - WithArgs("test"). + s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? AND table_name IN (?,?) ORDER BY table_name, ordinal_position;\\E"). + WithArgs("test", "tbl01", "tbl02"). WillReturnRows(sqlmock.NewRows([]string{"table_name", "column_name", "column_type", "generation_expression", "extra"}). AddRow("tbl01", "id", "bigint(20)", "", "auto_increment"). AddRow("tbl01", "val", "varchar(255)", "", ""). @@ -503,12 +461,16 @@ func TestFetchRemoteTableModelsDropTableHalfway(t *testing.T) { s.mockDB.ExpectCommit() infoGetter := tidb.NewTargetInfoGetter(s.dbHandle) - tableInfos, err := infoGetter.FetchRemoteTableModels(context.Background(), "test") + tableInfos, err := infoGetter.FetchRemoteTableModels( + context.Background(), + "test", + []string{"tbl01", "tbl02"}, + ) require.NoError(t, err) ft := types.FieldType{} ft.SetFlag(mysql.AutoIncrementFlag) - require.Equal(t, []*model.TableInfo{ - { + require.Equal(t, map[string]*model.TableInfo{ + "tbl01": { Name: pmodel.NewCIStr("tbl01"), State: model.StatePublic, PKIsHandle: true, @@ -529,6 +491,78 @@ func TestFetchRemoteTableModelsDropTableHalfway(t *testing.T) { }, tableInfos) } +func TestFetchRemoteTableModelsConcurrency(t *testing.T) { + backupConcurrency := tidb.FetchRemoteTableModelsConcurrency + tidb.FetchRemoteTableModelsConcurrency = 2 + backupBatchSize := tidb.FetchRemoteTableModelsBatchSize + tidb.FetchRemoteTableModelsBatchSize = 3 + t.Cleanup(func() { + tidb.FetchRemoteTableModelsConcurrency = backupConcurrency + tidb.FetchRemoteTableModelsBatchSize = backupBatchSize + }) + + s := createMysqlSuite(t) + defer s.TearDownTest(t) + s.mockDB.MatchExpectationsInOrder(false) + + s.mockDB.ExpectBegin() + s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? AND table_name IN (?,?,?) ORDER BY table_name, ordinal_position;\\E"). + WithArgs("test", "t1", "t2", "t3"). + WillReturnRows(sqlmock.NewRows([]string{"table_name", "column_name", "column_type", "generation_expression", "extra"}). + AddRow("t1", "id", "bigint(20)", "", ""). + AddRow("t2", "id", "bigint(20)", "", "")) + s.mockDB.ExpectQuery("SHOW TABLE `test`.`t1` NEXT_ROW_ID"). + WillReturnRows(sqlmock.NewRows([]string{"DB_NAME", "TABLE_NAME", "COLUMN_NAME", "NEXT_GLOBAL_ROW_ID", "ID_TYPE"}). + AddRow("test", "t1", "id", int64(1), "_TIDB_ROWID")) + s.mockDB.ExpectQuery("SHOW TABLE `test`.`t2` NEXT_ROW_ID"). + WillReturnRows(sqlmock.NewRows([]string{"DB_NAME", "TABLE_NAME", "COLUMN_NAME", "NEXT_GLOBAL_ROW_ID", "ID_TYPE"}). + AddRow("test", "t2", "id", int64(1), "_TIDB_ROWID")) + s.mockDB.ExpectCommit() + + s.mockDB.ExpectBegin() + s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? AND table_name IN (?,?,?) ORDER BY table_name, ordinal_position;\\E"). + WithArgs("test", "t4", "t5", "t6"). + WillReturnRows(sqlmock.NewRows([]string{"table_name", "column_name", "column_type", "generation_expression", "extra"}). + AddRow("t4", "id", "bigint(20)", "", ""). + AddRow("t6", "id", "bigint(20)", "", "")) + s.mockDB.ExpectQuery("SHOW TABLE `test`.`t4` NEXT_ROW_ID"). + WillReturnRows(sqlmock.NewRows([]string{"DB_NAME", "TABLE_NAME", "COLUMN_NAME", "NEXT_GLOBAL_ROW_ID", "ID_TYPE"}). + AddRow("test", "t4", "id", int64(1), "_TIDB_ROWID")) + s.mockDB.ExpectQuery("SHOW TABLE `test`.`t6` NEXT_ROW_ID"). + WillReturnRows(sqlmock.NewRows([]string{"DB_NAME", "TABLE_NAME", "COLUMN_NAME", "NEXT_GLOBAL_ROW_ID", "ID_TYPE"}). + AddRow("test", "t6", "id", int64(1), "_TIDB_ROWID")) + s.mockDB.ExpectCommit() + + s.mockDB.ExpectBegin() + s.mockDB.ExpectQuery("\\QSELECT table_name, column_name, column_type, generation_expression, extra FROM information_schema.columns WHERE table_schema = ? AND table_name IN (?,?,?) ORDER BY table_name, ordinal_position;\\E"). + WithArgs("test", "t7", "t8", "t9"). + WillReturnRows(sqlmock.NewRows([]string{"table_name", "column_name", "column_type", "generation_expression", "extra"}). + AddRow("t8", "id", "bigint(20)", "", ""). + AddRow("t9", "id", "bigint(20)", "", "")) + s.mockDB.ExpectQuery("SHOW TABLE `test`.`t8` NEXT_ROW_ID"). + WillReturnRows(sqlmock.NewRows([]string{"DB_NAME", "TABLE_NAME", "COLUMN_NAME", "NEXT_GLOBAL_ROW_ID", "ID_TYPE"}). + AddRow("test", "t8", "id", int64(1), "_TIDB_ROWID")) + s.mockDB.ExpectQuery("SHOW TABLE `test`.`t9` NEXT_ROW_ID"). + WillReturnRows(sqlmock.NewRows([]string{"DB_NAME", "TABLE_NAME", "COLUMN_NAME", "NEXT_GLOBAL_ROW_ID", "ID_TYPE"}). + AddRow("test", "t9", "id", int64(1), "_TIDB_ROWID")) + s.mockDB.ExpectCommit() + + infoGetter := tidb.NewTargetInfoGetter(s.dbHandle) + tableInfos, err := infoGetter.FetchRemoteTableModels( + context.Background(), + "test", + []string{"t1", "t2", "t3", "t4", "t5", "t6", "t7", "t8", "t9"}, + ) + require.NoError(t, err) + require.Len(t, tableInfos, 6) + require.Contains(t, tableInfos, "t1") + require.Contains(t, tableInfos, "t2") + require.Contains(t, tableInfos, "t4") + require.Contains(t, tableInfos, "t6") + require.Contains(t, tableInfos, "t8") + require.Contains(t, tableInfos, "t9") +} + func TestWriteRowsErrorNoRetry(t *testing.T) { nonRetryableError := sql.ErrNoRows s := createMysqlSuite(t)