diff --git a/pipe.go b/pipe.go index f1fe9930..e72b3671 100644 --- a/pipe.go +++ b/pipe.go @@ -880,8 +880,8 @@ func (p *pipe) syncDo(dl time.Time, dlOk bool, cmd cmds.Completed) (resp RedisRe err = context.DeadlineExceeded } p.error.CompareAndSwap(nil, &errs{error: err}) - atomic.CompareAndSwapInt32(&p.state, 1, 3) // stopping the worker and let it do the cleaning - p.background() // start the background worker + p.conn.Close() + p.background() // start the background worker to clean up goroutines } return newResult(msg, err) } @@ -921,8 +921,8 @@ abort: err = context.DeadlineExceeded } p.error.CompareAndSwap(nil, &errs{error: err}) - atomic.CompareAndSwapInt32(&p.state, 1, 3) // stopping the worker and let it do the cleaning - p.background() // start the background worker + p.conn.Close() + p.background() // start the background worker to clean up goroutines for i := 0; i < len(resp); i++ { resp[i] = newErrResult(err) } diff --git a/pipe_test.go b/pipe_test.go index ae8587c3..f8bd36e7 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -2841,6 +2841,56 @@ func TestForceClose_DoMulti_Canceled_Block(t *testing.T) { p.Close() } +func TestSyncModeSwitchingWithDeadlineExceed_Do(t *testing.T) { + p, mock, _, closeConn := setup(t, ClientOption{}) + defer closeConn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond*100) + defer cancel() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + if err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("unexpected err %v", err) + } + wg.Done() + }() + } + + mock.Expect("GET", "a") + time.Sleep(time.Microsecond * 200) + mock.Expect().ReplyString("OK") + wg.Wait() + p.Close() +} + +func TestSyncModeSwitchingWithDeadlineExceed_DoMulti(t *testing.T) { + p, mock, _, closeConn := setup(t, ClientOption{}) + defer closeConn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond*100) + defer cancel() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + if err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"}))[0].NonRedisError(); !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("unexpected err %v", err) + } + wg.Done() + }() + } + + mock.Expect("GET", "a") + time.Sleep(time.Microsecond * 200) + mock.Expect().ReplyString("OK") + wg.Wait() + p.Close() +} + func TestOngoingDeadlineContextInSyncMode_Do(t *testing.T) { p, _, _, closeConn := setup(t, ClientOption{}) defer closeConn()