diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 0e7a6115b091..d84c9b988597 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1665,18 +1665,21 @@ def step(self, closure=None):
             return

         # Step 1:- Calculate gradient norm using fp-16 grads
-        see_memory_usage('Before norm calculation')
-        scaled_global_grad_norm = self.scaled_global_norm()
-        self._global_grad_norm = scaled_global_grad_norm / prev_scale
+        if self.dtype == torch.float16:
+            see_memory_usage('Before norm calculation')
+            scaled_global_grad_norm = self.scaled_global_norm()
+            self._global_grad_norm = scaled_global_grad_norm / prev_scale
+            see_memory_usage('After norm before optimizer')

-        see_memory_usage('After norm before optimizer')
         # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
             self.start_timers([OPTIMIZER_GRADIENTS])
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             if self.cpu_offload:
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                if self.dtype == torch.float16:
+                    self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+
                 self.stop_timers([OPTIMIZER_GRADIENTS])
                 self.start_timers([OPTIMIZER_STEP])
                 self._optimizer_step(i)
@@ -1715,7 +1718,9 @@ def step(self, closure=None):

                 self.averaged_gradients[i] = None

-                self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+                if self.dtype == torch.float16:
+                    self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
+
                 self.stop_timers([OPTIMIZER_GRADIENTS])

                 # Step 3:- run the optimizer if no offloading
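
For context, the patch gates the loss-scale-based norm calculation and `unscale_and_clip_grads` calls on `self.dtype == torch.float16`, presumably because only fp16 training uses dynamic loss scaling while bf16/fp32 gradients need no unscaling. Below is a minimal standalone sketch of that gating pattern, not DeepSpeed's actual implementation; the helper names `unscale_and_clip_` and `maybe_unscale` and their signatures are hypothetical.

```python
import torch


def unscale_and_clip_(grad_partitions, loss_scale, clip_norm, scaled_global_norm):
    """Divide out the loss scale (and optionally clip) in place, roughly
    mirroring the combined-scale idea used by ZeRO's unscale_and_clip_grads."""
    combined_scale = loss_scale
    if clip_norm > 0:
        # scaled_global_norm still includes the loss scale, so unscale before comparing.
        clip = (scaled_global_norm / loss_scale) / clip_norm
        if clip > 1:
            combined_scale = clip * loss_scale
    for grad in grad_partitions:
        grad.div_(combined_scale)


def maybe_unscale(grad_partitions, dtype, loss_scale, clip_norm, scaled_global_norm):
    # The dtype gate added by the patch: only fp16 training carries a loss scale,
    # so bf16/fp32 gradient partitions are left untouched.
    if dtype == torch.float16:
        unscale_and_clip_(grad_partitions, loss_scale, clip_norm, scaled_global_norm)


if __name__ == "__main__":
    # fp16 training: the fp32 gradient partition is divided by the loss scale.
    grads = [torch.full((4,), 8.0)]
    maybe_unscale(grads, torch.float16, loss_scale=4.0, clip_norm=0.0, scaled_global_norm=16.0)
    print(grads[0])       # tensor([2., 2., 2., 2.])

    # bf16 training: no loss scaling, gradients pass through unchanged.
    grads_bf16 = [torch.full((4,), 8.0, dtype=torch.bfloat16)]
    maybe_unscale(grads_bf16, torch.bfloat16, loss_scale=4.0, clip_norm=0.0, scaled_global_norm=16.0)
    print(grads_bf16[0])  # tensor([8., 8., 8., 8.], dtype=torch.bfloat16)
```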