Fix Task.launch_multi_node() not supported when used via pytorch lightning
allegroai committed Jul 4, 2024
1 parent aa227a0 commit e27d277
Showing 1 changed file with 57 additions and 10 deletions.
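In short: PyTorch Lightning spawns one process per device, and each spawned process re-enters launch_multi_node(), so before this fix every local process could re-clone and re-enqueue the child nodes. The diff guards the master-only branch behind a CLEARML_MULTI_NODE_MASTER environment variable and derives RANK and WORLD_SIZE from the per-node device count. A minimal sketch of the guard's control flow (simplified, not the actual implementation):

import os

def launch_multi_node(total_num_nodes):
    # Runs in every process; Lightning later spawns one process per device,
    # and each spawned process executes this function again.
    if os.environ.get("CLEARML_MULTI_NODE_MASTER") is None:
        # Master branch: clone and enqueue the other nodes -- exactly once.
        for node_rank in range(1, total_num_nodes):
            ...  # clone this task, set its node_rank, enqueue it
    os.environ["CLEARML_MULTI_NODE_MASTER"] = "1"
    # Spawned child processes inherit the variable and skip the branch above,
    # so the set of nodes is never duplicated.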
67 changes: 57 additions & 10 deletions clearml/task.py
@@ -180,6 +180,8 @@ class Task(_Task):
__detect_repo_async = deferred_config('development.vcs_repo_detect_async', False)
__default_output_uri = DEV_DEFAULT_OUTPUT_URI.get() or deferred_config('development.default_output_uri', None)

__hidden_tag = "hidden"

_launch_multi_node_section = "launch_multi_node"
_launch_multi_node_instance_tag = "multi_node_instance"

@@ -1921,8 +1923,16 @@ def get_logger(self):
"""
return self._get_logger(auto_connect_streams=self._log_to_backend)

def launch_multi_node(self, total_num_nodes, port=29500, queue=None, wait=False, addr=None):
# type: (int, Optional[int], Optional[str], bool, Optional[str]) -> dict
def launch_multi_node(
self,
total_num_nodes, # type: int
port=29500, # type: Optional[int]
queue=None, # type: Optional[str]
wait=False, # type: bool
addr=None, # type: Optional[str]
devices=None, # type: Optional[Union[int, Sequence[int]]]
hide_children=False # type: bool
):
"""
Enqueue multiple clones of the current task to a queue, allowing the task
to be run by multiple workers in parallel. Each task running this way is called a node.
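A minimal usage sketch of the updated signature (the project, task and queue names are illustrative placeholders, not taken from this commit):

from clearml import Task

task = Task.init(project_name="examples", task_name="multi-node-run")  # hypothetical names

config = task.launch_multi_node(
    total_num_nodes=2,    # this process becomes rank 0; one clone is enqueued
    queue="gpu-queue",    # placeholder queue name
    devices=-1,           # new in this commit: -1 = all available CUDA devices
    hide_children=True,   # new in this commit: hide the cloned node tasks in the UI
    wait=True,            # master blocks until the other nodes report in
)
# MASTER_ADDR / MASTER_PORT / RANK / NODE_RANK / WORLD_SIZE are now exported,
# ready for torch.distributed or PyTorch Lightning to consume.
print(config["node_rank"], config["total_num_nodes"])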
@@ -1996,6 +2006,9 @@ def run(rank, size):
parameter will be set to the one defined in ``MASTER_ADDR``. If neither environment variables exist,
the value passed to the parameter will be used. If this value is None (default), the private IP of
the machine the master node is running on will be used.
:param devices: The devices to use. This can be a positive number indicating the number of devices to use,
a sequence of indices or the value ``-1`` to indicate all available devices should be used.
:param hide_children: If True, the child tasks will be hidden. Otherwise, they will be visible in the UI.
:return: A dictionary containing relevant information regarding the multi node run. This dictionary has the following entries:
@@ -2006,9 +2019,12 @@
- `node_rank` - the rank of the current node (master has rank 0)
- `wait` - if True, the master node will wait for the other nodes to start
"""

def set_launch_multi_node_runtime_props(task, conf):
# noinspection PyProtectedMember
task._set_runtime_properties({"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()})
task._set_runtime_properties(
{"{}/{}".format(self._launch_multi_node_section, k): v for k, v in conf.items()}
)

if total_num_nodes < 1:
raise UsageError("total_num_nodes needs to be at least 1")
@@ -2024,6 +2040,7 @@ def set_launch_multi_node_runtime_props(task, conf):
),
"node_rank": 0,
"wait": wait,
"devices": devices
}
editable_conf = {"total_num_nodes": total_num_nodes, "queue": queue}
editable_conf = self.connect(editable_conf, name=self._launch_multi_node_section)
@@ -2033,23 +2050,27 @@ def set_launch_multi_node_runtime_props(task, conf):
runtime_properties = self._get_runtime_properties()
remote_node_rank = runtime_properties.get("{}/node_rank".format(self._launch_multi_node_section))

current_conf = master_conf
if remote_node_rank:
# self is a child node, build the conf from the runtime properties
current_conf = {
entry: runtime_properties.get("{}/{}".format(self._launch_multi_node_section, entry))
for entry in master_conf.keys()
}
else:
elif os.environ.get("CLEARML_MULTI_NODE_MASTER") is None:
nodes_to_wait = []
# self is the master node, enqueue the other nodes
set_launch_multi_node_runtime_props(self, master_conf)
current_conf = master_conf
for node_rank in range(1, master_conf.get("total_num_nodes", total_num_nodes)):
node = self.clone(source_task=self, parent=self.id)
node_conf = copy.deepcopy(master_conf)
node_conf["node_rank"] = node_rank
set_launch_multi_node_runtime_props(node, node_conf)
node.set_system_tags(node.get_system_tags() + [self._launch_multi_node_instance_tag])
node.set_system_tags(
node.get_system_tags()
+ [self._launch_multi_node_instance_tag]
+ ([self.__hidden_tag] if hide_children else [])
)
if master_conf.get("queue"):
Task.enqueue(node, queue_name=master_conf["queue"])
else:
@@ -2064,16 +2085,42 @@ def set_launch_multi_node_runtime_props(task, conf):
Task.TaskStatusEnum.stopped,
Task.TaskStatusEnum.closed,
Task.TaskStatusEnum.failed,
Task.TaskStatusEnum.in_progress
Task.TaskStatusEnum.in_progress,
),
check_interval_sec=10
check_interval_sec=10,
)
self.log.info("Node with task ID {} and rank {} detected".format(node_to_wait.id, rank))
os.environ["CLEARML_MULTI_NODE_MASTER"] = "1"

num_devices = 1
if devices is not None:
try:
num_devices = int(devices)
except TypeError:
try:
num_devices = len(devices)
except Exception as ex:
raise ValueError("Failed parsing number of devices: {}".format(ex))
except ValueError as ex:
raise ValueError("Failed parsing number of devices: {}".format(ex))
if num_devices < 0:
try:
import torch

num_devices = torch.cuda.device_count()
except ImportError:
raise ImportError(
"Could not import `torch` while finding the number of devices. "
"Please install it or set `devices` to a value different than -1"
)

os.environ["MASTER_ADDR"] = current_conf.get("master_addr", "")
os.environ["MASTER_PORT"] = str(current_conf.get("master_port", ""))
os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", ""))
os.environ["RANK"] = str(current_conf.get("node_rank", ""))
os.environ["RANK"] = str(
current_conf.get("node_rank", 0) * num_devices + int(os.environ.get("LOCAL_RANK", "0"))
)
os.environ["NODE_RANK"] = str(current_conf.get("node_rank", ""))
os.environ["WORLD_SIZE"] = str(current_conf.get("total_num_nodes", total_num_nodes) * num_devices)

return current_conf
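
The net effect of the new device handling and rank arithmetic, as a worked example (the node and device counts are assumed for illustration):

# Assume total_num_nodes=2 and devices=[0, 1, 2, 3] -> num_devices = len(devices) = 4.
# Lightning sets LOCAL_RANK for each process it spawns on a node; the globals become:
#
#   node_rank   LOCAL_RANK   RANK = node_rank * 4 + LOCAL_RANK   WORLD_SIZE = 2 * 4
#       0           0                      0                           8
#       0           3                      3                           8
#       1           0                      4                           8
#       1           3                      7                           8
#
# With devices=-1, num_devices instead falls back to torch.cuda.device_count().
node_rank, local_rank = 1, 3
num_devices, total_num_nodes = 4, 2
rank = node_rank * num_devices + local_rank      # 7
world_size = total_num_nodes * num_devices       # 8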

