From 2cc9bafad8f0c851ab012635f09ae9b591fb1575 Mon Sep 17 00:00:00 2001 From: Jay Rolette Date: Sat, 10 Sep 2016 12:50:20 -0500 Subject: [PATCH] fixes the various scenarios where ConnectionPool.disconnect() and Connection.disconnect() were ripping connections out from under the threads or processes that owned them https://github.com/andymccurdy/redis-py/issues/732 --- redis/connection.py | 95 ++++++++++++++++++++++++++++++++++++--------- tests/conftest.py | 2 +- 2 files changed, 78 insertions(+), 19 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index ed051a3..cd49023 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,4 +1,5 @@ from __future__ import with_statement +from copy import copy from distutils.version import StrictVersion from itertools import chain import os @@ -409,7 +410,8 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, socket_keepalive=False, socket_keepalive_options=None, retry_on_timeout=False, encoding='utf-8', encoding_errors='strict', decode_responses=False, - parser_class=DefaultParser, socket_read_size=65536): + parser_class=DefaultParser, socket_read_size=65536, + **kwargs): self.pid = os.getpid() self.host = host self.port = int(port) @@ -432,6 +434,14 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, } self._connect_callbacks = [] + # If the connection isn't used in a process that forks, there are + # certain optimizations we can make. Default assumes that we have to + # be safe for process forks. + self.fork_safe = kwargs.get('fork_safe', True) + + # Connection pool generation id + self.pool_generation = 0 + def __repr__(self): return self.description_format % self._description_args @@ -544,7 +554,14 @@ def disconnect(self): if self._sock is None: return try: - self._sock.shutdown(socket.SHUT_RDWR) + # socket.shutdown() kills the underlying TCP connection + # immediately. Good when you can do it, but it ignores + # the ref count on the descriptor. If the process forked + # and the child inherited the socket, it's not safe to + # call .shutdown(). Just close() the socket and let the OS + # close the connection when appropriate. + if not self.fork_safe: + self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() except socket.error: pass @@ -716,7 +733,8 @@ def __init__(self, path='', db=0, password=None, socket_timeout=None, encoding='utf-8', encoding_errors='strict', decode_responses=False, retry_on_timeout=False, - parser_class=DefaultParser, socket_read_size=65536): + parser_class=DefaultParser, socket_read_size=65536, + **kwargs): self.pid = os.getpid() self.path = path self.db = db @@ -734,6 +752,14 @@ def __init__(self, path='', db=0, password=None, } self._connect_callbacks = [] + # If the connection isn't used in a process that forks, there are + # certain optimizations we can make. Default assumes that we have to + # be safe for process forks. + self.fork_safe = kwargs.get('fork_safe', True) + + # Connection pool generation id + self.pool_generation = 0 + def _connect(self): "Create a Unix domain socket connection" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) @@ -919,6 +945,9 @@ def __init__(self, connection_class=Connection, max_connections=None, self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.max_connections = max_connections + self.fork_safe = connection_kwargs.get('fork_safe', True) + + self.generation = 0 self.reset() @@ -936,20 +965,28 @@ def reset(self): self._check_lock = threading.Lock() def _checkpid(self): - if self.pid != os.getpid(): - with self._check_lock: - if self.pid == os.getpid(): - # another thread already did the work while we waited - # on the lock. - return - self.disconnect() - self.reset() + # No need to check PID if process isn't supposed to fork + if self.fork_safe: + if self.pid != os.getpid(): + with self._check_lock: + if self.pid == os.getpid(): + # another thread already did the work while we waited + # on the lock. + return + self.disconnect() + self.reset() def get_connection(self, command_name, *keys, **options): "Get a connection from the pool" self._checkpid() try: connection = self._available_connections.pop() + if connection.pool_generation != self.generation: + # generation counter mismatch: let this connection go and + # create a new one + connection.disconnect() + self._created_connections -= 1 + connection = self.make_connection() except IndexError: connection = self.make_connection() self._in_use_connections.add(connection) @@ -959,8 +996,11 @@ def make_connection(self): "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") + + new_connection = self.connection_class(**self.connection_kwargs) + new_connection.pool_generation = self.generation self._created_connections += 1 - return self.connection_class(**self.connection_kwargs) + return new_connection def release(self, connection): "Releases the connection back to the pool" @@ -968,14 +1008,33 @@ def release(self, connection): if connection.pid != self.pid: return self._in_use_connections.remove(connection) - self._available_connections.append(connection) - def disconnect(self): - "Disconnects all connections in the pool" - all_conns = chain(self._available_connections, - self._in_use_connections) - for connection in all_conns: + # Verify generation id before putting connection back in the free list + if connection.pool_generation == self.generation: + self._available_connections.append(connection) + else: + # generation id mismatch, kill the connection connection.disconnect() + self._created_connections -= 1 + + def disconnect(self, immediate=False): + """ + Disconnects all connections in the pool + + By default, the disconnect happens over time as connections + cycle into and out of the pool. This avoids the problem of ripping + connections out from under threads that are using them. + + If `immediate` is True, then we forcibly disconnect all connections + and leave it to the owner of the connection to deal with the various + errors that can occur. This option is not recommended. + """ + self.generation += 1 + if immediate: + all_conns = chain(copy(self._available_connections), + copy(self._in_use_connections)) + for connection in all_conns: + connection.disconnect() class BlockingConnectionPool(ConnectionPool): diff --git a/tests/conftest.py b/tests/conftest.py index d7b2b14..f26daf4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,7 +27,7 @@ def _get_client(cls, request=None, **kwargs): if request: def teardown(): client.flushdb() - client.connection_pool.disconnect() + client.connection_pool.disconnect(immediate=True) request.addfinalizer(teardown) return client