Support sharded optimizers outside of DDP sharded strategy #11867
Conversation
Force-pushed from db55ef2 to 180ac6a
Force-pushed from 180ac6a to 57d6ece
```python
def optimizer_state(self, optimizer: Optimizer) -> Optional[Dict[str, Tensor]]:
    """Returns state of an optimizer.

    Allows for syncing/collating optimizer state from processes in custom plugins.
    """
    if (_TORCH_GREATER_EQUAL_1_10 and isinstance(optimizer, ZeroRedundancyOptimizer)) or (
        _FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)
    ):
        optimizer.consolidate_state_dict()
        # only call state_dict on the rank where the states were consolidated
        return self._rank_zero_only_optim_state_dict(optimizer)
    else:
        return optimizer.state_dict()

@rank_zero_only
def _rank_zero_only_optim_state_dict(self, optimizer):
    """Retrieves state dict only on rank 0, which contains the entire optimizer state after
    calling :meth:`consolidate_state_dict`."""
    return optimizer.state_dict()
```
Optimizers can return a nested dict in their state dict due to the parameter groups, so the type hint here isn't correct. Let's not change the return type to be Optional either.
Suggested change (this also drops the `_rank_zero_only_optim_state_dict` helper):

```python
def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]:
    """Returns state of an optimizer.

    Allows for syncing/collating optimizer state from processes in custom plugins.
    """
    if (_TORCH_GREATER_EQUAL_1_10 and isinstance(optimizer, ZeroRedundancyOptimizer)) or (
        _FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)
    ):
        optimizer.consolidate_state_dict()
        # only call state_dict on the rank where the states were consolidated
        return optimizer.state_dict() if self.is_global_zero else {}
    else:
        return optimizer.state_dict()
```
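For context on the type-hint point above: `Optimizer.state_dict()` in PyTorch returns a nested structure keyed by `"state"` and `"param_groups"`, not a flat `Dict[str, Tensor]`, which is why `Dict[str, Any]` is the more accurate annotation. A minimal, Lightning-free illustration:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so the optimizer actually has per-parameter state.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

sd = optimizer.state_dict()
print(list(sd.keys()))           # ['state', 'param_groups']
print(type(sd["state"]))         # dict: per-parameter state (tensors, ints, ...)
print(type(sd["param_groups"]))  # list of dicts holding the hyperparameters
```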
Thanks for the suggestion!
One question: wouldn't this differ from the original behavior of the `optimizer_state` function of `DDPShardedStrategy`, which returns None when rank != 0?
https://github.com/PyTorchLightning/pytorch-lightning/blob/79c4e5de60685dbc895641b0139ffc6180d069aa/pytorch_lightning/strategies/sharded.py#L88-L100
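For background on that behavior: a `rank_zero_only`-style decorator runs the wrapped function only on global rank 0 and returns `None` on every other rank, which is why the linked `optimizer_state` yields `None` when rank != 0. A minimal sketch of the pattern (an illustration only, not the library's actual implementation; reading the rank from the `RANK` environment variable is an assumption):

```python
import os
from functools import wraps
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def rank_zero_only(fn: Callable[..., T]) -> Callable[..., Optional[T]]:
    """Run ``fn`` only on global rank 0; all other ranks get ``None`` back."""

    @wraps(fn)
    def wrapped(*args, **kwargs) -> Optional[T]:
        if int(os.environ.get("RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None  # non-zero ranks: the function is never called

    return wrapped
```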
Yes, though this is not visible to users, as the checkpoint is ultimately only saved from rank 0, which contains all the states.
Thanks! Addressed the comment. BTW, do you mean the return value change in the non-rank-0 state dict is invisible to users?
Co-authored-by: ananthsub <[email protected]>
…orch-lightning into bug/optimizer_state_dict
Will redo once #11952 gets merged.
What does this PR do?
Fixes #6387
Does your PR introduce any breaking changes? If yes, please list them.
This check was removed.
https://github.com/PyTorchLightning/pytorch-lightning/blob/6f22b3623c28028026b3cb8bb534c1ebca9c5ac8/pytorch_lightning/strategies/sharded.py#L88-L90
However, the check was also removed in sharded_spawn.py, and since the optimizer type hint there is `OSS`, this breakage might not occur in practice.
https://github.com/PyTorchLightning/pytorch-lightning/blob/6f22b3623c28028026b3cb8bb534c1ebca9c5ac8/pytorch_lightning/strategies/sharded_spawn.py#L72-L75
See also the original PR that removed the check.
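For reference, here is a hedged end-to-end sketch of the flow this PR targets for `ZeroRedundancyOptimizer` outside of `DDPShardedStrategy`. The script, its name, and the launch command are illustrative assumptions, not code from this PR; it would be run with something like `torchrun --nproc_per_node=2 zero_sketch.py`:

```python
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

# torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT for us.
dist.init_process_group("gloo")  # use "nccl" on GPUs
rank = dist.get_rank()

model = torch.nn.Linear(4, 2)
optimizer = ZeroRedundancyOptimizer(
    model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3
)

model(torch.randn(8, 4)).sum().backward()
optimizer.step()

# Each rank only holds its shard of the optimizer state;
# gather the full state onto rank 0 before saving a checkpoint.
optimizer.consolidate_state_dict(to=0)
full_state = optimizer.state_dict() if rank == 0 else {}
if rank == 0:
    torch.save(full_state, "optimizer.pt")

dist.destroy_process_group()
```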
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