Skip to content

Commit

Permalink
release, instead of disconnect on any error, when fetching connection…
Browse files Browse the repository at this point in the history
…s in cluster pipeline
  • Loading branch information
zach-iee committed Aug 21, 2023
1 parent 8d17920 commit 71677a5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
9 changes: 5 additions & 4 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,7 +2034,7 @@ def _send_cluster_commands(
redis_node = self.get_redis_connection(node)
try:
connection = get_connection(redis_node, c.args)
except (ConnectionError, TimeoutError) as e:
except BaseException as e:
for n in nodes.values():
n.connection_pool.release(n.connection)
n.connection = None
Expand All @@ -2043,9 +2043,10 @@ def _send_cluster_commands(
backoff = self.retry._backoff.compute(attempts_count)
if backoff > 0:
time.sleep(backoff)
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
if isinstance(e, (ConnectionError, TimeoutError)):
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
raise
nodes[node_name] = NodeCommands(
redis_node.parse_response,
Expand Down
49 changes: 45 additions & 4 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2805,8 +2805,10 @@ def raise_error():

m.side_effect = raise_error

with pytest.raises(Exception, match="unexpected error"):
r.pipeline().get("a").execute()
with patch.object(Connection, "disconnect") as d:
with pytest.raises(Exception, match="unexpected error"):
r.pipeline().get("a").execute()
assert d.call_count == 1

for cluster_node in r.nodes_manager.nodes_cache.values():
connection_pool = cluster_node.redis_connection.connection_pool
Expand Down Expand Up @@ -3127,7 +3129,7 @@ def raise_ask_error():
assert res == ["MOCK_OK"]

@pytest.mark.parametrize("error", [ConnectionError, TimeoutError])
def test_return_previous_acquired_connections(self, r, error):
def test_return_previous_acquired_connections_with_retry(self, r, error):
# in order to ensure that a pipeline will make use of connections
# from different nodes
assert r.keyslot("a") != r.keyslot("b")
Expand All @@ -3143,7 +3145,13 @@ def raise_error(target_node, *args, **kwargs):

get_connection.side_effect = raise_error

r.pipeline().get("a").get("b").execute()
with patch.object(NodesManager, "initialize") as i:
# in order to remove disconnect caused by initialize
i.side_effect = lambda: None

with patch.object(Connection, "disconnect") as d:
r.pipeline().get("a").get("b").execute()
assert d.call_count == 0

# there should have been two get_connections per execution and
# two executions due to exception raised in the first execution
Expand All @@ -3153,6 +3161,39 @@ def raise_error(target_node, *args, **kwargs):
num_of_conns = len(connection_pool._available_connections)
assert num_of_conns == connection_pool._created_connections

@pytest.mark.parametrize("error", [RedisClusterException, BaseException])
def test_return_previous_acquired_connections_without_retry(self, r, error):
# in order to ensure that a pipeline will make use of connections
# from different nodes
assert r.keyslot("a") != r.keyslot("b")

orig_func = redis.cluster.get_connection
with patch("redis.cluster.get_connection") as get_connection:

def raise_error(target_node, *args, **kwargs):
if get_connection.call_count == 2:
raise error("mocked error")
else:
return orig_func(target_node, *args, **kwargs)

get_connection.side_effect = raise_error

with patch.object(Connection, "disconnect") as d:
with pytest.raises(error):
r.pipeline().get("a").get("b").execute()
assert d.call_count == 0

# there should have been two get_connections per execution and
# two executions due to exception raised in the first execution
assert get_connection.call_count == 2
for cluster_node in r.nodes_manager.nodes_cache.values():
connection_pool = cluster_node.redis_connection.connection_pool
num_of_conns = len(connection_pool._available_connections)
assert num_of_conns == connection_pool._created_connections
# connection must remain connected
for conn in connection_pool._available_connections:
assert conn._sock is not None

def test_empty_stack(self, r):
"""
If pipeline is executed with no commands it should
Expand Down

0 comments on commit 71677a5

Please sign in to comment.