Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change zero_grad() argument to match pytorch #2741

Merged
merged 2 commits into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepspeed/runtime/fp16/fused_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ def initialize_optimizer_states(self):

return

def zero_grad(self, set_grads_to_None=True):
def zero_grad(self, set_to_none=False):
"""
Zero FP16 parameter grads.
"""
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
if set_to_none:
p.grad = None
else:
if p.grad is not None:
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/runtime/fp16/unfused_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ def __init__(self,

self.initialize_optimizer_states()

def zero_grad(self, set_grads_to_None=True):
def zero_grad(self, set_to_none=False):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist outside of the step function
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
if set_to_none:
p.grad = None
else:
if p.grad is not None:
Expand Down
8 changes: 4 additions & 4 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def get_partition_info(self, tensor_list, partition_size, partition_id):
return params_in_partition, params_not_in_partition, first_offset

@instrument_w_nvtx
def zero_grad(self, set_grads_to_None=True):
def zero_grad(self, set_to_none=False):
"""
Zero FP16 parameter grads.
"""
Expand All @@ -1519,7 +1519,7 @@ def zero_grad(self, set_grads_to_None=True):
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
if set_to_none:
if p.grad is not None and p.grad.is_cuda:
p.grad.record_stream(torch.cuda.current_stream())
p.grad = None
Expand Down Expand Up @@ -1708,7 +1708,7 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id):
self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition

# release all the gradient since we have already created a necessary copy in dp_grad_partition
self.zero_grad()
self.zero_grad(set_to_none=True)

for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]):
grad.record_stream(torch.cuda.current_stream())
Expand Down Expand Up @@ -1816,7 +1816,7 @@ def _unflatten_partitioned_parameters(self, sub_group_id):

def _overflow_clean_up(self, prev_scale):
see_memory_usage('After overflow before clearing gradients', force=False)
self.zero_grad()
self.zero_grad(set_to_none=True)

if self.offload_optimizer:
self.reset_cpu_buffers()
Expand Down
8 changes: 4 additions & 4 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def independent_gradient_partition_epilogue(self):
# No need to keep the gradients anymore.
# All gradients required by the step
# are in self.averaged_gradients
self.zero_grad()
self.zero_grad(set_to_none=True)
see_memory_usage(f"End ipg_epilogue")

# resets all partition to no reduced
Expand Down Expand Up @@ -1526,15 +1526,15 @@ def get_partition_info(self, tensor_list, partition_size, partition_id):

return params_in_partition, params_not_in_partition, first_offset

def zero_grad(self, set_grads_to_None=True):
def zero_grad(self, set_to_none=False):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.bit16_groups:
for p in group:
if set_grads_to_None:
if set_to_none:
p.grad = None # epilogue and in step
else:
if p.grad is not None:
Expand Down Expand Up @@ -1766,7 +1766,7 @@ def step(self, closure=None):
self.loss_scale))

see_memory_usage('After overflow before clearing gradients')
self.zero_grad()
self.zero_grad(set_to_none=True)
if self.cpu_offload:
self.reset_cpu_buffers()
else:
Expand Down