From feb4b14ebfa46f3ea49306f84106d6f32c5141b8 Mon Sep 17 00:00:00 2001 From: glorv Date: Fri, 5 Nov 2021 19:29:03 +0800 Subject: [PATCH] br/pkg/lightning: make tidb version check more compitible (#29505) --- br/pkg/lightning/backend/local/local.go | 12 +-- br/pkg/lightning/backend/tidb/tidb.go | 9 +- br/pkg/lightning/restore/restore.go | 11 +- br/pkg/lightning/restore/restore_test.go | 8 +- br/pkg/version/version.go | 129 ++++++++++++++++++++++- 5 files changed, 142 insertions(+), 27 deletions(-) diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index c04a0b9278282..39b8cb60f899e 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -2314,11 +2314,9 @@ func (local *local) CleanupEngine(ctx context.Context, engineUUID uuid.UUID) err } func (local *local) CheckRequirements(ctx context.Context, checkCtx *backend.CheckCtx) error { - versionStr, err := local.g.GetSQLExecutor().ObtainStringWithLog( - ctx, - "SELECT version();", - "check TiDB version", - log.L()) + // TODO: support lightning via SQL + db, _ := local.g.GetDB() + versionStr, err := version.FetchVersion(ctx, db) if err != nil { return errors.Trace(err) } @@ -2332,8 +2330,8 @@ func (local *local) CheckRequirements(ctx context.Context, checkCtx *backend.Che return err } - tidbVersion, _ := version.ExtractTiDBVersion(versionStr) - return checkTiFlashVersion(ctx, local.g, checkCtx, *tidbVersion) + serverInfo := version.ParseServerInfo(versionStr) + return checkTiFlashVersion(ctx, local.g, checkCtx, *serverInfo.ServerVersion) } func checkTiDBVersion(_ context.Context, versionStr string, requiredMinVersion, requiredMaxVersion semver.Version) error { diff --git a/br/pkg/lightning/backend/tidb/tidb.go b/br/pkg/lightning/backend/tidb/tidb.go index db6c4c47b8544..5a60e7f8fd589 100644 --- a/br/pkg/lightning/backend/tidb/tidb.go +++ b/br/pkg/lightning/backend/tidb/tidb.go @@ -450,13 +450,10 @@ func (be *tidbBackend) FetchRemoteTableModels(ctx context.Context, schemaName st err = s.Transact(ctx, "fetch table columns", func(c context.Context, tx *sql.Tx) error { var versionStr string - if err = tx.QueryRowContext(ctx, "SELECT version()").Scan(&versionStr); err != nil { - return err - } - tidbVersion, err := version.ExtractTiDBVersion(versionStr) - if err != nil { + if versionStr, err = version.FetchVersion(ctx, tx); err != nil { return err } + serverInfo := version.ParseServerInfo(versionStr) rows, e := tx.Query(` SELECT table_name, column_name, column_type, extra @@ -513,7 +510,7 @@ func (be *tidbBackend) FetchRemoteTableModels(ctx context.Context, schemaName st } // shard_row_id/auto random is only available after tidb v4.0.0 // `show table next_row_id` is also not available before tidb v4.0.0 - if tidbVersion.Major < 4 { + if serverInfo.ServerType != version.ServerTypeTiDB || serverInfo.ServerVersion.Major < 4 { return nil } diff --git a/br/pkg/lightning/restore/restore.go b/br/pkg/lightning/restore/restore.go index 0478dc99ad127..b323c5025d9c9 100644 --- a/br/pkg/lightning/restore/restore.go +++ b/br/pkg/lightning/restore/restore.go @@ -1393,16 +1393,13 @@ func (tr *TableRestore) restoreTable( zap.Int("filesCnt", cp.CountChunks()), ) } else if cp.Status < checkpoints.CheckpointStatusAllWritten { - versionStr, err := rc.tidbGlue.GetSQLExecutor().ObtainStringWithLog( - ctx, "SELECT version()", "fetch tidb version", log.L()) + db, _ := rc.tidbGlue.GetDB() + versionStr, err := version.FetchVersion(ctx, db) if err != nil { return false, errors.Trace(err) } - tidbVersion, err := version.ExtractTiDBVersion(versionStr) - if err != nil { - return false, errors.Trace(err) - } + versionInfo := version.ParseServerInfo(versionStr) if err := tr.populateChunks(ctx, rc, cp); err != nil { return false, errors.Trace(err) @@ -1417,7 +1414,7 @@ func (tr *TableRestore) restoreTable( } // "show table next_row_id" is only available after v4.0.0 - if tidbVersion.Major >= 4 && (rc.cfg.TikvImporter.Backend == config.BackendLocal || rc.cfg.TikvImporter.Backend == config.BackendImporter) { + if versionInfo.ServerVersion.Major >= 4 && (rc.cfg.TikvImporter.Backend == config.BackendLocal || rc.cfg.TikvImporter.Backend == config.BackendImporter) { // first, insert a new-line into meta table if err = metaMgr.InitTableMeta(ctx); err != nil { return false, err diff --git a/br/pkg/lightning/restore/restore_test.go b/br/pkg/lightning/restore/restore_test.go index da3d38a15477c..6c3cddb5e9fc0 100644 --- a/br/pkg/lightning/restore/restore_test.go +++ b/br/pkg/lightning/restore/restore_test.go @@ -969,8 +969,12 @@ func (s *tableRestoreSuite) TestTableRestoreMetrics(c *C) { }() exec := mock.NewMockSQLExecutor(controller) g.EXPECT().GetSQLExecutor().Return(exec).AnyTimes() - exec.EXPECT().ObtainStringWithLog(gomock.Any(), "SELECT version()", gomock.Any(), gomock.Any()). - Return("5.7.25-TiDB-v5.0.1", nil).AnyTimes() + db, mock, err := sqlmock.New() + c.Assert(err, IsNil) + g.EXPECT().GetDB().Return(db, nil).AnyTimes() + mock.ExpectQuery("SELECT tidb_version\\(\\);"). + WillReturnRows(sqlmock.NewRows([]string{"tidb_version"}). + AddRow("Release Version: v5.2.1\nEdition: Community\n")) web.BroadcastInitProgress(rc.dbMetas) diff --git a/br/pkg/version/version.go b/br/pkg/version/version.go index 10d24d5f4bb22..70f990d17eb12 100644 --- a/br/pkg/version/version.go +++ b/br/pkg/version/version.go @@ -14,6 +14,8 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/log" berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/version/build" pd "github.com/tikv/pd/client" "go.uber.org/zap" @@ -213,13 +215,13 @@ func ExtractTiDBVersion(version string) (*semver.Version, error) { return semver.NewVersion(rawVersion) } -// CheckTiDBVersion is equals to ExtractTiDBVersion followed by CheckVersion. +// CheckTiDBVersion is equals to ParseServerInfo followed by CheckVersion. func CheckTiDBVersion(versionStr string, requiredMinVersion, requiredMaxVersion semver.Version) error { - version, err := ExtractTiDBVersion(versionStr) - if err != nil { - return errors.Trace(err) + serverInfo := ParseServerInfo(versionStr) + if serverInfo.ServerType != ServerTypeTiDB { + return errors.Errorf("server with version '%s' is not TiDB", versionStr) } - return CheckVersion("TiDB", *version, requiredMinVersion, requiredMaxVersion) + return CheckVersion("TiDB", *serverInfo.ServerVersion, requiredMinVersion, requiredMaxVersion) } // NormalizeBackupVersion normalizes the version string from backupmeta. @@ -238,3 +240,120 @@ func NormalizeBackupVersion(version string) *semver.Version { } return ver } + +// FetchVersion gets the version information from the database server +// +// NOTE: the executed query will be: +// - `select tidb_version()` if target db is tidb +// - `select version()` if target db is not tidb +func FetchVersion(ctx context.Context, db common.QueryExecutor) (string, error) { + var versionInfo string + const queryTiDB = "SELECT tidb_version();" + tidbRow := db.QueryRowContext(ctx, queryTiDB) + err := tidbRow.Scan(&versionInfo) + if err == nil { + return versionInfo, nil + } + log.L().Warn("select tidb_version() failed, will fallback to 'select version();'", logutil.ShortError(err)) + const query = "SELECT version();" + row := db.QueryRowContext(ctx, query) + err = row.Scan(&versionInfo) + if err != nil { + return "", errors.Annotatef(err, "sql: %s", query) + } + return versionInfo, nil +} + +type ServerType int + +const ( + // ServerTypeUnknown represents unknown server type + ServerTypeUnknown = iota + // ServerTypeMySQL represents MySQL server type + ServerTypeMySQL + // ServerTypeMariaDB represents MariaDB server type + ServerTypeMariaDB + // ServerTypeTiDB represents TiDB server type + ServerTypeTiDB + + // ServerTypeAll represents All server types + ServerTypeAll +) + +var serverTypeString = []string{ + ServerTypeUnknown: "Unknown", + ServerTypeMySQL: "MySQL", + ServerTypeMariaDB: "MariaDB", + ServerTypeTiDB: "TiDB", +} + +// String implements Stringer.String +func (s ServerType) String() string { + if s >= ServerTypeAll { + return "" + } + return serverTypeString[s] +} + +// ServerInfo is the combination of ServerType and ServerInfo +type ServerInfo struct { + ServerType ServerType + ServerVersion *semver.Version +} + +var ( + mysqlVersionRegex = regexp.MustCompile(`^\d+\.\d+\.\d+([0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?`) + // `select version()` result + tidbVersionRegex = regexp.MustCompile(`-[v]?\d+\.\d+\.\d+([0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?`) + // `select tidb_version()` result + tidbReleaseVersionRegex = regexp.MustCompile(`v\d+\.\d+\.\d+([0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?`) +) + +// ParseServerInfo parses exported server type and version info from version string +func ParseServerInfo(src string) ServerInfo { + lowerCase := strings.ToLower(src) + serverInfo := ServerInfo{} + isReleaseVersion := false + switch { + case strings.Contains(lowerCase, "release version:"): + // this version string is tidb release version + serverInfo.ServerType = ServerTypeTiDB + isReleaseVersion = true + case strings.Contains(lowerCase, "tidb"): + serverInfo.ServerType = ServerTypeTiDB + case strings.Contains(lowerCase, "mariadb"): + serverInfo.ServerType = ServerTypeMariaDB + case mysqlVersionRegex.MatchString(lowerCase): + serverInfo.ServerType = ServerTypeMySQL + default: + serverInfo.ServerType = ServerTypeUnknown + } + + var versionStr string + if serverInfo.ServerType == ServerTypeTiDB { + if isReleaseVersion { + versionStr = tidbReleaseVersionRegex.FindString(src) + } else { + versionStr = tidbVersionRegex.FindString(src)[1:] + } + versionStr = strings.TrimPrefix(versionStr, "v") + } else { + versionStr = mysqlVersionRegex.FindString(src) + } + + var err error + serverInfo.ServerVersion, err = semver.NewVersion(versionStr) + if err != nil { + log.L().Warn("fail to parse version", + zap.String("version", versionStr)) + } + var version string + if serverInfo.ServerVersion != nil { + version = serverInfo.ServerVersion.String() + } + log.L().Info("detect server version", + zap.String("type", serverInfo.ServerType.String()), + zap.String("version", version)) + + return serverInfo +}