diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 6d24ed469f3a..5204d4b09de4 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -148,6 +148,10 @@ def forward(self, inputs):
         self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
         self.global_rank = dist.get_rank(group=self.world_group)
         self.world_size = dist.get_world_size(group=self.world_group)
+        # Check presence before converting: int(None) would raise TypeError before any assert could run.
+        local_rank = os.environ.get("LOCAL_RANK")
+        assert local_rank is not None, "LOCAL_RANK environment variable is not set"
+        self.local_rank = int(local_rank)
 
         if topology:
             self._topo = topology
@@ -186,7 +190,7 @@ def forward(self, inputs):
 
         #with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
         self._build()
-        self.to('cuda')
+        self.to(f'cuda:{self.local_rank}')
 
         self.tied_comms = self._index_tied_modules()
         self._synchronize_tied_weights()