diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 6db54895f4..3d59016bb3 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -158,6 +158,7 @@ def __init__( encoding_errors: str = "strict", decode_responses: bool = False, retry_on_timeout: bool = False, + retry_on_error: Optional[list] = None, ssl: bool = False, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, @@ -176,8 +177,10 @@ def __init__( ): """ Initialize a new Redis client. - To specify a retry policy, first set `retry_on_timeout` to `True` - then set `retry` to a valid `Retry` object + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ kwargs: Dict[str, Any] # auto_close_connection_pool only has an effect if connection_pool is @@ -188,6 +191,10 @@ def __init__( auto_close_connection_pool if connection_pool is None else False ) if not connection_pool: + if not retry_on_error: + retry_on_error = [] + if retry_on_timeout is True: + retry_on_error.append(TimeoutError) kwargs = { "db": db, "username": username, @@ -197,6 +204,7 @@ def __init__( "encoding_errors": encoding_errors, "decode_responses": decode_responses, "retry_on_timeout": retry_on_timeout, + "retry_on_error": retry_on_error, "retry": copy.deepcopy(retry), "max_connections": max_connections, "health_check_interval": health_check_interval, @@ -461,7 +469,10 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): is not a TimeoutError """ await conn.disconnect() - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): raise error # COMMAND EXECUTION AND PROTOCOL PARSING diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 38465fc0d7..35536fc883 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -578,6 +578,7 @@ class Connection: "socket_type", "redis_connect_func", "retry_on_timeout", + "retry_on_error", "health_check_interval", "next_health_check", "last_active_at", @@ -606,6 +607,7 @@ def __init__( socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, socket_type: int = 0, retry_on_timeout: bool = False, + retry_on_error: Union[list, _Sentinel] = SENTINEL, encoding: str = "utf-8", encoding_errors: str = "strict", decode_responses: bool = False, @@ -631,12 +633,19 @@ def __init__( self.socket_keepalive_options = socket_keepalive_options or {} self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout + if retry_on_error is SENTINEL: + retry_on_error = [] if retry_on_timeout: + retry_on_error.append(TimeoutError) + self.retry_on_error = retry_on_error + if retry_on_error: if not retry: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable self.retry = copy.deepcopy(retry) + # Update the retry's supported errors with the specified errors + self.retry.update_supported_errors(retry_on_error) else: self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval @@ -1169,6 +1178,7 @@ def __init__( encoding_errors: str = "strict", decode_responses: bool = False, retry_on_timeout: bool = False, + retry_on_error: Union[list, _Sentinel] = SENTINEL, parser_class: Type[BaseParser] = DefaultParser, socket_read_size: int = 65536, health_check_interval: float = 0.0, @@ -1190,12 +1200,19 @@ def __init__( self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None self.retry_on_timeout = retry_on_timeout + if retry_on_error is SENTINEL: + retry_on_error = [] if retry_on_timeout: + retry_on_error.append(TimeoutError) + self.retry_on_error = retry_on_error + if retry_on_error: if retry is None: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable self.retry = copy.deepcopy(retry) + # Update the retry's supported errors with the specified errors + self.retry.update_supported_errors(retry_on_error) else: self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py index 0934ad0d9f..7c5e3b0e7d 100644 --- a/redis/asyncio/retry.py +++ b/redis/asyncio/retry.py @@ -35,6 +35,14 @@ def __init__( self._retries = retries self._supported_errors = supported_errors + def update_supported_errors(self, specified_errors: list): + """ + Updates the supported errors with the specified error types + """ + self._supported_errors = tuple( + set(self._supported_errors + tuple(specified_errors)) + ) + async def call_with_retry( self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] ) -> T: diff --git a/redis/client.py b/redis/client.py index fcc2758dae..86061d56d5 100755 --- a/redis/client.py +++ b/redis/client.py @@ -914,7 +914,7 @@ def __init__( errors=None, decode_responses=False, retry_on_timeout=False, - retry_on_error=[], + retry_on_error=None, ssl=False, ssl_keyfile=None, ssl_certfile=None, @@ -958,6 +958,8 @@ def __init__( ) ) encoding_errors = errors + if not retry_on_error: + retry_on_error = [] if retry_on_timeout is True: retry_on_error.append(TimeoutError) kwargs = { diff --git a/redis/connection.py b/redis/connection.py index 1bc2ae1f4e..3438bafe57 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -515,7 +515,7 @@ def __init__( socket_keepalive_options=None, socket_type=0, retry_on_timeout=False, - retry_on_error=[], + retry_on_error=SENTINEL, encoding="utf-8", encoding_errors="strict", decode_responses=False, @@ -547,6 +547,8 @@ def __init__( self.socket_keepalive_options = socket_keepalive_options or {} self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout + if retry_on_error is SENTINEL: + retry_on_error = [] if retry_on_timeout: # Add TimeoutError to the errors list to retry on retry_on_error.append(TimeoutError) @@ -1065,7 +1067,7 @@ def __init__( encoding_errors="strict", decode_responses=False, retry_on_timeout=False, - retry_on_error=[], + retry_on_error=SENTINEL, parser_class=DefaultParser, socket_read_size=65536, health_check_interval=0, @@ -1088,6 +1090,8 @@ def __init__( self.password = password self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout + if retry_on_error is SENTINEL: + retry_on_error = [] if retry_on_timeout: # Add TimeoutError to the errors list to retry on retry_on_error.append(TimeoutError) diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index d696d72d1c..38e353bc36 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -3,7 +3,7 @@ from redis.asyncio.connection import Connection, UnixDomainSocketConnection from redis.asyncio.retry import Retry from redis.backoff import AbstractBackoff, NoBackoff -from redis.exceptions import ConnectionError +from redis.exceptions import ConnectionError, TimeoutError class BackoffMock(AbstractBackoff): @@ -22,9 +22,28 @@ def compute(self, failures): class TestConnectionConstructorWithRetry: "Test that the Connection constructors properly handles Retry objects" + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_error_set(self, Class): + class CustomError(Exception): + pass + + retry_on_error = [ConnectionError, TimeoutError, CustomError] + c = Class(retry_on_error=retry_on_error) + assert c.retry_on_error == retry_on_error + assert isinstance(c.retry, Retry) + assert c.retry._retries == 1 + assert set(c.retry._supported_errors) == set(retry_on_error) + + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_error_not_set(self, Class): + c = Class() + assert c.retry_on_error == [] + assert isinstance(c.retry, Retry) + assert c.retry._retries == 0 + @pytest.mark.parametrize("retry_on_timeout", [False, True]) @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) - def test_retry_on_timeout_boolean(self, Class, retry_on_timeout): + def test_retry_on_timeout(self, Class, retry_on_timeout): c = Class(retry_on_timeout=retry_on_timeout) assert c.retry_on_timeout == retry_on_timeout assert isinstance(c.retry, Retry) @@ -32,13 +51,26 @@ def test_retry_on_timeout_boolean(self, Class, retry_on_timeout): @pytest.mark.parametrize("retries", range(10)) @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) - def test_retry_on_timeout_retry(self, Class, retries: int): + def test_retry_with_retry_on_timeout(self, Class, retries: int): retry_on_timeout = retries > 0 c = Class(retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries)) assert c.retry_on_timeout == retry_on_timeout assert isinstance(c.retry, Retry) assert c.retry._retries == retries + @pytest.mark.parametrize("retries", range(10)) + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_with_retry_on_error(self, Class, retries: int): + class CustomError(Exception): + pass + + retry_on_error = [ConnectionError, TimeoutError, CustomError] + c = Class(retry_on_error=retry_on_error, retry=Retry(NoBackoff(), retries)) + assert c.retry_on_error == retry_on_error + assert isinstance(c.retry, Retry) + assert c.retry._retries == retries + assert set(c.retry._supported_errors) == set(retry_on_error) + class TestRetry: "Test that Retry calls backoff and retries the expected number of times"