diff --git a/redis/connection.py b/redis/connection.py index 004c7a6f78..6b6a8faed1 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -534,6 +534,20 @@ def disconnect(self): pass self._sock = None + def shutdown_socket(self): + """ + Shutdown the socket hold by the current connection, called from + the connection pool class u other manager to singal it that has to be + disconnected in a thread safe way. Later the connection instance + will get an error and will call `disconnect` by it self. + """ + try: + self._sock.shutdown(socket.SHUT_RDWR) + except AttributeError: + # either _sock attribute does not exist or + # connection thread removed it. + pass + def send_packed_command(self, command): "Send an already packed command to the Redis server" if not self._sock: @@ -953,7 +967,7 @@ def disconnect(self): all_conns = chain(self._available_connections, self._in_use_connections) for connection in all_conns: - connection.disconnect() + connection.shutdown_socket() class BlockingConnectionPool(ConnectionPool): @@ -1072,4 +1086,4 @@ def release(self, connection): def disconnect(self): "Disconnects all connections in the pool." for connection in self._connections: - connection.disconnect() + connection.shutdown_socket() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 11c20080a9..2a764a0df4 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,4 +1,6 @@ from __future__ import with_statement +from mock import Mock + import os import pytest import redis @@ -69,6 +71,30 @@ def test_repr_contains_db_info_unix(self): expected = 'ConnectionPool>' assert repr(pool) == expected + def test_disconnect_active_connections(self): + + class MyConnection(redis.Connection): + + connect_calls = 0 + + def __init__(self, *args, **kwargs): + super(MyConnection, self).__init__(*args, **kwargs) + self.register_connect_callback(self.count_connect) + + def count_connect(self, connection): + MyConnection.connect_calls += 1 + + pool = self.get_pool(connection_class=MyConnection) + r = redis.StrictRedis(connection_pool=pool) + r.ping() + pool.disconnect() + r.ping() + + # If the connection is not disconnected by the pool the + # callback belonging to Connection will be called just + # one time. + assert MyConnection.connect_calls == 2 + class TestBlockingConnectionPool(object): def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):