Fairscale import updates #14721
Merged · 6 commits · Sep 29, 2022
46 changes: 19 additions & 27 deletions src/pytorch_lightning/overrides/fairscale.py
@@ -14,43 +14,35 @@
 from typing import Optional, Union

 import torch.nn as nn
-from lightning_utilities.core.imports import module_available

 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import (
     _LightningModuleWrapperBase,
     _LightningPrecisionModuleWrapperBase,
     unwrap_lightning_module,
 )
-from pytorch_lightning.utilities.imports import _IS_WINDOWS
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

-_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")

-if _FAIRSCALE_AVAILABLE:  # pragma: no-cover
-    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
-
-    class LightningShardedDataParallel(_LightningModuleWrapperBase):
-        def __init__(
-            self,
-            forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-            pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        ) -> None:
-            self._validate_init_arguments(pl_module, forward_module)
-            super().__init__(forward_module=(pl_module or forward_module))
-
-    def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
-        rank_zero_deprecation(
-            "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0."
-            " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
-        )
-        model = wrapped_model
-        if isinstance(model, ShardedDataParallel):
-            model = model.module
-
-        return unwrap_lightning_module(model, _suppress_warning=True)
-
-else:
-    LightningShardedDataParallel = ...  # type: ignore[assignment,misc]
-    unwrap_lightning_module_sharded = ...  # type: ignore[assignment]
+class LightningShardedDataParallel(_LightningModuleWrapperBase):
+    def __init__(
+        self,
+        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
+        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
+    ) -> None:
+        self._validate_init_arguments(pl_module, forward_module)
+        super().__init__(forward_module=(pl_module or forward_module))
+
+
+def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
+    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
+
+    rank_zero_deprecation(
+        "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0."
+        " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
+    )
+    model = wrapped_model
+    if isinstance(model, ShardedDataParallel):
+        model = model.module
+
+    return unwrap_lightning_module(model, _suppress_warning=True)
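The deprecation notice above already names the replacement: read the module off the strategy instead of unwrapping it by hand. A minimal migration sketch, assuming fairscale is installed and reusing the `BoringModel` demo class seen in the tests further down; the `"ddp_sharded"` strategy string and the tiny `fast_dev_run` configuration are assumptions of this sketch, not part of the diff:

import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = pl.Trainer(strategy="ddp_sharded", accelerator="cpu", devices=2, fast_dev_run=True)
trainer.fit(model)

# Deprecated in v1.8.0: unwrapping the sharded wrapper manually.
#   from pytorch_lightning.overrides.fairscale import unwrap_lightning_module_sharded
#   lightning_module = unwrap_lightning_module_sharded(trainer.strategy.model)

# Replacement named in the deprecation message: the strategy exposes the LightningModule directly.
lightning_module = trainer.strategy.lightning_module
assert lightning_module is model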
@@ -13,7 +13,7 @@
 # limitations under the License.
 from typing import Optional, Union

-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/ddp.py
@@ -30,6 +30,7 @@

 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device
 from lightning_lite.utilities.distributed import group as _group
 from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
@@ -39,7 +40,6 @@
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/fully_sharded.py
@@ -19,10 +19,10 @@

 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.enums import PrecisionType
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/sharded.py
@@ -19,11 +19,11 @@
 from torch.optim import Optimizer

 import pytorch_lightning as pl
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.enums import PrecisionType
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/sharded_spawn.py
@@ -19,9 +19,9 @@
 from torch.optim import Optimizer

 import pytorch_lightning as pl
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
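Each strategy module above receives the same mechanical swap: `_FAIRSCALE_AVAILABLE` now comes from `lightning_lite.strategies.fairscale` instead of `pytorch_lightning.overrides.fairscale`. A hedged sketch of the guard pattern the flag supports, with a hypothetical helper (`_wrap_sharded` and its error message are not the strategies' actual code) and the same lazy fairscale import style used in the overrides file above:

from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _wrap_sharded(module, sharded_optimizers):
    # Hypothetical helper: check the flag before touching fairscale at all.
    if not _FAIRSCALE_AVAILABLE:
        raise MisconfigurationException("fairscale is required for sharded training: `pip install fairscale`")

    # Lazy import keeps this module importable on platforms without fairscale,
    # mirroring `unwrap_lightning_module_sharded` above.
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

    return ShardedDataParallel(module, sharded_optimizer=sharded_optimizers)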
2 changes: 1 addition & 1 deletion tests/tests_pytorch/helpers/runif.py
@@ -20,9 +20,9 @@
 from packaging.version import Version
 from pkg_resources import get_distribution

+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.accelerators.mps import MPSAccelerator
 from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
 from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.imports import (
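The test helpers consume the flag from its new location as well. A small sketch of a fairscale-gated test written with plain pytest; the `RunIf(fairscale=True)` marker used elsewhere in this test suite presumably reduces to the same skip condition:

import pytest

from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE


@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="fairscale is not installed")
def test_fairscale_import_smoke():
    # Runs only when fairscale is importable on this platform.
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

    assert ShardedDataParallel is not None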
@@ -15,7 +15,7 @@
 import pytest
 import torch

-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins import ShardedNativeMixedPrecisionPlugin
 from tests_pytorch.helpers.runif import RunIf

@@ -5,10 +5,10 @@
 import pytest
 import torch

+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPFullyShardedStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
2 changes: 1 addition & 1 deletion tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -20,9 +20,9 @@
 from torch.nn.parallel import DistributedDataParallel

 from lightning_lite.plugins.environments import ClusterEnvironment, LightningEnvironment
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
2 changes: 1 addition & 1 deletion tests/tests_pytorch/strategies/test_sharded_strategy.py
@@ -5,9 +5,9 @@
 import pytest
 import torch

+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.runif import RunIf
16 changes: 0 additions & 16 deletions tests/tests_pytorch/utilities/test_imports.py
@@ -20,9 +20,7 @@
 from lightning_utilities.core.imports import compare_version, module_available, RequirementCache
 from torch.distributed import is_available

-from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
-from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE


@@ -41,20 +39,6 @@ def test_imports():
     else:
         assert _BAGUA_AVAILABLE

-    try:
-        import deepspeed  # noqa
-    except ModuleNotFoundError:
-        assert not _DEEPSPEED_AVAILABLE
-    else:
-        assert _DEEPSPEED_AVAILABLE
-
-    try:
-        import fairscale.nn  # noqa
-    except ModuleNotFoundError:
-        assert not _FAIRSCALE_AVAILABLE
-    else:
-        assert _FAIRSCALE_AVAILABLE
-
     try:
         import horovod.torch  # noqa
     except ModuleNotFoundError:
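The diff removes the old definition of `_FAIRSCALE_AVAILABLE` from `src/pytorch_lightning/overrides/fairscale.py` but does not show its new home. Judging by the deleted line and the import path used throughout this PR, `lightning_lite/strategies/fairscale.py` presumably defines the flag along these lines; this is a sketch reconstructed from the removed code, not the verified contents of that file, and the `_IS_WINDOWS` import path in particular is an assumption:

# Presumed shape of the relocated flag (sketch, not verified against the repository).
from lightning_utilities.core.imports import module_available

from lightning_lite.utilities.imports import _IS_WINDOWS  # assumed location of _IS_WINDOWS in lightning_lite

_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and module_available("fairscale.nn")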