From 1c2d9ebb83ef5e556b7727cc2cf4ee7daf7c4699 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 22 Aug 2022 11:39:21 -0400 Subject: [PATCH 01/11] Add autowrapping support for Fairscale FSDP --- .../strategies/fully_sharded.py | 58 ++++++++++------ .../tuner/batch_size_scaling.py | 8 +++ ..._ddp_fully_sharded_with_full_state_dict.py | 67 ++++++++++++++++--- 3 files changed, 104 insertions(+), 29 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index 239e4844b146e..b87fcef0fbcf1 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -18,6 +18,7 @@ import torch import pytorch_lightning as pl +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -27,7 +28,7 @@ from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.optimizer import optimizers_to_device -from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, TrainingStep, ValidationStep +from pytorch_lightning.utilities.types import STEP_OUTPUT if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: from fairscale.nn import default_auto_wrap_policy, enable_wrap @@ -144,17 +145,38 @@ def setup(self, trainer: "pl.Trainer") -> None: self.accelerator.setup(trainer) if trainer.state.fn == TrainerFn.FITTING: - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) - if self._layer_sync: assert self.model self.model = self._layer_sync.apply(self.model) - self.setup_precision_plugin() self.configure_ddp() + self.model = _LightningModuleWrapperBase(self.model) + self.model = self._setup_model(self.model) + self.setup_optimizers(self.lightning_module.trainer) + optimizers_to_device(self.optimizers, self.root_device) self.barrier() + self.setup_precision_plugin() + + def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: + """Wraps the model into a + :class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" + log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.") + + return FullyShardedDataParallel( + module=model, + process_group=self.process_group, + cpu_offload=self.cpu_offload, + move_grads_to_cpu=self.move_grads_to_cpu, + flatten_parameters=self.flatten_parameters, + mixed_precision=(self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)), + reshard_after_forward=self.reshard_after_forward, + fp32_reduce_scatter=self.fp32_reduce_scatter, + compute_dtype=self.compute_dtype, + bucket_cap_mb=self.bucket_cap_mb, + state_dict_device=self.state_dict_device, + ) + @contextlib.contextmanager def model_sharded_context(self) -> Generator: log.detail(f"{self.__class__.__name__}: entered model_sharded_context.") @@ -190,10 +212,6 @@ def configure_ddp(self) -> None: # (TODO: need to figure out solution) self.model_to_device() - # setup optimizers after fully sharded has wrapped the lightning module - assert self.lightning_module - self.setup_optimizers(self.lightning_module.trainer) - def model_to_device(self) -> None: log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...") # ensure we update the device type in the lightning module @@ -201,24 +219,22 @@ def model_to_device(self) -> None: self.lightning_module.to(self.root_device) def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - with self.precision_plugin.train_step_context(): - assert isinstance(self.model, TrainingStep) - return self.model.training_step(*args, **kwargs) + # we don't need precision context since casting is done by FSDP + # read `mixed_precision` docstring here: https://pytorch.org/docs/stable/fsdp.html + assert self.model is not None + return self.model(*args, **kwargs) def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - with self.precision_plugin.val_step_context(): - assert isinstance(self.model, ValidationStep) - return self.model.validation_step(*args, **kwargs) + assert self.model is not None + return self.model(*args, **kwargs) def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - with self.precision_plugin.test_step_context(): - assert isinstance(self.model, TestStep) - return self.model.test_step(*args, **kwargs) + assert self.model is not None + return self.model(*args, **kwargs) def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - with self.precision_plugin.predict_step_context(): - assert isinstance(self.model, PredictStep) - return self.model.predict_step(*args, **kwargs) + assert self.model is not None + return self.model(*args, **kwargs) def post_training_step(self) -> None: pass diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index a1f8a2de4b9d8..a8c845f162184 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -143,6 +143,9 @@ def _run_power_scaling( # If we fail in power mode, half the size and return garbage_collection_cuda() new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc="failed") + # Force the train dataloader to reset as the batch size has changed + trainer.reset_train_dataloader(model) + trainer.reset_val_dataloader(model) break else: raise # some other error not memory related @@ -203,7 +206,12 @@ def _run_binsearch_scaling( garbage_collection_cuda() high = new_size midval = (high + low) // 2 + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed") + # Force the train dataloader to reset as the batch size has changed + trainer.reset_train_dataloader(model) + trainer.reset_val_dataloader(model) + if high - low <= 1: break else: diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index 2790f014c7212..1490878700ff0 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -41,7 +41,7 @@ def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) -class TestFSDPModel(BoringModel): +class TestFSDPModelManualWrapped(BoringModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.layer: Optional[torch.nn.Module] = None @@ -96,11 +96,46 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer.module[2].mixed_precision +class TestFSDPModelAutoWrapped(BoringModel): + def __init__(self): + super().__init__() + self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + + def configure_optimizers(self): + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) + + def on_train_start(self) -> None: + self._assert_layer_fsdp_instance() + + def on_test_start(self) -> None: + self._assert_layer_fsdp_instance() + + def on_validation_start(self) -> None: + self._assert_layer_fsdp_instance() + + def on_prediction_start(self) -> None: + self._assert_layer_fsdp_instance() + + def _assert_layer_fsdp_instance(self) -> None: + assert isinstance(self.layer, FullyShardedDataParallel) + assert isinstance(self.layer.module[0], FullyShardedDataParallel) + assert isinstance(self.layer.module[2], FullyShardedDataParallel) + + # Assert that the nested layers are set reshard_after_forward to True + assert self.layer.module[0].reshard_after_forward is True + assert self.layer.module[2].reshard_after_forward is True + + if isinstance(self.trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin): + assert self.layer.mixed_precision + assert self.layer.module[0].mixed_precision + assert self.layer.module[2].mixed_precision + + @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) def test_fully_sharded_strategy_checkpoint(tmpdir): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" - model = TestFSDPModel() + model = TestFSDPModelManualWrapped() trainer = Trainer( default_root_dir=tmpdir, accelerator="gpu", @@ -115,18 +150,28 @@ def test_fully_sharded_strategy_checkpoint(tmpdir): @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, fairscale_fully_sharded=True) -def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir): +@pytest.mark.parametrize( + "model, strategy", + [ + (TestFSDPModelManualWrapped(), DDPFullyShardedStrategy(min_num_params=2)), + (TestFSDPModelAutoWrapped(), "fsdp"), + ], +) +def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir, model, strategy): """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" - model = TestFSDPModel() ck = ModelCheckpoint(save_last=True) trainer = Trainer( default_root_dir=tmpdir, accelerator="gpu", devices=2, - strategy="fsdp", + strategy=strategy, precision=16, max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + limit_predict_batches=2, callbacks=[ck], enable_progress_bar=False, enable_model_summary=False, @@ -134,7 +179,7 @@ def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir): _run_multiple_stages(trainer, model) -def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): +def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModelManualWrapped): # Use FullySharded to get the state dict for the sake of comparison model_state_dict = trainer.strategy.lightning_module_state_dict() @@ -153,14 +198,20 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): trainer.save_checkpoint(model_path, weights_only=True) - _assert_save_equality(trainer, model_path, cls=TestFSDPModel) + _assert_save_equality(trainer, model_path, cls=model.__class__) # Test entry point trainer.test(model) # model is wrapped, will not call configure_shared_model - # provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap trainer.test(ckpt_path=model_path) + # Predict entry point + trainer.predict(model) # model is wrapped, will not call `configure_sharded_model` + + # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap + trainer.predict(ckpt_path=model_path) + @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) def test_fsdp_gradient_clipping_raises(tmpdir): From d51c53b418db7223d5fddaa9287e7e762076e5b6 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 24 Aug 2022 17:46:36 -0400 Subject: [PATCH 02/11] update --- .../strategies/fully_sharded.py | 11 +++++- src/pytorch_lightning/strategies/strategy.py | 2 +- ..._ddp_fully_sharded_with_full_state_dict.py | 34 ++++++++----------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index b87fcef0fbcf1..a803be7f2f1aa 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -37,6 +37,11 @@ log = logging.getLogger(__name__) +class _DDPFullyShardedStrategyModuleWrapper(_LightningModuleWrapperBase): + def state_dict(self, *args: Any, **kwargs: Any): + return self._forward_module.state_dict(*args, **kwargs) + + class DDPFullyShardedStrategy(DDPStrategy): strategy_name = "ddp_fully_sharded" @@ -133,6 +138,10 @@ def process_group(self) -> Any: self._process_group = torch.distributed.new_group() return self._process_group + def lightning_module_state_dict(self) -> Dict[str, Any]: + """Returns model state.""" + return self.model.module.state_dict() + def setup_distributed(self) -> None: if not self.root_device.type == "cuda": raise MisconfigurationException( @@ -150,7 +159,7 @@ def setup(self, trainer: "pl.Trainer") -> None: self.model = self._layer_sync.apply(self.model) self.configure_ddp() - self.model = _LightningModuleWrapperBase(self.model) + self.model = _DDPFullyShardedStrategyModuleWrapper(self.model) self.model = self._setup_model(self.model) self.setup_optimizers(self.lightning_module.trainer) optimizers_to_device(self.optimizers, self.root_device) diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 0abc5fe516273..a3f78ee98c415 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -443,7 +443,7 @@ def handles_gradient_accumulation(self) -> bool: """Whether the plugin handles gradient accumulation internally.""" return False - def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: + def lightning_module_state_dict(self) -> Dict[str, Any]: """Returns model state.""" assert self.lightning_module is not None return self.lightning_module.state_dict() diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index 1490878700ff0..615602dd0a6bf 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -69,16 +69,16 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) - def on_train_start(self) -> None: + def on_train_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_test_start(self) -> None: + def on_test_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_validation_start(self) -> None: + def on_validation_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_prediction_start(self) -> None: + def on_prediction_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() def _assert_layer_fsdp_instance(self) -> None: @@ -87,8 +87,8 @@ def _assert_layer_fsdp_instance(self) -> None: assert isinstance(self.layer.module[2], FullyShardedDataParallel) # Assert that the nested layers are set reshard_after_forward to True - assert self.layer.module[0].reshard_after_forward is True - assert self.layer.module[2].reshard_after_forward is True + assert self.layer.module[0].reshard_after_forward + assert self.layer.module[2].reshard_after_forward if isinstance(self.trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin): assert self.layer.mixed_precision @@ -104,31 +104,25 @@ def __init__(self): def configure_optimizers(self): return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) - def on_train_start(self) -> None: + def on_train_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_test_start(self) -> None: + def on_test_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_validation_start(self) -> None: + def on_validation_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() - def on_prediction_start(self) -> None: + def on_prediction_batch_end(self, *_, **__) -> None: self._assert_layer_fsdp_instance() def _assert_layer_fsdp_instance(self) -> None: - assert isinstance(self.layer, FullyShardedDataParallel) - assert isinstance(self.layer.module[0], FullyShardedDataParallel) - assert isinstance(self.layer.module[2], FullyShardedDataParallel) - - # Assert that the nested layers are set reshard_after_forward to True - assert self.layer.module[0].reshard_after_forward is True - assert self.layer.module[2].reshard_after_forward is True + assert isinstance(self.trainer.model, FullyShardedDataParallel) + # `disable_reshard_on_root=True` (default) in FSDP which turns-off resharding + assert not self.trainer.model.reshard_after_forward if isinstance(self.trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin): - assert self.layer.mixed_precision - assert self.layer.module[0].mixed_precision - assert self.layer.module[2].mixed_precision + assert self.trainer.model.mixed_precision @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) From 7985f64a0632754797061825d75fdfbc3bc560b5 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 25 Aug 2022 11:55:27 -0400 Subject: [PATCH 03/11] fix auto wrap due to limitations --- .../strategies/fully_sharded.py | 31 +++++++++++++++-- ..._ddp_fully_sharded_with_full_state_dict.py | 33 +++++++++++++++++-- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index a803be7f2f1aa..e57bbf418d063 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -27,7 +27,9 @@ from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.optimizer import optimizers_to_device +from pytorch_lightning.utilities.rank_zero import rank_zero_info from pytorch_lightning.utilities.types import STEP_OUTPUT if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: @@ -38,7 +40,7 @@ class _DDPFullyShardedStrategyModuleWrapper(_LightningModuleWrapperBase): - def state_dict(self, *args: Any, **kwargs: Any): + def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: ignore[override] return self._forward_module.state_dict(*args, **kwargs) @@ -140,7 +142,22 @@ def process_group(self) -> Any: def lightning_module_state_dict(self) -> Dict[str, Any]: """Returns model state.""" - return self.model.module.state_dict() + assert self.model is not None + return self.model.state_dict() + + def connect(self, model: "pl.LightningModule") -> None: + """Called by the accelerator to connect the accelerator and the model with this plugin.""" + # TODO: Wait for this issue to resolve and remove this blocker + # https://github.com/facebookresearch/fairscale/issues/648 + # Also make sure to update the tests + if not is_overridden("configure_sharded_model", self.lightning_module) and len(list(model.parameters())) == 0: + assert self.lightning_module is not None + raise MisconfigurationException( + f"Using the same instance of model with `trainer.{self.lightning_module.trainer.state.fn}()` is not" + " supported with Fairscale FSDP auto-wrap. Please reinitialize your `LightningModule` and pass that." + ) + + super().connect(model) def setup_distributed(self) -> None: if not self.root_device.type == "cuda": @@ -159,8 +176,11 @@ def setup(self, trainer: "pl.Trainer") -> None: self.model = self._layer_sync.apply(self.model) self.configure_ddp() + assert isinstance(self.model, pl.LightningModule) self.model = _DDPFullyShardedStrategyModuleWrapper(self.model) - self.model = self._setup_model(self.model) + assert self.lightning_module is not None + if not is_overridden("configure_sharded_model", self.lightning_module): + self.model = self._setup_model(self.model) self.setup_optimizers(self.lightning_module.trainer) optimizers_to_device(self.optimizers, self.root_device) self.barrier() @@ -172,6 +192,11 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: :class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.") + rank_zero_info( + "When using FairScale FSDP auto-wrap, make sure to initalize your model using trainer else" + " you will get an error.\ntorch.optim.Optimizer(self.trainer.model.parameters(), ...)" + ) + return FullyShardedDataParallel( module=model, process_group=self.process_group, diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index 615602dd0a6bf..0b825cfe013c6 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -195,16 +195,27 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): _assert_save_equality(trainer, model_path, cls=model.__class__) # Test entry point + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() trainer.test(model) # model is wrapped, will not call configure_shared_model # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap - trainer.test(ckpt_path=model_path) + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.test(model, ckpt_path=model_path) # Predict entry point + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() trainer.predict(model) # model is wrapped, will not call `configure_sharded_model` # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap - trainer.predict(ckpt_path=model_path) + if model.__class__ is TestFSDPModelAutoWrapped: + model = TestFSDPModelAutoWrapped() + trainer.predict(model, ckpt_path=model_path) @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) @@ -227,3 +238,21 @@ def test_fsdp_gradient_clipping_raises(tmpdir): MisconfigurationException, match="gradient_clip_algorithm='norm'` is currently not supported for `FullySharded" ): trainer.fit(model) + + +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) +def test_fsdp_rewrap_limitation(tmpdir): + trainer = Trainer( + default_root_dir=tmpdir, + accelerator="gpu", + devices=1, + max_steps=1, + limit_val_batches=0, + limit_test_batches=1, + strategy="fsdp", + ) + model = TestFSDPModelAutoWrapped() + trainer.fit(model) + + with pytest.raises(MisconfigurationException, match="Using the same instance of model .* not supported"): + trainer.test(model) From 9754f3247113a99e24385873f410ff7336a0fd84 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 25 Aug 2022 11:55:42 -0400 Subject: [PATCH 04/11] update docs --- .../advanced/model_parallel.rst | 85 ++++++++++++------- 1 file changed, 56 insertions(+), 29 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index 50ae2cd2827d0..022d2a7fbe9fe 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -1,7 +1,8 @@ .. _model-parallel: +################################## Train 1 trillion+ parameter models -================================== +################################## When training large models, fitting larger batch sizes, or trying to increase throughput using multi-GPU compute, Lightning provides advanced optimized distributed training strategies to support these cases and offer substantial improvements in memory usage. @@ -19,8 +20,9 @@ Check out this amazing video explaining model parallelism and how it works behin allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen> +********************************************* Choosing an Advanced Distributed GPU Strategy -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +********************************************* If you would like to stick with PyTorch DDP, see :ref:`ddp-optimizations`. @@ -29,7 +31,7 @@ Unlike :class:`~torch.nn.parallel.DistributedDataParallel` (DDP) where the maxim There are many considerations when choosing a strategy as described below. In addition, check out the visualization of various strategy benchmarks using `minGPT `__ `here `__. Pre-training vs Fine-tuning -""""""""""""""""""""""""""" +=========================== When fine-tuning, we often use a magnitude less data compared to pre-training a model. This is important when choosing a distributed strategy as usually for pre-training, **we are compute-bound**. This means we cannot sacrifice throughput as much as if we were fine-tuning, because in fine-tuning the data requirement is smaller. @@ -45,7 +47,7 @@ For example when using 128 GPUs, you can **pre-train** large 10 to 20 Billion pa But for **fine-tuning** a model, you can reach 10 to 20 Billion parameter models using :ref:`deepspeed-zero-stage-3-offload` on a **single GPU**. This does come with a significant throughput hit, which needs to be weighed accordingly. When Shouldn't I use an Optimized Distributed Strategy? -""""""""""""""""""""""""""""""""""""""""""""""""""""""" +======================================================= Sharding techniques help when model sizes are fairly large; roughly 500M+ parameters is where we've seen benefits. However, in the following cases, we recommend sticking to ordinary distributed strategies * When your model is small (ResNet50 of around 80M Parameters), unless you are using unusually large batch sizes or inputs. @@ -55,8 +57,10 @@ Sharding techniques help when model sizes are fairly large; roughly 500M+ parame .. _sharded-training: +**************** Sharded Training -^^^^^^^^^^^^^^^^ +**************** + Lightning integration of optimizer sharded training provided by `FairScale `_. The technique can be found within `DeepSpeed ZeRO `_ and `ZeRO-2 `_, @@ -93,8 +97,9 @@ Internally we re-initialize your optimizers and shard them across your machines .. _fully-sharded-training: +******************************** FairScale Fully Sharded Training -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +******************************** .. warning:: FairScale Fully Sharded Training is in BETA and the API is subject to change. Please create an `issue `_ if you run into any problems. @@ -104,7 +109,7 @@ FairScale Fully Sharded Training Fully Sharded Training alleviates the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes for distributed communication with minimal effort. Shard Parameters to Reach 10+ Billion Parameters -"""""""""""""""""""""""""""""""""""""""""""""""" +================================================ To reach larger parameter sizes and to be memory efficient, we have to shard parameters. There are various ways to enable this. @@ -114,9 +119,27 @@ To reach larger parameter sizes and to be memory efficient, we have to shard par This is a limitation of Fully Sharded Training that will be resolved in the future. Enabling Module Sharding for Maximum Memory Efficiency -"""""""""""""""""""""""""""""""""""""""""""""""""""""" +====================================================== -To activate parameter sharding, you must wrap your model using the ``wrap`` or ``auto_wrap`` functions. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` and ``auto_wrap`` parameters are passed correctly. +Auto Wrapping +------------- + +Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The +simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't +have to ``wrap`` layers manually as in the case of manual wrapping. + +.. code-block:: python + + model = BoringModel() + trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16) + trainer.fit(model) + + +Manual Wrapping +--------------- + +Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate +parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly. When not using Fully Sharded Training these wrap functions are a no-op. That means once the changes have been made, there is no need to remove the changes for other strategies. @@ -179,7 +202,7 @@ Here's an example using both ``wrap`` and ``auto_wrap`` to create your model: .. _fairscale-activation-checkpointing: FairScale Activation Checkpointing -"""""""""""""""""""""""""""""""""" +================================== Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. Activation checkpointing is very useful when you have intermediate layers that produce large activations. @@ -208,8 +231,9 @@ This saves memory when training larger models, however it requires wrapping modu .. _fully-sharded-native-training: +****************************** PyTorch Fully Sharded Training -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +****************************** PyTorch has it's own version of `FSDP `_ which is upstreamed from their `fairscale `__ project. It was introduced in their `v1.11.0 release `_ but it is recommended to use it with PyTorch v1.12 or more and that's what @@ -301,8 +325,9 @@ Check out `this tutorial `_ if you run into any issues. @@ -343,7 +368,7 @@ If you run into an issue with the install or later in training, ensure that the .. _deepspeed-zero-stage-1: DeepSpeed ZeRO Stage 1 -"""""""""""""""""""""" +====================== `DeepSpeed ZeRO Stage 1 `_ partitions your optimizer states (Stage 1) across your GPUs to reduce memory. @@ -361,7 +386,7 @@ It is recommended to skip Stage 1 and use Stage 2, which comes with larger memor .. _deepspeed-zero-stage-2: DeepSpeed ZeRO Stage 2 -"""""""""""""""""""""" +====================== `DeepSpeed ZeRO Stage 2 `_ partitions your optimizer states (Stage 1) and your gradients (Stage 2) across your GPUs to reduce memory. In most cases, this is more efficient or at parity with DDP, primarily due to the optimized custom communications written by the DeepSpeed team. As a result, benefits can also be seen on a single GPU. Do note that the default bucket sizes allocate around ``3.6GB`` of VRAM to use during distributed communications, which can be tweaked when instantiating the strategy described in a few sections below. @@ -382,7 +407,7 @@ As a result, benefits can also be seen on a single GPU. Do note that the default .. _deepspeed-zero-stage-2-offload: DeepSpeed ZeRO Stage 2 Offload -"""""""""""""""""""""""""""""" +------------------------------ Below we show an example of running `ZeRO-Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. @@ -452,7 +477,7 @@ For even more speed benefit, DeepSpeed offers an optimized CPU version of ADAM c .. _deepspeed-zero-stage-3: DeepSpeed ZeRO Stage 3 -"""""""""""""""""""""" +====================== DeepSpeed ZeRO Stage 3 shards the optimizer states, gradients and the model parameters (also optionally activations). Sharding model parameters and activations comes with an increase in distributed communication, however allows you to scale your models massively from one GPU to multiple GPUs. **The DeepSpeed team report the ability to fine-tune models with over 40B parameters on a single GPU and over 2 Trillion parameters on 512 GPUs.** For more information we suggest checking the `DeepSpeed ZeRO-3 Offload documentation `__. @@ -511,7 +536,7 @@ You can also use the Lightning Trainer to run predict or evaluate with DeepSpeed Shard Model Instantly to Reduce Initialization Time/Memory -"""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +---------------------------------------------------------- When instantiating really large models, it is sometimes necessary to shard the model layers instantly. @@ -550,7 +575,7 @@ This reduces the time taken to initialize very large models, as well as ensure w .. _deepspeed-zero-stage-3-offload: DeepSpeed ZeRO Stage 3 Offload -"""""""""""""""""""""""""""""" +------------------------------ DeepSpeed ZeRO Stage 3 Offloads optimizer state, gradients to the host CPU to reduce memory usage as ZeRO Stage 2 does, however additionally allows you to offload the parameters as well for even more memory saving. @@ -584,7 +609,7 @@ DeepSpeed ZeRO Stage 3 Offloads optimizer state, gradients to the host CPU to re DeepSpeed Infinity (NVMe Offloading) -"""""""""""""""""""""""""""""""""""" +------------------------------------ Additionally, DeepSpeed supports offloading to NVMe drives for even larger models, utilizing the large memory space found in NVMes. DeepSpeed `reports `__ the ability to fine-tune 1 Trillion+ parameters using NVMe Offloading on one 8 GPU machine. Below shows how to enable this, assuming the NVMe drive is mounted in a directory called ``/local_nvme``. @@ -621,7 +646,7 @@ When offloading to NVMe you may notice that the speed is slow. There are paramet .. _deepspeed-activation-checkpointing: DeepSpeed Activation Checkpointing -"""""""""""""""""""""""""""""""""" +---------------------------------- Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. @@ -697,7 +722,7 @@ This saves memory when training larger models, however requires using a checkpoi .. _deepspeed-zero-stage-3-tips: DeepSpeed ZeRO Stage 3 Tips -""""""""""""""""""""""""""" +--------------------------- Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lightning. @@ -709,7 +734,7 @@ Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lig .. _deepspeed-zero-stage-3-single-file: Collating Single File Checkpoint for DeepSpeed ZeRO Stage 3 -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +----------------------------------------------------------- After training using ZeRO Stage 3, you'll notice that your checkpoints are a directory of sharded model and optimizer states. If you'd like to collate a single file from the checkpoint directory please use the below command, which handles all the Lightning states additionally when collating the file. @@ -728,7 +753,7 @@ After training using ZeRO Stage 3, you'll notice that your checkpoints are a dir This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the ``trainer.fit(ckpt_path=)`` call. Ensure to keep the sharded checkpoint directory if this is required. Custom DeepSpeed Config -""""""""""""""""""""""" +======================= In some cases you may want to define your own DeepSpeed Config, to access all parameters defined. We've exposed most of the important parameters, however, there may be debugging parameters to enable. Also, DeepSpeed allows the use of custom DeepSpeed optimizers and schedulers defined within a config file that is supported. @@ -801,12 +826,13 @@ You can use also use an environment variable via your PyTorch Lightning script: .. _ddp-optimizations: +***************** DDP Optimizations -^^^^^^^^^^^^^^^^^ +***************** When Using DDP Strategies, Set find_unused_parameters=False -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +=========================================================== By default, we have set ``find_unused_parameters=True`` for compatibility reasons that have been observed in the past (refer to the `discussion `_ for more details). When enabled, it can result in a performance hit and can be disabled in most cases. Read more about it `here `_. @@ -836,7 +862,7 @@ When enabled, it can result in a performance hit and can be disabled in most cas DDP Static Graph -"""""""""""""""" +================ `DDP static graph `__ assumes that your model employs the same set of used/unused parameters in every iteration, so that it can deterministically know the flow of @@ -854,7 +880,7 @@ training and apply special optimizations during runtime. When Using DDP on a Multi-node Cluster, Set NCCL Parameters -""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" +=========================================================== `NCCL `__ is the NVIDIA Collective Communications Library that is used by PyTorch to handle communication across nodes and GPUs. There are reported benefits in terms of speedups when adjusting NCCL parameters as seen in this `issue `__. In the issue, we see a 30% speed improvement when training the Transformer XLM-RoBERTa and a 15% improvement in training with Detectron2. @@ -875,7 +901,7 @@ NCCL parameters can be adjusted via environment variables. Gradients as Bucket View -"""""""""""""""""""""""" +======================== Enabling ``gradient_as_bucket_view=True`` in the ``DDPStrategy`` will make gradients views point to different offsets of the ``allreduce`` communication buckets. See :class:`~torch.nn.parallel.DistributedDataParallel` for more information. @@ -894,8 +920,9 @@ This can reduce peak memory usage and throughput as saved memory will be equal t trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(gradient_as_bucket_view=True)) trainer.fit(model) + DDP Communication Hooks -""""""""""""""""""""""" +======================= DDP Communication hooks is an interface to control how gradients are communicated across workers, overriding the standard allreduce in DistributedDataParallel. This allows you to enable performance improving communication hooks when using multiple nodes. From 4fd39ca276c6ae77c61244f9502c0a90e70fd317 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 25 Aug 2022 11:56:52 -0400 Subject: [PATCH 05/11] chlog --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 1776b917c761a..a843c627cc532 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208)) +Added support for auto wrapping for `DDPFullyShardedStrategy` ([#14383](https://github.com/Lightning-AI/lightning/issues/14383)) + + ### Changed From dcb601bb85977e98137cd07379116469fe7db0de Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 25 Aug 2022 23:09:54 +0530 Subject: [PATCH 06/11] fix --- .../advanced/model_parallel.rst | 23 +++++++++---------- .../strategies/fully_sharded.py | 2 ++ ..._ddp_fully_sharded_with_full_state_dict.py | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index 022d2a7fbe9fe..23116b7812a38 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -57,9 +57,9 @@ Sharding techniques help when model sizes are fairly large; roughly 500M+ parame .. _sharded-training: -**************** -Sharded Training -**************** +************************** +FairScale Sharded Training +************************** Lightning integration of optimizer sharded training provided by `FairScale `_. The technique can be found within `DeepSpeed ZeRO `_ and @@ -97,9 +97,8 @@ Internally we re-initialize your optimizers and shard them across your machines .. _fully-sharded-training: -******************************** -FairScale Fully Sharded Training -******************************** +Fully Sharded Training +====================== .. warning:: FairScale Fully Sharded Training is in BETA and the API is subject to change. Please create an `issue `_ if you run into any problems. @@ -109,7 +108,7 @@ FairScale Fully Sharded Training Fully Sharded Training alleviates the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes for distributed communication with minimal effort. Shard Parameters to Reach 10+ Billion Parameters -================================================ +------------------------------------------------ To reach larger parameter sizes and to be memory efficient, we have to shard parameters. There are various ways to enable this. @@ -119,10 +118,10 @@ To reach larger parameter sizes and to be memory efficient, we have to shard par This is a limitation of Fully Sharded Training that will be resolved in the future. Enabling Module Sharding for Maximum Memory Efficiency -====================================================== +------------------------------------------------------ Auto Wrapping -------------- +^^^^^^^^^^^^^ Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't @@ -136,7 +135,7 @@ have to ``wrap`` layers manually as in the case of manual wrapping. Manual Wrapping ---------------- +^^^^^^^^^^^^^^^ Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly. @@ -201,8 +200,8 @@ Here's an example using both ``wrap`` and ``auto_wrap`` to create your model: .. _fairscale-activation-checkpointing: -FairScale Activation Checkpointing -================================== +Activation Checkpointing +------------------------ Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. Activation checkpointing is very useful when you have intermediate layers that produce large activations. diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index e57bbf418d063..4ceaace05b64e 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -35,6 +35,8 @@ if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: from fairscale.nn import default_auto_wrap_policy, enable_wrap from fairscale.nn.data_parallel import FullyShardedDataParallel +else: + FullyShardedDataParallel = None log = logging.getLogger(__name__) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index 0b825cfe013c6..46491707dd4fc 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -221,7 +221,7 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) def test_fsdp_gradient_clipping_raises(tmpdir): """Test to ensure that an exception is raised when clipping gradients by value with FSDP.""" - model = BoringModel() + model = TestFSDPModelManualWrapped() trainer = Trainer( default_root_dir=tmpdir, strategy="fsdp", From 23148912c163b20ea5b443a259fc998a780b497f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 26 Aug 2022 00:48:03 +0530 Subject: [PATCH 07/11] fix test --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 7 ++++--- .../tests_pytorch/callbacks/test_stochastic_weight_avg.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 6650bb3f0c479..ecb2c2b7155ae 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -25,6 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy +from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig @@ -144,6 +145,9 @@ def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool: return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy, DeepSpeedStrategy)): + raise MisconfigurationException("SWA does not currently support sharded models.") + # copy the model before moving it to accelerator device. with pl_module._prevent_trainer_and_dataloaders_deepcopy(): self._average_model = deepcopy(pl_module) @@ -155,9 +159,6 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - if len(trainer.lr_scheduler_configs) > 1: raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") - if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)): - raise MisconfigurationException("SWA does not currently support sharded models.") - if isinstance(self._swa_epoch_start, float): self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start) diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 7f1692e30a3f2..a58d3baf3c6f0 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -348,6 +348,7 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir): [ pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)), pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)), + pytest.param("fsdp_native", marks=RunIf(min_cuda_gpus=1, skip_windows=True, min_torch="1.12")), ], ) def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str): From 42db8a0e825f464e170fcb0dff29b681cb79b04c Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 26 Aug 2022 19:59:06 +0530 Subject: [PATCH 08/11] update docs --- .../source-pytorch/advanced/model_parallel.rst | 18 ++++++++++++++---- docs/source-pytorch/extensions/strategy.rst | 4 ++-- .../tuner/batch_size_scaling.py | 8 -------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index 23116b7812a38..d72bd8ad358af 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -97,8 +97,8 @@ Internally we re-initialize your optimizers and shard them across your machines .. _fully-sharded-training: -Fully Sharded Training -====================== +FairScale Fully Sharded Training +================================ .. warning:: FairScale Fully Sharded Training is in BETA and the API is subject to change. Please create an `issue `_ if you run into any problems. @@ -127,9 +127,19 @@ Model layers should be wrapped in FSDP in a nested way to save peak memory and e simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't have to ``wrap`` layers manually as in the case of manual wrapping. +.. note:: + While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else + PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your + ``lightning_module.parameters()`` will return a generator with no params. + .. code-block:: python - model = BoringModel() + class MyModel(BoringModel): + def configure_optimizers(self): + return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-2) + + + model = MyModel() trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16) trainer.fit(model) @@ -186,7 +196,7 @@ Here's an example using both ``wrap`` and ``auto_wrap`` to create your model: self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block) def configure_optimizers(self): - return torch.optim.AdamW(self.model.parameters()) + return torch.optim.AdamW(self.model.parameters(), lr=1e-2) model = MyModel() diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst index ed39f68d45e23..21a6e8a8814b2 100644 --- a/docs/source-pytorch/extensions/strategy.rst +++ b/docs/source-pytorch/extensions/strategy.rst @@ -83,10 +83,10 @@ The below table lists all relevant strategies available in Lightning with their - Strategy for Fully Sharded Data Parallel provided by FairScale. :ref:`Learn more. ` * - ddp_sharded - :class:`~pytorch_lightning.strategies.DDPShardedStrategy` - - Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. ` + - Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. ` * - ddp_sharded_spawn - :class:`~pytorch_lightning.strategies.DDPSpawnShardedStrategy` - - Optimizer sharded training provided by FairScale. :ref:`Learn more. ` + - Optimizer sharded training provided by FairScale. :ref:`Learn more. ` * - ddp_spawn - :class:`~pytorch_lightning.strategies.DDPSpawnStrategy` - Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. ` diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index a8c845f162184..a1f8a2de4b9d8 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -143,9 +143,6 @@ def _run_power_scaling( # If we fail in power mode, half the size and return garbage_collection_cuda() new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc="failed") - # Force the train dataloader to reset as the batch size has changed - trainer.reset_train_dataloader(model) - trainer.reset_val_dataloader(model) break else: raise # some other error not memory related @@ -206,12 +203,7 @@ def _run_binsearch_scaling( garbage_collection_cuda() high = new_size midval = (high + low) // 2 - new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed") - # Force the train dataloader to reset as the batch size has changed - trainer.reset_train_dataloader(model) - trainer.reset_val_dataloader(model) - if high - low <= 1: break else: From 84adb6e1a042eace58cc340663dbb97fa4256ab5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 26 Aug 2022 20:00:58 +0530 Subject: [PATCH 09/11] update docs --- docs/source-pytorch/advanced/model_parallel.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index d72bd8ad358af..49d5158dc899e 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -250,7 +250,8 @@ Lightning supports. The API is pretty similar to that of FairScale. Auto Wrapping -""""""""""""" +============= + Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't have to ``wrap`` layers manually as in the case of manual wrapping. @@ -266,7 +267,7 @@ Read more `here Date: Mon, 5 Sep 2022 16:09:19 +0530 Subject: [PATCH 10/11] update --- docs/source-pytorch/advanced/model_parallel.rst | 2 +- src/pytorch_lightning/strategies/fully_sharded.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index 49d5158dc899e..757b7dffa4580 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -130,7 +130,7 @@ have to ``wrap`` layers manually as in the case of manual wrapping. .. note:: While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your - ``lightning_module.parameters()`` will return a generator with no params. + ``lightning_module.parameters()`` will return a generator with no params. This inconvenience will be addressed in the future. .. code-block:: python diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index 3ef1244f89397..55bbca95edfa4 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -43,6 +43,9 @@ class _DDPFullyShardedStrategyModuleWrapper(_LightningModuleWrapperBase): def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: ignore[override] + # this is required because with FSDP lightning_module is empty because weights are sharded. + # So we need to call self.trainer.model.state_dict (wrapped version) and use this wraper to + # avoid extra keys `_forward_module.layer.weight.` since we want `layer.weight.` in state_dict. return self._forward_module.state_dict(*args, **kwargs) @@ -152,10 +155,9 @@ def connect(self, model: "pl.LightningModule") -> None: # TODO: Wait for this issue to resolve and remove this blocker # https://github.com/facebookresearch/fairscale/issues/648 # Also make sure to update the tests - if not is_overridden("configure_sharded_model", self.lightning_module) and len(list(model.parameters())) == 0: - assert self.lightning_module is not None + if not is_overridden("configure_sharded_model", self.model) and len(list(model.parameters())) == 0: raise MisconfigurationException( - f"Using the same instance of model with `trainer.{self.lightning_module.trainer.state.fn}()` is not" + f"Using the same instance of model with `trainer.{self.model.trainer.state.fn}()` is not" " supported with Fairscale FSDP auto-wrap. Please reinitialize your `LightningModule` and pass that." ) From d7d263e627914553526d30cbf6bcd1c00a4d00db Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 5 Sep 2022 23:48:57 +0530 Subject: [PATCH 11/11] rev --- src/pytorch_lightning/strategies/fully_sharded.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index 55bbca95edfa4..a364d7d19a679 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -155,9 +155,10 @@ def connect(self, model: "pl.LightningModule") -> None: # TODO: Wait for this issue to resolve and remove this blocker # https://github.com/facebookresearch/fairscale/issues/648 # Also make sure to update the tests - if not is_overridden("configure_sharded_model", self.model) and len(list(model.parameters())) == 0: + if not is_overridden("configure_sharded_model", self.lightning_module) and len(list(model.parameters())) == 0: + assert self.lightning_module is not None raise MisconfigurationException( - f"Using the same instance of model with `trainer.{self.model.trainer.state.fn}()` is not" + f"Using the same instance of model with `trainer.{self.lightning_module.trainer.state.fn}()` is not" " supported with Fairscale FSDP auto-wrap. Please reinitialize your `LightningModule` and pass that." )