Skip to content

Commit

Permalink
Merge branch 'master' into dup-error-table-name
Browse files Browse the repository at this point in the history
  • Loading branch information
ekexium authored Oct 19, 2022
2 parents e8b600e + ff5fd4e commit f8c8500
Show file tree
Hide file tree
Showing 16 changed files with 258 additions and 147 deletions.
20 changes: 18 additions & 2 deletions br/pkg/lightning/checkpoints/checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,15 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) {

switch cfg.Checkpoint.Driver {
case config.CheckpointDriverMySQL:
db, err := common.ConnectMySQL(cfg.Checkpoint.DSN)
var (
db *sql.DB
err error
)
if cfg.Checkpoint.MySQLParam != nil {
db, err = cfg.Checkpoint.MySQLParam.Connect()
} else {
db, err = sql.Open("mysql", cfg.Checkpoint.DSN)
}
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -546,7 +554,15 @@ func IsCheckpointsDBExists(ctx context.Context, cfg *config.Config) (bool, error
}
switch cfg.Checkpoint.Driver {
case config.CheckpointDriverMySQL:
db, err := sql.Open("mysql", cfg.Checkpoint.DSN)
var (
db *sql.DB
err error
)
if cfg.Checkpoint.MySQLParam != nil {
db, err = cfg.Checkpoint.MySQLParam.Connect()
} else {
db, err = sql.Open("mysql", cfg.Checkpoint.DSN)
}
if err != nil {
return false, errors.Trace(err)
}
Expand Down
51 changes: 28 additions & 23 deletions br/pkg/lightning/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -58,28 +57,38 @@ type MySQLConnectParam struct {
Vars map[string]string
}

func (param *MySQLConnectParam) ToDSN() string {
hostPort := net.JoinHostPort(param.Host, strconv.Itoa(param.Port))
dsn := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s",
param.User, param.Password, hostPort,
param.SQLMode, param.MaxAllowedPacket, param.TLS)
func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config {
cfg := mysql.NewConfig()
cfg.Params = make(map[string]string)

cfg.User = param.User
cfg.Passwd = param.Password
cfg.Net = "tcp"
cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port))
cfg.Params["charset"] = "utf8mb4"
cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode)
cfg.MaxAllowedPacket = int(param.MaxAllowedPacket)
cfg.TLSConfig = param.TLS

for k, v := range param.Vars {
dsn += fmt.Sprintf("&%s='%s'", k, url.QueryEscape(v))
cfg.Params[k] = fmt.Sprintf("'%s'", v)
}

return dsn
return cfg
}

