From cd5148f320f0e3154abd6721f338d7687fbf6076 Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Thu, 25 May 2023 16:35:50 -0400
Subject: [PATCH] DS init should not broadcast or move zero.Init models (#3611)

---
 deepspeed/runtime/engine.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index fc910f1ac47f..93ab0bdefc91 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1027,20 +1027,22 @@ def _set_client_model(self, model):
 
     def _configure_distributed_model(self, model):
         self._set_client_model(model)
+        is_zero3_model = self.zero_optimization_partition_weights() and any(
+            [hasattr(param, "ds_id") for param in self.module.parameters()])
+
         if self.fp16_enabled():
-            if self.zero_optimization_partition_weights() and any(
-                    [hasattr(param, "ds_id") for param in self.module.parameters()]):
+            if is_zero3_model:
                 self.__check_params(self.module, torch.half)
             self.module.half()
         elif self.bfloat16_enabled():
-            if self.zero_optimization_partition_weights() and any(
-                    hasattr(param, 'ds_id') for param in self.module.parameters()):
+            if is_zero3_model:
                 self.__check_params(self.module, torch.bfloat16)
             self.module.bfloat16()
         else:
             self.__check_params(self.module, torch.float)
 
-        if not self.dont_change_device:
+        # zero.Init() handles device placement of model
+        if not (self.dont_change_device or is_zero3_model):
             self.module.to(self.device)
 
         # MoE related initialization
@@ -1076,7 +1078,7 @@ def _configure_distributed_model(self, model):
         self.expert_parallel_group = groups._get_expert_parallel_group_dict()
         self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
 
-        if not self.amp_enabled():
+        if not (self.amp_enabled() or is_zero3_model):
             self._broadcast_model()
 
         # check if parameters are duplicated in optimizer param_groups
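
Note (reviewer sketch, not part of the patch): a minimal illustration of the flow
this change targets, assuming ZeRO stage 3 with a model constructed under
deepspeed.zero.Init(). The config values and the toy model below are arbitrary
placeholders, not taken from the patch.

    import deepspeed
    import torch

    ds_config = {
        "train_batch_size": 8,
        "bf16": {"enabled": True},
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        "zero_optimization": {"stage": 3},
    }

    # Parameters are partitioned and placed at construction time; each one
    # is tagged with a ds_id attribute by zero.Init().
    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model = torch.nn.Linear(1024, 1024)

    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config=ds_config)

    # With this patch, _configure_distributed_model() sets is_zero3_model from
    # the ds_id attributes and therefore skips both self.module.to(self.device)
    # and self._broadcast_model(), so the engine neither moves nor overwrites
    # the already-partitioned weights.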