Skip to content

Commit

Permalink
Support sharded optimizer state dumping outside of sharded strategies (
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Aug 26, 2022
1 parent 6a999f1 commit e67842d
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 35 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))


- Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208))



### Changed

Expand Down
17 changes: 1 addition & 16 deletions src/pytorch_lightning/strategies/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple, Union
from typing import Dict, Generator, List, Tuple, Union

from torch import Tensor
from torch.nn import Module
Expand All @@ -27,7 +27,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
Expand Down Expand Up @@ -120,20 +119,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin
del optimizer
return optimizers

def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

@rank_zero_only
def _optim_state_dict(self, optimizer):
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
"""
return optimizer.state_dict()

def pre_backward(self, closure_loss: Tensor) -> None:
pass

Expand Down
16 changes: 1 addition & 15 deletions src/pytorch_lightning/strategies/sharded_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Tuple
from typing import Dict, Generator, List, Tuple

from torch import Tensor
from torch.nn import Module
Expand All @@ -25,7 +25,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
Expand Down Expand Up @@ -85,11 +84,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:

return self._reinit_optimizers_with_oss(optimizers)

def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
if isinstance(optimizer, OSS):
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

@contextmanager
def block_backward_sync(self) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.
Expand All @@ -103,14 +97,6 @@ def block_backward_sync(self) -> Generator:
else:
yield None

@rank_zero_only
def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
"""
return optimizer.state_dict()

def pre_backward(self, closure_loss: Tensor) -> None:
pass

Expand Down
10 changes: 10 additions & 0 deletions src/pytorch_lightning/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
Allows for syncing/collating optimizer state from processes in custom plugins.
"""
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer

if hasattr(optimizer, "consolidate_state_dict"):
# there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their
# states, and to avoid OOM we consolidate the full state on rank 0 only
optimizer.consolidate_state_dict()
return optimizer.state_dict() if self.is_global_zero else {}

# for optimizers that are not sharded, we return the state dict on all ranks
return optimizer.state_dict()

def backward(
Expand Down
53 changes: 53 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@
from pytorch_lightning.plugins.environments import ClusterEnvironment, LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
from tests_pytorch.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import ZeroRedundancyOptimizer


class BoringModelGPU(BoringModel):
def on_train_start(self) -> None:
Expand Down Expand Up @@ -252,3 +258,50 @@ def test_ddp_strategy_set_timeout(mock_init_process_group):
mock_init_process_group.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
)


class BoringFairScaleOptimizerModel(BoringModel):
def configure_optimizers(self):
base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)


@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"))
def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
"""Test to ensure that checkpoint is saved correctly when using faircale optimizer."""
model = BoringFairScaleOptimizerModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)

trainer.fit(model)

checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

# Assert model parameters are identical after loading
for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(trained_param.to("cpu"), loaded_param)


class BoringZeroRedundancyOptimizerModel(BoringModel):
def configure_optimizers(self):
return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)


@RunIf(min_cuda_gpus=2, skip_windows=True, min_torch="1.10")
@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"))
def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
"""Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer."""
model = BoringZeroRedundancyOptimizerModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)

trainer.fit(model)

checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

# Assert model parameters are identical after loading
for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(trained_param.to("cpu"), loaded_param)
33 changes: 29 additions & 4 deletions tests/tests_pytorch/strategies/test_sharded_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from fairscale.optim import OSS


@pytest.mark.parametrize("clip_val", [0, 10])
Expand Down Expand Up @@ -70,8 +71,8 @@ def test_ddp_sharded_strategy_checkpoint_cpu(tmpdir):
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(ddp_param.to("cpu"), shard_param)
for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(trained_param.to("cpu"), loaded_param)


@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
Expand All @@ -87,8 +88,8 @@ def test_ddp_sharded_strategy_checkpoint_multi_gpu(tmpdir):
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(ddp_param.to("cpu"), shard_param)
for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(trained_param.to("cpu"), loaded_param)


@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
Expand Down Expand Up @@ -314,3 +315,27 @@ def test_block_backward_sync():
def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
trainer = Trainer(strategy=strategy_name)
assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs


class BoringFairScaleOptimizerModel(BoringModel):
def configure_optimizers(self):
base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)


@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
@pytest.mark.parametrize("strategy", (pytest.param("ddp_sharded", marks=RunIf(standalone=True)), "ddp_sharded_spawn"))
def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
"""Test to ensure that checkpoint is saved correctly when using fairscale optimizers."""
model = BoringFairScaleOptimizerModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)

trainer.fit(model)

checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

# Assert model parameters are identical after loading
for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(trained_param.to("cpu"), loaded_param)

0 comments on commit e67842d

Please sign in to comment.