Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add find_unused_parameters option to DeepSpeedEngine #945

Merged
merged 32 commits into from
Apr 25, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
976629d
Merge pull request #3 from microsoft/master
ghosthamlet Apr 12, 2021
3c15d98
Add find_unused_parameters option
ghosthamlet Apr 12, 2021
ecc8b11
Add find_unused_parameters option
ghosthamlet Apr 12, 2021
429c0bb
Fix syntax error
ghosthamlet Apr 12, 2021
e32db04
Fix yapf error
ghosthamlet Apr 12, 2021
f27165f
Fix yapf error
ghosthamlet Apr 12, 2021
bcfee31
Fix yapf error
ghosthamlet Apr 12, 2021
afc2da1
Fix yapf error
ghosthamlet Apr 12, 2021
257b5ca
Move stage2 find_unused_parameters to config file
ghosthamlet Apr 24, 2021
f72e2e7
Add stage2 find_unused_parameters
ghosthamlet Apr 24, 2021
26d6a28
Add stage2 find_unused_parameters
ghosthamlet Apr 24, 2021
9249fa7
Add stage2_find_unused_parameters option
ghosthamlet Apr 24, 2021
6605053
Change error msg to reflect zero_optimization config change
ghosthamlet Apr 24, 2021
a197882
Merge branch 'master' into find_unused_parameters
ghosthamlet Apr 24, 2021
10702d1
Fix yapf error
ghosthamlet Apr 24, 2021
8844dea
Fix yapf errors
ghosthamlet Apr 24, 2021
8abcaee
Change find_unused_parameters option name
ghosthamlet Apr 25, 2021
7cedcdf
Change find_unused_parameters option name
ghosthamlet Apr 25, 2021
86464f7
Change find_unused_parameters option name
ghosthamlet Apr 25, 2021
0837674
Change find_unused_parameters option name
ghosthamlet Apr 25, 2021
ec4bdee
Change find_unused_parameters option name
ghosthamlet Apr 25, 2021
8b07db1
Add UnusedParametersModel for test option find_unused_parameters
ghosthamlet Apr 25, 2021
785d910
Add unit test for stage2 find_unused_parameters
ghosthamlet Apr 25, 2021
1430bd5
Add cpu-adam compatible check
ghosthamlet Apr 25, 2021
43ebbbb
Remove dups import
ghosthamlet Apr 25, 2021
42bccaa
Trim spaces
ghosthamlet Apr 25, 2021
d3d2eb4
Fix yapf errors
ghosthamlet Apr 25, 2021
1579123
Trim spaces
ghosthamlet Apr 25, 2021
405aa2f
Add False Positive test check
ghosthamlet Apr 25, 2021
ad3724e
Fix find_unused_parameters test
ghosthamlet Apr 25, 2021
89c9b81
Trim spaces
ghosthamlet Apr 25, 2021
2bb8b66
Fix yapf error
ghosthamlet Apr 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ def zero_cpu_offload(self):
def zero_sub_group_size(self):
return self._config.zero_config.sub_group_size

def zero_find_unused_parameters(self):
return self._config.zero_config.find_unused_parameters

def zero_optimization_stage(self):
return self._config.zero_optimization_stage

Expand Down Expand Up @@ -789,7 +792,8 @@ def _configure_zero_optimizer(self, optimizer):
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps())
gradient_accumulation_steps=self.gradient_accumulation_steps(),
find_unused_parameters=self.zero_find_unused_parameters())
elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
Expand Down
8 changes: 8 additions & 0 deletions deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(self, param_dict):
self.offload_optimizer = None
self.sub_group_size = None

#Stage2 Specific Parameters
self.find_unused_parameters = None

#Stage3 Specific Parameters
self.prefetch_bucket_size = None
self.param_persistence_threshold = None
Expand Down Expand Up @@ -151,6 +154,11 @@ def _initialize(self, zero_config_dict):
ZERO_OPTIMIZATION_SUB_GROUP_SIZE,
ZERO_OPTIMIZATION_SUB_GROUP_SIZE_DEFAULT)

self.find_unused_parameters = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS,
ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS_DEFAULT)

self.max_live_parameters = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_MAX_LIVE_PARAMETERS,
Expand Down
15 changes: 13 additions & 2 deletions deepspeed/runtime/zero/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
"cpu_offload_use_pin_memory": [true|false] (deprecated),
"sub_group_size" : 1000000000000,
"offload_param": {...},
"offload_optimizer": {...}
"offload_optimizer": {...},
"stage2_find_unused_parameters": [true|false]
ghosthamlet marked this conversation as resolved.
Show resolved Hide resolved
}
}
'''
Expand Down Expand Up @@ -113,6 +114,14 @@
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save'
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False

# Used in stage2 complete_grad_norm_calculation_for_cpu_offload
# Enable this option to avoid:
# https://github.com/microsoft/DeepSpeed/issues/707
# torch.nn.parallel.DistributedDataParallel has the same option with
# similar usage
ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS = 'stage2_find_unused_parameters'
ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS_DEFAULT = False

ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
Expand Down Expand Up @@ -145,5 +154,7 @@
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD:
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT,
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE:
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT,
ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS:
ZERO_OPTIMIZATION_FIND_UNUSED_PARAMETERS_DEFAULT
}
17 changes: 16 additions & 1 deletion deepspeed/runtime/zero/stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def __init__(self,
allreduce_always_fp32=False,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1):
gradient_accumulation_steps=1,
find_unused_parameters=False):

if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(self,
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = 0
self.find_unused_parameters = find_unused_parameters

if self.reduce_scatter:
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
Expand Down Expand Up @@ -886,6 +888,19 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
if param_id in self.norm_for_param_grads:
param_norm = self.norm_for_param_grads[param_id]
total_norm += param_norm.item()**2
else:
# As unused parameters in modules may not be expected sometimes,
# add an explicit error msg when it occurred and an option to
# avoid the error
# Error msg adapted from torch.nn.parallel.DistributedDataParallel
assert self.find_unused_parameters, """
This error indicates that your module has parameters that
were not used in producing loss.
You can avoid this error by
(1) enable stage2_find_unused_parameters option in zero_optimization config;
(2) making sure all trainable parameters and `forward` function
outputs participate in calculating loss.
"""

# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
Expand Down
8 changes: 7 additions & 1 deletion docs/_pages/config-json.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ Enabling and configuring ZeRO memory optimizations
"stage3_param_persistence_threshold" : 1e6,
"sub_group_size" : 1e12,
"elastic_checkpoint" : [true|false],
"stage3_gather_fp16_weights_on_model_save": [true|false]
"stage3_gather_fp16_weights_on_model_save": [true|false],
"stage2_find_unused_parameters": [true|false]
}
```

Expand Down Expand Up @@ -396,6 +397,7 @@ Enabling and configuring ZeRO memory optimizations
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Consolidate the weights before saving the model by `save_fp16_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gather the weights when this option is enabled and then saves the fp16 model weights. | `False` |


***cpu_offload***: [boolean]

**Deprecated:** **cpu_offload** is disabled and will be removed in future, please use `offload_optimizer` instead.
Expand Down Expand Up @@ -538,6 +540,10 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s
| ------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Submit requests to storage device in an overlapped fashion without waiting for completion of earlier requests. | `true` |

***stage2_find_unused_parameters***: [boolean]
| Description | Default |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| As unused parameters in modules may not be expected sometimes, it will cause an explicit error msg when it occurred and enable this option to avoid the error, `torch.nn.parallel.DistributedDataParallel` has the same `find_unused_parameters` option with similar usage. | `False` |

### Logging

Expand Down