Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

txn: make load data transactional #49079

Merged
merged 9 commits into from
Dec 18, 2023
86 changes: 61 additions & 25 deletions pkg/executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,64 @@ import (
"golang.org/x/sync/errgroup"
)

// LoadDataVarKey is a variable key for load data.
const LoadDataVarKey loadDataVarKeyType = 0

// LoadDataReaderBuilderKey stores the reader channel that reads from the connection.
const LoadDataReaderBuilderKey loadDataVarKeyType = 1

var (
taskQueueSize = 16 // the maximum number of pending tasks to commit in queue
)

// LoadDataReaderBuilder is a function type that builds a reader from a file path.
type LoadDataReaderBuilder func(filepath string) (
r io.ReadCloser, err error,
)

// LoadDataExec represents a load data executor.
type LoadDataExec struct {
exec.BaseExecutor

FileLocRef ast.FileLocRefTp
loadDataWorker *LoadDataWorker

// fields for loading local file
infileReader io.ReadCloser
}

// Open implements the Executor interface.
func (e *LoadDataExec) Open(_ context.Context) error {
if rb, ok := e.Ctx().Value(LoadDataReaderBuilderKey).(LoadDataReaderBuilder); ok {
var err error
e.infileReader, err = rb(e.loadDataWorker.GetInfilePath())
if err != nil {
return err
}
}
return nil
}

// Close implements the Executor interface.
func (e *LoadDataExec) Close() error {
return e.closeLocalReader(nil)
}

func (e *LoadDataExec) closeLocalReader(originalErr error) error {
err := originalErr
if e.infileReader != nil {
if err2 := e.infileReader.Close(); err2 != nil {
logutil.BgLogger().Error(
"close local reader failed", zap.Error(err2),
zap.NamedError("original error", originalErr),
)
if err == nil {
err = err2
}
}
e.infileReader = nil
}
return err
}

// Next implements the Executor Next interface.
Expand All @@ -66,14 +114,17 @@ func (e *LoadDataExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) {
case ast.FileLocServerOrRemote:
return e.loadDataWorker.loadRemote(ctx)
case ast.FileLocClient:
// let caller use handleFileTransInConn to read data in this connection
// This is for legacy test only
// TODO: adjust tests to remove LoadDataVarKey
sctx := e.loadDataWorker.UserSctx
val := sctx.Value(LoadDataVarKey)
if val != nil {
sctx.SetValue(LoadDataVarKey, nil)
return errors.New("previous load data option wasn't closed normally")
}
sctx.SetValue(LoadDataVarKey, e.loadDataWorker)

err = e.loadDataWorker.LoadLocal(ctx, e.infileReader)
if err != nil {
logutil.Logger(ctx).Error("load local data failed", zap.Error(err))
err = e.closeLocalReader(err)
return err
}
}
return nil
}
Expand Down Expand Up @@ -145,6 +196,10 @@ func (e *LoadDataWorker) loadRemote(ctx context.Context) error {

// LoadLocal reads from client connection and do load data job.
func (e *LoadDataWorker) LoadLocal(ctx context.Context, r io.ReadCloser) error {
if r == nil {
return errors.New("load local data, reader is nil")
}

compressTp := mydump.ParseCompressionOnFileExtension(e.GetInfilePath())
compressTp2, err := mydump.ToStorageCompressType(compressTp)
if err != nil {
Expand Down Expand Up @@ -172,11 +227,6 @@ func (e *LoadDataWorker) load(ctx context.Context, readerInfos []importer.LoadDa
commitTaskCh := make(chan commitTask, taskQueueSize)
// commitWork goroutines -> done -> UpdateJobProgress goroutine

// TODO: support explicit transaction and non-autocommit
if err = sessiontxn.NewTxn(groupCtx, e.UserSctx); err != nil {
return err
}

// processOneStream goroutines.
group.Go(func() error {
err2 := encoder.processStream(groupCtx, readerInfoCh, commitTaskCh)
Expand Down Expand Up @@ -530,16 +580,6 @@ func (w *commitWorker) commitWork(ctx context.Context, inCh <-chan commitTask) (
zap.Stack("stack"))
err = util.GetRecoverError(r)
}

if err != nil {
background := context.Background()
w.Ctx().StmtRollback(background, false)
w.Ctx().RollbackTxn(background)
} else {
if err = w.Ctx().CommitTxn(ctx); err != nil {
logutil.Logger(ctx).Error("commit error refresh", zap.Error(err))
}
}
}()

var (
Expand Down Expand Up @@ -578,7 +618,6 @@ func (w *commitWorker) commitOneTask(ctx context.Context, task commitTask) error
failpoint.Inject("commitOneTaskErr", func() {
failpoint.Return(errors.New("mock commit one task error"))
})
w.Ctx().StmtCommit(ctx)
return nil
}

Expand Down Expand Up @@ -734,6 +773,3 @@ type loadDataVarKeyType int
func (loadDataVarKeyType) String() string {
return "load_data_var"
}

// LoadDataVarKey is a variable key for load data.
const LoadDataVarKey loadDataVarKeyType = 0
2 changes: 1 addition & 1 deletion pkg/executor/test/loaddatatest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ go_test(
],
flaky = True,
race = "on",
shard_count = 10,
shard_count = 11,
deps = [
"//br/pkg/lightning/mydump",
"//pkg/config",
Expand Down
Loading