diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index 3a57f7b9499bc..ca6d2be30faa5 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -352,6 +352,12 @@ Model layers should be wrapped in FSDP in a nested way to save peak memory and e simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't have to ``wrap`` layers manually as in the case of manual wrapping. +.. note:: + While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else + PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your + ``lightning_module.parameters()`` will return a generator with no params. This inconvenience will be addressed in the future. + + .. code-block:: python model = BoringModel() diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index f0157edbf336f..12895bcee5466 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -196,3 +196,9 @@ def no_backward_sync(self, module: Module) -> Generator: ) with module.no_sync(): yield None + + +def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: + from fairscale.nn.misc.flatten_params_wrapper import FlatParameter + + return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]) diff --git a/src/lightning_lite/strategies/fsdp_native.py b/src/lightning_lite/strategies/fsdp_native.py new file mode 100644 index 0000000000000..9e70400e476dc --- /dev/null +++ b/src/lightning_lite/strategies/fsdp_native.py @@ -0,0 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torch.optim import Optimizer + + +def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: + from torch.distributed.fsdp import FlatParameter + + return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 5c064c52962c4..bb4115995dda9 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +- Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301)) + + ### Changed - From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index 4c4deb31433fb..9801cff27f9ae 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -19,7 +19,7 @@ import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment -from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE, _optimizer_has_flat_params from lightning_lite.utilities.enums import PrecisionType from lightning_lite.utilities.optimizer import _optimizers_to_device from pytorch_lightning.overrides.base import _LightningModuleWrapperBase @@ -28,7 +28,6 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_info from pytorch_lightning.utilities.types import STEP_OUTPUT if _FAIRSCALE_AVAILABLE: @@ -191,16 +190,27 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + invalid_params_error = False + try: + super().setup_optimizers(trainer) + except ValueError as e: + if "optimizer got an empty parameter list" not in str(e): + raise + invalid_params_error = True + + if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers): + raise ValueError( + "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the" + " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the" + " `configure_optimizers()` hook." + ) + def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: """Wraps the model into a :class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.") - rank_zero_info( - "When using FairScale FSDP auto-wrap, make sure to initialize your model using trainer: " - "`torch.optim.Optimizer(self.trainer.model.parameters(), ...)`" - ) - return FullyShardedDataParallel( module=model, process_group=self.process_group, diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 47011f9d4b7e0..c628f2a653a79 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -20,6 +20,7 @@ import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment +from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params from lightning_lite.utilities.distributed import ( _get_default_process_group_backend_for_device, _init_dist_connection, @@ -215,6 +216,7 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: del self.kwargs["auto_wrap_policy"] log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}") + return FullyShardedDataParallel( module=model, process_group=self.process_group, @@ -255,6 +257,22 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + invalid_params_error = False + try: + super().setup_optimizers(trainer) + except ValueError as e: + if "optimizer got an empty parameter list" not in str(e): + raise + invalid_params_error = True + + if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers): + raise ValueError( + "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the" + " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the" + " `configure_optimizers()` hook." + ) + def model_to_device(self) -> None: pass diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index be8bced2cbf5f..4c1e1a306d540 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -18,46 +18,6 @@ from torch.distributed.fsdp.wrap import wrap -def custom_auto_wrap_policy( - module, - recurse, - unwrapped_params: int, - min_num_params: int = int(1e8), -) -> bool: - return unwrapped_params >= 2 - - -@RunIf(min_torch="1.12") -def test_invalid_on_cpu(tmpdir): - """Test to ensure that we raise Misconfiguration for Native FSDP on CPU.""" - with pytest.raises( - MisconfigurationException, - match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, " - "but GPU accelerator is not used.", - ): - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native") - assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy) - trainer.strategy.setup_environment() - - -@RunIf(min_torch="1.12", min_cuda_gpus=1) -@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) -def test_precision_plugin_config(precision, expected): - plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda") - config = plugin.mixed_precision_config - assert config.param_dtype == expected - assert config.buffer_dtype == expected - assert config.reduce_dtype == expected - - -@RunIf(min_torch="1.12") -def test_fsdp_custom_mixed_precision(tmpdir): - """Test to ensure that passing a custom mixed precision config works.""" - config = MixedPrecision() - strategy = DDPFullyShardedNativeStrategy(mixed_precision=config) - assert strategy.mixed_precision_config == config - - class TestFSDPModel(BoringModel): def __init__(self): super().__init__() @@ -154,6 +114,80 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer[layer_num].mixed_precision.buffer_dtype == precision +def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): + trainer.fit(model) + model_path = trainer.strategy.broadcast(model_path) + model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path + + trainer.save_checkpoint(model_path, weights_only=True) + + _assert_save_equality(trainer, model_path, cls=model.__class__) + + # Test entry point + trainer.test(model) # model is wrapped, will not call `configure_sharded_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + trainer.test(ckpt_path=model_path) + + # Predict entry point + trainer.predict(model) # model is wrapped, will not call `configure_sharded_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + trainer.predict(ckpt_path=model_path) + + +def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): + # Use FullySharded to get the state dict for the sake of comparison + model_state_dict = trainer.strategy.lightning_module_state_dict() + + if trainer.is_global_zero: + saved_model = cls.load_from_checkpoint(ckpt_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): + assert torch.equal(ddp_param.float().cpu(), shard_param) + + +def custom_auto_wrap_policy( + module, + recurse, + unwrapped_params: int, + min_num_params: int = int(1e8), +) -> bool: + return unwrapped_params >= 2 + + +@RunIf(min_torch="1.12") +def test_invalid_on_cpu(tmpdir): + """Test to ensure that we raise Misconfiguration for Native FSDP on CPU.""" + with pytest.raises( + MisconfigurationException, + match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, " + "but GPU accelerator is not used.", + ): + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native") + assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy) + trainer.strategy.setup_environment() + + +@RunIf(min_torch="1.12", min_cuda_gpus=1) +@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) +def test_precision_plugin_config(precision, expected): + plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda") + config = plugin.mixed_precision_config + assert config.param_dtype == expected + assert config.buffer_dtype == expected + assert config.reduce_dtype == expected + + +@RunIf(min_torch="1.12") +def test_fsdp_custom_mixed_precision(tmpdir): + """Test to ensure that passing a custom mixed precision config works.""" + config = MixedPrecision() + strategy = DDPFullyShardedNativeStrategy(mixed_precision=config) + assert strategy.mixed_precision_config == config + + @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir): """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run.""" @@ -214,35 +248,23 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, stra _run_multiple_stages(trainer, model) -def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): - trainer.fit(model) - model_path = trainer.strategy.broadcast(model_path) - model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path - - trainer.save_checkpoint(model_path, weights_only=True) - - _assert_save_equality(trainer, model_path, cls=model.__class__) - - # Test entry point - trainer.test(model) # model is wrapped, will not call `configure_sharded_model` - - # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap - trainer.test(ckpt_path=model_path) +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") +def test_invalid_parameters_in_optimizer(tmpdir): + trainer = Trainer(strategy="fsdp_native", accelerator="cuda", devices=1) - # Predict entry point - trainer.predict(model) # model is wrapped, will not call `configure_sharded_model` + class EmptyParametersModel(BoringModel): + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-2) - # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap - trainer.predict(ckpt_path=model_path) + model = EmptyParametersModel() + with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): + trainer.fit(model) + class NoFlatParametersModel(BoringModel): + def configure_optimizers(self): + layer = torch.nn.Linear(4, 5) + return torch.optim.Adam(layer.parameters(), lr=1e-2) -def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): - # Use FullySharded to get the state dict for the sake of comparison - model_state_dict = trainer.strategy.lightning_module_state_dict() - - if trainer.is_global_zero: - saved_model = cls.load_from_checkpoint(ckpt_path) - - # Assert model parameters are identical after loading - for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): - assert torch.equal(ddp_param.float().cpu(), shard_param) + model = NoFlatParametersModel() + with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): + trainer.fit(model) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index 5043d3a8c4aa3..ba77fc561db1f 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -18,27 +18,6 @@ from fairscale.nn import FullyShardedDataParallel, wrap -def test_invalid_on_cpu(tmpdir): - """Test to ensure that to raise Misconfiguration for FSDP on CPU.""" - with pytest.raises( - MisconfigurationException, match="You selected strategy to be `ddp_fully_sharded`, but GPU is not available." - ): - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp") - assert isinstance(trainer.strategy, DDPFullyShardedStrategy) - trainer.strategy.setup_environment() - - -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) -@RunIf(fairscale=True) -def test_fsdp_with_sharded_amp(cuda_count_1, tmpdir): - """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" - trainer = Trainer( - default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", accelerator="gpu", devices=1, precision=16 - ) - assert isinstance(trainer.strategy, DDPFullyShardedStrategy) - assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) - - class TestFSDPModelManualWrapped(BoringModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -123,6 +102,72 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.trainer.model.mixed_precision +def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModelManualWrapped): + # Use FullySharded to get the state dict for the sake of comparison + model_state_dict = trainer.strategy.lightning_module_state_dict() + + if trainer.is_global_zero: + saved_model = cls.load_from_checkpoint(ckpt_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): + assert torch.equal(ddp_param.float().cpu(), shard_param) + + +def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): + trainer.fit(model) + + model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path + + trainer.save_checkpoint(model_path, weights_only=True) + + _assert_save_equality(trainer, model_path, cls=model.__class__) + + # Test entry point + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.test(model) # model is wrapped, will not call configure_shared_model + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.test(model, ckpt_path=model_path) + + # Predict entry point + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.predict(model) # model is wrapped, will not call `configure_sharded_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.predict(model, ckpt_path=model_path) + + +def test_invalid_on_cpu(tmpdir): + """Test to ensure that to raise Misconfiguration for FSDP on CPU.""" + with pytest.raises( + MisconfigurationException, match="You selected strategy to be `ddp_fully_sharded`, but GPU is not available." + ): + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp") + assert isinstance(trainer.strategy, DDPFullyShardedStrategy) + trainer.strategy.setup_environment() + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) +@RunIf(fairscale=True) +def test_fsdp_with_sharded_amp(cuda_count_1, tmpdir): + """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" + trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", accelerator="gpu", devices=1, precision=16 + ) + assert isinstance(trainer.strategy, DDPFullyShardedStrategy) + assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) + + @RunIf(min_cuda_gpus=1, standalone=True, fairscale=True) def test_fully_sharded_strategy_checkpoint(tmpdir): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" @@ -171,51 +216,6 @@ def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir, model, strategy): _run_multiple_stages(trainer, model) -def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModelManualWrapped): - # Use FullySharded to get the state dict for the sake of comparison - model_state_dict = trainer.strategy.lightning_module_state_dict() - - if trainer.is_global_zero: - saved_model = cls.load_from_checkpoint(ckpt_path) - - # Assert model parameters are identical after loading - for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): - assert torch.equal(ddp_param.float().cpu(), shard_param) - - -def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): - trainer.fit(model) - - model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path - - trainer.save_checkpoint(model_path, weights_only=True) - - _assert_save_equality(trainer, model_path, cls=model.__class__) - - # Test entry point - if model.__class__ is TestFSDPModelAutoWrapped: - model = TestFSDPModelAutoWrapped() - trainer.test(model) # model is wrapped, will not call configure_shared_model - - # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap - if model.__class__ is TestFSDPModelAutoWrapped: - model = TestFSDPModelAutoWrapped() - trainer.test(model, ckpt_path=model_path) - - # Predict entry point - if model.__class__ is TestFSDPModelAutoWrapped: - model = TestFSDPModelAutoWrapped() - - if model.__class__ is TestFSDPModelAutoWrapped: - model = TestFSDPModelAutoWrapped() - trainer.predict(model) # model is wrapped, will not call `configure_sharded_model` - - # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap - if model.__class__ is TestFSDPModelAutoWrapped: - model = TestFSDPModelAutoWrapped() - trainer.predict(model, ckpt_path=model_path) - - @RunIf(min_cuda_gpus=1, standalone=True, fairscale=True) def test_fsdp_gradient_clipping_raises(tmpdir): """Test to ensure that an exception is raised when clipping gradients by value with FSDP.""" @@ -254,3 +254,25 @@ def test_fsdp_rewrap_limitation(tmpdir): with pytest.raises(MisconfigurationException, match="Using the same instance of model .* not supported"): trainer.test(model) + + +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) +def test_invalid_parameters_in_optimizer(tmpdir): + trainer = Trainer(strategy="fsdp", accelerator="gpu", devices=1) + + class EmptyParametersModel(BoringModel): + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-2) + + model = EmptyParametersModel() + with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): + trainer.fit(model) + + class NoFlatParametersModel(BoringModel): + def configure_optimizers(self): + layer = torch.nn.Linear(4, 5) + return torch.optim.Adam(layer.parameters(), lr=1e-2) + + model = NoFlatParametersModel() + with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): + trainer.fit(model)