
Support sharded optimizers outside of DDP sharded strategy #11867

Conversation

Contributor

@DuYicong515 DuYicong515 commented Feb 11, 2022

What does this PR do?

Fixes #6387

Does your PR introduce any breaking changes? If yes, please list them.

This check was removed.
https://github.com/PyTorchLightning/pytorch-lightning/blob/6f22b3623c28028026b3cb8bb534c1ebca9c5ac8/pytorch_lightning/strategies/sharded.py#L88-L90

However, the check was also removed in sharded_spawn.py; since the optimizer type hint there is "OSS", this case might not occur in practice.
https://github.com/PyTorchLightning/pytorch-lightning/blob/6f22b3623c28028026b3cb8bb534c1ebca9c5ac8/pytorch_lightning/strategies/sharded_spawn.py#L72-L75

The original PR link that removed the check

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • [N/A] Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • [N/A] Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@DuYicong515 DuYicong515 force-pushed the bug/optimizer_state_dict branch from 180ac6a to 57d6ece on February 11, 2022 22:46
@DuYicong515 DuYicong515 marked this pull request as ready for review February 11, 2022 22:58
@carmocca carmocca changed the base branch from master to ananthsub-patch-1 February 12, 2022 04:12
CHANGELOG.md (outdated)
Comment on lines 158 to 179
    def optimizer_state(self, optimizer: Optimizer) -> Optional[Dict[str, Tensor]]:
        """Returns state of an optimizer.

        Allows for syncing/collating optimizer state from processes in custom plugins.
        """
        if (_TORCH_GREATER_EQUAL_1_10 and isinstance(optimizer, ZeroRedundancyOptimizer)) or (
            _FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)
        ):
            optimizer.consolidate_state_dict()
            # only call state_dict on the rank where the states were consolidated
            return self._rank_zero_only_optim_state_dict(optimizer)
        else:
            return optimizer.state_dict()

    @rank_zero_only
    def _rank_zero_only_optim_state_dict(self, optimizer):
        """Retrieves state dict only on rank 0, which contains the entire optimizer state after
        calling :meth:`consolidate_state_dict`."""
        return optimizer.state_dict()

Contributor

optimizers can return a nested dict in their state dict due to the parameter groups. the typehint here isn't correct. let's not change the return type to be optional either

Suggested change

    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]:
        """Returns state of an optimizer.

        Allows for syncing/collating optimizer state from processes in custom plugins.
        """
        if (_TORCH_GREATER_EQUAL_1_10 and isinstance(optimizer, ZeroRedundancyOptimizer)) or (
            _FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)
        ):
            optimizer.consolidate_state_dict()
            # only call state_dict on the rank where the states were consolidated
            return optimizer.state_dict() if self.is_global_zero else {}
        else:
            return optimizer.state_dict()
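The suggested behavior (full state dict on global rank 0, empty dict elsewhere) can be sketched standalone. This is an illustration only, not Lightning code: `FakeShardedOptimizer` and the free-standing `optimizer_state` function are stand-ins, and the nested `{"state": ..., "param_groups": ...}` shape shows why `Dict[str, Tensor]` was too narrow an annotation.

```python
from typing import Any, Dict


class FakeShardedOptimizer:
    """Stand-in for a sharded optimizer such as fairscale's OSS."""

    def consolidate_state_dict(self) -> None:
        # A real sharded optimizer gathers the shards onto rank 0 here;
        # in this sketch it is a no-op.
        pass

    def state_dict(self) -> Dict[str, Any]:
        # Optimizer state dicts are nested: "state" plus "param_groups".
        return {"state": {}, "param_groups": [{"lr": 0.1}]}


def optimizer_state(optimizer: FakeShardedOptimizer, is_global_zero: bool) -> Dict[str, Any]:
    optimizer.consolidate_state_dict()
    # Only the rank holding the consolidated state returns it; others get {}.
    return optimizer.state_dict() if is_global_zero else {}


print(optimizer_state(FakeShardedOptimizer(), True))   # full nested dict
print(optimizer_state(FakeShardedOptimizer(), False))  # {}
```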

Contributor Author

@DuYicong515 DuYicong515 Feb 14, 2022

Thanks for the suggestion!

One question: wouldn't this differ from the original behavior of the optimizer_state function of DDPShardedStrategy, which returns None when rank != 0?
https://github.com/PyTorchLightning/pytorch-lightning/blob/79c4e5de60685dbc895641b0139ffc6180d069aa/pytorch_lightning/strategies/sharded.py#L88-L100
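The None-on-other-ranks behavior mentioned in the question comes from the rank-zero-only decorator pattern. A simplified sketch of such a decorator (an assumption about the general pattern, not Lightning's actual implementation, which reads the rank differently) shows why the wrapped state-dict call yields None off rank 0:

```python
import os
from functools import wraps


def rank_zero_only(fn):
    """Run the wrapped function only on global rank 0; return None elsewhere."""

    @wraps(fn)
    def wrapped(*args, **kwargs):
        # Hypothetical rank lookup: real frameworks resolve the rank
        # from their distributed state, not just this env var.
        if int(os.environ.get("RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None  # non-zero ranks implicitly get None back
    return wrapped


@rank_zero_only
def optim_state_dict():
    # Stand-in for optimizer.state_dict() after consolidation.
    return {"state": {}, "param_groups": []}


print(optim_state_dict())  # the dict on rank 0, None on other ranks
```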

Contributor

Yes, though this is visible to users as the checkpoint is ultimately only saved from rank 0, which contains all the states

Contributor Author

Thanks! Addressed the comment. BTW, do you mean the return value change in the non-rank-0 state dict is invisible to users?

@carmocca carmocca added this to the 1.5.x milestone Feb 16, 2022
@ananthsub ananthsub deleted the branch Lightning-AI:ananthsub-patch-1 February 17, 2022 06:11
@DuYicong515
Copy link
Contributor Author

Will redo once #11952 gets merged

Development

Successfully merging this pull request may close these issues.

Support for sharded optimizers when dumping checkpoints outside of the DDP sharded training type plugin
3 participants