
Support sharded optimizer state dumping outside of sharded strategies #14208

Merged: 17 commits, Aug 26, 2022
Changes from 11 commits
3 changes: 2 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -12,7 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))


-
- Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208))



### Changed
17 changes: 1 addition & 16 deletions src/pytorch_lightning/strategies/sharded.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple, Union
from typing import Dict, Generator, List, Tuple, Union

from torch import Tensor
from torch.nn import Module
Expand All @@ -27,7 +27,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -120,20 +119,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin
del optimizer
return optimizers

def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

@rank_zero_only
def _optim_state_dict(self, optimizer):
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
"""
return optimizer.state_dict()

def pre_backward(self, closure_loss: Tensor) -> None:
pass

16 changes: 1 addition & 15 deletions src/pytorch_lightning/strategies/sharded_spawn.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Tuple
from typing import Dict, Generator, List, Tuple

from torch import Tensor
from torch.nn import Module
@@ -25,7 +25,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -85,11 +84,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:

return self._reinit_optimizers_with_oss(optimizers)

def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
if isinstance(optimizer, OSS):
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

@contextmanager
def block_backward_sync(self) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.
@@ -103,14 +97,6 @@ def block_backward_sync(self) -> Generator:
else:
yield None

@rank_zero_only
def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
"""
return optimizer.state_dict()

def pre_backward(self, closure_loss: Tensor) -> None:
pass

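Note: the rank-zero-only consolidation removed from `DDPShardedStrategy` and `DDPSpawnShardedStrategy` above is not dropped, it moves into the base `Strategy.optimizer_state` (next file), which only relies on the optimizer exposing `consolidate_state_dict()`. A minimal sketch of that mechanism with fairscale's `OSS`, assuming fairscale is installed and a distributed process group is already initialized; the parameter setup here is illustrative, not taken from this PR:

```python
import torch
import torch.distributed as dist
from fairscale.optim import OSS

# OSS wraps a regular optimizer and shards its state across the ranks of the
# default process group (assumed to be initialized, e.g. via torchrun)
params = [torch.nn.Parameter(torch.randn(4, 4))]
optimizer = OSS(params=params, optim=torch.optim.SGD, lr=0.1)

# ... training steps happen here; each rank only keeps its shard of the state ...

# gather the full optimizer state onto rank 0 (the default recipient rank)
optimizer.consolidate_state_dict()

if dist.get_rank() == 0:
    # only rank 0 holds the consolidated state, so only rank 0 reads it,
    # which is exactly what the base Strategy implementation below does
    full_state = optimizer.state_dict()
```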
10 changes: 10 additions & 0 deletions src/pytorch_lightning/strategies/strategy.py
@@ -170,6 +170,16 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:

Allows for syncing/collating optimizer state from processes in custom plugins.
"""
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer

if hasattr(optimizer, "consolidate_state_dict"):
# there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their
# states, and to avoid OOM we consolidate the full state on rank 0 only
optimizer.consolidate_state_dict()
return optimizer.state_dict() if self.is_global_zero else {}

# for optimizers that are not sharded, we return the state dict on all ranks
return optimizer.state_dict()

def backward(
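For context, a minimal end-to-end sketch of what this change enables: checkpointing a sharded optimizer under the plain `ddp` strategy. It assumes 2 GPUs, torch >= 1.10 for `ZeroRedundancyOptimizer`, and a Lightning build that includes this PR; the model, dataset, and checkpoint path are made up for the example:

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ShardedOptimModel(LightningModule):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x = batch[0]
        return self.layer(x).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        # each rank only keeps the optimizer state for its shard of the parameters
        return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(ShardedOptimModel())
    # save_checkpoint goes through Strategy.optimizer_state, which now consolidates
    # the sharded optimizer state on rank 0 before it is written to disk
    trainer.save_checkpoint("model.pt")
```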
53 changes: 53 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -24,8 +24,14 @@
from pytorch_lightning.plugins.environments import ClusterEnvironment, LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
from tests_pytorch.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import ZeroRedundancyOptimizer


class BoringModelGPU(BoringModel):
def on_train_start(self) -> None:
@@ -252,3 +258,50 @@ def test_ddp_strategy_set_timeout(mock_init_process_group):
mock_init_process_group.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
)


class BoringFairScaleOptimizerModel(BoringModel):
def configure_optimizers(self):
base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)


@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"))
def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
"""Test to ensure that checkpoint is saved correctly when using faircale optimizer."""
model = BoringFairScaleOptimizerModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_epochs=2)

trainer.fit(model)

checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(ddp_param.to("cpu"), shard_param)


class BoringZeroRedundancyOptimizerModel(BoringModel):
def configure_optimizers(self):
return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)


@RunIf(min_cuda_gpus=2, skip_windows=True, min_torch="1.10")
@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"))
def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
"""Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer."""
model = BoringZeroRedundancyOptimizerModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_epochs=2)

trainer.fit(model)

checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(ddp_param.to("cpu"), shard_param)
25 changes: 25 additions & 0 deletions tests/tests_pytorch/strategies/test_sharded_strategy.py
@@ -14,6 +14,7 @@

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from fairscale.optim import OSS


@pytest.mark.parametrize("clip_val", [0, 10])
@@ -314,3 +315,27 @@ def test_block_backward_sync():
def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
trainer = Trainer(strategy=strategy_name)
assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs


class BoringFairScaleOptimizerModel(BoringModel):
def configure_optimizers(self):
base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)


@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
@pytest.mark.parametrize("strategy", (pytest.param("ddp_sharded", marks=RunIf(standalone=True)), "ddp_sharded_spawn"))
def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
"""Test to ensure that checkpoint is saved correctly when using fairscale optimizers."""
model = BoringFairScaleOptimizerModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_epochs=2)

trainer.fit(model)

checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(ddp_param.to("cpu"), shard_param)