diff --git a/src/lightning_app/components/multi_node.py b/src/lightning_app/components/multi_node.py index 3d308b83c3a5f..66bebb76a492b 100644 --- a/src/lightning_app/components/multi_node.py +++ b/src/lightning_app/components/multi_node.py @@ -58,27 +58,30 @@ def run( self._cloud_compute = cloud_compute self._work_args = work_args self._work_kwargs = work_kwargs - self.has_initialized = False + self.has_started = False def run(self) -> None: - # 1. Create & start the works - if not self.has_initialized: - for node_rank in range(self.nodes): - self.ws.append( - self._work_cls( - *self._work_args, - cloud_compute=self._cloud_compute, - **self._work_kwargs, - parallel=True, + if not self.has_started: + + # 1. Create & start the works + if not self.ws: + for node_rank in range(self.nodes): + self.ws.append( + self._work_cls( + *self._work_args, + cloud_compute=self._cloud_compute, + **self._work_kwargs, + parallel=True, + ) ) - ) - # Starting node `node_rank`` ... - self.ws[-1].start() - self.has_initialized = True + # Starting node `node_rank`` ... + self.ws[-1].start() + + # 2. Wait for all machines to be started ! + if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws): + return - # 2. Wait for all machines to be started ! - if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws): - return + self.has_started = True # Loop over all node machines for node_rank in range(self.nodes):