Skip to content

Commit

Permalink
Slightly safer multi node (#15538)
Browse files Browse the repository at this point in the history
update

Co-authored-by: Luca Antiga <[email protected]>
  • Loading branch information
tchaton and lantiga authored Nov 5, 2022
1 parent dcfaa06 commit d48aa03
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions src/lightning_app/components/multi_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,30 @@ def run(
self._cloud_compute = cloud_compute
self._work_args = work_args
self._work_kwargs = work_kwargs
self.has_initialized = False
self.has_started = False

def run(self) -> None:
# 1. Create & start the works
if not self.has_initialized:
for node_rank in range(self.nodes):
self.ws.append(
self._work_cls(
*self._work_args,
cloud_compute=self._cloud_compute,
**self._work_kwargs,
parallel=True,
if not self.has_started:

# 1. Create & start the works
if not self.ws:
for node_rank in range(self.nodes):
self.ws.append(
self._work_cls(
*self._work_args,
cloud_compute=self._cloud_compute,
**self._work_kwargs,
parallel=True,
)
)
)
# Starting node `node_rank` ...
self.ws[-1].start()
self.has_initialized = True
# Starting node `node_rank` ...
self.ws[-1].start()

# 2. Wait for all machines to be started !
if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
return

# 2. Wait for all machines to be started !
if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
return
self.has_started = True

# Loop over all node machines
for node_rank in range(self.nodes):
Expand Down

0 comments on commit d48aa03

Please sign in to comment.