From 642c5dc0d68a21c30d87e94bb90b13d0fcaa5705 Mon Sep 17 00:00:00 2001
From: ekexium
Date: Thu, 30 Nov 2023 17:54:56 +0800
Subject: [PATCH 1/8] make load data transactional

Signed-off-by: ekexium
---
 pkg/executor/load_data.go | 20 ++++++--------------
 pkg/server/conn.go        |  7 +++++++
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/pkg/executor/load_data.go b/pkg/executor/load_data.go
index c613324943a72..776a7d1bf1961 100644
--- a/pkg/executor/load_data.go
+++ b/pkg/executor/load_data.go
@@ -175,9 +175,9 @@ func (e *LoadDataWorker) load(ctx context.Context, readerInfos []importer.LoadDa
     // commitWork goroutines -> done -> UpdateJobProgress goroutine
-
     // TODO: support explicit transaction and non-autocommit
-    if err = sessiontxn.NewTxn(groupCtx, e.UserSctx); err != nil {
-        return err
-    }
+    // if err = sessiontxn.NewTxn(groupCtx, e.UserSctx); err != nil {
+    //     return err
+    // }

     // processOneStream goroutines.
     group.Go(func() error {
@@ -532,16 +532,7 @@ func (w *commitWorker) commitWork(ctx context.Context, inCh <-chan commitTask) (
                 zap.Stack("stack"))
             err = util.GetRecoverError(r)
         }
-
-        if err != nil {
-            background := context.Background()
-            w.Ctx().StmtRollback(background, false)
-            w.Ctx().RollbackTxn(background)
-        } else {
-            if err = w.Ctx().CommitTxn(ctx); err != nil {
-                logutil.Logger(ctx).Error("commit error refresh", zap.Error(err))
-            }
-        }
+        w.Ctx().StmtCommit(ctx)
     }()

     var (
@@ -580,7 +571,8 @@ func (w *commitWorker) commitOneTask(ctx context.Context, task commitTask) error
     failpoint.Inject("commitOneTaskErr", func() {
         failpoint.Return(errors.New("mock commit one task error"))
     })
-    w.Ctx().StmtCommit(ctx)
+    // NOTE: this is not the end of a statement. Should not call StmtCommit here
+    // w.Ctx().StmtCommit(ctx)
     return nil
 }

diff --git a/pkg/server/conn.go b/pkg/server/conn.go
index 57d9856f01c03..bb6f6aaa0d1c3 100644
--- a/pkg/server/conn.go
+++ b/pkg/server/conn.go
@@ -1613,6 +1613,13 @@ func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *execut
             }
         }
     }
+    // if current session is auto-commit and not in a transaction, commit it here
+    if cc.ctx.GetSessionVars().IsAutocommit() && !cc.ctx.GetSessionVars().InTxn() {
+        err = cc.ctx.CommitTxn(ctx)
+        if err != nil {
+            return err
+        }
+    }
     return err
 }

From a573b27958431aef0ac34bd858341f0914d31a58 Mon Sep 17 00:00:00 2001
From: ekexium
Date: Mon, 4 Dec 2023 19:51:59 +0800
Subject: [PATCH 2/8] fix autocommit load data stmt

Signed-off-by: ekexium
---
 pkg/server/conn.go                            | 22 +++---
 .../testserverclient/server_client.go         | 74 +++++++++++++++++++
 pkg/server/tests/tidb_serial_test.go          |  6 ++
 3 files changed, 93 insertions(+), 9 deletions(-)

diff --git a/pkg/server/conn.go b/pkg/server/conn.go
index bb6f6aaa0d1c3..4bef6a68e76bf 100644
--- a/pkg/server/conn.go
+++ b/pkg/server/conn.go
@@ -1531,7 +1531,7 @@ func (cc *clientConn) writeReq(ctx context.Context, filePath string) error {

 // handleLoadData does the additional work after processing the 'load data' query.
 // It sends client a file path, then reads the file content from client, inserts data into database.
-func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *executor.LoadDataWorker) error {
+func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *executor.LoadDataWorker) (err error) {
     // If the server handles the load data request, the client has to set the ClientLocalFiles capability.
     if cc.capability&mysql.ClientLocalFiles == 0 {
         return servererr.ErrNotAllowedCommand
@@ -1540,7 +1540,7 @@ func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *execut
         return errors.New("load data info is empty")
     }
     infile := loadDataWorker.GetInfilePath()
-    err := cc.writeReq(ctx, infile)
+    err = cc.writeReq(ctx, infile)
     if err != nil {
         return err
     }
@@ -1585,6 +1585,17 @@ func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *execut
     }()

     ctx = kv.WithInternalSourceType(ctx, kv.InternalLoadData)
+    // if is autocommit and not in txn, begin a new transaction
+    if cc.ctx.GetSessionVars().IsAutocommit() && !cc.ctx.GetSessionVars().InTxn() {
+        defer func() {
+            if err == nil {
+                err = cc.ctx.TiDBContext.Session.CommitTxn(ctx)
+            } else {
+                cc.ctx.TiDBContext.Session.RollbackTxn(ctx)
+            }
+        }()
+        sessiontxn.NewTxn(ctx, cc.ctx)
+    }
     err = loadDataWorker.LoadLocal(ctx, r)
     _ = r.Close()
     wg.Wait()
@@ -1613,13 +1624,6 @@ func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *execut
             }
         }
     }
-    // if current session is auto-commit and not in a transaction, commit it here
-    if cc.ctx.GetSessionVars().IsAutocommit() && !cc.ctx.GetSessionVars().InTxn() {
-        err = cc.ctx.CommitTxn(ctx)
-        if err != nil {
-            return err
-        }
-    }
     return err
 }

diff --git a/pkg/server/internal/testserverclient/server_client.go b/pkg/server/internal/testserverclient/server_client.go
index a05b24b2b8ccb..90af63073741a 100644
--- a/pkg/server/internal/testserverclient/server_client.go
+++ b/pkg/server/internal/testserverclient/server_client.go
@@ -1008,6 +1008,80 @@ func columnsAsExpected(t *testing.T, columns []*sql.NullString, expected []strin
     }
 }

