Skip to content

Commit

Permalink
add cluster "host_port_remap" feature
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Apr 5, 2023
1 parent 204d403 commit a030a98
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
Type,
TypeVar,
Tuple,
Union,
)

Expand Down Expand Up @@ -250,6 +251,7 @@ def __init__(
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
host_port_remap: List[Dict[str, Any]] = [],
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -337,7 +339,12 @@ def __init__(
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
self.nodes_manager = NodesManager(
startup_nodes,
require_full_coverage,
kwargs,
host_port_remap=host_port_remap,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
Expand Down Expand Up @@ -1044,17 +1051,20 @@ class NodesManager:
"require_full_coverage",
"slots_cache",
"startup_nodes",
"host_port_remap",
)

def __init__(
self,
startup_nodes: List["ClusterNode"],
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
host_port_remap: List[Dict[str, Any]] = [],
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
self.connection_kwargs = connection_kwargs
self.host_port_remap = host_port_remap

self.default_node: "ClusterNode" = None
self.nodes_cache: Dict[str, "ClusterNode"] = {}
Expand Down Expand Up @@ -1213,6 +1223,7 @@ async def initialize(self) -> None:
if host == "":
host = startup_node.host
port = int(primary_node[1])
host, port = self.remap_host_port(host, port)

target_node = tmp_nodes_cache.get(get_node_name(host, port))
if not target_node:
Expand All @@ -1231,6 +1242,7 @@ async def initialize(self) -> None:
for replica_node in replica_nodes:
host = replica_node[0]
port = replica_node[1]
host, port = self.remap_host_port(host, port)

target_replica_node = tmp_nodes_cache.get(
get_node_name(host, port)
Expand Down Expand Up @@ -1304,6 +1316,30 @@ async def close(self, attr: str = "nodes_cache") -> None:
)
)

def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
"""
Remap the host and port returned from the cluster to a different
internal value. Useful if the client is not connecting directly
to the cluster.
"""
for map_entry in self.host_port_remap:
mapped = False
if "from_host" in map_entry:
if host != map_entry["from_host"]:
continue
else:
host = map_entry["to_host"]
mapped = True
if "from_port" in map_entry:
if port != map_entry["from_port"]:
continue
else:
port = map_entry["to_port"]
mapped = True
if mapped:
break
return host, port


class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
"""
Expand Down

0 comments on commit a030a98

Please sign in to comment.