func tryConnectMySQL(dsn string) (*sql.DB, error) {
driverName := "mysql"
failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) {
driverName = val.(string)
func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) {
pwd := val.(string)
if cfg.Passwd != pwd {
failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"})
}
failpoint.Return(nil, nil)
})
db, err := sql.Open(driverName, dsn)
c, err := mysql.NewConnector(cfg)
if err != nil {
return nil, errors.Trace(err)
}
db := sql.OpenDB(c)
if err = db.Ping(); err != nil {
_ = db.Close()
return nil, errors.Trace(err)
Expand All @@ -89,13 +98,9 @@ func tryConnectMySQL(dsn string) (*sql.DB, error) {

// ConnectMySQL connects to MySQL with the given connection settings. If access is denied and the password is a valid
// base64 encoding, it retries the connection with the base64-decoded password.
func ConnectMySQL(dsn string) (*sql.DB, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, errors.Trace(err)
}
func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) {
// Try plain password first.
db, firstErr := tryConnectMySQL(dsn)
db, firstErr := tryConnectMySQL(cfg)
if firstErr == nil {
return db, nil
}
Expand All @@ -104,9 +109,9 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
// If password is encoded by base64, try the decoded string as well.
if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd {
cfg.Passwd = string(password)
db, err = tryConnectMySQL(cfg.FormatDSN())
db2, err := tryConnectMySQL(cfg)
if err == nil {
return db, nil
return db2, nil
}
}
}
Expand All @@ -115,7 +120,7 @@ func ConnectMySQL(dsn string) (*sql.DB, error) {
}

func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
db, err := ConnectMySQL(param.ToDSN())
db, err := ConnectMySQL(param.ToDriverConfig())
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
69 changes: 5 additions & 64 deletions br/pkg/lightning/common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@ package common_test

import (
"context"
"database/sql"
"database/sql/driver"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"

Expand All @@ -35,7 +31,6 @@ import (
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/log"
tmysql "github.com/pingcap/tidb/errno"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -85,66 +80,14 @@ func TestGetJSON(t *testing.T) {
require.Regexp(t, ".*http status code != 200.*", err.Error())
}

// TestToDSN checks the DSN string rendered by MySQLConnectParam.ToDSN for an
// IPv4 host and for an IPv6 host (which is expected to appear in brackets).
func TestToDSN(t *testing.T) {
	cases := []struct {
		host     string
		expected string
	}{
		{
			host:     "127.0.0.1",
			expected: "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'",
		},
		{
			host:     "::1",
			expected: "root:123456@tcp([::1]:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'",
		},
	}

	for _, tc := range cases {
		// Only the host differs between cases; every other field is fixed.
		connParam := common.MySQLConnectParam{
			Host:             tc.host,
			Port:             4000,
			User:             "root",
			Password:         "123456",
			SQLMode:          "strict",
			MaxAllowedPacket: 1234,
			TLS:              "cluster",
			Vars: map[string]string{
				"tidb_distsql_scan_concurrency": "1",
			},
		}
		require.Equal(t, tc.expected, connParam.ToDSN())
	}
}

// mockDriver is a test double for a database/sql driver. It embeds
// driver.Driver and overrides Open so no real MySQL server is needed.
type mockDriver struct {
driver.Driver
// plainPsw is the password that Open compares against the password parsed
// from the DSN; any other password produces an access-denied connection.
plainPsw string
}

// Open parses dsn and returns a mockConn whose accessDenied flag is set when
// the password in the DSN differs from m.plainPsw. A DSN that fails to parse
// yields the parse error unchanged.
func (m *mockDriver) Open(dsn string) (driver.Conn, error) {
	parsed, parseErr := mysql.ParseDSN(dsn)
	if parseErr != nil {
		return nil, parseErr
	}
	conn := &mockConn{accessDenied: parsed.Passwd != m.plainPsw}
	return conn, nil
}

// mockConn is the connection handed out by mockDriver.Open. It embeds
// driver.Conn and driver.Pinger and overrides Ping and Close.
type mockConn struct {
driver.Conn
driver.Pinger
// accessDenied makes Ping fail with an access-denied MySQL error when true.
accessDenied bool
}

// Ping reports nil for a connection opened with the expected password and an
// access-denied MySQLError otherwise. The context argument is not consulted.
func (c *mockConn) Ping(ctx context.Context) error {
	if !c.accessDenied {
		return nil
	}
	return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}
}

// Close is a no-op for the mock connection and always returns nil.
func (c *mockConn) Close() error {
return nil
}

func TestConnect(t *testing.T) {
plainPsw := "dQAUoDiyb1ucWZk7"
driverName := "mysql-mock-" + strconv.Itoa(rand.Int())
sql.Register(driverName, &mockDriver{plainPsw: plainPsw})

require.NoError(t, failpoint.Enable(
"github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver",
fmt.Sprintf("return(\"%s\")", driverName)))
"github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword",
fmt.Sprintf("return(\"%s\")", plainPsw)))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver"))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MustMySQLPassword"))
}()

param := common.MySQLConnectParam{
Expand All @@ -155,13 +98,11 @@ func TestConnect(t *testing.T) {
SQLMode: "strict",
MaxAllowedPacket: 1234,
}
db, err := param.Connect()
_, err := param.Connect()
require.NoError(t, err)
require.NoError(t, db.Close())
param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw))
db, err = param.Connect()
_, err = param.Connect()
require.NoError(t, err)
require.NoError(t, db.Close())
}

