Skip to content

Commit

Permalink
Fix device placement when setting up FSDP model in Lite (#15822)
Browse files Browse the repository at this point in the history
* fix
* debug test
* simplify

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
awaelchli and pre-commit-ci[bot] authored Nov 28, 2022
1 parent 3fad651 commit 657bfc5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/lightning_lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,9 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _LiteM
module = self._strategy.setup_module(module)
module = _LiteModule(module, self._precision, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
module.to(self.device if move_to_device else next(module.parameters()).device)
if not isinstance(self._strategy, FSDPStrategy):
# Update the _DeviceDtypeModuleMixin's device parameter
module.to(self.device if move_to_device else next(module.parameters()).device)

self._models_setup += 1
return module
Expand Down
27 changes: 27 additions & 0 deletions tests/tests_lite/strategies/test_fsdp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from unittest import mock

import pytest
import torch
Expand Down Expand Up @@ -116,3 +117,29 @@ def test_fsdp_train_save_load(manual_wrapping, precision):
lite._strategy.save_checkpoint(model.state_dict(), ckpt_path)

_assert_save_equality(lite, model, ckpt_path)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning_lite.wrappers._LiteModule")
def test_setup_module_move_to_device(lite_module_mock, move_to_device):
"""Test that `move_to_device` does nothing, FSDP decides which device parameters get moved to which device
(sharding)."""
strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy)
lite = LightningLite(accelerator="cuda", devices=2, strategy=strategy)
lite.launch()

model = torch.nn.Linear(10, 10, bias=False) # total params: 10 * 10 = 100
lite_model = lite.setup_module(model, move_to_device=move_to_device)
lite_module_mock.assert_not_called()

assert list(param.device for param in model.parameters()) == []
assert len(list(lite_model.parameters())) == 1

# the linear layer got sharded and each part is on the expected device
assert next(lite_model.parameters()).device == torch.device("cuda", lite.local_rank)
assert next(lite_model.parameters()).numel() == 50

# The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
assert lite_model.device == torch.device("cpu")
assert lite.device == torch.device("cuda", lite.local_rank)

0 comments on commit 657bfc5

Please sign in to comment.