Add find_unused_parameters option
As unused parameters in modules may sometimes be unexpected,
add an explicit error message when this occurs, along with an option to suppress the error: microsoft#707
ghosthamlet authored Apr 12, 2021
1 parent 3c15d98 commit ecc8b11
Showing 1 changed file with 14 additions and 1 deletion.
deepspeed/runtime/zero/stage2.py: 14 additions & 1 deletion
@@ -119,7 +119,8 @@ def __init__(self,
                  allreduce_always_fp32=False,
                  postscale_gradients=True,
                  gradient_predivide_factor=1.0,
-                 gradient_accumulation_steps=1):
+                 gradient_accumulation_steps=1,
+                 find_unused_parameters=False):
 
         # Load pre-installed or JIT compile (un)flatten ops
         util_ops = UtilsBuilder().load()
@@ -173,6 +174,7 @@ def __init__(self,
         self.postscale_gradients = postscale_gradients
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.micro_step_id = 0
+        self.find_unused_parameters = find_unused_parameters
 
         if self.reduce_scatter:
             assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
@@ -889,6 +891,17 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
                 if param_id in self.norm_for_param_grads:
                     param_norm = self.norm_for_param_grads[param_id]
                     total_norm += param_norm.item()**2
+                else:
+                    # As unused parameters in modules may not be expected sometimes,
+                    # add an explicit error msg when it occurred and an option to avoid the error
+                    # Error msg adapted from torch.nn.parallel.DistributedDataParallel
+                    assert self.find_unused_parameters, """
+                        This error indicates that your module has parameters that were not used in producing loss.
+                        You can enable unused parameter detection by
+                        (1) passing the keyword argument `find_unused_parameters=True` to `deepspeed.runtime.engine.DeepSpeedEngine`;
+                        (2) making sure all trainable parameters and `forward` function outputs participate in calculating loss.
+                        """
+
 
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
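
For context, here is a minimal sketch (plain PyTorch, not part of this commit) of the situation the new assert guards against: a module that registers a parameter but never uses it in `forward`, so no gradient is ever produced for it and its norm is never recorded in `norm_for_param_grads`. Per the error message above, training such a model would require passing `find_unused_parameters=True`; the model and attribute names below are hypothetical.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical model with one parameter that never participates in the loss."""
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(10, 1)
        # Registered but never referenced in forward(), so it receives no gradient.
        self.unused = nn.Parameter(torch.zeros(10))

    def forward(self, x):
        return self.used(x)

model = ToyModel()
loss = model(torch.randn(4, 10)).sum()
loss.backward()

# The unused parameter ends up with no gradient; under ZeRO stage 2 with CPU
# offload its norm is never recorded, which is what the new assert detects.
print(model.used.weight.grad is not None)   # True
print(model.unused.grad is None)            # True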
