diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 7734d6ef0a29..ea3afa77e404 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -130,14 +130,14 @@ def initialize_optimizer_states(self): return - def zero_grad(self, set_grads_to_None=True): + def zero_grad(self, set_to_none=False): """ Zero FP16 parameter grads. """ # For speed, set model fp16 grad to None by default for group in self.fp16_groups: for p in group: - if set_grads_to_None: + if set_to_none: p.grad = None else: if p.grad is not None: diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py index 3bf906404e87..f553c2815444 100755 --- a/deepspeed/runtime/fp16/unfused_optimizer.py +++ b/deepspeed/runtime/fp16/unfused_optimizer.py @@ -110,7 +110,7 @@ def __init__(self, self.initialize_optimizer_states() - def zero_grad(self, set_grads_to_None=True): + def zero_grad(self, set_to_none=False): """ Zero FP16 parameter grads. """ @@ -118,7 +118,7 @@ def zero_grad(self, set_grads_to_None=True): # For speed, set model fp16 grad to None by default for group in self.fp16_groups: for p in group: - if set_grads_to_None: + if set_to_none: p.grad = None else: if p.grad is not None: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 2a242fcc2ecf..c3f834d0acf8 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1509,7 +1509,7 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): return params_in_partition, params_not_in_partition, first_offset @instrument_w_nvtx - def zero_grad(self, set_grads_to_None=True): + def zero_grad(self, set_to_none=False): """ Zero FP16 parameter grads. """ @@ -1519,7 +1519,7 @@ def zero_grad(self, set_grads_to_None=True): # For speed, set model fp16 grad to None by default for group in self.fp16_groups: for p in group: - if set_grads_to_None: + if set_to_none: if p.grad is not None and p.grad.is_cuda: p.grad.record_stream(torch.cuda.current_stream()) p.grad = None @@ -1708,7 +1708,7 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id): self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition # release all the gradient since we have already created a necessary copy in dp_grad_partition - self.zero_grad() + self.zero_grad(set_to_none=True) for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): grad.record_stream(torch.cuda.current_stream()) @@ -1816,7 +1816,7 @@ def _unflatten_partitioned_parameters(self, sub_group_id): def _overflow_clean_up(self, prev_scale): see_memory_usage('After overflow before clearing gradients', force=False) - self.zero_grad() + self.zero_grad(set_to_none=True) if self.offload_optimizer: self.reset_cpu_buffers() diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index d1e292caee7b..62cbf5d07fe6 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -753,7 +753,7 @@ def independent_gradient_partition_epilogue(self): # No need to keep the gradients anymore. # All gradients required by the step # are in self.averaged_gradients - self.zero_grad() + self.zero_grad(set_to_none=True) see_memory_usage(f"End ipg_epilogue") # resets all partition to no reduced @@ -1526,7 +1526,7 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): return params_in_partition, params_not_in_partition, first_offset - def zero_grad(self, set_grads_to_None=True): + def zero_grad(self, set_to_none=False): """ Zero FP16 parameter grads. """ @@ -1534,7 +1534,7 @@ def zero_grad(self, set_grads_to_None=True): # For speed, set model fp16 grad to None by default for group in self.bit16_groups: for p in group: - if set_grads_to_None: + if set_to_none: p.grad = None # epilogue and in step else: if p.grad is not None: @@ -1766,7 +1766,7 @@ def step(self, closure=None): self.loss_scale)) see_memory_usage('After overflow before clearing gradients') - self.zero_grad() + self.zero_grad(set_to_none=True) if self.cpu_offload: self.reset_cpu_buffers() else: