
Commit e3e2a9e

Check torch.distributed availability before sharded tensor state dict hook registration (Lightning-AI#10621)

Co-authored-by: Carlos Mocholí <[email protected]>
2 people authored and Raalsky committed Nov 23, 2021
1 parent d56d9a5 commit e3e2a9e
Showing 2 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- When a tensor is logged with `self.log`, run its computation with the same `dtype` ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076))


+- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#10621](https://github.com/PyTorchLightning/pytorch-lightning/pull/10621))


- Fixed `LightningLite` `_wrap_init` popping non-existent keys from DataLoader signature parameters ([#10613](https://github.com/PyTorchLightning/pytorch-lightning/pull/10613))


5 changes: 3 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -46,7 +46,7 @@
)
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
+from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import get_model_size_mb
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
@@ -1990,7 +1990,8 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
        These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
"""
if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS:
if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS or not torch.distributed.is_available():
rank_zero_debug("Could not register sharded tensor state dict hooks")
return

from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
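The hunk above only shows the changed guard. Below is a minimal, self-contained sketch of the pattern this commit enforces, applied to a plain `nn.Module` for illustration: only the `torch.distributed.is_available()` check, the debug message, and the lazy `torch.distributed._sharded_tensor` import come from the diff itself; the helper name, the logger, and the registration calls after the guard are assumptions for the example. Lightning additionally gates on `_TORCH_GREATER_EQUAL_1_10` and `_IS_WINDOWS`, which are elided here.

```python
import logging

import torch
from torch import nn

log = logging.getLogger(__name__)


def register_sharded_tensor_hooks_if_available(module: nn.Module) -> None:
    """Attach ShardedTensor state dict hooks only when torch.distributed exists."""
    # The fix: bail out early when torch.distributed is not available (e.g. some
    # CPU-only or macOS builds), instead of crashing on the import below.
    if not torch.distributed.is_available():
        log.debug("Could not register sharded tensor state dict hooks")
        return

    # Imported lazily, as in the diff, so importing this module never requires
    # torch.distributed to be present.
    from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

    # Hypothetical registration calls using nn.Module's private hook APIs
    # (present in PyTorch 1.10); the exact Lightning registration is not shown
    # in this hunk.
    module._register_state_dict_hook(state_dict_hook)
    module._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
```

`torch.distributed.is_available()` returns `False` when PyTorch was built without distributed support, so the early return avoids an `ImportError` that the previous version-and-platform check alone did not prevent.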
