diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 0ab5f68c43..eb09b0585e 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -5,6 +5,7 @@ import warnings from typing import ( Any, + Callable, Deque, Dict, Generator, @@ -251,7 +252,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, - host_port_remap: List[Dict[str, Any]] = [], + host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -1059,7 +1060,7 @@ def __init__( startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], - host_port_remap: List[Dict[str, Any]] = [], + host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage @@ -1322,22 +1323,8 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: internal value. Useful if the client is not connecting directly to the cluster. """ - for map_entry in self.host_port_remap: - mapped = False - if "from_host" in map_entry: - if host != map_entry["from_host"]: - continue - else: - host = map_entry["to_host"] - mapped = True - if "from_port" in map_entry: - if port != map_entry["from_port"]: - continue - else: - port = map_entry["to_port"] - mapped = True - if mapped: - break + if self.host_port_remap: + return self.host_port_remap(host, port) return host, port diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index 872fb503ae..66060a0668 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -181,37 +181,39 @@ async def test_cluster(request, redis_addr): remap_base = 7372 n_nodes = 6 - remap = [] + def remap(host, port): + return host, remap_base + port - cluster_port + proxies = [] for i in range(n_nodes): port = cluster_port + i remapped = remap_base + i - remap.append({"from_port": port, "to_port": remapped}) forward_addr = redis_addr[0], port proxy = DelayProxy( addr=("127.0.0.1", remapped), redis_addr=forward_addr, delay=0 ) proxies.append(proxy) - # start proxies - await asyncio.gather(*[p.start() for p in proxies]) - + # helpers to work with all or any proxy def all_clear(): for p in proxies: p.send_event.clear() - async def wait_for_send(): + async def any_wait(): asyncio.wait( [p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED ) @contextlib.contextmanager - def override(delay: int = 0): + def all_override(delay: int = 0): with contextlib.ExitStack() as stack: for p in proxies: stack.enter_context(p.override(delay=delay)) yield + # start proxies + await asyncio.gather(*[p.start() for p in proxies]) + with contextlib.closing( RedisCluster.from_url(f"redis://127.0.0.1:{remap_base}", host_port_remap=remap) ) as r: @@ -220,10 +222,10 @@ def override(delay: int = 0): await r.set("bar", "bar") all_clear() - with override(delay=delay): + with all_override(delay=delay): t = asyncio.create_task(r.get("foo")) # cannot wait on the send event, we don't know which node will be used - await wait_for_send() + await any_wait() await asyncio.sleep(delay) t.cancel() with pytest.raises(asyncio.CancelledError): @@ -237,4 +239,5 @@ async def doit(): await asyncio.gather(*[doit() for _ in range(10)]) + # stop proxies await asyncio.gather(*(p.stop() for p in proxies))