From eef4c32f9d70b805adb3c131dec5b373c6199a14 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Thu, 18 Nov 2021 12:36:57 -0800
Subject: [PATCH 1/2] Check torch.distributed availability for sharded tensor state dict hook registration

---
 CHANGELOG.md                        | 2 +-
 pytorch_lightning/core/lightning.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d7cd27977ddbf..037dea7032ee9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -157,7 +157,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
 
--
+- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#]())
 
 -
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index dc3ce5f0f4063..89f46949a525c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -46,7 +46,7 @@
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
+from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import get_model_size_mb
 from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
@@ -1990,7 +1990,8 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
 
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS:
+        if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS or not torch.distributed.is_available():
+            rank_zero_debug("Could not register sharded tensor state dict hooks")
             return
 
         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

From 3d52916fa56c737fd8a1cee087f4bedf0198b579 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Thu, 18 Nov 2021 12:38:00 -0800
Subject: [PATCH 2/2] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 037dea7032ee9..a810ee74cf495 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -157,7 +157,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
 
-- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#]())
+- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#10621](https://github.com/PyTorchLightning/pytorch-lightning/pull/10621))
 
 -
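
Note on the pattern this patch applies: some PyTorch builds (notably certain Windows wheels) are compiled without distributed support, in which case importing from torch.distributed submodules fails, so torch.distributed.is_available() must be checked before the import. Below is a minimal standalone sketch of the same guard; the helper name register_sharded_tensor_hooks and its module argument are illustrative rather than part of the patch, and the private hook-registration calls are included only to round out the example, assuming the torch 1.10 API that lightning.py targets here.

import torch
from torch import nn


def register_sharded_tensor_hooks(module: nn.Module) -> None:
    # torch.distributed can be compiled out of a build; is_available()
    # is the supported way to check before touching the package.
    if not torch.distributed.is_available():
        return

    # Import only after the guard: in torch 1.10 the hooks live here
    # (later releases relocate them under torch.distributed._shard).
    from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

    # Register through nn.Module's (private) hook API so ShardedTensors
    # are written out by state_dict() and restored on load_state_dict().
    module._register_state_dict_hook(state_dict_hook)
    module._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)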