diff --git a/pkg/etcd/client_test.go b/pkg/etcd/client_test.go index efa03c6d795..42220d48e9d 100644 --- a/pkg/etcd/client_test.go +++ b/pkg/etcd/client_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/errors" "github.com/stretchr/testify/require" "go.etcd.io/etcd/clientv3" + "go.etcd.io/etcd/etcdserver/api/v3rpc/rpctypes" ) type mockClient struct { @@ -43,6 +44,10 @@ func (m *mockClient) Put(ctx context.Context, key, val string, opts ...clientv3. return nil, errors.New("mock error") } +func (m *mockClient) Txn(ctx context.Context) clientv3.Txn { + return &mockTxn{ctx: ctx} +} + type mockWatcher struct { clientv3.Watcher watchCh chan clientv3.WatchResponse @@ -82,6 +87,32 @@ func TestRetry(t *testing.T) { _, err = retrycli.Put(context.TODO(), "", "") require.NotNil(t, err) require.Containsf(t, errors.Cause(err).Error(), "mock error", "err:%v", err.Error()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Test Txn case + // case 0: normal + rsp, err := retrycli.Txn(ctx, nil, nil, nil) + require.Nil(t, err) + require.False(t, rsp.Succeeded) + + // case 1: errors.ErrReachMaxTry + _, err = retrycli.Txn(ctx, TxnEmptyCmps, nil, nil) + require.Regexp(t, ".*CDC:ErrReachMaxTry.*", err) + + // case 2: errors.ErrReachMaxTry + _, err = retrycli.Txn(ctx, nil, TxnEmptyOpsThen, nil) + require.Regexp(t, ".*CDC:ErrReachMaxTry.*", err) + + // case 3: context.DeadlineExceeded + _, err = retrycli.Txn(ctx, TxnEmptyCmps, TxnEmptyOpsThen, nil) + require.Equal(t, context.DeadlineExceeded, err) + + // other case: mock error + _, err = retrycli.Txn(ctx, TxnEmptyCmps, TxnEmptyOpsThen, TxnEmptyOpsElse) + require.Containsf(t, errors.Cause(err).Error(), "mock error", "err:%v", err.Error()) + maxTries = originValue } @@ -276,3 +307,44 @@ func TestRevisionNotFallBack(t *testing.T) { // while WatchCh was reset require.Equal(t, *watcher.rev, revision) } + +type mockTxn struct { + ctx context.Context + mode int +} + +func (txn *mockTxn) If(cs ...clientv3.Cmp) clientv3.Txn { + if cs != nil { + txn.mode += 1 + } + return txn +} + +func (txn *mockTxn) Then(ops ...clientv3.Op) clientv3.Txn { + if ops != nil { + txn.mode += 1 << 1 + } + return txn +} + +func (txn *mockTxn) Else(ops ...clientv3.Op) clientv3.Txn { + if ops != nil { + txn.mode += 1 << 2 + } + return txn +} + +func (txn *mockTxn) Commit() (*clientv3.TxnResponse, error) { + switch txn.mode { + case 0: + return &clientv3.TxnResponse{}, nil + case 1: + return nil, rpctypes.ErrNoSpace + case 2: + return nil, rpctypes.ErrTimeoutDueToLeaderFail + case 3: + return nil, context.DeadlineExceeded + default: + return nil, errors.New("mock error") + } +} diff --git a/pkg/orchestrator/etcd_worker.go b/pkg/orchestrator/etcd_worker.go index 823f7c37608..2dec376abcf 100644 --- a/pkg/orchestrator/etcd_worker.go +++ b/pkg/orchestrator/etcd_worker.go @@ -202,7 +202,7 @@ func (worker *EtcdWorker) Run(ctx context.Context, session *concurrency.Session, if len(pendingPatches) > 0 { // Here we have some patches yet to be uploaded to Etcd. pendingPatches, err = worker.applyPatchGroups(ctx, pendingPatches) - if isRetryableError(err) { + if cerrors.ErrEtcdTryAgain.Equal(err) { continue } if err != nil { @@ -254,13 +254,6 @@ func (worker *EtcdWorker) Run(ctx context.Context, session *concurrency.Session, } } -func isRetryableError(err error) bool { - err = errors.Cause(err) - return cerrors.ErrEtcdTryAgain.Equal(err) || - cerrors.ErrReachMaxTry.Equal(err) || - context.DeadlineExceeded == err -} - func (worker *EtcdWorker) handleEvent(_ context.Context, event *clientv3.Event) { if worker.isDeleteCounterKey(event.Kv.Key) { switch event.Type {