From a18b6409d13ded0e12752a9649bdb8dc763d4cbb Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Fri, 19 Nov 2021 09:34:23 -0800
Subject: [PATCH] Check torch.distributed availability before sharded tensor
 state dict hook registration (#10621)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                        | 3 +++
 pytorch_lightning/core/lightning.py | 5 +++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f5f3644b6846..96830878d84ae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - When a tensor is logged with `self.log`, run its computation with the same `dtype` ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076))
 
+- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#10621](https://github.com/PyTorchLightning/pytorch-lightning/pull/10621))
+
+
 - Fixed LigtningLite `_wrap_init` popping unexisting keys from DataLoader signature parameters ([#10613](https://github.com/PyTorchLightning/pytorch-lightning/pull/10613))
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index dc3ce5f0f4063..89f46949a525c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -46,7 +46,7 @@
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
+from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import get_model_size_mb
 from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
@@ -1990,7 +1990,8 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule
         correctly.
         """
-        if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS:
+        if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS or not torch.distributed.is_available():
+            rank_zero_debug("Could not register sharded tensor state dict hooks")
             return
 
         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
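
For context, the guarded registration this patch completes can be sketched as a
standalone snippet. This is a minimal sketch, not the library's code:
`_sharded_tensor_hooks_supported` and `register_sharded_tensor_hooks` are
illustrative names, the inline version parse stands in for Lightning's
`_TORCH_GREATER_EQUAL_1_10` constant, and `print` stands in for
`rank_zero_debug`; the two private `nn.Module` registration calls reflect
roughly how `lightning.py` attaches the hooks imported in the diff above.

    import platform

    import torch


    def _sharded_tensor_hooks_supported() -> bool:
        """Mirror the patched guard: torch >= 1.10, not Windows, and a torch
        build compiled with distributed support (CPU-only builds may lack it)."""
        major, minor = (int(part) for part in torch.__version__.split(".")[:2])
        return (
            (major, minor) >= (1, 10)
            and platform.system() != "Windows"
            and torch.distributed.is_available()
        )


    def register_sharded_tensor_hooks(module: torch.nn.Module) -> None:
        if not _sharded_tensor_hooks_supported():
            # Stand-in for rank_zero_debug in the actual patch.
            print("Could not register sharded tensor state dict hooks")
            return

        # Deferred import: on torch builds without distributed support this
        # import can itself fail, which is why the guard must run first.
        from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

        # Attach via the private nn.Module hook APIs; the trailing True asks
        # for the module itself to be passed into the pre-load hook.
        module._register_state_dict_hook(state_dict_hook)
        module._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)

The point of the new third condition is that `_TORCH_GREATER_EQUAL_1_10` and
`_IS_WINDOWS` alone do not rule out a torch build compiled without distributed
support, where the `torch.distributed._sharded_tensor` import would fail.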