Skip to content

Commit

Permalink
Add distributed config "dataclasses"
Browse files Browse the repository at this point in the history
Can't use the actual `dataclasses` as those were introduced in 3.7 and `valohai-utils` currently supports 3.6
  • Loading branch information
ruksi committed Jun 1, 2022
1 parent 5e62bb7 commit 83b0ae8
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 0 deletions.
7 changes: 7 additions & 0 deletions valohai/internals/distributed_config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from valohai.internals.distributed_config.distributed_config import DistributedConfig
from valohai.internals.distributed_config.member import Member

__all__ = [
"DistributedConfig",
"Member",
]
30 changes: 30 additions & 0 deletions valohai/internals/distributed_config/distributed_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any, Dict, List

from valohai.internals.distributed_config.member import Member
from valohai.internals.distributed_config.utils import rank_members


class DistributedConfig:
def __init__(
self,
*,
group_name: str,
member_id: str,
required_count: int,
members: List[Member],
):
self.group_name = group_name
self.member_id = member_id
self.required_count = required_count
self.members = members

@classmethod
def from_json_data(cls, json_data: Dict[str, Any]) -> "DistributedConfig":
members = [Member.from_json_data(m) for m in json_data["members"]]
rank_members(members)
return cls(
group_name=json_data["config"]["group_name"],
member_id=json_data["config"]["member_id"],
required_count=json_data["config"]["required_count"],
members=members,
)
58 changes: 58 additions & 0 deletions valohai/internals/distributed_config/member.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Optional


class Member:
def __init__(
self,
*,
announce_time: str,
identity: str,
job_id: str,
member_id: str,
exposed_ports: Dict[str, str],
local_ips: List[str],
public_ips: List[str],
rank: Optional[int] = None,
):
self.announce_time = announce_time
self.identity = identity
self.job_id = job_id
self.member_id = member_id
self.exposed_ports = exposed_ports
self.local_ips = local_ips
self.public_ips = public_ips
self.rank = rank # populated by `DistributedConfig.from_json_data()`

@property
def is_master(self) -> bool:
return self.rank == 0

@property
def primary_local_ip(self) -> str:
try:
return self.local_ips[0]
except IndexError as ie:
raise RuntimeError(
"There are no local IPs in the distributed worker network configuration"
) from ie

@property
def primary_public_ip(self) -> str:
try:
return self.public_ips[0]
except IndexError as ie:
raise RuntimeError(
"There are no public IPs in the distributed worker network configuration"
) from ie

@classmethod
def from_json_data(cls, json_data: Dict[str, Any]) -> "Member":
return cls(
announce_time=json_data["announce_time"],
identity=json_data["identity"],
job_id=json_data["job_id"],
member_id=json_data["member_id"],
exposed_ports=json_data["network"]["exposed_ports"],
local_ips=json_data["network"]["local_ips"],
public_ips=json_data["network"]["public_ips"],
)
27 changes: 27 additions & 0 deletions valohai/internals/distributed_config/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import TYPE_CHECKING, Dict, Iterable, Union

if TYPE_CHECKING:
from valohai.internals.distributed_config import Member


def rank_members(members: Iterable["Member"]) -> None:
"""Add ranks to members in-place."""
mapping = compute_member_id_ranks([m.member_id for m in members])
for m in members:
m.rank = mapping[m.member_id]


def compute_member_id_ranks(member_ids: Iterable[str]) -> Dict[str, int]:
"""Given member ids, return member id to rank mapping."""
id_to_sortable: Dict[str, Union[str, int]]
try:
id_to_sortable = {member_id: int(member_id) for member_id in member_ids}
except ValueError:
# if we fail to parse all member ids as an integer, just sort as string
id_to_sortable = {member_id: member_id for member_id in member_ids}
return {
member_id: index
for index, (member_id, sort_value) in enumerate(
sorted(id_to_sortable.items(), key=lambda kv: kv[1])
)
}

0 comments on commit 83b0ae8

Please sign in to comment.