From 20a442fead53a37e7cda1e3acf6ae2368b34347a Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Wed, 7 Jun 2023 20:19:42 +0800 Subject: [PATCH] import into: precheck and register to pd (#44313) ref pingcap/tidb#42930 --- Makefile | 2 + br/pkg/mock/BUILD.bazel | 1 + br/pkg/mock/task_register.go | 77 +++++++++++ br/pkg/utils/BUILD.bazel | 2 +- br/pkg/utils/register.go | 77 +++++++++-- br/pkg/utils/register_test.go | 35 +++++ disttask/loaddata/BUILD.bazel | 2 + disttask/loaddata/dispatcher.go | 107 ++++++++++++++- disttask/loaddata/subtask_executor.go | 7 + executor/importer/BUILD.bazel | 14 +- executor/importer/precheck.go | 125 ++++++++++-------- executor/importer/precheck_test.go | 30 ----- tests/realtikvtest/loaddatatest/BUILD.bazel | 6 + .../loaddatatest/load_data_test.go | 92 +++++++++++++ .../loaddatatest/precheck_test.go | 58 +++++++- 15 files changed, 521 insertions(+), 114 deletions(-) create mode 100644 br/pkg/mock/task_register.go delete mode 100644 executor/importer/precheck_test.go diff --git a/Makefile b/Makefile index 910def78bca4b..09583ea08bd45 100644 --- a/Makefile +++ b/Makefile @@ -360,10 +360,12 @@ br_compatibility_test: mock_s3iface: @mockgen -package mock github.com/aws/aws-sdk-go/service/s3/s3iface S3API > br/pkg/mock/s3iface.go +# mock interface for lightning and IMPORT INTO mock_lightning: @mockgen -package mock github.com/pingcap/tidb/br/pkg/lightning/backend Backend,EngineWriter,TargetInfoGetter,ChunkFlushStatus > br/pkg/mock/backend.go @mockgen -package mock github.com/pingcap/tidb/br/pkg/lightning/backend/encode Encoder,EncodingBuilder,Rows,Row > br/pkg/mock/encode.go @mockgen -package mocklocal github.com/pingcap/tidb/br/pkg/lightning/backend/local DiskUsage,TiKVModeSwitcher > br/pkg/mock/mocklocal/local.go + @mockgen -package mock github.com/pingcap/tidb/br/pkg/utils TaskRegister > br/pkg/mock/task_register.go # There is no FreeBSD environment for GitHub actions. So cross-compile on Linux # but that doesn't work with CGO_ENABLED=1, so disable cgo. The reason to have diff --git a/br/pkg/mock/BUILD.bazel b/br/pkg/mock/BUILD.bazel index 1a42fe49e7a38..4c842da816705 100644 --- a/br/pkg/mock/BUILD.bazel +++ b/br/pkg/mock/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "importer.go", "mock_cluster.go", "s3iface.go", + "task_register.go", ], importpath = "github.com/pingcap/tidb/br/pkg/mock", visibility = ["//visibility:public"], diff --git a/br/pkg/mock/task_register.go b/br/pkg/mock/task_register.go new file mode 100644 index 0000000000000..609f073fa7417 --- /dev/null +++ b/br/pkg/mock/task_register.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/pingcap/tidb/br/pkg/utils (interfaces: TaskRegister) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockTaskRegister is a mock of TaskRegister interface. +type MockTaskRegister struct { + ctrl *gomock.Controller + recorder *MockTaskRegisterMockRecorder +} + +// MockTaskRegisterMockRecorder is the mock recorder for MockTaskRegister. +type MockTaskRegisterMockRecorder struct { + mock *MockTaskRegister +} + +// NewMockTaskRegister creates a new mock instance. +func NewMockTaskRegister(ctrl *gomock.Controller) *MockTaskRegister { + mock := &MockTaskRegister{ctrl: ctrl} + mock.recorder = &MockTaskRegisterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockTaskRegister) EXPECT() *MockTaskRegisterMockRecorder {
+	return m.recorder
+}
+
+// Close mocks base method.
+func (m *MockTaskRegister) Close(arg0 context.Context) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "Close", arg0)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// Close indicates an expected call of Close.
+func (mr *MockTaskRegisterMockRecorder) Close(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTaskRegister)(nil).Close), arg0)
+}
+
+// RegisterTask mocks base method.
+func (m *MockTaskRegister) RegisterTask(arg0 context.Context) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "RegisterTask", arg0)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// RegisterTask indicates an expected call of RegisterTask.
+func (mr *MockTaskRegisterMockRecorder) RegisterTask(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterTask", reflect.TypeOf((*MockTaskRegister)(nil).RegisterTask), arg0)
+}
+
+// RegisterTaskOnce mocks base method.
+func (m *MockTaskRegister) RegisterTaskOnce(arg0 context.Context) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "RegisterTaskOnce", arg0)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// RegisterTaskOnce indicates an expected call of RegisterTaskOnce.
+func (mr *MockTaskRegisterMockRecorder) RegisterTaskOnce(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterTaskOnce", reflect.TypeOf((*MockTaskRegister)(nil).RegisterTaskOnce), arg0)
+}
diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel
index 3ca5b6b83554a..fe6ea07bff214 100644
--- a/br/pkg/utils/BUILD.bazel
+++ b/br/pkg/utils/BUILD.bazel
@@ -90,7 +90,7 @@ go_test(
     ],
     embed = [":utils"],
     flaky = True,
-    shard_count = 29,
+    shard_count = 30,
     deps = [
         "//br/pkg/errors",
         "//br/pkg/metautil",
diff --git a/br/pkg/utils/register.go b/br/pkg/utils/register.go
index 7116fdfd50a6f..95a102ae68d26 100644
--- a/br/pkg/utils/register.go
+++ b/br/pkg/utils/register.go
@@ -34,6 +34,7 @@ type RegisterTaskType int
 const (
 	RegisterRestore RegisterTaskType = iota
 	RegisterLightning
+	RegisterImportInto
 )
 
 func (tp RegisterTaskType) String() string {
@@ -42,20 +43,40 @@ func (tp RegisterTaskType) String() string {
 		return "restore"
 	case RegisterLightning:
 		return "lightning"
+	case RegisterImportInto:
+		return "import-into"
 	}
 	return "default"
 }
 
 // The key format should be {RegisterImportTaskPrefix}/{RegisterTaskType}/{taskName}
 const (
+	// RegisterImportTaskPrefix is the prefix of the keys used for task registration
+	// todo: remove "/import" suffix, it's confusing to have a key like "/tidb/brie/import/restore/restore-xxx"
 	RegisterImportTaskPrefix = "/tidb/brie/import"
 
 	RegisterRetryInternal  = 10 * time.Second
 	defaultTaskRegisterTTL = 3 * time.Minute // 3 minutes
 )
 
-// TaskRegister can register the task to PD with a lease, and keepalive it in the background
-type TaskRegister struct {
+// TaskRegister can register the task to PD with a lease.
+type TaskRegister interface {
+	// Close closes the background keepalive task if RegisterTask was used,
+	// and revokes the lease.
+	// NOTE: we don't close the etcd client here, the caller should do it.
+	Close(ctx context.Context) (err error)
+	// RegisterTask first puts its key to PD with a lease,
+	// then starts keeping the lease alive in the background.
+	// DO NOT mix calls to RegisterTask and RegisterTaskOnce.
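+	// The registered key is removed automatically once its lease expires or Close revokes it.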
+	RegisterTask(c context.Context) error
+	// RegisterTaskOnce puts its key to PD with a lease if the key does not exist,
+	// otherwise it refreshes the lease.
+	// You have to call this method periodically to keep the lease alive.
+	// DO NOT mix calls to RegisterTask and RegisterTaskOnce.
+	RegisterTaskOnce(ctx context.Context) error
+}
+
+type taskRegister struct {
 	client    *clientv3.Client
 	ttl       time.Duration
 	secondTTL int64
@@ -68,8 +89,8 @@ type TaskRegister struct {
 }
 
 // NewTaskRegisterWithTTL build a TaskRegister with key format {RegisterTaskPrefix}/{RegisterTaskType}/{taskName}
-func NewTaskRegisterWithTTL(client *clientv3.Client, ttl time.Duration, tp RegisterTaskType, taskName string) *TaskRegister {
-	return &TaskRegister{
+func NewTaskRegisterWithTTL(client *clientv3.Client, ttl time.Duration, tp RegisterTaskType, taskName string) TaskRegister {
+	return &taskRegister{
 		client:    client,
 		ttl:       ttl,
 		secondTTL: int64(ttl / time.Second),
@@ -80,13 +101,16 @@ func NewTaskRegisterWithTTL(client *clientv3.Client, ttl time.Duration, tp Regis
 }
 
 // NewTaskRegister build a TaskRegister with key format {RegisterTaskPrefix}/{RegisterTaskType}/{taskName}
-func NewTaskRegister(client *clientv3.Client, tp RegisterTaskType, taskName string) *TaskRegister {
+func NewTaskRegister(client *clientv3.Client, tp RegisterTaskType, taskName string) TaskRegister {
 	return NewTaskRegisterWithTTL(client, defaultTaskRegisterTTL, tp, taskName)
 }
 
-// Close closes the background task of taskRegister
-func (tr *TaskRegister) Close(ctx context.Context) (err error) {
-	tr.cancel()
+// Close implements the TaskRegister interface
+func (tr *taskRegister) Close(ctx context.Context) (err error) {
+	// not needed if using RegisterTaskOnce
+	if tr.cancel != nil {
+		tr.cancel()
+	}
 	tr.wg.Wait()
 	if tr.curLeaseID != clientv3.NoLease {
 		_, err = tr.client.Lease.Revoke(ctx, tr.curLeaseID)
@@ -97,7 +121,7 @@ func (tr *TaskRegister) Close(ctx context.Context) (err error) {
 	return err
 }
 
-func (tr *TaskRegister) grant(ctx context.Context) (*clientv3.LeaseGrantResponse, error) {
+func (tr *taskRegister) grant(ctx context.Context) (*clientv3.LeaseGrantResponse, error) {
 	lease, err := tr.client.Lease.Grant(ctx, tr.secondTTL)
 	if err != nil {
 		return nil, err
@@ -108,9 +132,36 @@ func (tr *TaskRegister) grant(ctx context.Context) (*clientv3.LeaseGrantResponse
 	return lease, nil
 }
 
-// RegisterTask firstly put its key to PD with a lease,
-// and start to keepalive the lease in the background.
-func (tr *TaskRegister) RegisterTask(c context.Context) error {
+// RegisterTaskOnce implements the TaskRegister interface
+func (tr *taskRegister) RegisterTaskOnce(ctx context.Context) error {
+	resp, err := tr.client.Get(ctx, tr.key)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	if len(resp.Kvs) == 0 {
+		lease, err2 := tr.grant(ctx)
+		if err2 != nil {
+			return errors.Annotatef(err2, "failed to grant a lease")
+		}
+		tr.curLeaseID = lease.ID
+		_, err2 = tr.client.KV.Put(ctx, tr.key, "", clientv3.WithLease(lease.ID))
+		if err2 != nil {
+			return errors.Trace(err2)
+		}
+	} else {
+		// if the task runs distributedly, like IMPORT INTO, we should refresh the lease ID,
+		// in case the owner changed during registration and the new owner created the key.
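+		// KeepAliveOnce below renews that existing lease, resetting its remaining TTL to the granted value.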
+		tr.curLeaseID = clientv3.LeaseID(resp.Kvs[0].Lease)
+		_, err2 := tr.client.Lease.KeepAliveOnce(ctx, tr.curLeaseID)
+		if err2 != nil {
+			return errors.Trace(err2)
+		}
+	}
+	return nil
+}
+
+// RegisterTask implements the TaskRegister interface
+func (tr *taskRegister) RegisterTask(c context.Context) error {
 	cctx, cancel := context.WithCancel(c)
 	tr.cancel = cancel
 	lease, err := tr.grant(cctx)
@@ -133,7 +184,7 @@ func (tr *TaskRegister) RegisterTask(c context.Context) error {
 	return nil
 }
 
-func (tr *TaskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.LeaseKeepAliveResponse) {
+func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.LeaseKeepAliveResponse) {
 	defer tr.wg.Done()
 	const minTimeLeftThreshold time.Duration = 20 * time.Second
 	var (
diff --git a/br/pkg/utils/register_test.go b/br/pkg/utils/register_test.go
index b3a889ed82e8f..aeaef6fac58a5 100644
--- a/br/pkg/utils/register_test.go
+++ b/br/pkg/utils/register_test.go
@@ -46,6 +46,41 @@ func TestTaskRegister(t *testing.T) {
 	require.NoError(t, register.Close(ctx))
 }
 
+func TestTaskRegisterOnce(t *testing.T) {
+	integration.BeforeTestExternal(t)
+	testEtcdCluster := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
+	defer testEtcdCluster.Terminate(t)
+
+	// do not close the client manually, or the test will fail, since Terminate will close it too.
+	client := testEtcdCluster.RandClient()
+
+	ctx := context.Background()
+	register := NewTaskRegisterWithTTL(client, 10*time.Second, RegisterImportInto, "test")
+	defer register.Close(ctx)
+	err := register.RegisterTaskOnce(ctx)
+	require.NoError(t, err)
+
+	// sleep 3 seconds so the remaining lease TTL becomes smaller.
+	time.Sleep(3 * time.Second)
+	list, err := GetImportTasksFrom(ctx, client)
+	require.NoError(t, err)
+	require.Len(t, list.Tasks, 1)
+	currTask := list.Tasks[0]
+	t.Log(currTask.MessageToUser())
+	require.Equal(t, "/tidb/brie/import/import-into/test", currTask.Key)
+
+	// then register again; this time it only refreshes the lease, so the remaining TTL will be larger.
+	err = register.RegisterTaskOnce(ctx)
+	require.NoError(t, err)
+	list, err = GetImportTasksFrom(ctx, client)
+	require.NoError(t, err)
+	require.Len(t, list.Tasks, 1)
+	thisTask := list.Tasks[0]
+	require.Equal(t, currTask.Key, thisTask.Key)
+	require.Equal(t, currTask.LeaseID, thisTask.LeaseID)
+	require.Greater(t, thisTask.TTL, currTask.TTL)
+}
+
 func TestTaskRegisterFailedGrant(t *testing.T) {
 	integration.BeforeTestExternal(t)
 	testEtcdCluster := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1, GRPCKeepAliveInterval: time.Second, GRPCKeepAliveTimeout: 10 * time.Second})
diff --git a/disttask/loaddata/BUILD.bazel b/disttask/loaddata/BUILD.bazel
index ed1cf2ebbe6ca..b352a6163880e 100644
--- a/disttask/loaddata/BUILD.bazel
+++ b/disttask/loaddata/BUILD.bazel
@@ -20,6 +20,7 @@ go_library(
         "//br/pkg/lightning/config",
         "//br/pkg/lightning/mydump",
         "//br/pkg/lightning/verification",
+        "//br/pkg/utils",
         "//disttask/framework/dispatcher",
         "//disttask/framework/handle",
         "//disttask/framework/proto",
@@ -32,6 +33,7 @@ go_library(
         "//parser/mysql",
         "//sessionctx",
         "//table/tables",
+        "//util/etcd",
         "//util/logutil",
         "//util/sqlexec",
         "@com_github_go_sql_driver_mysql//:mysql",
diff --git a/disttask/loaddata/dispatcher.go b/disttask/loaddata/dispatcher.go
index 8513e6ddec603..b51d4cb5282ae 100644
--- a/disttask/loaddata/dispatcher.go
+++ b/disttask/loaddata/dispatcher.go
@@ -17,6 +17,7 @@ package loaddata
 import (
 	"context"
 	"encoding/json"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -28,6 +29,7 @@ import (
 	"github.com/pingcap/tidb/br/pkg/lightning/common"
 	"github.com/pingcap/tidb/br/pkg/lightning/config"
 	verify "github.com/pingcap/tidb/br/pkg/lightning/verification"
+	"github.com/pingcap/tidb/br/pkg/utils"
 	"github.com/pingcap/tidb/disttask/framework/dispatcher"
 	"github.com/pingcap/tidb/disttask/framework/proto"
 	"github.com/pingcap/tidb/domain/infosync"
@@ -35,6 +37,7 @@ import (
 	"github.com/pingcap/tidb/executor/importer"
 	"github.com/pingcap/tidb/sessionctx"
 	"github.com/pingcap/tidb/table/tables"
+	"github.com/pingcap/tidb/util/etcd"
 	"github.com/pingcap/tidb/util/logutil"
 	"github.com/pingcap/tidb/util/sqlexec"
 	"go.uber.org/atomic"
@@ -42,18 +45,102 @@ import (
 	"go.uber.org/zap"
 )
 
+const (
+	registerTaskTTL        = 10 * time.Minute
+	refreshTaskTTLInterval = 3 * time.Minute
+	registerTimeout        = 5 * time.Second
+)
+
+// NewTaskRegisterWithTTL is the constructor of TaskRegister.
+// It is exported for testing.
+var NewTaskRegisterWithTTL = utils.NewTaskRegisterWithTTL
+
+type taskInfo struct {
+	taskID int64
+
+	// operations on taskInfo run inside the detect-task goroutine, so no synchronization is needed.
+	lastRegisterTime time.Time
+
+	// initialized lazily in register()
+	etcdClient   *etcd.Client
+	taskRegister utils.TaskRegister
+}
+
+func (t *taskInfo) register(ctx context.Context) {
+	if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval {
+		return
+	}
+
+	logger := logutil.BgLogger().With(zap.Int64("task_id", t.taskID))
+	if t.taskRegister == nil {
+		client, err := importer.GetEtcdClient()
+		if err != nil {
+			logger.Warn("get etcd client failed", zap.Error(err))
+			return
+		}
+		t.etcdClient = client
+		t.taskRegister = NewTaskRegisterWithTTL(client.GetClient(), registerTaskTTL,
+			utils.RegisterImportInto, strconv.FormatInt(t.taskID, 10))
+	}
+	timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout)
+	defer cancel()
+	if err := t.taskRegister.RegisterTaskOnce(timeoutCtx); err != nil {
+		logger.Warn("register task failed", zap.Error(err))
+	} else {
+		logger.Info("register task to pd or refresh lease success")
+	}
+	// we set it even if registration failed: the TTL is 10min and the refresh interval is 3min,
+	// so we can retry twice before the lease expires.
+	t.lastRegisterTime = time.Now()
+}
+
+func (t *taskInfo) close(ctx context.Context) {
+	logger := logutil.BgLogger().With(zap.Int64("task_id", t.taskID))
+	if t.taskRegister != nil {
+		timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout)
+		defer cancel()
+		if err := t.taskRegister.Close(timeoutCtx); err != nil {
+			logger.Warn("unregister task failed", zap.Error(err))
+		} else {
+			logger.Info("unregister task success")
+		}
+		t.taskRegister = nil
+	}
+	if t.etcdClient != nil {
+		if err := t.etcdClient.Close(); err != nil {
+			logger.Warn("close etcd client failed", zap.Error(err))
+		}
+		t.etcdClient = nil
+	}
+}
+
 type flowHandle struct {
 	mu sync.RWMutex
+	// NOTE: there's actually no need to synchronize the two fields below, since we restrict that only one
+	// task can be running at a time. But we might support task queuing in the future, so leave it for now.
 	// the last time we switch TiKV into IMPORT mode, this is a global operation, do it for one task makes
 	// no difference to do it for all tasks. So we do not need to record the switch time for each task.
 	lastSwitchTime atomic.Time
+	// taskInfoMap is a map from taskID to taskInfo
+	taskInfoMap sync.Map
 }
 
 var _ dispatcher.TaskFlowHandle = (*flowHandle)(nil)
 
 func (h *flowHandle) OnTicker(ctx context.Context, task *proto.Task) {
-	// only switch TiKV mode when task is running and reach the interval
-	if task.State != proto.TaskStateRunning || time.Since(h.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval {
+	// only switch TiKV mode or register the task when the task is running
+	if task.State != proto.TaskStateRunning {
+		return
+	}
+	h.switchTiKVMode(ctx, task)
+	h.registerTask(ctx, task)
+}
+
+func (h *flowHandle) switchTiKVMode(ctx context.Context, task *proto.Task) {
+	if time.Since(h.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval {
 		return
 	}
 
@@ -73,6 +160,19 @@ func (h *flowHandle) OnTicker(ctx context.Context, task *proto.Task) {
 	h.lastSwitchTime.Store(time.Now())
 }
 
+func (h *flowHandle) registerTask(ctx context.Context, task *proto.Task) {
+	val, _ := h.taskInfoMap.LoadOrStore(task.ID, &taskInfo{taskID: task.ID})
+	info := val.(*taskInfo)
+	info.register(ctx)
+}
+
+func (h *flowHandle) unregisterTask(ctx context.Context, task *proto.Task) {
+	if val, loaded := h.taskInfoMap.LoadAndDelete(task.ID); loaded {
+		info := val.(*taskInfo)
+		info.close(ctx)
+	}
+}
+
 func (h *flowHandle) ProcessNormalFlow(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) {
 	logger := logutil.BgLogger().With(zap.String("component", "dispatcher"), zap.String("type", gTask.Type), zap.Int64("ID", gTask.ID))
 	taskMeta := &TaskMeta{}
@@ -104,6 +204,7 @@ func (h *flowHandle) ProcessNormalFlow(ctx context.Context, handle dispatcher.Ta
 		return metaBytes, nil
 	case Import:
 		h.switchTiKV2NormalMode(ctx, logutil.BgLogger())
+		defer h.unregisterTask(ctx, gTask)
 		if err := postProcess(ctx, handle, gTask, taskMeta, logger); err != nil {
 			return nil, err
 		}
@@ -118,6 +219,8 @@ func (h *flowHandle) ProcessErrFlow(ctx context.Context, handle dispatcher.TaskH
 	logger := logutil.BgLogger().With(zap.String("component", "dispatcher"), zap.String("type", gTask.Type), zap.Int64("ID", gTask.ID))
 	logger.Info("process error flow", zap.ByteStrings("error message", receiveErr))
 	h.switchTiKV2NormalMode(ctx, logger)
+	h.unregisterTask(ctx, gTask)
+
 	gTask.Error = receiveErr[0]
 
 	errStr := string(receiveErr[0])
diff --git a/disttask/loaddata/subtask_executor.go b/disttask/loaddata/subtask_executor.go
index 7df9133de6b33..b9b7e93fd87eb 100644
--- a/disttask/loaddata/subtask_executor.go
+++ b/disttask/loaddata/subtask_executor.go
@@ -27,6 +27,9 @@ import (
 	"go.uber.org/zap"
 )
 
+// TestSyncChan is used in tests to synchronize with the subtask executor.
+var TestSyncChan = make(chan struct{})
+
 // ImportMinimalTaskExecutor is a subtask executor for load data.
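+// Each Run call processes a single chunk of the source data via importer.ProcessChunk.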
 type ImportMinimalTaskExecutor struct {
 	task *MinimalTaskMeta
 
@@ -40,6 +43,10 @@ func (e *ImportMinimalTaskExecutor) Run(ctx context.Context) error {
 		time.Sleep(3 * time.Second) // wait ToImportMode called
 		failpoint.Return(errors.New("occur an error when sort chunk"))
 	})
+	failpoint.Inject("syncBeforeSortChunk", func() {
+		TestSyncChan <- struct{}{}
+		<-TestSyncChan
+	})
 	chunkCheckpoint := toChunkCheckpoint(e.task.Chunk)
 	sharedVars := e.task.SharedVars
 	if err := importer.ProcessChunk(ctx, &chunkCheckpoint, sharedVars.TableImporter, sharedVars.DataEngine, sharedVars.IndexEngine, sharedVars.Progress, logger); err != nil {
diff --git a/executor/importer/BUILD.bazel b/executor/importer/BUILD.bazel
index 7e2cbdcc2b849..cd12ba6099875 100644
--- a/executor/importer/BUILD.bazel
+++ b/executor/importer/BUILD.bazel
@@ -22,9 +22,10 @@ go_library(
         "//br/pkg/lightning/config",
         "//br/pkg/lightning/log",
         "//br/pkg/lightning/mydump",
-        "//br/pkg/lightning/precheck",
         "//br/pkg/lightning/verification",
         "//br/pkg/storage",
+        "//br/pkg/streamhelper",
+        "//br/pkg/utils",
         "//config",
         "//executor/asyncloaddata",
         "//expression",
@@ -47,14 +48,13 @@ go_library(
         "//util/chunk",
         "//util/dbterror",
         "//util/dbterror/exeerrors",
+        "//util/etcd",
         "//util/filter",
         "//util/intest",
         "//util/logutil",
         "//util/sqlexec",
         "//util/stringutil",
         "@com_github_docker_go_units//:go-units",
-        "@com_github_jedib0t_go_pretty_v6//table",
-        "@com_github_jedib0t_go_pretty_v6//text",
         "@com_github_pingcap_errors//:errors",
         "@com_github_pingcap_log//:log",
         "@com_github_tikv_client_go_v2//config",
@@ -70,18 +70,14 @@ go_library(
 go_test(
     name = "importer_test",
     timeout = "short",
-    srcs = [
-        "import_test.go",
-        "precheck_test.go",
-    ],
+    srcs = ["import_test.go"],
     embed = [":importer"],
     flaky = True,
     race = "on",
-    shard_count = 6,
+    shard_count = 5,
     deps = [
         "//br/pkg/errors",
         "//br/pkg/lightning/config",
-        "//br/pkg/lightning/precheck",
         "//expression",
         "//parser",
         "//parser/ast",
diff --git a/executor/importer/precheck.go b/executor/importer/precheck.go
index 3d883c9c0e014..56d5fe243534a 100644
--- a/executor/importer/precheck.go
+++ b/executor/importer/precheck.go
@@ -17,20 +17,41 @@ package importer
 import (
 	"context"
 	"fmt"
+	"time"
 
-	"github.com/jedib0t/go-pretty/v6/table"
-	"github.com/jedib0t/go-pretty/v6/text"
+	"github.com/pingcap/errors"
 	"github.com/pingcap/tidb/br/pkg/lightning/common"
-	"github.com/pingcap/tidb/br/pkg/lightning/precheck"
+	"github.com/pingcap/tidb/br/pkg/streamhelper"
+	"github.com/pingcap/tidb/br/pkg/utils"
+	tidb "github.com/pingcap/tidb/config"
 	"github.com/pingcap/tidb/parser/terror"
+	"github.com/pingcap/tidb/util"
 	"github.com/pingcap/tidb/util/dbterror/exeerrors"
+	"github.com/pingcap/tidb/util/etcd"
 	"github.com/pingcap/tidb/util/sqlexec"
 )
 
-// CheckRequirements checks the requirements for load data.
+const (
+	etcdDialTimeout = 5 * time.Second
+)
+
+// CheckRequirements checks the requirements for IMPORT INTO.
+// We check the following things here:
+//  1. target table should be empty
+//  2. no CDC or PiTR tasks running
+//
+// todo: check if there are running lightning tasks
+// We check the items one by one and return the first error we meet.
+// todo: check all items and return all errors at once.
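+// Any failed check is reported as ErrLoadDataPreCheckFailed.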
func (e *LoadDataController) CheckRequirements(ctx context.Context, conn sqlexec.SQLExecutor) error { - collector := newPreCheckCollector() // todo: maybe we can reuse checker in lightning + if err := e.checkTableEmpty(ctx, conn); err != nil { + return err + } + return e.checkCDCPiTRTasks(ctx) +} + +func (e *LoadDataController) checkTableEmpty(ctx context.Context, conn sqlexec.SQLExecutor) error { sql := fmt.Sprintf("SELECT 1 FROM %s USE INDEX() LIMIT 1", common.UniqueTable(e.DBName, e.Table.Meta().Name.L)) rs, err := conn.ExecuteInternal(ctx, sql) if err != nil { @@ -42,66 +63,56 @@ func (e *LoadDataController) CheckRequirements(ctx context.Context, conn sqlexec return err } if len(rows) > 0 { - collector.fail(precheck.CheckTargetTableEmpty, "target table is not empty") - } else { - collector.pass(precheck.CheckTargetTableEmpty) - } - if !collector.success() { - return exeerrors.ErrLoadDataPreCheckFailed.FastGenByArgs("\n" + collector.output()) + return exeerrors.ErrLoadDataPreCheckFailed.FastGenByArgs("target table is not empty") } return nil } -const ( - failed = "failed" - passed = "passed" -) - -type preCheckCollector struct { - failCount int - t table.Writer -} - -func newPreCheckCollector() *preCheckCollector { - t := table.NewWriter() - t.AppendHeader(table.Row{"Check Item", "Result", "Detailed Message"}) - t.SetColumnConfigs([]table.ColumnConfig{ - {Name: "Check Item", WidthMax: 20}, - {Name: "Result", WidthMax: 6}, - {Name: "Detailed Message", WidthMax: 130}, - }) - style := table.StyleDefault - style.Format.Header = text.FormatDefault - t.SetStyle(style) - return &preCheckCollector{ - t: t, +func (e *LoadDataController) checkCDCPiTRTasks(ctx context.Context) error { + cli, err := GetEtcdClient() + if err != nil { + return err } -} + defer terror.Call(cli.Close) -func (c *preCheckCollector) fail(item precheck.CheckItemID, msg string) { - c.failCount++ - c.t.AppendRow(table.Row{item.DisplayName(), failed, msg}) - c.t.AppendSeparator() -} + pitrCli := streamhelper.NewMetaDataClient(cli.GetClient()) + tasks, err := pitrCli.GetAllTasks(ctx) + if err != nil { + return err + } + if len(tasks) > 0 { + names := make([]string, 0, len(tasks)) + for _, task := range tasks { + names = append(names, task.Info.GetName()) + } + return exeerrors.ErrLoadDataPreCheckFailed.FastGenByArgs(fmt.Sprintf("found PiTR log streaming task(s): %v,", names)) + } -func (c *preCheckCollector) success() bool { - return c.failCount == 0 -} + nameSet, err := utils.GetCDCChangefeedNameSet(ctx, cli.GetClient()) + if err != nil { + return errors.Trace(err) + } -func (c *preCheckCollector) output() string { - c.t.SetAllowedRowLength(170) - c.t.SetRowPainter(func(row table.Row) text.Colors { - if result, ok := row[1].(string); ok { - if result == failed { - return text.Colors{text.FgRed} - } - } - return nil - }) - return c.t.Render() + "\n" + if !nameSet.Empty() { + return exeerrors.ErrLoadDataPreCheckFailed.FastGenByArgs(nameSet.MessageToUser()) + } + return nil } -func (c *preCheckCollector) pass(item precheck.CheckItemID) { - c.t.AppendRow(table.Row{item.DisplayName(), passed, ""}) - c.t.AppendSeparator() +// GetEtcdClient returns an etcd client. +// exported for testing. 
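+// It derives the etcd endpoints from the PD address in the global config's Path,
+// and the TLS options from the cluster security settings.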
+func GetEtcdClient() (*etcd.Client, error) {
+	tidbCfg := tidb.GetGlobalConfig()
+	tls, err := util.NewTLSConfig(
+		util.WithCAPath(tidbCfg.Security.ClusterSSLCA),
+		util.WithCertAndKeyPath(tidbCfg.Security.ClusterSSLCert, tidbCfg.Security.ClusterSSLKey),
+	)
+	if err != nil {
+		return nil, err
+	}
+	etcdEndpoints, err := util.ParseHostPortAddr(tidbCfg.Path)
+	if err != nil {
+		return nil, err
+	}
+	return etcd.NewClientFromCfg(etcdEndpoints, etcdDialTimeout, "", tls)
+}
diff --git a/tests/realtikvtest/loaddatatest/load_data_test.go b/tests/realtikvtest/loaddatatest/load_data_test.go
index 642bf9a92cf72..627ac07417c15 100644
--- a/tests/realtikvtest/loaddatatest/load_data_test.go
+++ b/tests/realtikvtest/loaddatatest/load_data_test.go
@@ -23,16 +23,20 @@ import (
 	"os"
 	"path"
 	"strconv"
+	"sync"
 	"time"
 
 	"github.com/fsouza/fake-gcs-server/fakestorage"
 	"github.com/golang/mock/gomock"
 	"github.com/ngaut/pools"
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/log"
 	"github.com/pingcap/tidb/br/pkg/lightning/backend/local"
 	"github.com/pingcap/tidb/br/pkg/lightning/common"
 	"github.com/pingcap/tidb/br/pkg/lightning/config"
+	"github.com/pingcap/tidb/br/pkg/mock"
 	"github.com/pingcap/tidb/br/pkg/mock/mocklocal"
+	"github.com/pingcap/tidb/br/pkg/utils"
 	"github.com/pingcap/tidb/disttask/framework/storage"
 	"github.com/pingcap/tidb/disttask/loaddata"
 	"github.com/pingcap/tidb/domain/infosync"
@@ -43,6 +47,8 @@ import (
"github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util/sem" "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -880,6 +886,92 @@ func (s *mockGCSSuite) TestImportMode() { s.Greater(intoNormalTime, intoImportTime) } +func (s *mockGCSSuite) TestRegisterTask() { + var registerTime, unregisterTime time.Time + var taskID atomic.String + controller := gomock.NewController(s.T()) + taskRegister := mock.NewMockTaskRegister(controller) + mockedRegister := func(ctx context.Context) error { + log.L().Info("register task", zap.String("task_id", taskID.Load())) + registerTime = time.Now() + return nil + } + mockedClose := func(ctx context.Context) error { + log.L().Info("unregister task", zap.String("task_id", taskID.Load())) + unregisterTime = time.Now() + return nil + } + taskRegister.EXPECT().RegisterTaskOnce(gomock.Any()).DoAndReturn(mockedRegister).Times(1) + taskRegister.EXPECT().Close(gomock.Any()).DoAndReturn(mockedClose).Times(1) + backup := loaddata.NewTaskRegisterWithTTL + loaddata.NewTaskRegisterWithTTL = func(_ *clientv3.Client, _ time.Duration, _ utils.RegisterTaskType, name string) utils.TaskRegister { + // we use taskID as the task name + taskID.Store(name) + return taskRegister + } + s.T().Cleanup(func() { + loaddata.NewTaskRegisterWithTTL = backup + }) + + s.tk.MustExec("DROP DATABASE IF EXISTS load_data;") + s.tk.MustExec("CREATE DATABASE load_data;") + s.tk.MustExec(`CREATE TABLE load_data.register_task (a INT, b INT, c int);`) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-load", Name: "register_task-1.tsv"}, + Content: []byte("1,11,111"), + }) + + // NOTE: this case only runs when current instance is TiDB owner, if you run it locally, + // better start a cluster without TiDB instance. 
+ sql := fmt.Sprintf(`IMPORT INTO load_data.register_task FROM 'gs://test-load/register_task-*.tsv?endpoint=%s'`, gcsEndpoint) + s.tk.MustExec(sql) + s.tk.MustQuery("SELECT * FROM load_data.register_task;").Sort().Check(testkit.Rows("1 11 111")) + s.Greater(unregisterTime, registerTime) + + // on error, we should also unregister the task + registerTime, unregisterTime = time.Time{}, time.Time{} + taskRegister.EXPECT().RegisterTaskOnce(gomock.Any()).DoAndReturn(mockedRegister).Times(1) + taskRegister.EXPECT().Close(gomock.Any()).DoAndReturn(mockedClose).Times(1) + s.tk.MustExec("truncate table load_data.register_task;") + s.enableFailpoint("github.com/pingcap/tidb/disttask/loaddata/errorWhenSortChunk", "return(true)") + err := s.tk.ExecToErr(sql) + s.Error(err) + s.Greater(unregisterTime, registerTime) + + loaddata.NewTaskRegisterWithTTL = backup + s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/loaddata/errorWhenSortChunk")) + s.enableFailpoint("github.com/pingcap/tidb/disttask/loaddata/syncBeforeSortChunk", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/disttask/framework/storage/testSetLastTaskID", "return(true)") + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + s.tk.MustExec(sql) + }() + // wait for the task to be registered + <-loaddata.TestSyncChan + client, err := importer.GetEtcdClient() + s.NoError(err) + s.T().Cleanup(func() { + _ = client.Close() + }) + etcdKey := fmt.Sprintf("/tidb/brie/import/import-into/%d", storage.TestLastTaskID.Load()) + s.Eventually(func() bool { + resp, err2 := client.GetClient().Get(context.Background(), etcdKey) + s.NoError(err2) + return len(resp.Kvs) == 1 + }, 5*time.Second, 300*time.Millisecond) + // continue the execution + loaddata.TestSyncChan <- struct{}{} + wg.Wait() + s.tk.MustQuery("SELECT * FROM load_data.register_task;").Sort().Check(testkit.Rows("1 11 111")) + + // the task should be unregistered + resp, err2 := client.GetClient().Get(context.Background(), etcdKey) + s.NoError(err2) + s.Len(resp.Kvs, 0) +} + func (s *mockGCSSuite) TestAddIndexBySQL() { s.tk.MustExec("DROP DATABASE IF EXISTS load_data;") s.tk.MustExec("CREATE DATABASE load_data;") diff --git a/tests/realtikvtest/loaddatatest/precheck_test.go b/tests/realtikvtest/loaddatatest/precheck_test.go index a01112182a158..acb5d6a5a0f37 100644 --- a/tests/realtikvtest/loaddatatest/precheck_test.go +++ b/tests/realtikvtest/loaddatatest/precheck_test.go @@ -15,11 +15,17 @@ package loaddatatest import ( + "context" "fmt" "github.com/fsouza/fake-gcs-server/fakestorage" + brpb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/streamhelper" + "github.com/pingcap/tidb/executor/importer" "github.com/pingcap/tidb/util/dbterror/exeerrors" "github.com/stretchr/testify/require" + "go.uber.org/zap" ) func (s *mockGCSSuite) TestPreCheckTableNotEmpty() { @@ -36,7 +42,55 @@ func (s *mockGCSSuite) TestPreCheckTableNotEmpty() { s.tk.MustExec("drop table if exists t;") s.tk.MustExec("create table t (a bigint primary key, b varchar(100), c int);") s.tk.MustExec("insert into t values(9, 'test9', 99);") - loadDataSQL := fmt.Sprintf(`IMPORT INTO t FROM 'gs://precheck-tbl-empty/file.csv?endpoint=%s'`, gcsEndpoint) - err := s.tk.ExecToErr(loadDataSQL) + sql := fmt.Sprintf(`IMPORT INTO t FROM 'gs://precheck-tbl-empty/file.csv?endpoint=%s'`, gcsEndpoint) + err := s.tk.ExecToErr(sql) require.ErrorIs(s.T(), err, exeerrors.ErrLoadDataPreCheckFailed) } + +func (s *mockGCSSuite) TestPreCheckCDCPiTRTasks() { + 
+	s.server.CreateObject(fakestorage.Object{
+		ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "precheck-cdc-pitr", Name: "file.csv"},
+		Content:     []byte(`1,test1,11`),
+	})
+	s.prepareAndUseDB("load_data")
+	s.tk.MustExec("drop table if exists t;")
+	s.tk.MustExec("create table t (a bigint primary key, b varchar(100), c int);")
+
+	client, err := importer.GetEtcdClient()
+	s.NoError(err)
+	s.T().Cleanup(func() {
+		_ = client.Close()
+	})
+
+	// Note: BR has a background worker listening on this key, so after this test
+	// there will be keys like "/tidb/br-stream/checkpoint/dummy-task/store/1" left in etcd.
+	// See OwnerManagerForLogBackup for more details.
+	pitrKey := streamhelper.PrefixOfTask() + "dummy-task"
+	pitrTaskInfo := brpb.StreamBackupTaskInfo{Name: "dummy-task"}
+	data, err := pitrTaskInfo.Marshal()
+	s.NoError(err)
+	_, err = client.GetClient().Put(context.Background(), pitrKey, string(data))
+	s.NoError(err)
+	s.T().Cleanup(func() {
+		_, err2 := client.GetClient().Delete(context.Background(), pitrKey)
+		s.NoError(err2)
+	})
+	sql := fmt.Sprintf(`IMPORT INTO t FROM 'gs://precheck-cdc-pitr/file.csv?endpoint=%s'`, gcsEndpoint)
+	err = s.tk.ExecToErr(sql)
+	log.Error("error", zap.Error(err))
+	s.ErrorIs(err, exeerrors.ErrLoadDataPreCheckFailed)
+	s.ErrorContains(err, "found PiTR log streaming task(s): [dummy-task],")
+
+	_, err2 := client.GetClient().Delete(context.Background(), pitrKey)
+	s.NoError(err2)
+	cdcKey := "/tidb/cdc/cluster-123/test/changefeed/info/feed-test"
+	_, err = client.GetClient().Put(context.Background(), cdcKey, `{"state": "normal"}`)
+	s.NoError(err)
+	s.T().Cleanup(func() {
+		_, err2 := client.GetClient().Delete(context.Background(), cdcKey)
+		s.NoError(err2)
+	})
+	err = s.tk.ExecToErr(sql)
+	s.ErrorIs(err, exeerrors.ErrLoadDataPreCheckFailed)
+	s.ErrorContains(err, "found CDC changefeed(s): cluster/namespace: cluster-123/test changefeed(s): [feed-test]")
+}
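
A minimal usage sketch, not part of the patch: how client code could drive the new TaskRegister interface and then inspect the registered key. The etcd endpoint and the task name are illustrative assumptions; NewTaskRegisterWithTTL, RegisterTask, Close and GetImportTasksFrom are the br/pkg/utils APIs introduced or exercised above.

	package main

	import (
		"context"
		"fmt"
		"log"
		"time"

		"github.com/pingcap/tidb/br/pkg/utils"
		clientv3 "go.etcd.io/etcd/client/v3"
	)

	func main() {
		// connect to PD's embedded etcd; the endpoint is an assumption
		cli, err := clientv3.New(clientv3.Config{
			Endpoints:   []string{"127.0.0.1:2379"},
			DialTimeout: 5 * time.Second,
		})
		if err != nil {
			log.Fatal(err)
		}
		// the task register never closes the client, the caller must do it
		defer cli.Close()

		ctx := context.Background()
		// the key becomes /tidb/brie/import/import-into/my-task
		register := utils.NewTaskRegisterWithTTL(cli, 3*time.Minute,
			utils.RegisterImportInto, "my-task")
		defer func() {
			// stops the keepalive goroutine and revokes the lease
			if err := register.Close(ctx); err != nil {
				log.Println("close task register:", err)
			}
		}()

		// a single long-lived process puts the key once and keeps the lease
		// alive in the background; a distributed refresher (like the dispatcher
		// above) would instead call register.RegisterTaskOnce periodically.
		// The two styles must not be mixed on one register.
		if err := register.RegisterTask(ctx); err != nil {
			log.Fatal(err)
		}

		// any component can list live import tasks under the shared prefix
		list, err := utils.GetImportTasksFrom(ctx, cli)
		if err != nil {
			log.Fatal(err)
		}
		for _, task := range list.Tasks {
			fmt.Println(task.MessageToUser())
		}
	}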