diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py
index a198efd1ceab8..5cd33f7c1baea 100644
--- a/src/lightning_lite/lite.py
+++ b/src/lightning_lite/lite.py
@@ -218,8 +218,9 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _LiteM
         module = self._strategy.setup_module(module)
         module = _LiteModule(module, self._precision, original_module=original_module)
 
-        # Update the _DeviceDtypeModuleMixin's device parameter
-        module.to(self.device if move_to_device else next(module.parameters()).device)
+        if not isinstance(self._strategy, FSDPStrategy):
+            # Update the _DeviceDtypeModuleMixin's device parameter
+            module.to(self.device if move_to_device else next(module.parameters()).device)
 
         self._models_setup += 1
         return module
diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py
index 052133e265e4c..963918d076d77 100644
--- a/tests/tests_lite/strategies/test_fsdp_integration.py
+++ b/tests/tests_lite/strategies/test_fsdp_integration.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import tempfile
+from unittest import mock
 
 import pytest
 import torch
@@ -116,3 +117,29 @@ def test_fsdp_train_save_load(manual_wrapping, precision):
     lite._strategy.save_checkpoint(model.state_dict(), ckpt_path)
 
     _assert_save_equality(lite, model, ckpt_path)
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
+@pytest.mark.parametrize("move_to_device", [True, False])
+@mock.patch("lightning_lite.wrappers._LiteModule")
+def test_setup_module_move_to_device(lite_module_mock, move_to_device):
+    """Test that `move_to_device` has no effect with FSDP: the strategy decides onto which device each
+    parameter shard gets moved."""
+    strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy)
+    lite = LightningLite(accelerator="cuda", devices=2, strategy=strategy)
+    lite.launch()
+
+    model = torch.nn.Linear(10, 10, bias=False)  # total params: 10 * 10 = 100
+    lite_model = lite.setup_module(model, move_to_device=move_to_device)
+    lite_module_mock.assert_not_called()
+
+    assert list(param.device for param in model.parameters()) == []
+    assert len(list(lite_model.parameters())) == 1
+
+    # the linear layer got sharded and each part is on the expected device
+    assert next(lite_model.parameters()).device == torch.device("cuda", lite.local_rank)
+    assert next(lite_model.parameters()).numel() == 50
+
+    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
+    assert lite_model.device == torch.device("cpu")
+    assert lite.device == torch.device("cuda", lite.local_rank)
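
Note (not part of the diff): the new `isinstance(self._strategy, FSDPStrategy)` guard leaves the module wherever FSDP sharded it, so the `_DeviceDtypeModuleMixin` device attribute is no longer updated for FSDP. Below is a minimal sketch of the resulting behavior, mirroring the new test. It assumes a machine with 2 CUDA GPUs, that the script is launched once per device (as the `standalone=True` test is), and that `LightningLite` and `FSDPStrategy` are importable as in the test module; it is an illustration, not part of the change.

```python
import torch

from lightning_lite import LightningLite
from lightning_lite.strategies import FSDPStrategy  # import path assumed from the test module

# Same setup as the new test, minus the custom auto-wrap policy: the root FSDP wrap
# applied by the strategy still flat-shards all parameters across the 2 ranks.
lite = LightningLite(accelerator="cuda", devices=2, strategy=FSDPStrategy())
lite.launch()

model = torch.nn.Linear(10, 10, bias=False)  # 100 parameters in total
lite_model = lite.setup_module(model)        # sharded by FSDP; no `module.to(...)` call anymore

# Each rank holds a 50-element shard on its own GPU ...
assert next(lite_model.parameters()).numel() == 50
assert next(lite_model.parameters()).device == torch.device("cuda", lite.local_rank)

# ... while the wrapper's mixin-reported device stays at the CPU default,
# because `setup_module` now skips the `.to()` call for FSDP strategies.
assert lite_model.device == torch.device("cpu")
assert lite.device == torch.device("cuda", lite.local_rank)
```

Reporting `cpu` for `lite_model.device` here is a consequence of the guard: with sharding there is no single device that describes the whole module, so the mixin's default is simply left untouched, as the comment in the new test points out.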