Check torch.distributed availability before sharded tensor state dict hook registration (#10621)

Co-authored-by: Carlos Mocholí <[email protected]>
2 people authored and awaelchli committed Nov 24, 2021
1 parent 0098f71 commit 130f0a5
Showing 2 changed files with 29 additions and 2 deletions.
26 changes: 26 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,32 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.5.3] - 2021-11-23

### Fixed

- When a tensor is logged with `self.log`, run its computation with the same `dtype` ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076))


- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#10621](https://github.com/PyTorchLightning/pytorch-lightning/pull/10621))


- Fixed `LightningLite` `_wrap_init` popping non-existing keys from DataLoader signature parameters ([#10613](https://github.com/PyTorchLightning/pytorch-lightning/pull/10613))


- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610))


- Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408))



## [1.5.2] - 2021-11-16

### Fixed
5 changes: 3 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -46,7 +46,7 @@
)
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
+from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import get_model_size_mb
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
@@ -2059,7 +2059,8 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
"""
-if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS:
+if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS or not torch.distributed.is_available():
+    rank_zero_debug("Could not register sharded tensor state dict hooks")
return

from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
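
For readers outside the Lightning codebase, here is a minimal standalone sketch of the guard this commit adds. The helper name `register_sharded_tensor_hooks`, the simplified version check, and the use of `logging` in place of Lightning's `rank_zero_debug` are illustrative assumptions; only `torch.distributed.is_available()` and the private `torch.distributed._sharded_tensor` hooks appear in the diff above.

```python
import logging
import platform

import torch
from torch import nn

log = logging.getLogger(__name__)

# Simplified stand-ins for Lightning's internal feature flags.
_TORCH_GREATER_EQUAL_1_10 = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2]) >= (1, 10)
_IS_WINDOWS = platform.system() == "Windows"


def register_sharded_tensor_hooks(module: nn.Module) -> None:
    """Register ShardedTensor state-dict hooks only when the runtime supports them."""
    # The fix: also bail out when torch.distributed is unavailable (e.g. some
    # CPU-only or minimal builds), instead of crashing on the import below.
    if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS or not torch.distributed.is_available():
        log.debug("Could not register sharded tensor state dict hooks")
        return

    # Imported lazily so unsupported environments never touch the private module.
    from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

    # Assumed registration calls on the (private) nn.Module hook API.
    module._register_state_dict_hook(state_dict_hook)
    module._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
```

The lazy import is the key design choice: environments that fail the guard never import the private `_sharded_tensor` module at all, so the check must cover every way that import could fail.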