func TestIsContextCanceledError(t *testing.T) {
Expand Down
13 changes: 7 additions & 6 deletions br/pkg/lightning/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,12 @@ type TikvImporter struct {
}

type Checkpoint struct {
Schema string `toml:"schema" json:"schema"`
DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON.
Driver string `toml:"driver" json:"driver"`
Enable bool `toml:"enable" json:"enable"`
KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"`
Schema string `toml:"schema" json:"schema"`
DSN string `toml:"dsn" json:"-"` // DSN may contain password, don't expose this to JSON.
MySQLParam *common.MySQLConnectParam `toml:"-" json:"-"` // For some security reason, we use MySQLParam instead of DSN.
Driver string `toml:"driver" json:"driver"`
Enable bool `toml:"enable" json:"enable"`
KeepAfterSuccess CheckpointKeepStrategy `toml:"keep-after-success" json:"keep-after-success"`
}

type Cron struct {
Expand Down Expand Up @@ -1142,7 +1143,7 @@ func (cfg *Config) AdjustCheckPoint() {
MaxAllowedPacket: defaultMaxAllowedPacket,
TLS: cfg.TiDB.TLS,
}
cfg.Checkpoint.DSN = param.ToDSN()
cfg.Checkpoint.MySQLParam = &param
case CheckpointDriverFile:
cfg.Checkpoint.DSN = "/tmp/" + cfg.Checkpoint.Schema + ".pb"
}
Expand Down
5 changes: 3 additions & 2 deletions br/pkg/lightning/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/BurntSushi/toml"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/config"
"github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -626,7 +625,9 @@ func TestLoadConfig(t *testing.T) {
taskCfg.TiDB.DistSQLScanConcurrency = 1
err = taskCfg.Adjust(context.Background())
require.NoError(t, err)
require.Equal(t, "guest:12345@tcp(172.16.30.11:4001)/?charset=utf8mb4&sql_mode='"+mysql.DefaultSQLMode+"'&maxAllowedPacket=67108864&tls=false", taskCfg.Checkpoint.DSN)
equivalentDSN := taskCfg.Checkpoint.MySQLParam.ToDriverConfig().FormatDSN()
expectedDSN := "guest:12345@tcp(172.16.30.11:4001)/?tls=false&maxAllowedPacket=67108864&charset=utf8mb4&sql_mode=%27ONLY_FULL_GROUP_BY%2CSTRICT_TRANS_TABLES%2CNO_ZERO_IN_DATE%2CNO_ZERO_DATE%2CERROR_FOR_DIVISION_BY_ZERO%2CNO_AUTO_CREATE_USER%2CNO_ENGINE_SUBSTITUTION%27"
require.Equal(t, expectedDSN, equivalentDSN)

result := taskCfg.String()
require.Regexp(t, `.*"pd-addr":"172.16.30.11:2379,172.16.30.12:2379".*`, result)
Expand Down
9 changes: 6 additions & 3 deletions br/pkg/lightning/mydump/region.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func MakeTableRegions(
go func() {
defer wg.Done()
for info := range fileChan {
regions, sizes, err := makeSourceFileRegion(execCtx, meta, info, columns, cfg, ioWorkers, store)
regions, sizes, err := MakeSourceFileRegion(execCtx, meta, info, columns, cfg, ioWorkers, store)
select {
case resultChan <- fileRegionRes{info: info, regions: regions, sizes: sizes, err: err}:
case <-ctx.Done():
Expand Down Expand Up @@ -255,7 +255,8 @@ func MakeTableRegions(
return filesRegions, nil
}

func makeSourceFileRegion(
// MakeSourceFileRegion creates a new source file region.
func MakeSourceFileRegion(
ctx context.Context,
meta *MDTableMeta,
fi FileInfo,
Expand Down Expand Up @@ -283,7 +284,9 @@ func makeSourceFileRegion(
// We increase the check threshold by 1/10 of the `max-region-size` because the source file size dumped by tools
// like dumpling might slightly exceed the threshold when it is equal to `max-region-size`, so we can
// avoid splitting the file into a lot of small chunks.
if isCsvFile && cfg.Mydumper.StrictFormat && dataFileSize > int64(cfg.Mydumper.MaxRegionSize+cfg.Mydumper.MaxRegionSize/largeCSVLowerThresholdRation) {
// If a csv file is compressed, we can't split it now because we can't get the exact size of a row.
if isCsvFile && cfg.Mydumper.StrictFormat && fi.FileMeta.Compression == CompressionNone &&
dataFileSize > int64(cfg.Mydumper.MaxRegionSize+cfg.Mydumper.MaxRegionSize/largeCSVLowerThresholdRation) {
_, regions, subFileSizes, err := SplitLargeFile(ctx, meta, cfg, fi, divisor, 0, ioWorkers, store)
return regions, subFileSizes, err
}
Expand Down
57 changes: 57 additions & 0 deletions br/pkg/lightning/mydump/region_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,63 @@ func TestAllocateEngineIDs(t *testing.T) {
})
}

// TestMakeSourceFileRegion exercises MakeSourceFileRegion on a CSV file with
// StrictFormat enabled and a MaxRegionSize of 1, in two configurations:
//   - uncompressed: the file is expected to be split into one region per data
//     row, each carrying the header column names;
//   - gzip-compressed: the file is expected to stay a single region spanning
//     [0, FileSize) with no columns, because a compressed CSV is not split.
//
// Preconditions that later statements depend on (storage creation, a nil
// error, the number of regions) use require rather than assert, so a failure
// stops the test with a readable message instead of panicking on a nil store
// or an out-of-range index.
func TestMakeSourceFileRegion(t *testing.T) {
	meta := &MDTableMeta{
		DB:   "csv",
		Name: "large_csv_file",
	}
	cfg := &config.Config{
		Mydumper: config.MydumperRuntime{
			ReadBlockSize: config.ReadBlockSize,
			// A 1-byte region size forces the large-file split path for any
			// non-trivial uncompressed CSV.
			MaxRegionSize: 1,
			CSV: config.CSVConfig{
				Separator:       ",",
				Delimiter:       "",
				Header:          true,
				TrimLastSep:     false,
				NotNull:         false,
				Null:            "NULL",
				BackslashEscape: true,
			},
			StrictFormat: true,
			Filter:       []string{"*.*"},
		},
	}
	filePath := "./csv/split_large_file.csv"
	dataFileInfo, err := os.Stat(filePath)
	require.NoError(t, err)
	fileSize := dataFileInfo.Size()
	fileInfo := FileInfo{FileMeta: SourceFileMeta{Path: filePath, Type: SourceTypeCSV, FileSize: fileSize}}
	colCnt := 3
	columns := []string{"a", "b", "c"}

	ctx := context.Background()
	ioWorkers := worker.NewPool(ctx, 4, "io")
	store, err := storage.NewLocalStorage(".")
	// The store is used by every call below; abort immediately if it is unusable.
	require.NoError(t, err)

	// test - no compression
	fileInfo.FileMeta.Compression = CompressionNone
	regions, _, err := MakeSourceFileRegion(ctx, meta, fileInfo, colCnt, cfg, ioWorkers, store)
	require.NoError(t, err)
	offsets := [][]int64{{6, 12}, {12, 18}, {18, 24}, {24, 30}}
	// Must be fatal: the loop below indexes regions by position.
	require.Len(t, regions, len(offsets))
	for i := range offsets {
		assert.Equal(t, offsets[i][0], regions[i].Chunk.Offset)
		assert.Equal(t, offsets[i][1], regions[i].Chunk.EndOffset)
		assert.Equal(t, columns, regions[i].Chunk.Columns)
	}

	// test - gzip compression
	fileInfo.FileMeta.Compression = CompressionGZ
	regions, _, err = MakeSourceFileRegion(ctx, meta, fileInfo, colCnt, cfg, ioWorkers, store)
	require.NoError(t, err)
	// Must be fatal: regions[0] is dereferenced right after.
	require.Len(t, regions, 1)
	assert.Equal(t, int64(0), regions[0].Chunk.Offset)
	assert.Equal(t, fileInfo.FileMeta.FileSize, regions[0].Chunk.EndOffset)
	assert.Len(t, regions[0].Chunk.Columns, 0)
}

func TestSplitLargeFile(t *testing.T) {
meta := &MDTableMeta{
DB: "csv",
Expand Down
Loading

0 comments on commit f8c8500

Please sign in to comment.