diff --git a/v4/export/dump.go b/v4/export/dump.go index c2f2bcd1..6b70ed62 100755 --- a/v4/export/dump.go +++ b/v4/export/dump.go @@ -461,7 +461,7 @@ func (d *Dumper) concurrentDumpTable(tctx *tcontext.Context, conn *sql.Conn, met return err } - field, err := pickupPossibleField(db, tbl, conn, conf) + field, err := pickupPossibleField(meta, conn, conf) if err != nil || field == "" { // skip split chunk logic if not found proper field tctx.L().Warn("fallback to sequential dump due to no proper field", diff --git a/v4/export/sql.go b/v4/export/sql.go index 586c5842..6cb88f9a 100644 --- a/v4/export/sql.go +++ b/v4/export/sql.go @@ -363,28 +363,34 @@ func GetPrimaryKeyColumns(db *sql.Conn, database, table string) ([]string, error return cols, nil } -// GetPrimaryKeyName try to get a numeric primary index -func GetPrimaryKeyName(db *sql.Conn, database, table string) (string, error) { - return getNumericIndex(db, database, table, "PRI") -} - -// GetUniqueIndexName try to get a numeric unique index -func GetUniqueIndexName(db *sql.Conn, database, table string) (string, error) { - return getNumericIndex(db, database, table, "UNI") -} - -func getNumericIndex(db *sql.Conn, database, table, indexType string) (string, error) { - keyQuery := "SELECT column_name FROM information_schema.columns " + - "WHERE table_schema = ? AND table_name = ? AND column_key = ? AND data_type IN ('int', 'bigint');" - var colName string - row := db.QueryRowContext(context.Background(), keyQuery, database, table, indexType) - err := row.Scan(&colName) - if errors.Cause(err) == sql.ErrNoRows { - return "", nil - } else if err != nil { - return "", errors.Annotatef(err, "sql: %s, indexType: %s", keyQuery, indexType) +func getNumericIndex(db *sql.Conn, meta TableMeta) (string, error) { + database, table := meta.DatabaseName(), meta.TableName() + colNames, colTypes := meta.ColumnNames(), meta.ColumnTypes() + colName2Type := make(map[string]string, len(colNames)) + for i := range colNames { + colName2Type[colNames[i]] = colTypes[i] + } + keyQuery := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", escapeString(database), escapeString(table)) + rows, err := db.QueryContext(context.Background(), keyQuery) + if err != nil { + return "", errors.Annotatef(err, "sql: %s", keyQuery) + } + results, err := GetSpecifiedColumnValuesAndClose(rows, "NON_UNIQUE", "KEY_NAME", "COLUMN_NAME") + if err != nil { + return "", errors.Annotatef(err, "sql: %s", keyQuery) + } + uniqueColumnName := "" + // check primary key first, then unique key + for _, oneRow := range results { + var ok bool + if _, ok = dataTypeNum[colName2Type[oneRow[2]]]; ok && oneRow[1] == "PRIMARY" { + return oneRow[2], nil + } + if uniqueColumnName != "" && oneRow[0] == "0" && ok { + uniqueColumnName = oneRow[2] + } } - return colName, nil + return uniqueColumnName, nil } // FlushTableWithReadLock flush tables with read lock @@ -465,6 +471,51 @@ func GetSpecifiedColumnValueAndClose(rows *sql.Rows, columnName string) ([]strin return strs, errors.Trace(rows.Err()) } +// GetSpecifiedColumnValuesAndClose get columns' values whose name is equal to columnName +func GetSpecifiedColumnValuesAndClose(rows *sql.Rows, columnName ...string) ([][]string, error) { + if rows == nil { + return [][]string{}, nil + } + defer rows.Close() + var strs [][]string + columns, err := rows.Columns() + if err != nil { + return strs, errors.Trace(err) + } + addr := make([]interface{}, len(columns)) + oneRow := make([]sql.NullString, len(columns)) + fieldIndexMp := make(map[int]int) + for i, col := range columns { + addr[i] = &oneRow[i] + for j, name := range columnName { + if strings.ToUpper(col) == name { + fieldIndexMp[i] = j + } + } + } + if len(fieldIndexMp) == 0 { + return strs, nil + } + for rows.Next() { + err := rows.Scan(addr...) + if err != nil { + return strs, errors.Trace(err) + } + written := false + tmpStr := make([]string, len(columnName)) + for colPos, namePos := range fieldIndexMp { + if oneRow[colPos].Valid { + written = true + tmpStr[namePos] = oneRow[colPos].String + } + } + if written { + strs = append(strs, tmpStr) + } + } + return strs, errors.Trace(rows.Err()) +} + // GetPdAddrs gets PD address from TiDB func GetPdAddrs(tctx *tcontext.Context, db *sql.DB) ([]string, error) { query := "SELECT * FROM information_schema.cluster_info where type = 'pd';" @@ -822,7 +873,8 @@ func simpleQueryWithArgs(conn *sql.Conn, handleOneRow func(*sql.Rows) error, sql return errors.Annotatef(rows.Err(), "sql: %s", sql) } -func pickupPossibleField(dbName, tableName string, db *sql.Conn, conf *Config) (string, error) { +func pickupPossibleField(meta TableMeta, db *sql.Conn, conf *Config) (string, error) { + dbName, tableName := meta.DatabaseName(), meta.TableName() // If detected server is TiDB, try using _tidb_rowid if conf.ServerInfo.ServerType == ServerTypeTiDB { ok, err := SelectTiDBRowID(db, dbName, tableName) @@ -834,17 +886,10 @@ func pickupPossibleField(dbName, tableName string, db *sql.Conn, conf *Config) ( } } // try to use pk - fieldName, err := GetPrimaryKeyName(db, dbName, tableName) + fieldName, err := getNumericIndex(db, meta) if err != nil { return "", err } - // try to use first uniqueIndex - if fieldName == "" { - fieldName, err = GetUniqueIndexName(db, dbName, tableName) - if err != nil { - return "", err - } - } // if fieldName == "", there is no proper index return fieldName, nil diff --git a/v4/export/status.go b/v4/export/status.go index 909469c0..608b553c 100644 --- a/v4/export/status.go +++ b/v4/export/status.go @@ -67,12 +67,7 @@ func (d *Dumper) getEstimateTotalRowsCount(tctx *tcontext.Context, conn *sql.Con for db, tables := range conf.Tables { for _, m := range tables { if m.Type == TableTypeBase { - // get pk or uk for explain - field, err := pickupPossibleField(db, m.Name, conn, conf) - if err != nil { - return err - } - c := estimateCount(tctx, db, m.Name, conn, field, conf) + c := estimateCount(tctx, db, m.Name, conn, "", conf) totalCount += c } }