From 2dd1b3d6c8e52ea507e97f30a0bf577f1dd58bb0 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 3 Apr 2024 09:28:54 +0000 Subject: [PATCH] jax.distributed.initialize: specify bind address. By default, the coordinator process listens on all interfaces. --- docs/multi_process.md | 3 +++ jax/_src/distributed.py | 24 ++++++++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/docs/multi_process.md b/docs/multi_process.md index 70bead4c42da..ab4beaf8c4e2 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -62,6 +62,9 @@ The API {func}`jax.distributed.initialize` takes several arguments, namely: with a port available on that process. Process 0 will start a JAX service exposed via that IP address and port, to which the other processes in the cluster will connect. + * `coordinator_bind_address`: the IP address and port to which the JAX service + on process 0 in your cluster will bind. By default, it will bind to all + available interfaces using the same port as `coordinator_address`. * `num_processes`: the number of processes in the cluster * `process_id`: the ID number of this process, in the range `[0 .. num_processes)`. diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 63464bb0e226..508f124a3b97 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -41,7 +41,8 @@ def initialize(self, num_processes: int | None = None, process_id: int | None = None, local_device_ids: int | Sequence[int] | None = None, - initialization_timeout: int = 300): + initialization_timeout: int = 300, + coordinator_bind_address: str | None = None): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS', None)) if isinstance(local_device_ids, int): @@ -66,6 +67,15 @@ def initialize(self, self.coordinator_address = coordinator_address + # The default value of [::]:port tells the coordinator to bind to all + # available addresses on the same port as coordinator_address. + default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1] + coordinator_bind_address = (coordinator_bind_address or + os.environ.get('JAX_COORDINATOR_BIND_ADDRESS', + default_coordinator_bind_address)) + if coordinator_bind_address is None: + raise ValueError('coordinator_bind_address should be defined.') + if local_device_ids: visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr] logger.info('JAX distributed initialized with visible devices: %s', visible_devices) @@ -79,7 +89,7 @@ def initialize(self, raise RuntimeError('distributed.initialize should only be called once.') logger.info('Starting JAX distributed service on %s', coordinator_address) self.service = xla_extension.get_distributed_runtime_service( - coordinator_address, num_processes) + coordinator_bind_address, num_processes) self.num_processes = num_processes @@ -118,7 +128,8 @@ def initialize(coordinator_address: str | None = None, num_processes: int | None = None, process_id: int | None = None, local_device_ids: int | Sequence[int] | None = None, - initialization_timeout: int = 300): + initialization_timeout: int = 300, + coordinator_bind_address: str | None = None): """Initializes the JAX distributed system. Calling :func:`~jax.distributed.initialize` prepares JAX for execution on @@ -156,6 +167,11 @@ def initialize(coordinator_address: str | None = None, initialization_timeout: Time period (in seconds) for which connection will be retried. If the initialization takes more than the timeout specified, the initialization will error. Defaults to 300 secs i.e. 5 mins. + coordinator_bind_address: the address and port to which the coordinator service + on process `0` should bind. If this is not specified, the default is to bind to + all available addresses on the same port as ``coordinator_address``. On systems + that have multiple network interfaces per node it may be insufficient to only + have the coordinator service listen on one address/interface. Raises: RuntimeError: If :func:`~jax.distributed.initialize` is called more than once. @@ -178,7 +194,7 @@ def initialize(coordinator_address: str | None = None, raise RuntimeError("jax.distributed.initialize() must be called before " "any JAX computations are executed.") global_state.initialize(coordinator_address, num_processes, process_id, - local_device_ids, initialization_timeout) + local_device_ids, initialization_timeout, coordinator_bind_address) atexit.register(shutdown)