Skip to content

Commit

Permalink
fixes the various scenarios where ConnectionPool.disconnect() and
Browse files Browse the repository at this point in the history
Connection.disconnect() were ripping connections out from under
the threads or processes that owned them

redis#732
  • Loading branch information
rolette authored and willhug committed Aug 12, 2021
1 parent 03c12da commit d4aaae7
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 19 deletions.
95 changes: 77 additions & 18 deletions redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import with_statement
from copy import copy
from distutils.version import StrictVersion
from itertools import chain
import os
Expand Down Expand Up @@ -440,7 +441,8 @@ def __init__(self, host='localhost', port=6379, db=0, password=None,
socket_keepalive=False, socket_keepalive_options=None,
retry_on_timeout=False, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser, socket_read_size=65536):
parser_class=DefaultParser, socket_read_size=65536,
**kwargs):
self.pid = os.getpid()
self.host = host
self.port = int(port)
Expand All @@ -461,6 +463,14 @@ def __init__(self, host='localhost', port=6379, db=0, password=None,
}
self._connect_callbacks = []

# If the connection isn't used in a process that forks, there are
# certain optimizations we can make. Default assumes that we have to
# be safe for process forks.
self.fork_safe = kwargs.get('fork_safe', True)

# Connection pool generation id
self.pool_generation = 0

def __repr__(self):
return self.description_format % self._description_args

Expand Down Expand Up @@ -573,7 +583,14 @@ def disconnect(self):
if self._sock is None:
return
try:
self._sock.shutdown(socket.SHUT_RDWR)
# socket.shutdown() kills the underlying TCP connection
# immediately. Good when you can do it, but it ignores
# the ref count on the descriptor. If the process forked
# and the child inherited the socket, it's not safe to
# call .shutdown(). Just close() the socket and let the OS
# close the connection when appropriate.
if not self.fork_safe:
self._sock.shutdown(socket.SHUT_RDWR)
self._sock.close()
except socket.error:
pass
Expand Down Expand Up @@ -729,7 +746,8 @@ def __init__(self, path='', db=0, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
retry_on_timeout=False,
parser_class=DefaultParser, socket_read_size=65536):
parser_class=DefaultParser, socket_read_size=65536,
**kwargs):
self.pid = os.getpid()
self.path = path
self.db = db
Expand All @@ -745,6 +763,14 @@ def __init__(self, path='', db=0, password=None,
}
self._connect_callbacks = []

# If the connection isn't used in a process that forks, there are
# certain optimizations we can make. Default assumes that we have to
# be safe for process forks.
self.fork_safe = kwargs.get('fork_safe', True)

# Connection pool generation id
self.pool_generation = 0

def _connect(self):
"Create a Unix domain socket connection"
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
Expand Down Expand Up @@ -930,6 +956,9 @@ def __init__(self, connection_class=Connection, max_connections=None,
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.max_connections = max_connections
self.fork_safe = connection_kwargs.get('fork_safe', True)

self.generation = 0

self.reset()

Expand All @@ -947,20 +976,28 @@ def reset(self):
self._check_lock = threading.Lock()

def _checkpid(self):
if self.pid != os.getpid():
with self._check_lock:
if self.pid == os.getpid():
# another thread already did the work while we waited
# on the lock.
return
self.disconnect()
self.reset()
# No need to check PID if process isn't supposed to fork
if self.fork_safe:
if self.pid != os.getpid():
with self._check_lock:
if self.pid == os.getpid():
# another thread already did the work while we waited
# on the lock.
return
self.disconnect()
self.reset()

def get_connection(self, command_name, *keys, **options):
"Get a connection from the pool"
self._checkpid()
try:
connection = self._available_connections.pop()
if connection.pool_generation != self.generation:
# generation counter mismatch: let this connection go and
# create a new one
connection.disconnect()
self._created_connections -= 1
connection = self.make_connection()
except IndexError:
connection = self.make_connection()
self._in_use_connections.add(connection)
Expand All @@ -979,23 +1016,45 @@ def make_connection(self):
"Create a new connection"
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")

new_connection = self.connection_class(**self.connection_kwargs)
new_connection.pool_generation = self.generation
self._created_connections += 1
return self.connection_class(**self.connection_kwargs)
return new_connection

def release(self, connection):
"Releases the connection back to the pool"
self._checkpid()
if connection.pid != self.pid:
return
self._in_use_connections.remove(connection)
self._available_connections.append(connection)

def disconnect(self):
"Disconnects all connections in the pool"
all_conns = chain(self._available_connections,
self._in_use_connections)
for connection in all_conns:
# Verify generation id before putting connection back in the free list
if connection.pool_generation == self.generation:
self._available_connections.append(connection)
else:
# generation id mismatch, kill the connection
connection.disconnect()
self._created_connections -= 1

def disconnect(self, immediate=False):
"""
Disconnects all connections in the pool
By default, the disconnect happens over time as connections
cycle into and out of the pool. This avoids the problem of ripping
connections out from under threads that are using them.
If `immediate` is True, then we forcibly disconnect all connections
and leave it to the owner of the connection to deal with the various
errors that can occur. This option is not recommended.
"""
self.generation += 1
if immediate:
all_conns = chain(copy(self._available_connections),
copy(self._in_use_connections))
for connection in all_conns:
connection.disconnect()


class BlockingConnectionPool(ConnectionPool):
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _get_client(cls, request=None, **kwargs):
if request:
def teardown():
client.flushdb()
client.connection_pool.disconnect()
client.connection_pool.disconnect(immediate=True)
request.addfinalizer(teardown)
return client

Expand Down

0 comments on commit d4aaae7

Please sign in to comment.