From ba1b5e8c9ec3eb46c3131ed88dcfa96deba78c11 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 15 Sep 2022 13:22:47 +0200 Subject: [PATCH 1/3] fairscale imports --- src/pytorch_lightning/overrides/fairscale.py | 6 +----- .../plugins/precision/sharded_native_amp.py | 2 +- src/pytorch_lightning/strategies/ddp.py | 2 +- .../strategies/fully_sharded.py | 2 +- src/pytorch_lightning/strategies/sharded.py | 2 +- .../strategies/sharded_spawn.py | 2 +- tests/tests_pytorch/helpers/runif.py | 2 +- .../plugins/precision/test_sharded_precision.py | 2 +- ...est_ddp_fully_sharded_with_full_state_dict.py | 2 +- .../strategies/test_ddp_strategy.py | 2 +- .../strategies/test_sharded_strategy.py | 2 +- tests/tests_pytorch/utilities/test_imports.py | 16 ---------------- 12 files changed, 11 insertions(+), 31 deletions(-) diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py index 0a35f9ddd4d8a..6bf3f942ee4f0 100644 --- a/src/pytorch_lightning/overrides/fairscale.py +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -14,20 +14,16 @@ from typing import Optional, Union import torch.nn as nn -from lightning_utilities.core.imports import module_available import pytorch_lightning as pl +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.overrides.base import ( _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase, unwrap_lightning_module, ) -from pytorch_lightning.utilities.imports import _IS_WINDOWS from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation -_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn") - - if _FAIRSCALE_AVAILABLE: # pragma: no-cover from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel diff --git a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py index d76db26a76358..30132b291e021 100644 --- a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -13,7 +13,7 @@ # limitations under the License. 
from typing import Optional, Union -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index d197aa7979a0a..576e5469a4439 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -31,6 +31,7 @@ import pytorch_lightning as pl from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from lightning_lite.utilities.distributed import ( _get_process_group_backend_from_env, distributed_available, @@ -44,7 +45,6 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py index 6979741d92be7..6c9d47bf61b3c 100644 --- a/src/pytorch_lightning/strategies/fully_sharded.py +++ b/src/pytorch_lightning/strategies/fully_sharded.py @@ -20,10 +20,10 @@ import pytorch_lightning as pl from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from lightning_lite.utilities.enums import PrecisionType from lightning_lite.utilities.optimizer import optimizers_to_device from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index df0d126385f32..8b9ccdc20462b 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -19,11 +19,11 @@ from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from lightning_lite.utilities.enums import PrecisionType from lightning_lite.utilities.optimizer import optimizers_to_device from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py 
b/src/pytorch_lightning/strategies/sharded_spawn.py index 438f6d5eb6a47..934cf680de0f4 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -19,9 +19,9 @@ from torch.optim import Optimizer import pytorch_lightning as pl +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from lightning_lite.utilities.optimizer import optimizers_to_device from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 62cb93ed38919..1f369b6c759a4 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -20,9 +20,9 @@ from packaging.version import Version from pkg_resources import get_distribution +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.accelerators.mps import MPSAccelerator from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.imports import ( diff --git a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py index 0c08c8e9540eb..b231455c6cf6f 100644 --- a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py @@ -15,7 +15,7 @@ import pytest import torch -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins import ShardedNativeMixedPrecisionPlugin from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py index bb3b63ea578c6..4a2e60dc39f76 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -5,10 +5,10 @@ import pytest import torch +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin from pytorch_lightning.strategies import DDPFullyShardedStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py index 2665eb7c3e370..7755a31d2129a 100644 --- a/tests/tests_pytorch/strategies/test_ddp_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py @@ -20,9 +20,9 @@ from torch.nn.parallel import DistributedDataParallel from 
lightning_lite.plugins.environments import ClusterEnvironment, LightningEnvironment +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10 diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py index 2c0a5579c9933..a2b3775eb6708 100644 --- a/tests/tests_pytorch/strategies/test_sharded_strategy.py +++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py @@ -5,9 +5,9 @@ import pytest import torch +from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy from pytorch_lightning.trainer.states import TrainerFn from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index 05845e0b15172..f2b188e6e7895 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE -from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE @@ -33,20 +31,6 @@ def test_imports(): else: assert _BAGUA_AVAILABLE - try: - import deepspeed # noqa - except ModuleNotFoundError: - assert not _DEEPSPEED_AVAILABLE - else: - assert _DEEPSPEED_AVAILABLE - - try: - import fairscale.nn # noqa - except ModuleNotFoundError: - assert not _FAIRSCALE_AVAILABLE - else: - assert _FAIRSCALE_AVAILABLE - try: import horovod.torch # noqa except ModuleNotFoundError: From 30a36eb5f5f26b969c7373d15e5a36c3750cfbfd Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 18 Sep 2022 00:12:14 +0200 Subject: [PATCH 2/3] refactor to avoid meta package build issue --- src/pytorch_lightning/overrides/fairscale.py | 46 +++++++++----------- src/pytorch_lightning/strategies/ddp.py | 2 +- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py index 6bf3f942ee4f0..a0918172dbd1b 100644 --- a/src/pytorch_lightning/overrides/fairscale.py +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -16,7 +16,6 @@ import torch.nn as nn import pytorch_lightning as pl -from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.overrides.base import ( _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase, @@ -24,29 +23,26 @@ ) from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation -if _FAIRSCALE_AVAILABLE: # pragma: no-cover + +class LightningShardedDataParallel(_LightningModuleWrapperBase): + def __init__( + self, + forward_module: 
Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) + + +def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule": from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel - class LightningShardedDataParallel(_LightningModuleWrapperBase): - def __init__( - self, - forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - ) -> None: - self._validate_init_arguments(pl_module, forward_module) - super().__init__(forward_module=(pl_module or forward_module)) - - def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule": - rank_zero_deprecation( - "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0." - " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." - ) - model = wrapped_model - if isinstance(model, ShardedDataParallel): - model = model.module - - return unwrap_lightning_module(model, _suppress_warning=True) - -else: - LightningShardedDataParallel = ... # type: ignore[assignment,misc] - unwrap_lightning_module_sharded = ... # type: ignore[assignment] + rank_zero_deprecation( + "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0." + " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." + ) + model = wrapped_model + if isinstance(model, ShardedDataParallel): + model = model.module + + return unwrap_lightning_module(model, _suppress_warning=True) diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 2c90b3b82a658..13cbee6f42b60 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -31,8 +31,8 @@ import pytorch_lightning as pl from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO -from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE +from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device from lightning_lite.utilities.distributed import group as _group from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from lightning_lite.utilities.optimizer import optimizers_to_device From 8d84a80cf0b79a143648b00fa998181eb15c3d1a Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 22 Sep 2022 09:16:09 +0200 Subject: [PATCH 3/3] import --- tests/tests_pytorch/utilities/test_imports.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index cf75e47c4a8a0..d770b8307a34a 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -20,7 +20,6 @@ from lightning_utilities.core.imports import compare_version, module_available, RequirementCache from torch.distributed import 
is_available -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE
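
Note on the moved flag: patch 1 deletes the module-level definition `_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")` from pytorch_lightning/overrides/fairscale.py and every consumer now imports the name from lightning_lite.strategies.fairscale instead. A minimal sketch of what that new home presumably provides, reconstructed from the deleted lines (the actual lightning_lite module may contain more than this; the file path comment is an assumption):

    # lightning_lite/strategies/fairscale.py -- sketch, not the verbatim upstream file
    import platform

    from lightning_utilities.core.imports import module_available

    # fairscale is treated as unavailable on Windows, mirroring the
    # platform guard deleted from pytorch_lightning.overrides.fairscale.
    _IS_WINDOWS = platform.system() == "Windows"

    _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")

Call sites are unchanged apart from the import path, e.g.:

    from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE

    if _FAIRSCALE_AVAILABLE:
        from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel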