diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 6fd233adc8..38b6c2abe9 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -23,6 +23,54 @@ class SlaveNotFoundError(ConnectionError): pass +class AsyncSentinelConnectionPoolProxy: + def __init__( + self, + connection_pool, + is_master, + check_connection, + service_name, + sentinel_manager, + ): + self.connection_pool_ref = weakref.ref(connection_pool) + self.is_master = is_master + self.check_connection = check_connection + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.reset() + + def reset(self): + self.master_address = None + self.slave_rr_counter = None + + async def get_master_address(self): + master_address = await self.sentinel_manager.discover_master(self.service_name) + if self.is_master and self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + connection_pool = self.connection_pool_ref() + if connection_pool is not None: + await connection_pool.disconnect(inuse_connections=False) + return master_address + + async def rotate_slaves(self) -> AsyncIterator: + slaves = await self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield await self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + class SentinelManagedConnection(Connection): def __init__(self, **kwargs): self.connection_pool = kwargs.pop("connection_pool") @@ -116,12 +164,17 @@ def __init__(self, service_name, sentinel_manager, **kwargs): ) self.is_master = kwargs.pop("is_master", True) self.check_connection = kwargs.pop("check_connection", False) + self.proxy = AsyncSentinelConnectionPoolProxy( + connection_pool=self, + is_master=self.is_master, + check_connection=self.check_connection, + service_name=service_name, + sentinel_manager=sentinel_manager, + ) super().__init__(**kwargs) - self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.connection_kwargs["connection_pool"] = self.proxy self.service_name = service_name self.sentinel_manager = sentinel_manager - self.master_address = None - self.slave_rr_counter = None def __repr__(self): return ( @@ -131,8 +184,11 @@ def __repr__(self): def reset(self): super().reset() - self.master_address = None - self.slave_rr_counter = None + self.proxy.reset() + + @property + def master_address(self): + return self.proxy.master_address def owns_connection(self, connection: Connection): check = not self.is_master or ( @@ -141,31 +197,12 @@ def owns_connection(self, connection: Connection): return check and super().owns_connection(connection) async def get_master_address(self): - master_address = await self.sentinel_manager.discover_master(self.service_name) - if self.is_master: - if self.master_address != master_address: - self.master_address = master_address - # disconnect any idle connections so that they reconnect - # to the new master the next time that they are used. - await self.disconnect(inuse_connections=False) - return master_address + return await self.proxy.get_master_address() async def rotate_slaves(self) -> AsyncIterator: """Round-robin slave balancer""" - slaves = await self.sentinel_manager.discover_slaves(self.service_name) - if slaves: - if self.slave_rr_counter is None: - self.slave_rr_counter = random.randint(0, len(slaves) - 1) - for _ in range(len(slaves)): - self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) - slave = slaves[self.slave_rr_counter] - yield slave - # Fallback to the master connection - try: - yield await self.get_master_address() - except MasterNotFoundError: - pass - raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + async for x in self.proxy.rotate_slaves(): + yield x class Sentinel(AsyncSentinelCommands): diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index 51e59d69d0..b9c6848c7d 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -96,6 +96,15 @@ async def test_discover_master_error(sentinel): await sentinel.discover_master("xxx") +@pytest.mark.onlynoncluster +async def test_dead_pool(sentinel): + master = sentinel.master_for("mymaster", db=9) + conn = await master.connection_pool.get_connection("_") + await conn.disconnect() + del master + await conn.connect() + + @pytest.mark.onlynoncluster async def test_discover_master_sentinel_down(cluster, sentinel, master_ip): # Put first sentinel 'foo' down