Skip to content

Commit

Permalink
Fix torch.distributed._sharded_tensor DeprecationWarning (#13261)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Jun 21, 2022
1 parent cd44512 commit d24178e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/pytorch_lightning/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.parsing import collect_init_args
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
Expand Down Expand Up @@ -1991,7 +1991,10 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
rank_zero_debug("Could not register sharded tensor state dict hooks")
return

from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
if _TORCH_GREATER_EQUAL_1_11:
from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook
else:
from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

self._register_state_dict_hook(state_dict_hook)

Expand Down
5 changes: 4 additions & 1 deletion tests/tests_pytorch/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,10 @@ def assert_device(device: torch.device) -> None:

@RunIf(min_torch="1.10", skip_windows=True)
def test_sharded_tensor_state_dict(single_process_pg):
from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
if _TORCH_GREATER_EQUAL_1_11:
from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
else:
from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
from torch.distributed._sharding_spec import ChunkShardingSpec

class BoringModelWithShardedTensor(BoringModel):
Expand Down

0 comments on commit d24178e

Please sign in to comment.