diff --git a/redis/cluster.py b/redis/cluster.py index d7a30cce9a..40e1546abc 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2034,7 +2034,7 @@ def _send_cluster_commands( redis_node = self.get_redis_connection(node) try: connection = get_connection(redis_node, c.args) - except (ConnectionError, TimeoutError) as e: + except BaseException as e: for n in nodes.values(): n.connection_pool.release(n.connection) n.connection = None @@ -2043,9 +2043,10 @@ def _send_cluster_commands( backoff = self.retry._backoff.compute(attempts_count) if backoff > 0: time.sleep(backoff) - self.nodes_manager.initialize() - if is_default_node: - self.replace_default_node() + if isinstance(e, (ConnectionError, TimeoutError)): + self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() raise nodes[node_name] = NodeCommands( redis_node.parse_response, diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 70dc509f2f..625f194911 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2805,8 +2805,10 @@ def raise_error(): m.side_effect = raise_error - with pytest.raises(Exception, match="unexpected error"): - r.pipeline().get("a").execute() + with patch.object(Connection, "disconnect") as d: + with pytest.raises(Exception, match="unexpected error"): + r.pipeline().get("a").execute() + assert d.call_count == 1 for cluster_node in r.nodes_manager.nodes_cache.values(): connection_pool = cluster_node.redis_connection.connection_pool @@ -3127,7 +3129,7 @@ def raise_ask_error(): assert res == ["MOCK_OK"] @pytest.mark.parametrize("error", [ConnectionError, TimeoutError]) - def test_return_previous_acquired_connections(self, r, error): + def test_return_previous_acquired_connections_with_retry(self, r, error): # in order to ensure that a pipeline will make use of connections # from different nodes assert r.keyslot("a") != r.keyslot("b") @@ -3143,7 +3145,13 @@ def raise_error(target_node, *args, **kwargs): get_connection.side_effect = raise_error - r.pipeline().get("a").get("b").execute() + with patch.object(NodesManager, "initialize") as i: + # in order to remove disconnect caused by initialize + i.side_effect = lambda: None + + with patch.object(Connection, "disconnect") as d: + r.pipeline().get("a").get("b").execute() + assert d.call_count == 0 # there should have been two get_connections per execution and # two executions due to exception raised in the first execution @@ -3153,6 +3161,39 @@ def raise_error(target_node, *args, **kwargs): num_of_conns = len(connection_pool._available_connections) assert num_of_conns == connection_pool._created_connections + @pytest.mark.parametrize("error", [RedisClusterException, BaseException]) + def test_return_previous_acquired_connections_without_retry(self, r, error): + # in order to ensure that a pipeline will make use of connections + # from different nodes + assert r.keyslot("a") != r.keyslot("b") + + orig_func = redis.cluster.get_connection + with patch("redis.cluster.get_connection") as get_connection: + + def raise_error(target_node, *args, **kwargs): + if get_connection.call_count == 2: + raise error("mocked error") + else: + return orig_func(target_node, *args, **kwargs) + + get_connection.side_effect = raise_error + + with patch.object(Connection, "disconnect") as d: + with pytest.raises(error): + r.pipeline().get("a").get("b").execute() + assert d.call_count == 0 + + # there should have been two get_connections per execution and + # two executions due to exception raised in the first execution + assert get_connection.call_count == 2 + for cluster_node in r.nodes_manager.nodes_cache.values(): + connection_pool = cluster_node.redis_connection.connection_pool + num_of_conns = len(connection_pool._available_connections) + assert num_of_conns == connection_pool._created_connections + # connection must remain connected + for conn in connection_pool._available_connections: + assert conn._sock is not None + def test_empty_stack(self, r): """ If pipeline is executed with no commands it should