Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for disconnect issues in #732 #784

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 77 additions & 18 deletions redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import with_statement
from copy import copy
from distutils.version import StrictVersion
from itertools import chain
import os
Expand Down Expand Up @@ -409,7 +410,8 @@ def __init__(self, host='localhost', port=6379, db=0, password=None,
socket_keepalive=False, socket_keepalive_options=None,
retry_on_timeout=False, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser, socket_read_size=65536):
parser_class=DefaultParser, socket_read_size=65536,
**kwargs):
self.pid = os.getpid()
self.host = host
self.port = int(port)
Expand All @@ -432,6 +434,14 @@ def __init__(self, host='localhost', port=6379, db=0, password=None,
}
self._connect_callbacks = []

# If the connection isn't used in a process that forks, there are
# certain optimizations we can make. Default assumes that we have to
# be safe for process forks.
self.fork_safe = kwargs.get('fork_safe', True)

# Connection pool generation id
self.pool_generation = 0

def __repr__(self):
return self.description_format % self._description_args

Expand Down Expand Up @@ -544,7 +554,14 @@ def disconnect(self):
if self._sock is None:
return
try:
self._sock.shutdown(socket.SHUT_RDWR)
# socket.shutdown() kills the underlying TCP connection
# immediately. Good when you can do it, but it ignores
# the ref count on the descriptor. If the process forked
# and the child inherited the socket, it's not safe to
# call .shutdown(). Just close() the socket and let the OS
# close the connection when appropriate.
if not self.fork_safe:
self._sock.shutdown(socket.SHUT_RDWR)
self._sock.close()
except socket.error:
pass
Expand Down Expand Up @@ -716,7 +733,8 @@ def __init__(self, path='', db=0, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
retry_on_timeout=False,
parser_class=DefaultParser, socket_read_size=65536):
parser_class=DefaultParser, socket_read_size=65536,
**kwargs):
self.pid = os.getpid()
self.path = path
self.db = db
Expand All @@ -734,6 +752,14 @@ def __init__(self, path='', db=0, password=None,
}
self._connect_callbacks = []

# If the connection isn't used in a process that forks, there are
# certain optimizations we can make. Default assumes that we have to
# be safe for process forks.
self.fork_safe = kwargs.get('fork_safe', True)

# Connection pool generation id
self.pool_generation = 0

def _connect(self):
"Create a Unix domain socket connection"
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
Expand Down Expand Up @@ -919,6 +945,9 @@ def __init__(self, connection_class=Connection, max_connections=None,
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.max_connections = max_connections
self.fork_safe = connection_kwargs.get('fork_safe', True)

self.generation = 0

self.reset()

Expand All @@ -936,20 +965,28 @@ def reset(self):
self._check_lock = threading.Lock()

def _checkpid(self):
if self.pid != os.getpid():
with self._check_lock:
if self.pid == os.getpid():
# another thread already did the work while we waited
# on the lock.
return
self.disconnect()
self.reset()
# No need to check PID if process isn't supposed to fork
if self.fork_safe:
if self.pid != os.getpid():
with self._check_lock:
if self.pid == os.getpid():
# another thread already did the work while we waited
# on the lock.
return
self.disconnect()
self.reset()

def get_connection(self, command_name, *keys, **options):
"Get a connection from the pool"
self._checkpid()
try:
connection = self._available_connections.pop()
if connection.pool_generation != self.generation:
# generation counter mismatch: let this connection go and
# create a new one
connection.disconnect()
self._created_connections -= 1
connection = self.make_connection()
except IndexError:
connection = self.make_connection()
self._in_use_connections.add(connection)
Expand All @@ -959,23 +996,45 @@ def make_connection(self):
"Create a new connection"
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")

new_connection = self.connection_class(**self.connection_kwargs)
new_connection.pool_generation = self.generation
self._created_connections += 1
return self.connection_class(**self.connection_kwargs)
return new_connection

def release(self, connection):
"Releases the connection back to the pool"
self._checkpid()
if connection.pid != self.pid:
return
self._in_use_connections.remove(connection)
self._available_connections.append(connection)

def disconnect(self):
"Disconnects all connections in the pool"
all_conns = chain(self._available_connections,
self._in_use_connections)
for connection in all_conns:
# Verify generation id before putting connection back in the free list
if connection.pool_generation == self.generation:
self._available_connections.append(connection)
else:
# generation id mismatch, kill the connection
connection.disconnect()
self._created_connections -= 1

def disconnect(self, immediate=False):
"""
Disconnects all connections in the pool

By default, the disconnect happens over time as connections
cycle into and out of the pool. This avoids the problem of ripping
connections out from under threads that are using them.

If `immediate` is True, then we forcibly disconnect all connections
and leave it to the owner of the connection to deal with the various
errors that can occur. This option is not recommended.
"""
self.generation += 1
if immediate:
all_conns = chain(copy(self._available_connections),
copy(self._in_use_connections))
for connection in all_conns:
connection.disconnect()


class BlockingConnectionPool(ConnectionPool):
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _get_client(cls, request=None, **kwargs):
if request:
def teardown():
client.flushdb()
client.connection_pool.disconnect()
client.connection_pool.disconnect(immediate=True)
request.addfinalizer(teardown)
return client

Expand Down