+func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T, server *server.Server) {
+    fp, err := os.CreateTemp("", "load_data_test.csv")
+    require.NoError(t, err)
+    path := fp.Name()
+
+    require.NotNil(t, fp)
+    defer func() {
+        err = fp.Close()
+        require.NoError(t, err)
+        err = os.Remove(path)
+        require.NoError(t, err)
+    }()
+
+    _, err = fp.WriteString("1")
+    require.NoError(t, err)
+
+    // load file in transaction can be rolled back
+    cli.RunTestsOnNewDB(
+        t, func(config *mysql.Config) {
+            config.AllowAllFiles = true
+            config.Params["sql_mode"] = "''"
+        }, "LoadDataInTransaction", func(dbt *testkit.DBTestKit) {
+            dbt.MustExec("create table t (a int)")
+            txn, err := dbt.GetDB().Begin()
+            require.NoError(t, err)
+            _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path))
+            require.NoError(t, err)
+            rows, err := txn.Query("select * from t")
+            require.NoError(t, err)
+            cli.CheckRows(t, rows, "1")
+            err = txn.Rollback()
+            require.NoError(t, err)
+            rows = dbt.MustQuery("select * from t")
+            cli.CheckRows(t, rows)
+        },
+    )
+
+    // load file in transaction doesn't commit until the transaction is committed
+    cli.RunTestsOnNewDB(
+        t, func(config *mysql.Config) {
+            config.AllowAllFiles = true
+            config.Params["sql_mode"] = "''"
+        }, "LoadDataInTransaction", func(dbt *testkit.DBTestKit) {
+            dbt.MustExec("create table t (a int)")
+            txn, err := dbt.GetDB().Begin()
+            require.NoError(t, err)
+            _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path))
+            require.NoError(t, err)
+            rows, err := txn.Query("select * from t")
+            require.NoError(t, err)
+            cli.CheckRows(t, rows, "1")
+            err = txn.Commit()
+            require.NoError(t, err)
+            rows = dbt.MustQuery("select * from t")
+            cli.CheckRows(t, rows, "1")
+        },
+    )
+
+    // load file in auto commit mode should succeed
+    cli.RunTestsOnNewDB(
+        t, func(config *mysql.Config) {
+            config.AllowAllFiles = true
+            config.Params["sql_mode"] = "''"
+        }, "LoadDataInAutoCommit", func(dbt *testkit.DBTestKit) {
+            dbt.MustExec("create table t (a int)")
+            dbt.MustExec(fmt.Sprintf("load data local infile %q into table t", path))
+            txn, err := dbt.GetDB().Begin()
+            require.NoError(t, err)
+            rows, _ := txn.Query("select * from t")
+            cli.CheckRows(t, rows, "1")
+        },
+    )
+}
+
 func (cli *TestServerClient) RunTestLoadData(t *testing.T, server *server.Server) {
     fp, err := os.CreateTemp("", "load_data_test.csv")
     require.NoError(t, err)
diff --git a/pkg/server/tests/tidb_serial_test.go b/pkg/server/tests/tidb_serial_test.go
index 8f7e263100528..fee659d723a7a 100644
--- a/pkg/server/tests/tidb_serial_test.go
+++ b/pkg/server/tests/tidb_serial_test.go
@@ -70,6 +70,12 @@ func TestLoadData1(t *testing.T) {
     ts.RunTestLoadDataForSlowLog(t)
 }

+func TestLoadDataInTransaction(t *testing.T) {
+    ts := createTidbTestSuite(t)
+
+    ts.RunTestLoadDataInTransaction(t, ts.server)
+}
+
 func TestConfigDefaultValue(t *testing.T) {
     ts := createTidbTestSuite(t)

From 9e292e1949c997ed232459b00885cdfbb796e8a2 Mon Sep 17 00:00:00 2001
From: ekexium
Date: Tue, 5 Dec 2023 17:14:18 +0800
Subject: [PATCH 3/8] refactor: move the actual work of loading local file into LoadDataExec

Signed-off-by: ekexium
---
 pkg/executor/load_data.go                     |  87 +++++--
 pkg/server/conn.go                            | 232 ++++++++++--------
 .../testserverclient/server_client.go         |  60 +++++
 3 files changed, 252 insertions(+), 127 deletions(-)

diff --git a/pkg/executor/load_data.go b/pkg/executor/load_data.go
index 776a7d1bf1961..bcef32fb763d6 100644
--- a/pkg/executor/load_data.go
+++ b/pkg/executor/load_data.go
@@ -20,6 +20,7 @@ import (
     "io"
     "math"
     "strings"
+    "sync"
     "time"

     "github.com/pingcap/errors"
@@ -50,16 +51,70 @@ import (
     "golang.org/x/sync/errgroup"
 )

+// LoadDataVarKey is a variable key for load data.
+const LoadDataVarKey loadDataVarKeyType = 0
+
+// LoadDataReaderBuilderKey stores the reader channel that reads from the connection.
+const LoadDataReaderBuilderKey loadDataVarKeyType = 1
+
+// LoadDataReaderCloseKey stores the close function of the reader
+const LoadDataReaderCloseKey loadDataVarKeyType = 2
+
 var (
     taskQueueSize = 16 // the maximum number of pending tasks to commit in queue
 )

+// LoadDataReaderBuilder is a function type that builds a reader from a file path.
+type LoadDataReaderBuilder func(filepath string) (
+    r io.ReadCloser, drained *bool, wg *sync.WaitGroup, err error,
+)
+
+// LoadDataReaderCloser is a function type that closes a reader.
+type LoadDataReaderCloser func(
+    r io.ReadCloser, drained *bool,
+    wg *sync.WaitGroup, err error,
+) error
+
 // LoadDataExec represents a load data executor.
 type LoadDataExec struct {
     exec.BaseExecutor

     FileLocRef     ast.FileLocRefTp
     loadDataWorker *LoadDataWorker
+
+    // fields for loading local file
+    infileReader io.ReadCloser
+    drained      *bool
+    wg           *sync.WaitGroup
+    readerCloser LoadDataReaderCloser
+}
+
+// Open implements the Executor Next interface.
+func (e *LoadDataExec) Open(_ context.Context) error {
+    if rb, ok := e.Ctx().Value(LoadDataReaderBuilderKey).(LoadDataReaderBuilder); ok {
+        e.readerCloser = e.Ctx().Value(LoadDataReaderCloseKey).(LoadDataReaderCloser)
+        var err error
+        e.infileReader, e.drained, e.wg, err = rb(e.loadDataWorker.GetInfilePath())
+        if err != nil {
+            return err
+        }
+    }
+    return nil
+}
+
+// Close implements the Executor Next interface.
+func (e *LoadDataExec) Close() error { + return e.closeLocalReader(nil) +} + +func (e *LoadDataExec) closeLocalReader(originalErr error) error { + var err error + if e.readerCloser != nil { + err = e.readerCloser(e.infileReader, e.drained, e.wg, originalErr) + } + // don't close it twice + e.readerCloser = nil + return err } // Next implements the Executor Next interface. @@ -68,14 +123,12 @@ func (e *LoadDataExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) { case ast.FileLocServerOrRemote: return e.loadDataWorker.loadRemote(ctx) case ast.FileLocClient: - // let caller use handleFileTransInConn to read data in this connection - sctx := e.loadDataWorker.UserSctx - val := sctx.Value(LoadDataVarKey) - if val != nil { - sctx.SetValue(LoadDataVarKey, nil) - return errors.New("previous load data option wasn't closed normally") + err = e.loadDataWorker.loadLocal(ctx, e.infileReader) + if err != nil { + logutil.Logger(ctx).Error("load local data failed", zap.Error(err)) + err = e.closeLocalReader(err) + return err } - sctx.SetValue(LoadDataVarKey, e.loadDataWorker) } return nil } @@ -146,7 +199,11 @@ func (e *LoadDataWorker) loadRemote(ctx context.Context) error { } // LoadLocal reads from client connection and do load data job. -func (e *LoadDataWorker) LoadLocal(ctx context.Context, r io.ReadCloser) error { +func (e *LoadDataWorker) loadLocal(ctx context.Context, r io.ReadCloser) error { + if r == nil { + return errors.New("load local data, reader is nil") + } + compressTp := mydump.ParseCompressionOnFileExtension(e.GetInfilePath()) compressTp2, err := mydump.ToStorageCompressType(compressTp) if err != nil { @@ -174,11 +231,6 @@ func (e *LoadDataWorker) load(ctx context.Context, readerInfos []importer.LoadDa commitTaskCh := make(chan commitTask, taskQueueSize) // commitWork goroutines -> done -> UpdateJobProgress goroutine - // TODO: support explicit transaction and non-autocommit - // if err = sessiontxn.NewTxn(groupCtx, e.UserSctx); err != nil { - // return err - // } - // processOneStream goroutines. group.Go(func() error { err2 := encoder.processStream(groupCtx, readerInfoCh, commitTaskCh) @@ -532,7 +584,8 @@ func (w *commitWorker) commitWork(ctx context.Context, inCh <-chan commitTask) ( zap.Stack("stack")) err = util.GetRecoverError(r) } - w.Ctx().StmtCommit(ctx) + // Why call it here? + // w.Ctx().StmtCommit(ctx) }() var ( @@ -571,8 +624,6 @@ func (w *commitWorker) commitOneTask(ctx context.Context, task commitTask) error failpoint.Inject("commitOneTaskErr", func() { failpoint.Return(errors.New("mock commit one task error")) }) - // NOTE: this is not the end of a statement. Should not call StmtCommit here - // w.Ctx().StmtCommit(ctx) return nil } @@ -729,14 +780,12 @@ func (loadDataVarKeyType) String() string { return "load_data_var" } -// LoadDataVarKey is a variable key for load data. -const LoadDataVarKey loadDataVarKeyType = 0 - var ( _ exec.Executor = (*LoadDataActionExec)(nil) ) // LoadDataActionExec executes LoadDataActionStmt. +// TODO: LoadDataActionExec and its corresponding syntax is not in use and should be deleted. type LoadDataActionExec struct { exec.BaseExecutor diff --git a/pkg/server/conn.go b/pkg/server/conn.go index 4bef6a68e76bf..379f97f7a8ad4 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -1529,104 +1529,6 @@ func (cc *clientConn) writeReq(ctx context.Context, filePath string) error { return cc.flush(ctx) } -// handleLoadData does the additional work after processing the 'load data' query. 
-// It sends client a file path, then reads the file content from client, inserts data into database. -func (cc *clientConn) handleLoadData(ctx context.Context, loadDataWorker *executor.LoadDataWorker) (err error) { - // If the server handles the load data request, the client has to set the ClientLocalFiles capability. - if cc.capability&mysql.ClientLocalFiles == 0 { - return servererr.ErrNotAllowedCommand - } - if loadDataWorker == nil { - return errors.New("load data info is empty") - } - infile := loadDataWorker.GetInfilePath() - err = cc.writeReq(ctx, infile) - if err != nil { - return err - } - - var ( - // use Pipe to convert cc.readPacket to io.Reader - r, w = io.Pipe() - drained bool - wg sync.WaitGroup - ) - wg.Add(1) - go func() { - defer wg.Done() - //nolint: errcheck - defer w.Close() - - var ( - data []byte - err2 error - ) - for { - if len(data) == 0 { - data, err2 = cc.readPacket() - if err2 != nil { - w.CloseWithError(err2) - return - } - // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html - if len(data) == 0 { - drained = true - return - } - } - - n, err3 := w.Write(data) - if err3 != nil { - logutil.Logger(ctx).Error("write data meet error", zap.Error(err3)) - return - } - data = data[n:] - } - }() - - ctx = kv.WithInternalSourceType(ctx, kv.InternalLoadData) - // if is autocommit and not in txn, begin a new transaction - if cc.ctx.GetSessionVars().IsAutocommit() && !cc.ctx.GetSessionVars().InTxn() { - defer func() { - if err == nil { - err = cc.ctx.TiDBContext.Session.CommitTxn(ctx) - } else { - cc.ctx.TiDBContext.Session.RollbackTxn(ctx) - } - }() - sessiontxn.NewTxn(ctx, cc.ctx) - } - err = loadDataWorker.LoadLocal(ctx, r) - _ = r.Close() - wg.Wait() - - if err != nil { - if !drained { - logutil.Logger(ctx).Info("not drained yet, try reading left data from client connection") - } - // drain the data from client conn util empty packet received, otherwise the connection will be reset - // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html - for !drained { - // check kill flag again, let the draining loop could quit if empty packet could not be received - if atomic.CompareAndSwapUint32(&loadDataWorker.UserSctx.GetSessionVars().SQLKiller.Signal, 1, 0) { - logutil.Logger(ctx).Warn("receiving kill, stop draining data, connection may be reset") - return exeerrors.ErrQueryInterrupted - } - curData, err1 := cc.readPacket() - if err1 != nil { - logutil.Logger(ctx).Error("drain reading left data encounter errors", zap.Error(err1)) - break - } - if len(curData) == 0 { - drained = true - logutil.Logger(ctx).Info("draining finished for error", zap.Error(err)) - break - } - } - } - return err -} - // getDataFromPath gets file contents from file path. func (cc *clientConn) getDataFromPath(ctx context.Context, path string) ([]byte, error) { err := cc.writeReq(ctx, path) @@ -2024,11 +1926,24 @@ func (cc *clientConn) prefetchPointPlanKeys(ctx context.Context, stmts []ast.Stm // The first return value indicates whether the call of handleStmt has no side effect and can be retried. // Currently, the first return value is used to fall back to TiKV when TiFlash is down. 
-func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns []stmtctx.SQLWarn, lastStmt bool) (bool, error) { +func (cc *clientConn) handleStmt( + ctx context.Context, stmt ast.StmtNode, + warns []stmtctx.SQLWarn, lastStmt bool, +) (bool, error) { ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) reg := trace.StartRegion(ctx, "ExecuteStmt") cc.audit(plugin.Starting) + + // if stmt is load data stmt, store the channel that reads from the conn + // into the ctx for executor to use + if _, ok := stmt.(*ast.LoadDataStmt); ok { + err := cc.handleLoadData(ctx) + if err != nil { + return false, err + } + } + rs, err := cc.ctx.ExecuteStmt(ctx, stmt) reg.End() // - If rs is not nil, the statement tracker detachment from session tracker @@ -2075,17 +1990,118 @@ func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns [ return false, err } -func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) { - handled := false - loadDataInfo := cc.ctx.Value(executor.LoadDataVarKey) - if loadDataInfo != nil { - handled = true - defer cc.ctx.SetValue(executor.LoadDataVarKey, nil) - //nolint:forcetypeassert - if err := cc.handleLoadData(ctx, loadDataInfo.(*executor.LoadDataWorker)); err != nil { - return handled, err +func (cc *clientConn) handleLoadData(ctx context.Context) error { + if cc.capability&mysql.ClientLocalFiles == 0 { + return servererr.ErrNotAllowedCommand + } + var readerBuilder executor.LoadDataReaderBuilder = func(filepath string) ( + r io.ReadCloser, drained *bool, wg *sync.WaitGroup, err error, + ) { + err = cc.writeReq(ctx, filepath) + if err != nil { + return nil, drained, wg, err + } + drained = new(bool) + r, w := io.Pipe() + wg = &sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + err := w.Close() + if err != nil { + logutil.Logger(ctx).Error( + "close pipe meet error in load data", + zap.Error(err), + ) + return + } + }() + + var ( + err2 error + data []byte + ) + for { + if len(data) == 0 { + data, err2 = cc.readPacket() + if err2 != nil { + err4 := w.CloseWithError(err2) + if err4 != nil { + logutil.Logger(ctx).Error( + "close pipe meet error in load data", + zap.Error(err4), + ) + } + return + } + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html + if len(data) == 0 { + *drained = true + return + } + } + + n, err3 := w.Write(data) + if err3 != nil { + logutil.Logger(ctx).Error( + "write data meet error in load data", + zap.Error(err3), + ) + return + } + data = data[n:] + } + }() + return r, drained, wg, nil + } + var readerCloser executor.LoadDataReaderCloser = func( + r io.ReadCloser, drained *bool, + wg *sync.WaitGroup, err error, + ) error { + _ = r.Close() + wg.Wait() + if err != nil { + if !*drained { + logutil.Logger(ctx).Info("not drained yet, try reading left data from client connection") + } + // drain the data from client conn util empty packet received, otherwise the connection will be reset + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html + for !*drained { + // check kill flag again, let the draining loop could quit if empty packet could not be received + if atomic.CompareAndSwapUint32( + &cc.ctx.GetSessionVars().SQLKiller.Signal, + 1, + 0, + ) { + logutil.Logger(ctx).Warn("receiving kill, stop draining data, connection may be 
reset") + return exeerrors.ErrQueryInterrupted + } + curData, err1 := cc.readPacket() + if err1 != nil { + logutil.Logger(ctx).Error( + "drain reading left data encounter errors", + zap.Error(err1), + ) + break + } + if len(curData) == 0 { + *drained = true + logutil.Logger(ctx).Info("draining finished for error", zap.Error(err)) + break + } + } } + return err } + // set these functions in the context + cc.ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) + cc.ctx.SetValue(executor.LoadDataReaderCloseKey, readerCloser) + return nil +} + +func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) { + handled := false loadStats := cc.ctx.Value(executor.LoadStatsVarKey) if loadStats != nil { diff --git a/pkg/server/internal/testserverclient/server_client.go b/pkg/server/internal/testserverclient/server_client.go index 90af63073741a..ebc213d434ade 100644 --- a/pkg/server/internal/testserverclient/server_client.go +++ b/pkg/server/internal/testserverclient/server_client.go @@ -28,6 +28,7 @@ import ( "regexp" "strconv" "strings" + "sync" "testing" "time" @@ -1080,6 +1081,65 @@ func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T, server * cli.CheckRows(t, rows, "1") }, ) + + // load file in a pessimistic transaction, + // should acquire locks when after its execution and before it commits. + // The lock should be observed by another transaction that is attempting to acquire the same + // lock. + dbName := "LoadDataInPessimisticTransaction" + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, dbName, func(dbt *testkit.DBTestKit) { + dbt.MustExec("set @@global.tidb_general_log = 1") + dbt.MustExec("create table t (a int primary key)") + txn, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoError(t, err) + _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + rows, err := txn.Query("select * from t") + require.NoError(t, err) + cli.CheckRows(t, rows, "1") + + var wg sync.WaitGroup + wg.Add(1) + txn2Locked := make(chan struct{}, 1) + failed := make(chan struct{}, 1) + go func() { + time.Sleep(2 * time.Second) + select { + case <-txn2Locked: + failed <- struct{}{} + default: + } + + err2 := txn.Commit() + require.NoError(t, err2) + wg.Done() + }() + txn2, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn2.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoError(t, err) + _, err = txn2.Exec("select * from t where a = 1 for update") + require.NoError(t, err) + txn2Locked <- struct{}{} + wg.Wait() + txn2.Rollback() + select { + case <-failed: + require.Fail(t, "txn2 should not be able to acquire the lock") + default: + } + + require.NoError(t, err) + rows = dbt.MustQuery("select * from t") + cli.CheckRows(t, rows, "1") + }, + ) } func (cli *TestServerClient) RunTestLoadData(t *testing.T, server *server.Server) { From 36b7107b151a096cddbf88d0f562e934ccdc99ff Mon Sep 17 00:00:00 2001 From: ekexium Date: Wed, 6 Dec 2023 19:55:55 +0800 Subject: [PATCH 4/8] test: adjust origianl loda_data tests Signed-off-by: ekexium --- pkg/executor/load_data.go | 15 +- pkg/executor/test/loaddatatest/BUILD.bazel | 2 +- .../test/loaddatatest/load_data_test.go | 150 ++++++++---------- pkg/server/conn.go | 27 +++- .../testserverclient/server_client.go | 37 ++++- pkg/server/tests/tidb_serial_test.go | 2 +- 6 files changed, 137 insertions(+), 96 
deletions(-) diff --git a/pkg/executor/load_data.go b/pkg/executor/load_data.go index bcef32fb763d6..5212fa4f1b420 100644 --- a/pkg/executor/load_data.go +++ b/pkg/executor/load_data.go @@ -89,7 +89,7 @@ type LoadDataExec struct { readerCloser LoadDataReaderCloser } -// Open implements the Executor Next interface. +// Open implements the Executor interface. func (e *LoadDataExec) Open(_ context.Context) error { if rb, ok := e.Ctx().Value(LoadDataReaderBuilderKey).(LoadDataReaderBuilder); ok { e.readerCloser = e.Ctx().Value(LoadDataReaderCloseKey).(LoadDataReaderCloser) @@ -102,7 +102,7 @@ func (e *LoadDataExec) Open(_ context.Context) error { return nil } -// Close implements the Executor Next interface. +// Close implements the Executor interface. func (e *LoadDataExec) Close() error { return e.closeLocalReader(nil) } @@ -123,7 +123,12 @@ func (e *LoadDataExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) { case ast.FileLocServerOrRemote: return e.loadDataWorker.loadRemote(ctx) case ast.FileLocClient: - err = e.loadDataWorker.loadLocal(ctx, e.infileReader) + // This is for legacy test only + // TODO: adjust tests to remove LoadDataVarKey + sctx := e.loadDataWorker.UserSctx + sctx.SetValue(LoadDataVarKey, e.loadDataWorker) + + err = e.loadDataWorker.LoadLocal(ctx, e.infileReader) if err != nil { logutil.Logger(ctx).Error("load local data failed", zap.Error(err)) err = e.closeLocalReader(err) @@ -199,7 +204,7 @@ func (e *LoadDataWorker) loadRemote(ctx context.Context) error { } // LoadLocal reads from client connection and do load data job. -func (e *LoadDataWorker) loadLocal(ctx context.Context, r io.ReadCloser) error { +func (e *LoadDataWorker) LoadLocal(ctx context.Context, r io.ReadCloser) error { if r == nil { return errors.New("load local data, reader is nil") } @@ -584,8 +589,6 @@ func (w *commitWorker) commitWork(ctx context.Context, inCh <-chan commitTask) ( zap.Stack("stack")) err = util.GetRecoverError(r) } - // Why call it here? 
- // w.Ctx().StmtCommit(ctx) }() var ( diff --git a/pkg/executor/test/loaddatatest/BUILD.bazel b/pkg/executor/test/loaddatatest/BUILD.bazel index 4f0f27f54363e..7c42b26c4e066 100644 --- a/pkg/executor/test/loaddatatest/BUILD.bazel +++ b/pkg/executor/test/loaddatatest/BUILD.bazel @@ -9,7 +9,7 @@ go_test( ], flaky = True, race = "on", - shard_count = 10, + shard_count = 11, deps = [ "//br/pkg/lightning/mydump", "//pkg/config", diff --git a/pkg/executor/test/loaddatatest/load_data_test.go b/pkg/executor/test/loaddatatest/load_data_test.go index 1284475340815..a2a2daa960053 100644 --- a/pkg/executor/test/loaddatatest/load_data_test.go +++ b/pkg/executor/test/loaddatatest/load_data_test.go @@ -15,7 +15,9 @@ package loaddatatest import ( - "context" + "fmt" + "io" + "sync" "testing" "github.com/pingcap/tidb/br/pkg/lightning/mydump" @@ -34,25 +36,34 @@ type testCase struct { func checkCases( tests []testCase, - ld *executor.LoadDataWorker, + loadSQL string, t *testing.T, tk *testkit.TestKit, ctx sessionctx.Context, selectSQL, deleteSQL string, ) { for _, tt := range tests { - parser, err := mydump.NewCSVParser( - context.Background(), - ld.GetController().GenerateCSVConfig(), - mydump.NewStringReader(string(tt.data)), - 1, - nil, - false, - nil) - require.NoError(t, err) - - err = ld.TestLoadLocal(parser) - require.NoError(t, err) + var reader io.ReadCloser = mydump.NewStringReader(string(tt.data)) + var readerBuilder executor.LoadDataReaderBuilder = func(_ string) ( + r io.ReadCloser, drained *bool, wg *sync.WaitGroup, err error, + ) { + return reader, new(bool), new(sync.WaitGroup), nil + } + var readerCloser executor.LoadDataReaderCloser = func( + r io.ReadCloser, drained *bool, + wg *sync.WaitGroup, err error, + ) error { + r.Close() + return nil + } + + ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) + ctx.SetValue(executor.LoadDataReaderCloseKey, readerCloser) + tk.MustExec(loadSQL) + warnings := tk.Session().GetSessionVars().StmtCtx.GetWarnings() + for _, w := range warnings { + fmt.Printf("warnnig: %#v\n", w.Err.Error()) + } require.Equal(t, tt.expectedMsg, tk.Session().LastMessage(), tt.expected) tk.MustQuery(selectSQL).Check(testkit.RowsWithSep("|", tt.expected...)) tk.MustExec(deleteSQL) @@ -130,12 +141,8 @@ func TestLoadData(t *testing.T) { tk.MustExec(createSQL) err = tk.ExecToErr("load data infile '/tmp/nonexistence.csv' into table load_data_test") require.Error(t, err) - tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test") + loadSQL := "load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" @@ -164,10 +171,11 @@ func TestLoadData(t *testing.T) { {[]byte("\t2\t3\t4\t5\n"), []string{"10|2|3|4"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 2"}, {[]byte("\t2\t34\t5\n"), []string{"11|2|34|5"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) // lines starting symbol is "" and terminated symbol length is 2, ReadOneBatchRows returns data is nil - ld.GetController().LinesTerminatedBy = "||" + loadSQL = "load data local infile '/tmp/nonexistence." 
+ + "csv' ignore into table load_data_test lines terminated by '||'" tests = []testCase{ {[]byte("0\t2\t3\t4\t5||"), []string{"12|2|3|4"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, {[]byte("1\t2\t3\t4\t5||"), []string{"1|2|3|4"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, @@ -179,12 +187,11 @@ func TestLoadData(t *testing.T) { []string{"4|2|3|4", "5|22|33|", "6|222||"}, "Records: 3 Deleted: 0 Skipped: 0 Warnings: 3"}, {[]byte("6\t2\t34\t5||"), []string{"6|2|34|5"}, trivialMsg}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) // fields and lines aren't default, ReadOneBatchRows returns data is nil - ld.GetController().FieldsTerminatedBy = "\\" - ld.GetController().LinesStartingBy = "xxx" - ld.GetController().LinesTerminatedBy = "|!#^" + loadSQL = "load data local infile '/tmp/nonexistence.csv' " + + `ignore into table load_data_test fields terminated by '\\' lines starting by 'xxx' terminated by '|!#^'` tests = []testCase{ {[]byte("xxx|!#^"), []string{"13|||"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 2"}, {[]byte("xxx\\|!#^"), []string{"14|0||"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 3"}, @@ -219,7 +226,7 @@ func TestLoadData(t *testing.T) { []string{"25|2|3|4", "27|222||"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 2"}, {[]byte("xxx\\2\\34\\5|!#^"), []string{"28|2|34|5"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) // TODO: not support it now // lines starting symbol is the same as terminated symbol, ReadOneBatchRows returns data is nil @@ -258,21 +265,25 @@ func TestLoadData(t *testing.T) { //checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) // test line terminator in field quoter - ld.GetController().LinesTerminatedBy = "\n" - ld.GetController().FieldsEnclosedBy = `"` + loadSQL = "load data local infile '/tmp/nonexistence.csv' " + + "ignore into table load_data_test " + + "fields terminated by '\\\\' enclosed by '\\\"' " + + "lines starting by 'xxx' terminated by '\\n'" tests = []testCase{ {[]byte("xxx1\\1\\\"2\n\"\\3\nxxx4\\4\\\"5\n5\"\\6"), []string{"1|1|2\n|3", "4|4|5\n5|6"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) - ld.GetController().LinesTerminatedBy = "#\n" - ld.GetController().FieldsTerminatedBy = "#" + loadSQL = "load data local infile '/tmp/nonexistence.csv' " + + "ignore into table load_data_test " + + "fields terminated by '#' enclosed by '\\\"' " + + "lines starting by 'xxx' terminated by '#\\n'" tests = []testCase{ {[]byte("xxx1#\nxxx2#\n"), []string{"1|||", "2|||"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 2"}, {[]byte("xxx1#2#3#4#\nnxxx2#3#4#5#\n"), []string{"1|2|3|4", "2|3|4|5"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, {[]byte("xxx1#2#\"3#\"#\"4\n\"#\nxxx2#3#\"#4#\n\"#5#\n"), []string{"1|2|3#|4", "2|3|#4#\n|5"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) // TODO: now support it now //ld.LinesInfo.Terminated = "#" @@ -293,12 +304,8 @@ func TestLoadDataEscape(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET 
utf8") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) // test escape tests := []testCase{ // data1 = nil, data2 != nil @@ -314,7 +321,7 @@ func TestLoadDataEscape(t *testing.T) { } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } // TestLoadDataSpecifiedColumns reuse TestLoadDataEscape's test case :-) @@ -324,12 +331,8 @@ func TestLoadDataSpecifiedColumns(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec(`create table load_data_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 varchar(255) default "def", c3 int default 0);`) - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test (c1, c2)") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test (c1, c2)" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) // test tests := []testCase{ {[]byte("7\ta string\n"), []string{"1|7|a string|0"}, trivialMsg}, @@ -342,7 +345,7 @@ func TestLoadDataSpecifiedColumns(t *testing.T) { } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataIgnoreLines(t *testing.T) { @@ -350,19 +353,15 @@ func TestLoadDataIgnoreLines(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test ignore 1 lines") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test ignore 1 lines" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ {[]byte("1\tline1\n2\tline2\n"), []string{"2|line2"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, {[]byte("1\tline1\n2\tline2\n3\tline3\n"), []string{"2|line2", "3|line3"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataNULL(t *testing.T) { @@ -374,13 +373,9 @@ func TestLoadDataNULL(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (id VARCHAR(20), value VARCHAR(20)) CHARACTER SET utf8") - tk.MustExec(`load data local infile '/tmp/nonexistence.csv' into table load_data_test -FIELDS TERMINATED BY ',' ENCLOSED BY '"' LINES TERMINATED BY '\n';`) + 
loadSQL := `load data local infile '/tmp/nonexistence.csv' into table load_data_test +FIELDS TERMINATED BY ',' ENCLOSED BY '"' LINES TERMINATED BY '\n';` ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ { []byte(`NULL,"NULL" @@ -392,7 +387,7 @@ FIELDS TERMINATED BY ',' ENCLOSED BY '"' LINES TERMINATED BY '\n';`) } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataReplace(t *testing.T) { @@ -401,19 +396,15 @@ func TestLoadDataReplace(t *testing.T) { tk.MustExec("USE test; DROP TABLE IF EXISTS load_data_replace;") tk.MustExec("CREATE TABLE load_data_replace (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL)") tk.MustExec("INSERT INTO load_data_replace VALUES(1,'val 1'),(2,'val 2')") - tk.MustExec("LOAD DATA LOCAL INFILE '/tmp/nonexistence.csv' REPLACE INTO TABLE load_data_replace") + loadSQL := "LOAD DATA LOCAL INFILE '/tmp/nonexistence.csv' REPLACE INTO TABLE load_data_replace" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ {[]byte("1\tline1\n2\tline2\n"), []string{"1|line1", "2|line2"}, "Records: 2 Deleted: 2 Skipped: 0 Warnings: 0"}, {[]byte("2\tnew line2\n3\tnew line3\n"), []string{"1|line1", "2|new line2", "3|new line3"}, "Records: 2 Deleted: 1 Skipped: 0 Warnings: 0"}, } deleteSQL := "DO 1" selectSQL := "TABLE load_data_replace;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } // TestLoadDataOverflowBigintUnsigned related to issue 6360 @@ -422,19 +413,15 @@ func TestLoadDataOverflowBigintUnsigned(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (a bigint unsigned);") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test") + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ {[]byte("-1\n-18446744073709551615\n-18446744073709551616\n"), []string{"0", "0", "0"}, "Records: 3 Deleted: 0 Skipped: 0 Warnings: 3"}, {[]byte("-9223372036854775809\n18446744073709551616\n"), []string{"0", "18446744073709551615"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 2"}, } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataWithUppercaseUserVars(t *testing.T) { @@ -442,19 +429,15 @@ func TestLoadDataWithUppercaseUserVars(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test; drop table if exists load_data_test;") tk.MustExec("CREATE TABLE load_data_test (a int, b int);") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test (@V1)" + - " set a = @V1, b = @V1*100") + loadSQL 
:= "load data local infile '/tmp/nonexistence.csv' into table load_data_test (@V1)" + + " set a = @V1, b = @V1*100" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) tests := []testCase{ {[]byte("1\n2\n"), []string{"1|100", "2|200"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"}, } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestLoadDataIntoPartitionedTable(t *testing.T) { @@ -465,14 +448,21 @@ func TestLoadDataIntoPartitionedTable(t *testing.T) { "partition p0 values less than (4)," + "partition p1 values less than (7)," + "partition p2 values less than (11))") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table range_t fields terminated by ','") ctx := tk.Session().(sessionctx.Context) - ld := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table range_t fields terminated by ','" tests := []testCase{ {[]byte("1,2\n3,4\n5,6\n7,8\n9,10\n"), []string{"1|2", "3|4", "5|6", "7|8", "9|10"}, "Records: 5 Deleted: 0 Skipped: 0 Warnings: 0"}, } deleteSQL := "delete from range_t" selectSQL := "select * from range_t order by a;" - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) +} + +func TestLoadDataFromServerFile(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table load_data_test (a int)") + err := tk.ExecToErr("load data infile 'remote.csv' into table load_data_test") + require.ErrorContains(t, err, "[executor:8154]Don't support load data from tidb-server's disk.") } diff --git a/pkg/server/conn.go b/pkg/server/conn.go index 379f97f7a8ad4..821a7d08eec96 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -1937,10 +1937,12 @@ func (cc *clientConn) handleStmt( // if stmt is load data stmt, store the channel that reads from the conn // into the ctx for executor to use - if _, ok := stmt.(*ast.LoadDataStmt); ok { - err := cc.handleLoadData(ctx) - if err != nil { - return false, err + if s, ok := stmt.(*ast.LoadDataStmt); ok { + if s.FileLocRef == ast.FileLocClient { + err := cc.preprocessLoadDataLocal(ctx) + if err != nil { + return false, err + } } } @@ -1953,6 +1955,12 @@ func (cc *clientConn) handleStmt( if rs != nil { defer terror.Call(rs.Close) } + if s, ok := stmt.(*ast.LoadDataStmt); ok { + if s.FileLocRef == ast.FileLocClient { + cc.postprocessLoadDataLocal() + } + } + if err != nil { // If error is returned during the planner phase or the executor.Open // phase, the rs will be nil, and StmtCtx.MemTracker StmtCtx.DiskTracker @@ -1990,7 +1998,10 @@ func (cc *clientConn) handleStmt( return false, err } -func (cc *clientConn) handleLoadData(ctx context.Context) error { +// Preprocess LOAD DATA. Load data from a local file requires reading from the connection. +// The function pass a builder to build the connection reader to the context, +// which will be used in LoadDataExec. 
+func (cc *clientConn) preprocessLoadDataLocal(ctx context.Context) error { if cc.capability&mysql.ClientLocalFiles == 0 { return servererr.ErrNotAllowedCommand } @@ -2094,12 +2105,16 @@ func (cc *clientConn) handleLoadData(ctx context.Context) error { } return err } - // set these functions in the context cc.ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) cc.ctx.SetValue(executor.LoadDataReaderCloseKey, readerCloser) return nil } +func (cc *clientConn) postprocessLoadDataLocal() { + cc.ctx.ClearValue(executor.LoadDataReaderBuilderKey) + cc.ctx.ClearValue(executor.LoadDataReaderCloseKey) +} + func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) { handled := false diff --git a/pkg/server/internal/testserverclient/server_client.go b/pkg/server/internal/testserverclient/server_client.go index ebc213d434ade..5d87e5684117d 100644 --- a/pkg/server/internal/testserverclient/server_client.go +++ b/pkg/server/internal/testserverclient/server_client.go @@ -1009,7 +1009,7 @@ func columnsAsExpected(t *testing.T, columns []*sql.NullString, expected []strin } } -func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T, server *server.Server) { +func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T) { fp, err := os.CreateTemp("", "load_data_test.csv") require.NoError(t, err) path := fp.Name() @@ -1092,7 +1092,6 @@ func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T, server * config.AllowAllFiles = true config.Params["sql_mode"] = "''" }, dbName, func(dbt *testkit.DBTestKit) { - dbt.MustExec("set @@global.tidb_general_log = 1") dbt.MustExec("create table t (a int primary key)") txn, err := dbt.GetDB().Begin() require.NoError(t, err) @@ -1140,6 +1139,40 @@ func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T, server * cli.CheckRows(t, rows, "1") }, ) + + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadDataFromServerFile", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (a int)") + _, err = dbt.GetDB().Exec(fmt.Sprintf("load data infile %q into table t", path)) + require.ErrorContains(t, err, "Don't support load data from tidb-server's disk.") + }, + ) + + // The test is intended to test if the load data statement correctly cleans up its + // resources after execution, and does not affect following statements. + // For example, the 1st load data builds the reader and finishes. 
+ // The 2nd load data should not be able to access the reader, especially when it should fail + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadDataCleanup", func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (a int)") + txn, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + _, err = txn.Exec("load data local infile '/tmp/does_not_exist' into table t") + require.ErrorContains(t, err, "no such file or directory") + err = txn.Commit() + require.NoError(t, err) + rows := dbt.MustQuery("select * from t") + cli.CheckRows(t, rows, "1") + }, + ) } func (cli *TestServerClient) RunTestLoadData(t *testing.T, server *server.Server) { diff --git a/pkg/server/tests/tidb_serial_test.go b/pkg/server/tests/tidb_serial_test.go index fee659d723a7a..132703e96fd18 100644 --- a/pkg/server/tests/tidb_serial_test.go +++ b/pkg/server/tests/tidb_serial_test.go @@ -73,7 +73,7 @@ func TestLoadData1(t *testing.T) { func TestLoadDataInTransaction(t *testing.T) { ts := createTidbTestSuite(t) - ts.RunTestLoadDataInTransaction(t, ts.server) + ts.RunTestLoadDataInTransaction(t) } func TestConfigDefaultValue(t *testing.T) { From 09e23102b8c876aaa9ea6a912811aa444cfe604e Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 7 Dec 2023 17:19:47 +0800 Subject: [PATCH 5/8] refactor: defer postprocessLoadDataLocal Signed-off-by: ekexium --- pkg/server/conn.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pkg/server/conn.go b/pkg/server/conn.go index 821a7d08eec96..b77d4df25d7c8 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -1940,6 +1940,7 @@ func (cc *clientConn) handleStmt( if s, ok := stmt.(*ast.LoadDataStmt); ok { if s.FileLocRef == ast.FileLocClient { err := cc.preprocessLoadDataLocal(ctx) + defer cc.postprocessLoadDataLocal() if err != nil { return false, err } @@ -1955,11 +1956,6 @@ func (cc *clientConn) handleStmt( if rs != nil { defer terror.Call(rs.Close) } - if s, ok := stmt.(*ast.LoadDataStmt); ok { - if s.FileLocRef == ast.FileLocClient { - cc.postprocessLoadDataLocal() - } - } if err != nil { // If error is returned during the planner phase or the executor.Open From ed9f3249e8bf7cbb0c4eec2ad170b2a9ae795749 Mon Sep 17 00:00:00 2001 From: ekexium Date: Mon, 11 Dec 2023 16:39:40 +0800 Subject: [PATCH 6/8] test: load remote Signed-off-by: ekexium --- .../test/loadremotetest/one_csv_test.go | 41 +++++++++++++++++++ .../testserverclient/server_client.go | 3 +- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/pkg/executor/test/loadremotetest/one_csv_test.go b/pkg/executor/test/loadremotetest/one_csv_test.go index 4a021ecebbe56..2fefa018d63fa 100644 --- a/pkg/executor/test/loadremotetest/one_csv_test.go +++ b/pkg/executor/test/loadremotetest/one_csv_test.go @@ -85,6 +85,47 @@ func (s *mockGCSSuite) TestLoadCSV() { s.tk.MustContainErrMsg(sql, "Don't support load data from tidb-server's disk. 
Or if you want to load local data via client, the path of INFILE '/etc/passwd' needs to specify the clause of LOCAL first") } +func (s *mockGCSSuite) TestLoadCsvInTransaction() { + s.tk.MustExec("DROP DATABASE IF EXISTS load_csv;") + s.tk.MustExec("CREATE DATABASE load_csv;") + s.tk.MustExec("CREATE TABLE load_csv.t (i INT, s varchar(32));") + + s.server.CreateObject( + fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{ + BucketName: "test-load-csv", + Name: "data.csv", + }, + Content: []byte("100, test100\n101, hello\n102, 😄😄😄😄😄\n104, bye"), + }, + ) + + s.tk.MustExec("begin pessimistic") + sql := fmt.Sprintf( + `LOAD DATA INFILE 'gs://test-load-csv/data.csv?endpoint=%s' INTO TABLE load_csv.t `+ + "FIELDS TERMINATED BY ','", + gcsEndpoint, + ) + // test: load data stmt doesn't commit it + s.tk.MustExec("insert into load_csv.t values (1, 'a')") + s.tk.MustExec(sql) + s.tk.MustQuery("select i from load_csv.t order by i").Check( + testkit.Rows( + "1", "100", "101", + "102", "104", + ), + ) + // load data can be rolled back + s.tk.MustExec("rollback") + s.tk.MustQuery("select * from load_csv.t").Check(testkit.Rows()) + + // load data commit + s.tk.MustExec("begin pessimistic") + s.tk.MustExec(sql) + s.tk.MustExec("commit") + s.tk.MustQuery("select i from load_csv.t").Check(testkit.Rows("100", "101", "102", "104")) +} + func (s *mockGCSSuite) TestIgnoreNLines() { s.tk.MustExec("DROP DATABASE IF EXISTS load_csv;") s.tk.MustExec("CREATE DATABASE load_csv;") diff --git a/pkg/server/internal/testserverclient/server_client.go b/pkg/server/internal/testserverclient/server_client.go index 5d87e5684117d..6e375c247b28e 100644 --- a/pkg/server/internal/testserverclient/server_client.go +++ b/pkg/server/internal/testserverclient/server_client.go @@ -1034,11 +1034,12 @@ func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T) { dbt.MustExec("create table t (a int)") txn, err := dbt.GetDB().Begin() require.NoError(t, err) + txn.Exec("insert into t values (100)") // `load data` doesn't commit current txn _, err = txn.Exec(fmt.Sprintf("load data local infile %q into table t", path)) require.NoError(t, err) rows, err := txn.Query("select * from t") require.NoError(t, err) - cli.CheckRows(t, rows, "1") + cli.CheckRows(t, rows, "100\n1") err = txn.Rollback() require.NoError(t, err) rows = dbt.MustQuery("select * from t") From f974e43b494583b1f4aa06b37979e9e76498e3b6 Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 14 Dec 2023 20:44:17 +0800 Subject: [PATCH 7/8] refactor: remove the closer function; let the goroutine in reader builder drain the connection when error occurs Signed-off-by: ekexium --- pkg/executor/load_data.go | 35 ++--- .../test/loaddatatest/load_data_test.go | 36 ++--- pkg/executor/test/writetest/write_test.go | 53 ++++---- pkg/privilege/privileges/privileges_test.go | 3 +- pkg/server/conn.go | 127 +++++++----------- 5 files changed, 107 insertions(+), 147 deletions(-) diff --git a/pkg/executor/load_data.go b/pkg/executor/load_data.go index 5212fa4f1b420..62d06fbeb0840 100644 --- a/pkg/executor/load_data.go +++ b/pkg/executor/load_data.go @@ -20,7 +20,6 @@ import ( "io" "math" "strings" - "sync" "time" "github.com/pingcap/errors" @@ -57,24 +56,15 @@ const LoadDataVarKey loadDataVarKeyType = 0 // LoadDataReaderBuilderKey stores the reader channel that reads from the connection. 
const LoadDataReaderBuilderKey loadDataVarKeyType = 1 -// LoadDataReaderCloseKey stores the close function of the reader -const LoadDataReaderCloseKey loadDataVarKeyType = 2 - var ( taskQueueSize = 16 // the maximum number of pending tasks to commit in queue ) // LoadDataReaderBuilder is a function type that builds a reader from a file path. type LoadDataReaderBuilder func(filepath string) ( - r io.ReadCloser, drained *bool, wg *sync.WaitGroup, err error, + r io.ReadCloser, err error, ) -// LoadDataReaderCloser is a function type that closes a reader. -type LoadDataReaderCloser func( - r io.ReadCloser, drained *bool, - wg *sync.WaitGroup, err error, -) error - // LoadDataExec represents a load data executor. type LoadDataExec struct { exec.BaseExecutor @@ -84,17 +74,13 @@ type LoadDataExec struct { // fields for loading local file infileReader io.ReadCloser - drained *bool - wg *sync.WaitGroup - readerCloser LoadDataReaderCloser } // Open implements the Executor interface. func (e *LoadDataExec) Open(_ context.Context) error { if rb, ok := e.Ctx().Value(LoadDataReaderBuilderKey).(LoadDataReaderBuilder); ok { - e.readerCloser = e.Ctx().Value(LoadDataReaderCloseKey).(LoadDataReaderCloser) var err error - e.infileReader, e.drained, e.wg, err = rb(e.loadDataWorker.GetInfilePath()) + e.infileReader, err = rb(e.loadDataWorker.GetInfilePath()) if err != nil { return err } @@ -108,12 +94,19 @@ func (e *LoadDataExec) Close() error { } func (e *LoadDataExec) closeLocalReader(originalErr error) error { - var err error - if e.readerCloser != nil { - err = e.readerCloser(e.infileReader, e.drained, e.wg, originalErr) + err := originalErr + if e.infileReader != nil { + if err2 := e.infileReader.Close(); err2 != nil { + logutil.BgLogger().Error( + "close local reader failed", zap.Error(err2), + zap.NamedError("original error", originalErr), + ) + if err == nil { + err = err2 + } + } + e.infileReader = nil } - // don't close it twice - e.readerCloser = nil return err } diff --git a/pkg/executor/test/loaddatatest/load_data_test.go b/pkg/executor/test/loaddatatest/load_data_test.go index a2a2daa960053..bb9f03d7a32d5 100644 --- a/pkg/executor/test/loaddatatest/load_data_test.go +++ b/pkg/executor/test/loaddatatest/load_data_test.go @@ -17,7 +17,6 @@ package loaddatatest import ( "fmt" "io" - "sync" "testing" "github.com/pingcap/tidb/br/pkg/lightning/mydump" @@ -45,20 +44,12 @@ func checkCases( for _, tt := range tests { var reader io.ReadCloser = mydump.NewStringReader(string(tt.data)) var readerBuilder executor.LoadDataReaderBuilder = func(_ string) ( - r io.ReadCloser, drained *bool, wg *sync.WaitGroup, err error, + r io.ReadCloser, err error, ) { - return reader, new(bool), new(sync.WaitGroup), nil - } - var readerCloser executor.LoadDataReaderCloser = func( - r io.ReadCloser, drained *bool, - wg *sync.WaitGroup, err error, - ) error { - r.Close() - return nil + return reader, nil } ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) - ctx.SetValue(executor.LoadDataReaderCloseKey, readerCloser) tk.MustExec(loadSQL) warnings := tk.Session().GetSessionVars().StmtCtx.GetWarnings() for _, w := range warnings { @@ -91,7 +82,7 @@ func TestLoadDataInitParam(t *testing.T) { // null def values testFunc := func(sql string, expectedNullDef []string, expectedNullOptEnclosed bool) { - require.NoError(t, tk.ExecToErr(sql)) + require.ErrorContains(t, tk.ExecToErr(sql), "reader is nil") defer ctx.SetValue(executor.LoadDataVarKey, nil) ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) 
require.True(t, ok) @@ -113,11 +104,26 @@ func TestLoadDataInitParam(t *testing.T) { []string{"NULL"}, false) // positive case - require.NoError(t, tk.ExecToErr("load data local infile '/a' format 'sql file' into table load_data_test")) + require.ErrorContains( + t, tk.ExecToErr( + "load data local infile '/a' format 'sql file' into table"+ + " load_data_test", + ), "reader is nil", + ) ctx.SetValue(executor.LoadDataVarKey, nil) - require.NoError(t, tk.ExecToErr("load data local infile '/a' into table load_data_test fields terminated by 'a'")) + require.ErrorContains( + t, tk.ExecToErr( + "load data local infile '/a' into table load_data_test fields"+ + " terminated by 'a'", + ), "reader is nil", + ) ctx.SetValue(executor.LoadDataVarKey, nil) - require.NoError(t, tk.ExecToErr("load data local infile '/a' format 'delimited data' into table load_data_test fields terminated by 'a'")) + require.ErrorContains( + t, tk.ExecToErr( + "load data local infile '/a' format 'delimited data' into"+ + " table load_data_test fields terminated by 'a'", + ), "reader is nil", + ) ctx.SetValue(executor.LoadDataVarKey, nil) // According to https://dev.mysql.com/doc/refman/8.0/en/load-data.html , fixed-row format should be used when fields diff --git a/pkg/executor/test/writetest/write_test.go b/pkg/executor/test/writetest/write_test.go index d279259fc0636..727e1b5c3431d 100644 --- a/pkg/executor/test/writetest/write_test.go +++ b/pkg/executor/test/writetest/write_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "io" "testing" "github.com/pingcap/failpoint" @@ -167,25 +168,26 @@ type testCase struct { func checkCases( tests []testCase, - ld *executor.LoadDataWorker, + loadSQL string, t *testing.T, tk *testkit.TestKit, ctx sessionctx.Context, selectSQL, deleteSQL string, ) { for _, tt := range tests { - parser, err := mydump.NewCSVParser( - context.Background(), - ld.GetController().GenerateCSVConfig(), - mydump.NewStringReader(string(tt.data)), - 1, - nil, - false, - nil) - require.NoError(t, err) + var reader io.ReadCloser = mydump.NewStringReader(string(tt.data)) + var readerBuilder executor.LoadDataReaderBuilder = func(_ string) ( + r io.ReadCloser, err error, + ) { + return reader, nil + } - err = ld.TestLoadLocal(parser) - require.NoError(t, err) + ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) + tk.MustExec(loadSQL) + warnings := tk.Session().GetSessionVars().StmtCtx.GetWarnings() + for _, w := range warnings { + fmt.Printf("warnnig: %#v\n", w.Err.Error()) + } require.Equal(t, tt.expectedMsg, tk.Session().LastMessage(), tt.expected) tk.MustQuery(selectSQL).Check(testkit.RowsWithSep("|", tt.expected...)) tk.MustExec(deleteSQL) @@ -198,12 +200,8 @@ func TestLoadDataMissingColumn(t *testing.T) { tk.MustExec("use test") createSQL := `create table load_data_missing (id int, t timestamp not null)` tk.MustExec(createSQL) - tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_missing") + loadSQL := "load data local infile '/tmp/nonexistence.csv' ignore into table load_data_missing" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) deleteSQL := "delete from load_data_missing" selectSQL := "select id, hour(t), minute(t) from load_data_missing;" @@ -215,7 +213,7 @@ func TestLoadDataMissingColumn(t *testing.T) { {[]byte(""), nil, "Records: 0 Deleted: 0 Skipped: 0 Warnings: 0"}, {[]byte("12\n"), 
[]string{fmt.Sprintf("12|%v|%v", timeHour, timeMinute)}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) tk.MustExec("alter table load_data_missing add column t2 timestamp null") curTime = types.CurrentTime(mysql.TypeTimestamp) @@ -225,7 +223,7 @@ func TestLoadDataMissingColumn(t *testing.T) { tests = []testCase{ {[]byte("12\n"), []string{fmt.Sprintf("12|%v|%v|", timeHour, timeMinute)}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) } func TestIssue18681(t *testing.T) { @@ -235,12 +233,8 @@ func TestIssue18681(t *testing.T) { createSQL := `drop table if exists load_data_test; create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1));` tk.MustExec(createSQL) - tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test") + loadSQL := "load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test" ctx := tk.Session().(sessionctx.Context) - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - defer ctx.SetValue(executor.LoadDataVarKey, nil) - require.NotNil(t, ld) deleteSQL := "delete from load_data_test" selectSQL := "select bin(a), bin(b), bin(c), bin(d) from load_data_test;" @@ -256,7 +250,7 @@ func TestIssue18681(t *testing.T) { tests := []testCase{ {[]byte("true\tfalse\t0\t1\n"), []string{"1|0|0|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, } - checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) + checkCases(tests, loadSQL, t, tk, ctx, selectSQL, deleteSQL) require.Equal(t, uint16(0), sc.WarningCount()) } @@ -270,13 +264,12 @@ func TestIssue34358(t *testing.T) { tk.MustExec("drop table if exists load_data_test") tk.MustExec("create table load_data_test (a varchar(10), b varchar(10))") - tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test ( @v1, @v2 ) set a = @v1, b = @v2") - ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataWorker) - require.True(t, ok) - require.NotNil(t, ld) + loadSQL := "load data local infile '/tmp/nonexistence.csv' into table load_data_test ( @v1, " + + "@v2 ) set a = @v1, b = @v2" checkCases([]testCase{ {[]byte("\\N\n"), []string{"|"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 1"}, - }, ld, t, tk, ctx, "select * from load_data_test", "delete from load_data_test") + }, loadSQL, t, tk, ctx, "select * from load_data_test", "delete from load_data_test", + ) } func TestLatch(t *testing.T) { diff --git a/pkg/privilege/privileges/privileges_test.go b/pkg/privilege/privileges/privileges_test.go index a8c39ceda572a..043ae41ca52fe 100644 --- a/pkg/privilege/privileges/privileges_test.go +++ b/pkg/privilege/privileges/privileges_test.go @@ -1003,7 +1003,8 @@ func TestLoadDataPrivilege(t *testing.T) { require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) tk.MustExec(`GRANT INSERT on *.* to 'test_load'@'localhost'`) require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "test_load", Hostname: "localhost"}, nil, nil, nil)) - tk.MustExec("LOAD DATA LOCAL INFILE '/tmp/load_data_priv.csv' INTO TABLE t_load") + err = tk.ExecToErr("LOAD DATA LOCAL INFILE '/tmp/load_data_priv.csv' INTO TABLE t_load") + require.ErrorContains(t, err, "reader is nil") require.NoError(t, 
tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) tk.MustExec(`GRANT INSERT on *.* to 'test_load'@'localhost'`) diff --git a/pkg/server/conn.go b/pkg/server/conn.go index b77d4df25d7c8..2f82f08416baa 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -2001,114 +2001,81 @@ func (cc *clientConn) preprocessLoadDataLocal(ctx context.Context) error { if cc.capability&mysql.ClientLocalFiles == 0 { return servererr.ErrNotAllowedCommand } + var readerBuilder executor.LoadDataReaderBuilder = func(filepath string) ( - r io.ReadCloser, drained *bool, wg *sync.WaitGroup, err error, + io.ReadCloser, error, ) { - err = cc.writeReq(ctx, filepath) + err := cc.writeReq(ctx, filepath) if err != nil { - return nil, drained, wg, err + return nil, err } - drained = new(bool) + + drained := false r, w := io.Pipe() - wg = &sync.WaitGroup{} - wg.Add(1) + go func() { - defer wg.Done() + var errOccurred error + defer func() { - err := w.Close() + if errOccurred != nil { + // Continue reading packets to drain the connection + for !drained { + data, err := cc.readPacket() + if err != nil { + logutil.Logger(ctx).Error( + "drain connection failed in load data", + zap.Error(err), + ) + break + } + if len(data) == 0 { + drained = true + } + } + } + err := w.CloseWithError(errOccurred) if err != nil { logutil.Logger(ctx).Error( - "close pipe meet error in load data", + "close pipe failed in `load data`", zap.Error(err), ) - return } }() - var ( - err2 error - data []byte - ) for { - if len(data) == 0 { - data, err2 = cc.readPacket() - if err2 != nil { - err4 := w.CloseWithError(err2) - if err4 != nil { - logutil.Logger(ctx).Error( - "close pipe meet error in load data", - zap.Error(err4), - ) - } - return - } - // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html - if len(data) == 0 { - *drained = true - return - } + data, err := cc.readPacket() + if err != nil { + errOccurred = err + return } - n, err3 := w.Write(data) - if err3 != nil { - logutil.Logger(ctx).Error( - "write data meet error in load data", - zap.Error(err3), - ) + if len(data) == 0 { + drained = true return } - data = data[n:] - } - }() - return r, drained, wg, nil - } - var readerCloser executor.LoadDataReaderCloser = func( - r io.ReadCloser, drained *bool, - wg *sync.WaitGroup, err error, - ) error { - _ = r.Close() - wg.Wait() - if err != nil { - if !*drained { - logutil.Logger(ctx).Info("not drained yet, try reading left data from client connection") - } - // drain the data from client conn util empty packet received, otherwise the connection will be reset - // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html - for !*drained { - // check kill flag again, let the draining loop could quit if empty packet could not be received - if atomic.CompareAndSwapUint32( - &cc.ctx.GetSessionVars().SQLKiller.Signal, - 1, - 0, - ) { - logutil.Logger(ctx).Warn("receiving kill, stop draining data, connection may be reset") - return exeerrors.ErrQueryInterrupted - } - curData, err1 := cc.readPacket() - if err1 != nil { - logutil.Logger(ctx).Error( - "drain reading left data encounter errors", - zap.Error(err1), - ) - break - } - if len(curData) == 0 { - *drained = true - logutil.Logger(ctx).Info("draining finished for error", zap.Error(err)) - break + + // Write all content in `data` + for len(data) > 0 { + n, err := w.Write(data) + if err != nil { + errOccurred = err + return + } + data = data[n:] } 
} - } - return err + }() + + return r, nil } + cc.ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) - cc.ctx.SetValue(executor.LoadDataReaderCloseKey, readerCloser) + return nil } func (cc *clientConn) postprocessLoadDataLocal() { cc.ctx.ClearValue(executor.LoadDataReaderBuilderKey) - cc.ctx.ClearValue(executor.LoadDataReaderCloseKey) } func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) { From 06bcc1354b2416fa23f583d97e96925edc85dd4c Mon Sep 17 00:00:00 2001 From: ekexium Date: Thu, 14 Dec 2023 23:22:43 +0800 Subject: [PATCH 8/8] test: optimistic transaction write conflict Signed-off-by: ekexium --- .../testserverclient/server_client.go | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/pkg/server/internal/testserverclient/server_client.go b/pkg/server/internal/testserverclient/server_client.go index 6e375c247b28e..0222217ef025e 100644 --- a/pkg/server/internal/testserverclient/server_client.go +++ b/pkg/server/internal/testserverclient/server_client.go @@ -1093,6 +1093,7 @@ func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T) { config.AllowAllFiles = true config.Params["sql_mode"] = "''" }, dbName, func(dbt *testkit.DBTestKit) { + dbt.MustExec("set @@global.tidb_txn_mode = 'pessimistic'") dbt.MustExec("create table t (a int primary key)") txn, err := dbt.GetDB().Begin() require.NoError(t, err) @@ -1141,6 +1142,36 @@ func (cli *TestServerClient) RunTestLoadDataInTransaction(t *testing.T) { }, ) + dbName = "LoadDataInExplicitTransaction" + cli.RunTestsOnNewDB( + t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, dbName, func(dbt *testkit.DBTestKit) { + // in optimistic txn, one should not block another + dbt.MustExec("set @@global.tidb_txn_mode = 'optimistic'") + dbt.MustExec("create table t (a int primary key)") + txn1, err := dbt.GetDB().Begin() + require.NoError(t, err) + txn2, err := dbt.GetDB().Begin() + require.NoError(t, err) + _, err = txn1.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoError(t, err) + _, err = txn2.Exec(fmt.Sprintf("USE `%s`;", dbName)) + require.NoError(t, err) + _, err = txn1.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + _, err = txn2.Exec(fmt.Sprintf("load data local infile %q into table t", path)) + require.NoError(t, err) + err = txn1.Commit() + require.NoError(t, err) + err = txn2.Commit() + require.ErrorContains(t, err, "Write conflict") + rows := dbt.MustQuery("select * from t") + cli.CheckRows(t, rows, "1") + }, + ) + cli.RunTestsOnNewDB( t, func(config *mysql.Config) { config.AllowAllFiles = true