diff --git a/pool.go b/pool.go index 396c0ce..abfa912 100644 --- a/pool.go +++ b/pool.go @@ -343,19 +343,14 @@ func (p *pool) runReconnect() func(context.Context) { } } -func (p *pool) checkConn(pc *poolConn) bool { - select { - case <-pc.lastIOErrCh: - default: - return true - } +func (p *pool) discardConn(pc *poolConn, reason trace.PoolConnClosedReason) { // ensure that the discard logic for the conn only occurs once, specifically // buffering a message on reconnectCh. var ok bool pc.once.Do(func() { ok = true }) if !ok { - return false + return } err := p.proc.WithLock(func() error { @@ -363,13 +358,12 @@ func (p *pool) checkConn(pc *poolConn) bool { return nil }) if err != nil { - return false + return } pc.Close() - p.traceConnClosed(trace.PoolConnClosedReasonError) + p.traceConnClosed(reason) p.reconnectCh <- struct{}{} - return false } func (p *pool) getConn(ctx context.Context) (*poolConn, error) { @@ -417,9 +411,15 @@ func (p *pool) useSharedConn(ctx context.Context, a Action) error { pc = p.conns.get(i) return pc.Do(ctx, a) }) + if pc != nil { - p.checkConn(pc) + select { + case <-pc.lastIOErrCh: + p.discardConn(pc, trace.PoolConnClosedReasonError) + default: + } } + if pc != nil || err != nil { return err } @@ -446,7 +446,13 @@ func (p *pool) Do(ctx context.Context, a Action) error { } err = pc.Do(ctx, a) - if p.checkConn(pc) { + if err != nil && !isRespErr(err) { + // Non-shared conns are used for commands which might block. Therefore + // any non-application errors result in closing the connection, because + // it might still have some blocking command holding it up, and we don't + // want to have other connections be blocked by it. + p.discardConn(pc, trace.PoolConnClosedReasonError) + } else { p.putConn(pc) } return err diff --git a/pool_test.go b/pool_test.go index e4cbffe..e1fcc89 100644 --- a/pool_test.go +++ b/pool_test.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "sync" + "testing" . "testing" "time" @@ -528,3 +529,29 @@ func TestPoolClose(t *T) { assert.NoError(t, h.pool.Close()) assert.Error(t, proc.ErrClosed, h.pool.Do(h.ctx, Cmd(nil, "PING"))) } + +func TestPoolIssue344(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + t.Log("creating pool") + rc, err := (poolConfig{PoolConfig: PoolConfig{ + Size: 1, + }}).new(ctx, "tcp", "127.0.0.1:6379") + assert.NoError(t, err) + + { + t.Log("forcing a timeout") + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + err := rc.Do(ctx, Cmd(new([]string), "BLPOP", randStr(), "0")) + cancel() + assert.True(t, errors.Is(err, context.DeadlineExceeded)) + } + + { + t.Log("pinging") + var pong string + assert.NoError(t, rc.Do(ctx, Cmd(&pong, "PING"))) + assert.Equal(t, "PONG", pong) + } +}