Fix zero stage2 cpu_offload when some model trainable parameters are skipped in training, as in microsoft#707

When some trainable model parameters are skipped during training,
their backward hooks registered in self.create_reduce_and_remove_grad_hooks() never run,
so those parameters have no entry in norm_for_param_grads.
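
For context, here is a minimal, self-contained sketch of the failure mode. It is not DeepSpeed code: the toy model, the hook wiring, and the dictionary mirroring norm_for_param_grads are illustrative assumptions only. A parameter that is trainable but never used in the forward pass receives no gradient, so its hook never fires, and indexing the norm dictionary by its id would raise KeyError without the membership check added by this commit.

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.used = torch.nn.Linear(4, 4)
        self.unused = torch.nn.Linear(4, 4)  # trainable, but never called in forward

    def forward(self, x):
        return self.used(x)

model = ToyModel()
norm_for_param_grads = {}

# Register one hook per parameter, loosely analogous to the gradient
# bookkeeping set up in create_reduce_and_remove_grad_hooks().
for param_id, p in enumerate(model.parameters()):
    def make_hook(pid):
        def hook(grad):
            norm_for_param_grads[pid] = grad.float().norm(2)
        return hook
    p.register_hook(make_hook(param_id))

model(torch.randn(2, 4)).sum().backward()

# Hooks fired only for the parameters of `self.used`; `self.unused` has no entry.
total_norm = 0.0
for param_id, p in enumerate(model.parameters()):
    # Without this membership check, the skipped parameters raise KeyError.
    if param_id in norm_for_param_grads:
        total_norm += norm_for_param_grads[param_id].item() ** 2
print(total_norm ** 0.5)

The guard in the diff below applies the same membership check inside complete_grad_norm_calculation_for_cpu_offload().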
ghosthamlet authored Mar 15, 2021
1 parent 517357e commit d8f1dcd
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions deepspeed/runtime/zero/stage2.py
@@ -878,8 +878,12 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
         for p in params:
             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                 param_id = self.get_param_id(p)
-                param_norm = self.norm_for_param_grads[param_id]
-                total_norm += param_norm.item()**2
+                # As some models have trainable parameters that are skipped in training,
+                # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
+                # so they have no entry in norm_for_param_grads.
+                if param_id in self.norm_for_param_grads:
+                    param_norm = self.norm_for_param_grads[param_id]
+                    total_norm += param_norm.item()**2
 
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
