Skip to content

Commit

Permalink
Merge pull request #20562 from olupton:bind-to-all
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621562851
  • Loading branch information
jax authors committed Apr 3, 2024
2 parents 57ee6b7 + 2dd1b3d commit 85cb169
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
3 changes: 3 additions & 0 deletions docs/multi_process.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ The API {func}`jax.distributed.initialize` takes several arguments, namely:
with a port available on that process. Process 0 will start a JAX service
exposed via that IP address and port, to which the other processes in the
cluster will connect.
* `coordinator_bind_address`: the IP address and port to which the JAX service
on process 0 in your cluster will bind. By default, it will bind to all
available interfaces using the same port as `coordinator_address`.
* `num_processes`: the number of processes in the cluster
* `process_id`: the ID number of this process, in the range `[0 ..
num_processes)`.
Expand Down
24 changes: 20 additions & 4 deletions jax/_src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def initialize(self,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
initialization_timeout: int = 300):
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
coordinator_address = (coordinator_address or
os.environ.get('JAX_COORDINATOR_ADDRESS', None))
if isinstance(local_device_ids, int):
Expand All @@ -66,6 +67,15 @@ def initialize(self,

self.coordinator_address = coordinator_address

# The default value of [::]:port tells the coordinator to bind to all
# available addresses on the same port as coordinator_address.
default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1]
coordinator_bind_address = (coordinator_bind_address or
os.environ.get('JAX_COORDINATOR_BIND_ADDRESS',
default_coordinator_bind_address))
if coordinator_bind_address is None:
raise ValueError('coordinator_bind_address should be defined.')

if local_device_ids:
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
Expand All @@ -79,7 +89,7 @@ def initialize(self,
raise RuntimeError('distributed.initialize should only be called once.')
logger.info('Starting JAX distributed service on %s', coordinator_address)
self.service = xla_extension.get_distributed_runtime_service(
coordinator_address, num_processes)
coordinator_bind_address, num_processes)

self.num_processes = num_processes

Expand Down Expand Up @@ -118,7 +128,8 @@ def initialize(coordinator_address: str | None = None,
num_processes: int | None = None,
process_id: int | None = None,
local_device_ids: int | Sequence[int] | None = None,
initialization_timeout: int = 300):
initialization_timeout: int = 300,
coordinator_bind_address: str | None = None):
"""Initializes the JAX distributed system.
Calling :func:`~jax.distributed.initialize` prepares JAX for execution on
Expand Down Expand Up @@ -156,6 +167,11 @@ def initialize(coordinator_address: str | None = None,
initialization_timeout: Time period (in seconds) for which connection will
be retried. If the initialization takes more than the timeout specified,
the initialization will error. Defaults to 300 secs i.e. 5 mins.
coordinator_bind_address: the address and port to which the coordinator service
on process `0` should bind. If this is not specified, the default is to bind to
all available addresses on the same port as ``coordinator_address``. On systems
that have multiple network interfaces per node it may be insufficient to only
have the coordinator service listen on one address/interface.
Raises:
RuntimeError: If :func:`~jax.distributed.initialize` is called more than once.
Expand All @@ -178,7 +194,7 @@ def initialize(coordinator_address: str | None = None,
raise RuntimeError("jax.distributed.initialize() must be called before "
"any JAX computations are executed.")
global_state.initialize(coordinator_address, num_processes, process_id,
local_device_ids, initialization_timeout)
local_device_ids, initialization_timeout, coordinator_bind_address)
atexit.register(shutdown)


Expand Down

0 comments on commit 85cb169

Please sign in to comment.