"""Distributed worker configuration: group settings, members and ranks.

Reconstructed from a whitespace-collapsed diff that added the package
``valohai/internals/distributed_config`` (``__init__.py``,
``distributed_config.py``, ``member.py`` and ``utils.py``).  The original
cross-module imports between those files are unnecessary here because all
definitions live in one module.
"""

from typing import Any, Dict, Iterable, List, Optional, Union

__all__ = [
    "DistributedConfig",
    "Member",
    "compute_member_id_ranks",
    "rank_members",
]


# --- valohai/internals/distributed_config/member.py -------------------------


class Member:
    """One member entry of a distributed worker group.

    All constructor arguments are keyword-only and are stored unchanged on
    the instance under the same names.
    """

    def __init__(
        self,
        *,
        announce_time: str,
        identity: str,
        job_id: str,
        member_id: str,
        exposed_ports: Dict[str, str],
        local_ips: List[str],
        public_ips: List[str],
        rank: Optional[int] = None,
    ):
        self.announce_time = announce_time
        self.identity = identity
        self.job_id = job_id
        self.member_id = member_id
        self.exposed_ports = exposed_ports
        self.local_ips = local_ips
        self.public_ips = public_ips
        self.rank = rank  # populated by `DistributedConfig.from_json_data()`

    @property
    def is_master(self) -> bool:
        """Return True when this member has rank 0.

        A member whose rank has not been assigned yet (``None``) is not
        considered the master.
        """
        return self.rank == 0

    @property
    def primary_local_ip(self) -> str:
        """Return the first local IP.

        Raises:
            RuntimeError: if ``local_ips`` is empty.
        """
        try:
            return self.local_ips[0]
        except IndexError as ie:
            raise RuntimeError(
                "There are no local IPs in the distributed worker network configuration"
            ) from ie

    @property
    def primary_public_ip(self) -> str:
        """Return the first public IP.

        Raises:
            RuntimeError: if ``public_ips`` is empty.
        """
        try:
            return self.public_ips[0]
        except IndexError as ie:
            raise RuntimeError(
                "There are no public IPs in the distributed worker network configuration"
            ) from ie

    @classmethod
    def from_json_data(cls, json_data: Dict[str, Any]) -> "Member":
        """Build a member from its parsed JSON mapping.

        Identity fields are read from the top level and the port/IP fields
        from the nested ``"network"`` mapping.  ``rank`` is left unset.

        Raises:
            KeyError: if any of the expected keys is missing.
        """
        network = json_data["network"]
        return cls(
            announce_time=json_data["announce_time"],
            identity=json_data["identity"],
            job_id=json_data["job_id"],
            member_id=json_data["member_id"],
            exposed_ports=network["exposed_ports"],
            local_ips=network["local_ips"],
            public_ips=network["public_ips"],
        )


# --- valohai/internals/distributed_config/utils.py --------------------------


def rank_members(members: Iterable["Member"]) -> None:
    """Add ranks to members in-place.

    Args:
        members: the members to rank; any iterable is accepted, including
            one-shot iterators and generators.
    """
    # The members are walked twice (collect ids, then assign ranks), so a
    # one-shot iterator must be materialized first or the second pass would
    # see nothing and silently leave every rank unset.
    member_list = list(members)
    mapping = compute_member_id_ranks(m.member_id for m in member_list)
    for m in member_list:
        m.rank = mapping[m.member_id]


def compute_member_id_ranks(member_ids: Iterable[str]) -> Dict[str, int]:
    """Given member ids, return member id to rank mapping.

    Ids are ordered numerically when every id parses as an integer and
    lexicographically as strings otherwise.  Ranks start at 0 and follow
    that order.  Duplicate ids collapse into a single entry.

    Args:
        member_ids: the ids to rank; any iterable is accepted, including
            one-shot iterators and generators.

    Returns:
        A dict mapping each distinct member id to its rank.
    """
    # Materialize up front: if integer parsing fails midway, the string
    # fallback must still see every id, not just the unconsumed remainder
    # of a one-shot iterator.
    ids = list(member_ids)
    id_to_sortable: Dict[str, Union[str, int]]
    try:
        id_to_sortable = {member_id: int(member_id) for member_id in ids}
    except ValueError:
        # if we fail to parse all member ids as an integer, just sort as string
        id_to_sortable = {member_id: member_id for member_id in ids}
    ordered = sorted(id_to_sortable.items(), key=lambda kv: kv[1])
    return {member_id: index for index, (member_id, _) in enumerate(ordered)}


# --- valohai/internals/distributed_config/distributed_config.py -------------


class DistributedConfig:
    """Group-level settings plus the ranked list of group members.

    All constructor arguments are keyword-only and are stored unchanged on
    the instance under the same names.
    """

    def __init__(
        self,
        *,
        group_name: str,
        member_id: str,
        required_count: int,
        members: List[Member],
    ):
        self.group_name = group_name
        # NOTE(review): presumably the id of the current worker within the
        # group -- confirm against the producer of the JSON.
        self.member_id = member_id
        self.required_count = required_count
        self.members = members

    @classmethod
    def from_json_data(cls, json_data: Dict[str, Any]) -> "DistributedConfig":
        """Build a configuration from its parsed JSON mapping.

        Members are created from ``json_data["members"]`` and ranked
        in-place before the configuration is constructed; the group fields
        are read from ``json_data["config"]``.

        Raises:
            KeyError: if any of the expected keys is missing.
        """
        members = [Member.from_json_data(m) for m in json_data["members"]]
        rank_members(members)
        config = json_data["config"]
        return cls(
            group_name=config["group_name"],
            member_id=config["member_id"],
            required_count=config["required_count"],
            members=members,
        )