Skip to content

Commit

Permalink
Add option to set coordinator lookup timeout for TPU clusters
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617383458
  • Loading branch information
jax authors committed Mar 20, 2024
1 parent 33e1a96 commit c5869fe
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
21 changes: 12 additions & 9 deletions jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def is_env_present(cls) -> bool:
return False

@classmethod
def get_coordinator_address(cls) -> str:
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
if has_megascale_address():
# For both GCE via QueuedResources and GKE via JobSet, the
# Megascale coordinator address is set as the host with process id = 0,
Expand All @@ -103,24 +103,27 @@ def get_coordinator_address(cls) -> str:
coordinator_address = cls._get_worker_list_in_slice()[0]
coordinator_address = coordinator_address.split(':')[0]
logger.debug("TPU Cluster using coordinator address: %s", coordinator_address)
cls.wait_for_coordinator(coordinator_address)
cls.wait_for_coordinator(coordinator_address, timeout_secs)
return f'{coordinator_address}:{coordinator_port}'

@classmethod
def wait_for_coordinator(cls, coordinator_address):
def wait_for_coordinator(cls, coordinator_address, timeout_secs):
# The coordinator may not be up before the other hosts try to
# communicate with it. We check for its existence with retries.
coordinator_found = False
lookup_attempt = 1
max_coordinator_lookups = 50
while not coordinator_found and lookup_attempt <= max_coordinator_lookups:
max_time = time.time() + timeout_secs
coordinator_retry_secs = 5
while not coordinator_found and time.time() < max_time:
try:
ip_address = socket.gethostbyname(coordinator_address)
coordinator_found = True
logger.debug("Found coordinator with address %s", coordinator_address)
except socket.gaierror:
print(f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying...")
lookup_attempt += 1
time.sleep(5)
logger.debug(
"Failed to recognize coordinator address %s"
" retrying...", coordinator_address
)
time.sleep(coordinator_retry_secs)
if not coordinator_found:
raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}")

Expand Down
7 changes: 4 additions & 3 deletions jax/_src/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def auto_detect_unset_distributed_params(cls,
coordinator_address: str | None,
num_processes: int | None,
process_id: int | None,
local_device_ids: Sequence[int] | None
local_device_ids: Sequence[int] | None,
initialization_timeout: int | None,
) -> tuple[str | None, int | None, int | None,
Sequence[int] | None]:
if all(p is not None for p in (coordinator_address, num_processes,
Expand All @@ -53,7 +54,7 @@ def auto_detect_unset_distributed_params(cls,
if env:
logger.debug('Initializing distributed JAX environment via %s', env.__name__)
if coordinator_address is None:
coordinator_address = env.get_coordinator_address()
coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout)
if num_processes is None:
num_processes = env.get_process_count()
if process_id is None:
Expand All @@ -79,7 +80,7 @@ def is_env_present(cls) -> bool:
raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")

@classmethod
def get_coordinator_address(cls) -> str:
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
"""Returns address and port used by JAX to bootstrap.
Process id 0 will open a tcp socket at "hostname:port" where
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/clusters/ompi_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def is_env_present(cls) -> bool:
return _ORTE_URI in os.environ

@classmethod
def get_coordinator_address(cls) -> str:
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
# Examples of orte_uri:
# 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
# 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/clusters/slurm_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def is_env_present(cls) -> bool:
return _JOBID_PARAM in os.environ

@classmethod
def get_coordinator_address(cls) -> str:
def get_coordinator_address(cls, timeout_secs: int | None) -> str:
# Pick port in ephemeral range [(65535 - 2^12 + 1), 65535]
port = int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1)

Expand Down
6 changes: 5 additions & 1 deletion jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def initialize(self,

(coordinator_address, num_processes, process_id, local_device_ids) = (
clusters.ClusterEnv.auto_detect_unset_distributed_params(
coordinator_address, num_processes, process_id, local_device_ids
coordinator_address,
num_processes,
process_id,
local_device_ids,
initialization_timeout,
)
)

Expand Down

0 comments on commit c5869fe

Please sign in to comment.