diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py
index 0a35f9ddd4d8a..a0918172dbd1b 100644
--- a/src/pytorch_lightning/overrides/fairscale.py
+++ b/src/pytorch_lightning/overrides/fairscale.py
@@ -14,7 +14,6 @@
 from typing import Optional, Union
 
 import torch.nn as nn
-from lightning_utilities.core.imports import module_available
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import (
@@ -22,35 +21,28 @@
     _LightningPrecisionModuleWrapperBase,
     unwrap_lightning_module,
 )
-from pytorch_lightning.utilities.imports import _IS_WINDOWS
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 
-_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")
 
+class LightningShardedDataParallel(_LightningModuleWrapperBase):
+    def __init__(
+        self,
+        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
+        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
+    ) -> None:
+        self._validate_init_arguments(pl_module, forward_module)
+        super().__init__(forward_module=(pl_module or forward_module))
 
-if _FAIRSCALE_AVAILABLE:  # pragma: no-cover
+
+def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
 
-    class LightningShardedDataParallel(_LightningModuleWrapperBase):
-        def __init__(
-            self,
-            forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-            pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        ) -> None:
-            self._validate_init_arguments(pl_module, forward_module)
-            super().__init__(forward_module=(pl_module or forward_module))
-
-    def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
-        rank_zero_deprecation(
-            "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0."
-            " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
-        )
-        model = wrapped_model
-        if isinstance(model, ShardedDataParallel):
-            model = model.module
-
-        return unwrap_lightning_module(model, _suppress_warning=True)
-
-else:
-    LightningShardedDataParallel = ...  # type: ignore[assignment,misc]
-    unwrap_lightning_module_sharded = ...  # type: ignore[assignment]
+    rank_zero_deprecation(
+        "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0."
+        " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
+    )
+    model = wrapped_model
+    if isinstance(model, ShardedDataParallel):
+        model = model.module
+
+    return unwrap_lightning_module(model, _suppress_warning=True)
diff --git a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py
index d76db26a76358..30132b291e021 100644
--- a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from typing import Optional, Union
 
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py
index 15dba4c98877b..de2d167c63a60 100644
--- a/src/pytorch_lightning/strategies/ddp.py
+++ b/src/pytorch_lightning/strategies/ddp.py
@@ -30,6 +30,7 @@
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device
 from lightning_lite.utilities.distributed import group as _group
 from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
@@ -39,7 +40,6 @@
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
diff --git a/src/pytorch_lightning/strategies/fully_sharded.py b/src/pytorch_lightning/strategies/fully_sharded.py
index 26256af600d8e..59c82400d68ce 100644
--- a/src/pytorch_lightning/strategies/fully_sharded.py
+++ b/src/pytorch_lightning/strategies/fully_sharded.py
@@ -19,10 +19,10 @@
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.enums import PrecisionType
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py
index df0d126385f32..8b9ccdc20462b 100644
--- a/src/pytorch_lightning/strategies/sharded.py
+++ b/src/pytorch_lightning/strategies/sharded.py
@@ -19,11 +19,11 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.enums import PrecisionType
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index 438f6d5eb6a47..934cf680de0f4 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -19,9 +19,9 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py
index 62cb93ed38919..1f369b6c759a4 100644
--- a/tests/tests_pytorch/helpers/runif.py
+++ b/tests/tests_pytorch/helpers/runif.py
@@ -20,9 +20,9 @@
 from packaging.version import Version
 from pkg_resources import get_distribution
 
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.accelerators.mps import MPSAccelerator
 from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
 from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.imports import (
diff --git a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py
index 0c08c8e9540eb..b231455c6cf6f 100644
--- a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py
+++ b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py
@@ -15,7 +15,7 @@
 import pytest
 import torch
 
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins import ShardedNativeMixedPrecisionPlugin
 from tests_pytorch.helpers.runif import RunIf
 
diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py
index e7b12bd7c7e6b..5043d3a8c4aa3 100644
--- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py
+++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_with_full_state_dict.py
@@ -5,10 +5,10 @@
 import pytest
 import torch
 
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPFullyShardedStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py
index 2665eb7c3e370..7755a31d2129a 100644
--- a/tests/tests_pytorch/strategies/test_ddp_strategy.py
+++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -20,9 +20,9 @@
 from torch.nn.parallel import DistributedDataParallel
 
 from lightning_lite.plugins.environments import ClusterEnvironment, LightningEnvironment
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py
index 2c0a5579c9933..a2b3775eb6708 100644
--- a/tests/tests_pytorch/strategies/test_sharded_strategy.py
+++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py
@@ -5,9 +5,9 @@
 import pytest
 import torch
 
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py
index 27c81306e6480..d770b8307a34a 100644
--- a/tests/tests_pytorch/utilities/test_imports.py
+++ b/tests/tests_pytorch/utilities/test_imports.py
@@ -20,9 +20,7 @@
 from lightning_utilities.core.imports import compare_version, module_available, RequirementCache
 from torch.distributed import is_available
 
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
-from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE
 
 
@@ -41,20 +39,6 @@ def test_imports():
     else:
         assert _BAGUA_AVAILABLE
 
-    try:
-        import deepspeed  # noqa
-    except ModuleNotFoundError:
-        assert not _DEEPSPEED_AVAILABLE
-    else:
-        assert _DEEPSPEED_AVAILABLE
-
-    try:
-        import fairscale.nn  # noqa
-    except ModuleNotFoundError:
-        assert not _FAIRSCALE_AVAILABLE
-    else:
-        assert _FAIRSCALE_AVAILABLE
-
     try:
         import horovod.torch  # noqa
     except ModuleNotFoundError:
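
For context, the `_FAIRSCALE_AVAILABLE` flag that every import above now pulls from `lightning_lite.strategies.fairscale` is a plain module-availability check. The actual contents of that new module are not shown in this patch; the sketch below is an assumption that simply mirrors the definition removed from `pytorch_lightning/overrides/fairscale.py`, including the Windows exclusion.

```python
# Hedged sketch of how the relocated availability flag is presumably defined.
# Not taken from this diff: it reproduces the removed definition
#   _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")
# in a self-contained form.
import platform

from lightning_utilities.core.imports import module_available

# fairscale's sharded data-parallel support is not usable on Windows,
# so the flag is gated on the platform as well as on the import check.
_IS_WINDOWS = platform.system() == "Windows"
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")

if _FAIRSCALE_AVAILABLE:
    # Only import fairscale-backed symbols once the flag says it is safe to do so.
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel  # noqa: F401
```

Because the flag is evaluated at import time without importing fairscale itself, strategies, precision plugins, and test helpers can depend on it unconditionally and fall back (or skip) when fairscale is absent, which is why the `if _FAIRSCALE_AVAILABLE:` guard around the wrapper class in `overrides/fairscale.py` could be dropped.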