Skip to content

Commit

Permalink
fixes the various scenarios where ConnectionPool.disconnect() and
Browse files Browse the repository at this point in the history
Connection.disconnect() were ripping connections out from under
the threads or processes that owned them

redis#732
  • Loading branch information
rolette authored and jdost committed Jun 24, 2019
1 parent da2ffc7 commit 791cab8
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 14 deletions.
84 changes: 71 additions & 13 deletions redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, with_statement
from copy import copy
from distutils.version import StrictVersion
from itertools import chain
import io
Expand Down Expand Up @@ -402,7 +403,8 @@ def __init__(self, host='localhost', port=6379, db=0, password=None,
socket_keepalive=False, socket_keepalive_options=None,
socket_type=0, retry_on_timeout=False, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser, socket_read_size=65536):
parser_class=DefaultParser, socket_read_size=65536,
**kwargs):
self.pid = os.getpid()
self.host = host
self.port = int(port)
Expand All @@ -426,6 +428,14 @@ def __init__(self, host='localhost', port=6379, db=0, password=None,
self._connect_callbacks = []
self._buffer_cutoff = 6000

# If the connection isn't used in a process that forks, there are
# certain optimizations we can make. Default assumes that we have to
# be safe for process forks.
self.fork_safe = kwargs.get('fork_safe', True)

# Connection pool generation id
self.pool_generation = 0

def __repr__(self):
return self.description_format % self._description_args

Expand Down Expand Up @@ -542,7 +552,13 @@ def disconnect(self):
self._selector.close()
self._selector = None
try:
if os.getpid() == self.pid:
# socket.shutdown() kills the underlying TCP connection
# immediately. Good when you can do it, but it ignores
# the ref count on the descriptor. If the process forked
# and the child inherited the socket, it's not safe to
# call .shutdown(). Just close() the socket and let the OS
# close the connection when appropriate.
if os.getpid() == self.pid and not self.fork_safe:
self._sock.shutdown(socket.SHUT_RDWR)
self._sock.close()
except socket.error:
Expand Down Expand Up @@ -728,7 +744,8 @@ def __init__(self, path='', db=0, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
retry_on_timeout=False,
parser_class=DefaultParser, socket_read_size=65536):
parser_class=DefaultParser, socket_read_size=65536,
**kwargs):
self.pid = os.getpid()
self.path = path
self.db = db
Expand All @@ -745,6 +762,14 @@ def __init__(self, path='', db=0, password=None,
self._connect_callbacks = []
self._buffer_cutoff = 6000

# If the connection isn't used in a process that forks, there are
# certain optimizations we can make. Default assumes that we have to
# be safe for process forks.
self.fork_safe = kwargs.get('fork_safe', True)

# Connection pool generation id
self.pool_generation = 0

def _connect(self):
"Create a Unix domain socket connection"
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
Expand Down Expand Up @@ -923,6 +948,9 @@ def __init__(self, connection_class=Connection, max_connections=None,
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.max_connections = max_connections
self.fork_safe = connection_kwargs.get('fork_safe', True)

self.generation = 0

self.reset()

Expand All @@ -940,19 +968,27 @@ def reset(self):
self._check_lock = threading.Lock()

def _checkpid(self):
if self.pid != os.getpid():
# No need to check PID if process isn't supposed to fork
if self.fork_safe and self.pid != os.getpid():
with self._check_lock:
if self.pid == os.getpid():
# another thread already did the work while we waited
# on the lock.
return
self.disconnect()
self.reset()

def get_connection(self, command_name, *keys, **options):
"Get a connection from the pool"
self._checkpid()
try:
connection = self._available_connections.pop()
if connection.pool_generation != self.generation:
# generation counter mismatch: let this connection go and
# create a new one
connection.disconnect()
self._created_connections -= 1
connection = self.make_connection()
except IndexError:
connection = self.make_connection()
self._in_use_connections.add(connection)
Expand Down Expand Up @@ -988,24 +1024,46 @@ def make_connection(self):
"Create a new connection"
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")

new_connection = self.connection_class(**self.connection_kwargs)
new_connection.pool_generation = self.generation
self._created_connections += 1
return self.connection_class(**self.connection_kwargs)
return new_connection

def release(self, connection):
"Releases the connection back to the pool"
self._checkpid()
if connection.pid != self.pid:
return
self._in_use_connections.remove(connection)
self._available_connections.append(connection)

def disconnect(self):
"Disconnects all connections in the pool"
self._checkpid()
all_conns = chain(self._available_connections,
self._in_use_connections)
for connection in all_conns:
# Verify generation id before putting connection back in the free list
if connection.pool_generation == self.generation:
self._available_connections.append(connection)
else:
# generation id mismatch, kill the connection
connection.disconnect()
self._created_connections -= 1

def disconnect(self, immediate=False):
"""
Disconnects all connections in the pool
By default, the disconnect happens over time as connections
cycle into and out of the pool. This avoids the problem of ripping
connections out from under threads that are using them.
If `immediate` is True, then we forcibly disconnect all connections
and leave it to the owner of the connection to deal with the various
errors that can occur. This option is not recommended.
"""
self._checkpid()
self.generation += 1
if immediate:
all_conns = chain(copy(self._available_connections),
copy(self._in_use_connections))
for connection in all_conns:
connection.disconnect()


class BlockingConnectionPool(ConnectionPool):
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def teardown():
# handle cases where a test disconnected a client
# just manually retry the flushdb
client.flushdb()
client.connection_pool.disconnect()
client.connection_pool.disconnect(immediate=True)
request.addfinalizer(teardown)
return client

Expand Down

0 comments on commit 791cab8

Please sign in to comment.