diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index ac7e68d177fbe..32303d6babb5d 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185)) +- Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208)) + + ### Changed diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index ce1e4cd96b961..3b77bc6ceeb70 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import contextmanager -from typing import Dict, Generator, List, Optional, Tuple, Union +from typing import Dict, Generator, List, Tuple, Union from torch import Tensor from torch.nn import Module @@ -27,7 +27,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE from pytorch_lightning.utilities.optimizer import optimizers_to_device -from pytorch_lightning.utilities.rank_zero import rank_zero_only if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel @@ -120,20 +119,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin del optimizer return optimizers - def optimizer_state(self, optimizer: "OSS") -> Optional[dict]: - if isinstance(optimizer, LightningOptimizer): - optimizer = optimizer._optimizer - optimizer.consolidate_state_dict() - return self._optim_state_dict(optimizer) - - @rank_zero_only - def _optim_state_dict(self, optimizer): - """ - Retrieves state dict only on rank 0, which contains the entire optimizer state after calling - :meth:`consolidate_state_dict`. - """ - return optimizer.state_dict() - def pre_backward(self, closure_loss: Tensor) -> None: pass diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index f19aae7302eea..01ccb75677544 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Tuple +from typing import Dict, Generator, List, Tuple from torch import Tensor from torch.nn import Module @@ -25,7 +25,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.utilities.optimizer import optimizers_to_device -from pytorch_lightning.utilities.rank_zero import rank_zero_only if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel @@ -85,11 +84,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]: return self._reinit_optimizers_with_oss(optimizers) - def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]: - if isinstance(optimizer, OSS): - optimizer.consolidate_state_dict() - return self._optim_state_dict(optimizer) - @contextmanager def block_backward_sync(self) -> Generator: """Blocks syncing gradients behaviour on backwards pass. @@ -103,14 +97,6 @@ def block_backward_sync(self) -> Generator: else: yield None - @rank_zero_only - def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]: - """ - Retrieves state dict only on rank 0, which contains the entire optimizer state after calling - :meth:`consolidate_state_dict`. - """ - return optimizer.state_dict() - def pre_backward(self, closure_loss: Tensor) -> None: pass diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index c09e7eae8c586..0abc5fe516273 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -170,6 +170,16 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: Allows for syncing/collating optimizer state from processes in custom plugins. 
""" + if isinstance(optimizer, LightningOptimizer): + optimizer = optimizer._optimizer + + if hasattr(optimizer, "consolidate_state_dict"): + # there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their + # states, and to avoid OOM we consolidate the full state on rank 0 only + optimizer.consolidate_state_dict() + return optimizer.state_dict() if self.is_global_zero else {} + + # for optimizers that are not sharded, we return the state dict on all ranks return optimizer.state_dict() def backward( diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py index 8d2d965f1d4c6..318505a984216 100644 --- a/tests/tests_pytorch/strategies/test_ddp_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py @@ -24,8 +24,14 @@ from pytorch_lightning.plugins.environments import ClusterEnvironment, LightningEnvironment from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10 from tests_pytorch.helpers.runif import RunIf +if _FAIRSCALE_AVAILABLE: + from fairscale.optim import OSS +if _TORCH_GREATER_EQUAL_1_10: + from torch.distributed.optim import ZeroRedundancyOptimizer + class BoringModelGPU(BoringModel): def on_train_start(self) -> None: @@ -252,3 +258,50 @@ def test_ddp_strategy_set_timeout(mock_init_process_group): mock_init_process_group.assert_called_with( process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta ) + + +class BoringFairScaleOptimizerModel(BoringModel): + def configure_optimizers(self): + base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults) + + +@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True) +@pytest.mark.parametrize("strategy", (pytest.param("ddp", 
marks=RunIf(standalone=True)), "ddp_spawn")) +def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy): + """Test to ensure that checkpoint is saved correctly when using fairscale optimizer.""" + model = BoringFairScaleOptimizerModel() + trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, "model.pt") + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(trained_param.to("cpu"), loaded_param) + + +class BoringZeroRedundancyOptimizerModel(BoringModel): + def configure_optimizers(self): + return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1) + + +@RunIf(min_cuda_gpus=2, skip_windows=True, min_torch="1.10") +@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn")) +def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy): + """Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer.""" + model = BoringZeroRedundancyOptimizerModel() + trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, "model.pt") + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(trained_param.to("cpu"), loaded_param) diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py index a0abfb3f73ec0..acefecbf4d2a7 100644 --- 
a/tests/tests_pytorch/strategies/test_sharded_strategy.py +++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py @@ -14,6 +14,7 @@ if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel + from fairscale.optim import OSS @pytest.mark.parametrize("clip_val", [0, 10]) @@ -70,8 +71,8 @@ def test_ddp_sharded_strategy_checkpoint_cpu(tmpdir): saved_model = BoringModel.load_from_checkpoint(checkpoint_path) # Assert model parameters are identical after loading - for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()): - assert torch.equal(ddp_param.to("cpu"), shard_param) + for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(trained_param.to("cpu"), loaded_param) @RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True) @@ -87,8 +88,8 @@ def test_ddp_sharded_strategy_checkpoint_multi_gpu(tmpdir): saved_model = BoringModel.load_from_checkpoint(checkpoint_path) # Assert model parameters are identical after loading - for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()): - assert torch.equal(ddp_param.to("cpu"), shard_param) + for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(trained_param.to("cpu"), loaded_param) @RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True) @@ -314,3 +315,27 @@ def test_block_backward_sync(): def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs): trainer = Trainer(strategy=strategy_name) assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs + + +class BoringFairScaleOptimizerModel(BoringModel): + def configure_optimizers(self): + base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults) + + +@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True) +@pytest.mark.parametrize("strategy", 
(pytest.param("ddp_sharded", marks=RunIf(standalone=True)), "ddp_sharded_spawn")) +def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy): + """Test to ensure that checkpoint is saved correctly when using fairscale optimizers.""" + model = BoringFairScaleOptimizerModel() + trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, "model.pt") + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(trained_param.to("cpu"), loaded_param)