
Commit ae88892

[GROW-3247] release connection even if an unexpected exception is thrown in cluster pipeline (#8)

* [GROW-3247] release connection even if an unexpected exception is thrown in cluster pipeline

* [GROW-3247] fix style issue

* unassign n.connection at every loop
zach-iee authored Jul 20, 2023
1 parent 68c3505 commit ae88892
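
At its core, the change wraps the body of _send_cluster_commands in a try/except BaseException so that checked-out connections always make it back to the pool, and it disconnects any connection whose socket state is unknown before releasing it. A minimal sketch of that pattern, with the surrounding method reduced to illustrative scaffolding (only the NodeCommands/connection-pool semantics mirror redis/cluster.py; everything else here is simplified):

def send_cluster_commands_sketch(nodes):
    # nodes: dict of node name -> NodeCommands-like object holding a
    # checked-out `connection` and the `connection_pool` it came from
    try:
        for n in nodes.values():
            n.write()
        for n in nodes.values():
            n.read()
        # happy path: replies fully drained, safe to pool the sockets again
        for n in nodes.values():
            n.connection_pool.release(n.connection)
            n.connection = None  # unassign so the except path cannot release it twice
        nodes = {}
    except BaseException:
        # any unexpected exception, not just RedisError: the socket may still
        # hold unread replies, so disconnect before returning it to the pool
        for n in nodes.values():
            if n.connection:
                n.connection.disconnect()
                n.connection_pool.release(n.connection)
        raise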
Showing 2 changed files with 185 additions and 146 deletions.
redis/cluster.py: 162 additions & 146 deletions
@@ -1996,158 +1996,174 @@ def _send_cluster_commands(
         # build a list of node objects based on node names we need to
         nodes = {}
 
-        # as we move through each command that still needs to be processed,
-        # we figure out the slot number that command maps to, then from
-        # the slot determine the node.
-        for c in attempt:
-            while True:
-                # refer to our internal node -> slot table that
-                # tells us where a given command should route to.
-                # (it might be possible we have a cached node that no longer
-                # exists in the cluster, which is why we do this in a loop)
-                passed_targets = c.options.pop("target_nodes", None)
-                if passed_targets and not self._is_nodes_flag(passed_targets):
-                    target_nodes = self._parse_target_nodes(passed_targets)
-                else:
-                    target_nodes = self._determine_nodes(
-                        *c.args, node_flag=passed_targets
-                    )
-                if not target_nodes:
-                    raise RedisClusterException(
-                        f"No targets were found to execute {c.args} command on"
-                    )
-                if len(target_nodes) > 1:
-                    raise RedisClusterException(
-                        f"Too many targets for command {c.args}"
-                    )
-
-                node = target_nodes[0]
-                if node == self.get_default_node():
-                    is_default_node = True
-
-                # now that we know the name of the node
-                # ( it's just a string in the form of host:port )
-                # we can build a list of commands for each node.
-                node_name = node.name
-                if node_name not in nodes:
-                    redis_node = self.get_redis_connection(node)
-                    try:
-                        connection = get_connection(redis_node, c.args)
-                    except (ConnectionError, TimeoutError) as e:
-                        for n in nodes.values():
-                            n.connection_pool.release(n.connection)
-                        if self.retry and isinstance(e, self.retry._supported_errors):
-                            backoff = self.retry._backoff.compute(attempts_count)
-                            if backoff > 0:
-                                time.sleep(backoff)
-                            self.nodes_manager.initialize()
-                            if is_default_node:
-                                self.replace_default_node()
-                        raise
-                    nodes[node_name] = NodeCommands(
-                        redis_node.parse_response,
-                        redis_node.connection_pool,
-                        connection,
-                    )
-                nodes[node_name].append(c)
-                break
-
-        # send the commands in sequence.
-        # we write to all the open sockets for each node first,
-        # before reading anything
-        # this allows us to flush all the requests out across the
-        # network essentially in parallel
-        # so that we can read them all in parallel as they come back.
-        # we don't multiplex on the sockets as they come available,
-        # but that shouldn't make too much difference.
-        node_commands = nodes.values()
-        for n in node_commands:
-            n.write()
-
-        for n in node_commands:
-            n.read()
-
-        # release all of the redis connections we allocated earlier
-        # back into the connection pool.
-        # we used to do this step as part of a try/finally block,
-        # but it is really dangerous to
-        # release connections back into the pool if for some
-        # reason the socket has data still left in it
-        # from a previous operation. The write and
-        # read operations already have try/catch around them for
-        # all known types of errors including connection
-        # and socket level errors.
-        # So if we hit an exception, something really bad
-        # happened and putting any of
-        # these connections back into the pool is a very bad idea.
-        # the socket might have unread buffer still sitting in it,
-        # and then the next time we read from it we pass the
-        # buffered result back from a previous command and
-        # every request sent to that connection after that will get
-        # a mismatched result.
-        for n in nodes.values():
-            n.connection_pool.release(n.connection)
-
-        # if the response isn't an exception, it is a
-        # valid response from the node
-        # we're all done with that command, YAY!
-        # if we have more commands to attempt, we've run into problems.
-        # collect all the commands we are allowed to retry.
-        # (MOVED, ASK, or connection errors or timeout errors)
-        attempt = sorted(
-            (
-                c
-                for c in attempt
-                if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY)
-            ),
-            key=lambda x: x.position,
-        )
-        if attempt and allow_redirections:
-            # RETRY MAGIC HAPPENS HERE!
-            # send these remaining commands one at a time using `execute_command`
-            # in the main client. This keeps our retry logic
-            # in one place mostly,
-            # and allows us to be more confident in correctness of behavior.
-            # at this point any speed gains from pipelining have been lost
-            # anyway, so we might as well make the best
-            # attempt to get the correct behavior.
-            #
-            # The client command will handle retries for each
-            # individual command sequentially as we pass each
-            # one into `execute_command`. Any exceptions
-            # that bubble out should only appear once all
-            # retries have been exhausted.
-            #
-            # If a lot of commands have failed, we'll be setting the
-            # flag to rebuild the slots table from scratch.
-            # So MOVED errors should correct themselves fairly quickly.
-            self.reinitialize_counter += 1
-            if self._should_reinitialized():
-                self.nodes_manager.initialize()
-                if is_default_node:
-                    self.replace_default_node()
-            for c in attempt:
-                try:
-                    # send each command individually like we
-                    # do in the main client.
-                    c.result = super().execute_command(*c.args, **c.options)
-                except RedisError as e:
-                    c.result = e
-
-        # turn the response back into a simple flat array that corresponds
-        # to the sequence of commands issued in the stack in pipeline.execute()
-        response = []
-        for c in sorted(stack, key=lambda x: x.position):
-            if c.args[0] in self.cluster_response_callbacks:
-                c.result = self.cluster_response_callbacks[c.args[0]](
-                    c.result, **c.options
-                )
-            response.append(c.result)
-
-        if raise_on_error:
-            self.raise_first_error(stack)
-
-        return response
+        try:
+            # as we move through each command that still needs to be processed,
+            # we figure out the slot number that command maps to, then from
+            # the slot determine the node.
+            for c in attempt:
+                while True:
+                    # refer to our internal node -> slot table that
+                    # tells us where a given command should route to.
+                    # (it might be possible we have a cached node that no longer
+                    # exists in the cluster, which is why we do this in a loop)
+                    passed_targets = c.options.pop("target_nodes", None)
+                    if passed_targets and not self._is_nodes_flag(passed_targets):
+                        target_nodes = self._parse_target_nodes(passed_targets)
+                    else:
+                        target_nodes = self._determine_nodes(
+                            *c.args, node_flag=passed_targets
+                        )
+                    if not target_nodes:
+                        raise RedisClusterException(
+                            f"No targets were found to execute {c.args} command on"
+                        )
+                    if len(target_nodes) > 1:
+                        raise RedisClusterException(
+                            f"Too many targets for command {c.args}"
+                        )
+
+                    node = target_nodes[0]
+                    if node == self.get_default_node():
+                        is_default_node = True
+
+                    # now that we know the name of the node
+                    # ( it's just a string in the form of host:port )
+                    # we can build a list of commands for each node.
+                    node_name = node.name
+                    if node_name not in nodes:
+                        redis_node = self.get_redis_connection(node)
+                        try:
+                            connection = get_connection(redis_node, c.args)
+                        except (ConnectionError, TimeoutError) as e:
+                            for n in nodes.values():
+                                n.connection_pool.release(n.connection)
+                                n.connection = None
+                            nodes = {}
+                            if self.retry and isinstance(
+                                e, self.retry._supported_errors
+                            ):
+                                backoff = self.retry._backoff.compute(attempts_count)
+                                if backoff > 0:
+                                    time.sleep(backoff)
+                                self.nodes_manager.initialize()
+                                if is_default_node:
+                                    self.replace_default_node()
+                            raise
+                        nodes[node_name] = NodeCommands(
+                            redis_node.parse_response,
+                            redis_node.connection_pool,
+                            connection,
+                        )
+                    nodes[node_name].append(c)
+                    break
+
+            # send the commands in sequence.
+            # we write to all the open sockets for each node first,
+            # before reading anything
+            # this allows us to flush all the requests out across the
+            # network essentially in parallel
+            # so that we can read them all in parallel as they come back.
+            # we don't multiplex on the sockets as they come available,
+            # but that shouldn't make too much difference.
+            node_commands = nodes.values()
+            for n in node_commands:
+                n.write()
+
+            for n in node_commands:
+                n.read()
+
+            # release all of the redis connections we allocated earlier
+            # back into the connection pool.
+            # we used to do this step as part of a try/finally block,
+            # but it is really dangerous to
+            # release connections back into the pool if for some
+            # reason the socket has data still left in it
+            # from a previous operation. The write and
+            # read operations already have try/catch around them for
+            # all known types of errors including connection
+            # and socket level errors.
+            # So if we hit an exception, something really bad
+            # happened and putting any of
+            # these connections back into the pool is a very bad idea.
+            # the socket might have unread buffer still sitting in it,
+            # and then the next time we read from it we pass the
+            # buffered result back from a previous command and
+            # every request sent to that connection after that will get
+            # a mismatched result.
+            for n in nodes.values():
+                n.connection_pool.release(n.connection)
+                n.connection = None
+            nodes = {}
+
+            # if the response isn't an exception, it is a
+            # valid response from the node
+            # we're all done with that command, YAY!
+            # if we have more commands to attempt, we've run into problems.
+            # collect all the commands we are allowed to retry.
+            # (MOVED, ASK, or connection errors or timeout errors)
+            attempt = sorted(
+                (
+                    c
+                    for c in attempt
+                    if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY)
+                ),
+                key=lambda x: x.position,
+            )
+            if attempt and allow_redirections:
+                # RETRY MAGIC HAPPENS HERE!
+                # send these remaining commands one at a time using `execute_command`
+                # in the main client. This keeps our retry logic
+                # in one place mostly,
+                # and allows us to be more confident in correctness of behavior.
+                # at this point any speed gains from pipelining have been lost
+                # anyway, so we might as well make the best
+                # attempt to get the correct behavior.
+                #
+                # The client command will handle retries for each
+                # individual command sequentially as we pass each
+                # one into `execute_command`. Any exceptions
+                # that bubble out should only appear once all
+                # retries have been exhausted.
+                #
+                # If a lot of commands have failed, we'll be setting the
+                # flag to rebuild the slots table from scratch.
+                # So MOVED errors should correct themselves fairly quickly.
+                self.reinitialize_counter += 1
+                if self._should_reinitialized():
+                    self.nodes_manager.initialize()
+                    if is_default_node:
+                        self.replace_default_node()
+                for c in attempt:
+                    try:
+                        # send each command individually like we
+                        # do in the main client.
+                        c.result = super().execute_command(*c.args, **c.options)
+                    except RedisError as e:
+                        c.result = e
+
+            # turn the response back into a simple flat array that corresponds
+            # to the sequence of commands issued in the stack in pipeline.execute()
+            response = []
+            for c in sorted(stack, key=lambda x: x.position):
+                if c.args[0] in self.cluster_response_callbacks:
+                    c.result = self.cluster_response_callbacks[c.args[0]](
+                        c.result, **c.options
+                    )
+                response.append(c.result)
+
+            if raise_on_error:
+                self.raise_first_error(stack)
+
+            return response
+        except BaseException:
+            # if nodes is not empty, a problem must have occurred
+            # since we can't guarantee the state of the connections,
+            # disconnect before returning them to the connection pool
+            for n in nodes.values():
+                if n.connection:
+                    n.connection.disconnect()
+                    n.connection_pool.release(n.connection)
+            raise
 
     def _fail_on_redirect(self, allow_redirections):
         """ """
tests/test_cluster.py: 23 additions & 0 deletions
@@ -28,6 +28,7 @@
     REPLICA,
     ClusterNode,
     LoadBalancer,
+    NodeCommands,
     NodesManager,
     RedisCluster,
     get_node_name,
@@ -2790,6 +2791,28 @@ class TestClusterPipeline:
     Tests for the ClusterPipeline class
     """
 
+    @pytest.mark.parametrize("function", ["write", "read"])
+    def test_connection_release_with_unexpected_error_in_node_commands(
+        self, r, function
+    ):
+        """
+        Test that the connection is released to the pool even when an
+        unexpected error is raised
+        """
+        with patch.object(NodeCommands, function) as m:
+
+            def raise_error():
+                raise Exception("unexpected error")
+
+            m.side_effect = raise_error
+
+            with pytest.raises(Exception, match="unexpected error"):
+                r.pipeline().get("a").execute()
+
+        for cluster_node in r.nodes_manager.nodes_cache.values():
+            connection_pool = cluster_node.redis_connection.connection_pool
+            num_of_conns = len(connection_pool._available_connections)
+            assert num_of_conns == connection_pool._created_connections
+
     def test_blocked_methods(self, r):
         """
         Currently some method calls on a Cluster pipeline
Expand Down
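
The new test's closing loop is a leak check against the pool's internal counters: after the failed execute(), every connection the pool ever created must be idle again. The same invariant as a standalone helper (note that _available_connections and _created_connections are private attributes of redis-py's ConnectionPool, so this belongs in tests, not application code):

def assert_no_connection_leak(connection_pool):
    # every connection ever created by this pool is back on the idle list
    idle = len(connection_pool._available_connections)
    assert idle == connection_pool._created_connections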
