[DSD] Add a test to verify FSDP lazy initialization case (#127069) (#127130)

* [DSD] Add a test to verify FSDP lazy initialization case (#127069)

Summary:
Distributed state_dict should not error out on a lazily initialized FSDP model, because the `model.state_dict()` call it makes will itself trigger FSDP initialization. (A standalone sketch of this scenario follows the commit message below.)

Pull Request resolved: #127069
Approved by: https://github.com/wz337

* Add missing import get_optimizer_state_dict

---------

Co-authored-by: Andrey Talman <[email protected]>
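To make the scenario concrete, here is a minimal standalone sketch of the lazy-initialization case (not part of this commit). It assumes a two-GPU host launched via torchrun, and it substitutes a plain nn.Sequential for the test suite's CompositeParamModel helper; the file name repro_fsdp_lazy_init.py is hypothetical.

    # repro_fsdp_lazy_init.py -- hypothetical standalone sketch, not the committed test.
    # Launch with: torchrun --nproc_per_node=2 repro_fsdp_lazy_init.py
    import copy
    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        get_optimizer_state_dict,
    )
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


    def main() -> None:
        dist.init_process_group("nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        mesh = init_device_mesh("cuda", (dist.get_world_size(),))

        # Stand-in for the test's CompositeParamModel helper.
        model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).cuda()
        fsdp_model = FSDP(copy.deepcopy(model), device_mesh=mesh)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters())

        # No forward pass has run yet, so the FSDP root is still lazily
        # uninitialized. These calls go through fsdp_model.state_dict(),
        # which triggers FSDP initialization, and must not raise.
        get_model_state_dict(fsdp_model)
        get_optimizer_state_dict(fsdp_model, fsdp_optim)

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

The key point is that neither call requires a prior forward pass: both should succeed because the underlying state_dict call performs FSDP's deferred root initialization.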
fegin and atalman authored May 27, 2024
1 parent e63004b commit 81b8854
Showing 1 changed file with 14 additions and 0 deletions.
test/distributed/checkpoint/test_state_dict.py (14 additions, 0 deletions)

@@ -22,6 +22,7 @@
     _patch_model_state_dict,
     _patch_optimizer_state_dict,
     get_model_state_dict,
+    get_optimizer_state_dict,
     get_state_dict,
     set_model_state_dict,
     set_state_dict,
@@ -555,6 +556,19 @@ def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None:
 
         self.assertEqual(original_keys, new_keys)
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_fsdp_root_not_initialized(self) -> None:
+        # This test verifies that FSDP root is not initialized but we should
+        # still be able to get the state_dict without errors because
+        # fsdp_model.state_dict() will trigger the FSDP initialization.
+        device_mesh = init_device_mesh("cuda", (self.world_size,))
+        model = CompositeParamModel(device=torch.device("cuda"))
+        fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
+        fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
+        get_model_state_dict(fsdp_model)
+        get_optimizer_state_dict(fsdp_model, fsdp_optim)
+
 
 if __name__ == "__main__":
     run_tests()
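On a machine with at least two GPUs (the test skips itself otherwise via skip_if_lt_x_gpu(2)), the new test should be runnable from a PyTorch source checkout with something like:

    python test/distributed/checkpoint/test_state_dict.py -k test_fsdp_root_not_initialized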
