Standalone Lite: Update LightningLite #14726

Merged · 23 commits · Sep 16, 2022
Changes from 15 commits
9 changes: 7 additions & 2 deletions src/lightning_lite/__init__.py
@@ -12,10 +12,15 @@
_logger.addHandler(logging.StreamHandler())
_logger.propagate = False

from lightning_lite.lite import LightningLite # noqa: E402
# TODO(lite): Re-enable this import
# from lightning_lite.lite import LightningLite
from lightning_lite.utilities.seed import seed_everything # noqa: E402

__all__ = ["LightningLite", "seed_everything"]
__all__ = [
# TODO(lite): Re-enable this import
# "LightningLite",
"seed_everything",
]

# for compatibility with namespace packages
__import__("pkg_resources").declare_namespace(__name__)
3 changes: 2 additions & 1 deletion src/lightning_lite/connector.py
@@ -184,7 +184,8 @@ def _check_config_and_set_final_flags(
if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
raise ValueError(
f"You selected an invalid strategy name: `strategy={strategy!r}`."
f" Available names are: {', '.join(self._registered_strategies)}."
" Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
" Find a complete list of options in our documentation at https://lightning.ai"
)

if (
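
Sketch (not part of the changeset) of how the shortened error message surfaces to users. The strategy name below is deliberately invalid, and the printed text is assumed to follow the ValueError raised above:

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self) -> None:
        pass


try:
    Lite(strategy="not_a_strategy")  # made-up name, not a registered strategy
except ValueError as err:
    print(err)  # "You selected an invalid strategy name: `strategy='not_a_strategy'` ..."
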
3 changes: 0 additions & 3 deletions src/lightning_lite/lite.py

This file was deleted.

7 changes: 1 addition & 6 deletions src/lightning_lite/strategies/ddp_spawn.py
@@ -114,10 +114,6 @@ def setup_module(self, module: Module) -> Module:
return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs)

def module_to_device(self, module: Module) -> None:
if self.root_device.type == "cuda":
# TODO(lite): This should be handled outside module_to_device, by a call to accelerator.setup_device()
# set the device on the spawned subprocesses
torch.cuda.set_device(self.root_device)
module.to(self.root_device)

def reduce(
@@ -200,8 +196,7 @@ def _setup_distributed(self) -> None:
def _get_process_group_backend(self) -> str:
return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device)

def _set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
def _set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
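
Illustration (not part of the diff): the CUDA device is no longer set inside module_to_device; the sketch below shows the intended split of responsibilities, assuming accelerator and strategy objects are already constructed.

import torch.nn as nn


def place_module(accelerator, strategy, module: nn.Module) -> nn.Module:
    # the accelerator configures the process-local device (e.g. torch.cuda.set_device)
    accelerator.setup_device(strategy.root_device)
    # the strategy only moves the module's parameters and buffers to that device
    strategy.module_to_device(module)
    return module
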
59 changes: 58 additions & 1 deletion src/lightning_lite/utilities/distributed.py
@@ -1,13 +1,15 @@
import logging
import os
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union

import torch
from lightning_utilities.core.rank_zero import rank_zero_deprecation
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import Dataset, DistributedSampler, Sampler

from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.utilities.exceptions import MisconfigurationException
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE
from lightning_lite.utilities.rank_zero import rank_zero_info as new_rank_zero_info

@@ -262,3 +264,58 @@ def _get_process_group_backend_from_env() -> Optional[str]:
" Specify `process_group_backend` directly on the strategy constructor."
)
return torch_backend


# TODO(lite): The error messages refer to 'replace_sampler_ddp' in PL but Lite has it named 'replace_sampler'
class _DatasetSamplerWrapper(Dataset):
"""Dataset to create indexes from `Sampler` or `Iterable`"""

def __init__(self, sampler: Union[Sampler, Iterable]) -> None:
if not isinstance(sampler, Sized):
raise TypeError(
"You seem to have configured a sampler in your DataLoader which"
" does not provide `__len__` method. The sampler was about to be"
" replaced by `DistributedSamplerWrapper` since `replace_sampler_ddp`"
" is True and you are using distributed training. Either provide `__len__`"
" method in your sampler, remove it from DataLoader or set `replace_sampler_ddp=False`"
" if you want to handle distributed sampling yourself."
)
if len(sampler) == float("inf"):
raise TypeError(
"You seem to have configured a sampler in your DataLoader which"
" does not provide finite `__len__` method. The sampler was about to be"
" replaced by `DistributedSamplerWrapper` since `replace_sampler_ddp`"
" is True and you are using distributed training. Either provide `__len__`"
" method in your sampler which returns a finite number, remove it from DataLoader"
" or set `replace_sampler_ddp=False` if you want to handle distributed sampling yourself."
)
self._sampler = sampler
# defer materializing an iterator until it is necessary
self._sampler_list: Optional[List[Any]] = None

def __getitem__(self, index: int) -> Any:
if self._sampler_list is None:
self._sampler_list = list(self._sampler)
return self._sampler_list[index]

def __len__(self) -> int:
return len(self._sampler)

def reset(self) -> None:
"""Reset the sampler list in order to get new sampling."""
self._sampler_list = list(self._sampler)


class DistributedSamplerWrapper(DistributedSampler):
"""Wrapper over ``Sampler`` for distributed training.

Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed mode if
sampler replacement is enabled.
"""

def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None:
super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs)

def __iter__(self) -> Iterator:
self.dataset.reset()
return (self.dataset[index] for index in super().__iter__())
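
Usage sketch (not part of the diff) for the DistributedSamplerWrapper added above. In real use num_replicas and rank come from the initialized process group; they are passed explicitly here so the snippet runs standalone.

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

from lightning_lite.utilities.distributed import DistributedSamplerWrapper

dataset = TensorDataset(torch.randn(100, 3))
weights = torch.rand(100)
sampler = WeightedRandomSampler(weights, num_samples=100)  # has a finite __len__

# each rank iterates over a distinct shard of the (re-)sampled indices
distributed_sampler = DistributedSamplerWrapper(sampler, num_replicas=2, rank=0)
loader = DataLoader(dataset, sampler=distributed_sampler)
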
117 changes: 31 additions & 86 deletions src/pytorch_lightning/lite/lite.py
@@ -25,23 +25,22 @@
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

from lightning_lite.utilities import _AcceleratorType, _StrategyType, move_data_to_device
from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.connector import _Connector, _PLUGIN_INPUT
from lightning_lite.plugins import Precision
from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy
from lightning_lite.strategies.strategy import TBroadcast
from lightning_lite.utilities import move_data_to_device
from lightning_lite.utilities.apply_func import convert_to_tensors
from lightning_lite.utilities.data import (
_auto_add_worker_init_fn,
_replace_dunder_methods,
_update_dataloader,
has_iterable_dataset,
)
from lightning_lite.utilities.distributed import DistributedSamplerWrapper
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper
from pytorch_lightning.plugins import PLUGIN_INPUT
from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy, TPUSpawnStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LightningLite(ABC):
@@ -76,34 +75,23 @@ def __init__(
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: Union[int, str] = 32,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
gpus: Optional[Union[List[int], str, int]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None,
) -> None:
self._check_accelerator_support(accelerator)
self._check_strategy_support(strategy)
self._accelerator_connector = AcceleratorConnector(
num_processes=None,
devices=devices,
tpu_cores=tpu_cores,
ipus=None,
self._connector = _Connector(
accelerator=accelerator,
strategy=strategy,
gpus=gpus,
devices=devices,
num_nodes=num_nodes,
sync_batchnorm=False, # TODO: add support?
benchmark=False,
replace_sampler_ddp=True,
deterministic=False,
precision=precision,
amp_type="native",
amp_level=None,
plugins=plugins,
auto_select_gpus=False,
tpu_cores=tpu_cores,
gpus=gpus,
)
self._strategy = self._accelerator_connector.strategy
self._accelerator = self._strategy.accelerator
self._precision_plugin = self._strategy.precision_plugin
self._strategy: Strategy = self._connector.strategy
self._accelerator: Accelerator = self._connector.accelerator
self._precision_plugin: Precision = self._strategy.precision_plugin
self._models_setup: int = 0

# wrap the run method so we can inject setup logic or spawn processes for the user
@@ -173,7 +161,7 @@ def setup(
model = self._move_model_to_device(model=model, optimizers=list(optimizers))

# Let accelerator/plugin wrap and connect the models and optimizers
model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers))
model = _LiteModule(model, self._precision_plugin, original_module=original_model)
optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
self._models_setup += 1
@@ -234,7 +222,7 @@ def _setup_dataloader(
_auto_add_worker_init_fn(dataloader, self.global_rank)

dataloader = self._strategy.process_dataloader(dataloader)
device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnStrategy) else None
device = self.device if move_to_device and not isinstance(self._strategy, XLAStrategy) else None
lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
lite_dataloader = cast(DataLoader, lite_dataloader)
return lite_dataloader
@@ -256,20 +244,18 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No
if isinstance(self._strategy, DeepSpeedStrategy):
if model is None:
if self._models_setup == 0:
raise MisconfigurationException(
"No models were setup for backward. Did you forget to call `self.setup()`?"
)
raise RuntimeError("No models were set up for backward. Did you forget to call `self.setup()`?")
if self._models_setup > 1:
raise MisconfigurationException(
raise ValueError(
"When using multiple models + deepspeed, please provide the model used to perform"
" the optimization: `self.backward(loss, model=model)`"
)
module = self._strategy.model
else:
# requires to attach the current `DeepSpeedEngine` for the `_LiteOptimizer.step` call.
self._strategy.model = module
self._strategy._deepspeed_engine = module

self._precision_plugin._run_backward(tensor, module, *args, **kwargs)
self._precision_plugin.backward(tensor, module, *args, **kwargs)

@contextmanager
def autocast(self) -> Generator[None, None, None]:
@@ -305,11 +291,8 @@ def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tens
A reference to the object that was moved to the new device.
"""
if isinstance(obj, nn.Module):
if self.device.type == "cuda":
# need to call this manually here again in case we spawned with DDPSpawnStrategy
# TODO: refactor to let accelerator handle this cleanly (see Accelerator.setup_device)
torch.cuda.set_device(self.device)
return obj.to(self.device)
self._accelerator.setup_device(self.device)
return self._strategy.module_to_device(obj)
return move_data_to_device(obj, device=self.device)

def print(self, *args: Any, **kwargs: Any) -> None:
@@ -404,13 +387,13 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:

def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()
with self._strategy.model_sharded_context(), _replace_dunder_methods(
with self._strategy.module_sharded_context(), _replace_dunder_methods(
DataLoader, "dataset"
), _replace_dunder_methods(BatchSampler):
return run_method(*args, **kwargs)

def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
if isinstance(self._strategy, TPUSpawnStrategy):
if isinstance(self._strategy, XLAStrategy):
# When the user creates the optimizer, they reference the parameters on the CPU.
# However, when running with TPU the parameters get copied and the reference in the optimizer
# remains invalid. We need to update the references to point to the parameter tensors on the device.
@@ -429,67 +412,29 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -

def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
return (
self._accelerator_connector.is_distributed
self._connector.is_distributed
and not isinstance(dataloader.sampler, DistributedSampler)
and not has_iterable_dataset(dataloader)
)

@staticmethod
def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler:
kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
# TODO(lite): Bring the DistributedSamplerWrapper to Lite package
return DistributedSamplerWrapper(dataloader.sampler, **kwargs)

def _check_accelerator_support(self, accelerator: Optional[Union[str, Accelerator]]) -> None:
supported = [t.value.lower() for t in self._supported_device_types()] + ["gpu", "auto"]
valid = accelerator is None or isinstance(accelerator, Accelerator) or accelerator in supported
if not valid:
raise MisconfigurationException(
f"`accelerator={repr(accelerator)}` is not a valid choice."
f" Choose one of {supported} or pass in a `Accelerator` instance."
)

def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> None:
supported = [t.lower() for t in self._supported_strategy_types()]
valid = strategy is None or isinstance(strategy, Strategy) or strategy in supported
if not valid:
raise MisconfigurationException(
f"`strategy={repr(strategy)}` is not a valid choice."
f" Choose one of {supported} or pass in a `Strategy` instance."
)

@staticmethod
def _supported_device_types() -> Sequence[_AcceleratorType]:
return (
_AcceleratorType.CPU,
_AcceleratorType.CUDA,
_AcceleratorType.TPU,
_AcceleratorType.MPS,
)

@staticmethod
def _supported_strategy_types() -> Sequence[_StrategyType]:
return (
_StrategyType.DP,
_StrategyType.DDP,
_StrategyType.DDP_SPAWN,
_StrategyType.DDP_FORK,
_StrategyType.DEEPSPEED,
_StrategyType.DDP_SHARDED,
_StrategyType.DDP_SHARDED_SPAWN,
)

@staticmethod
def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None:
if isinstance(model, _LiteModule):
raise MisconfigurationException("A model should be passed only once to the `setup` method.")
raise ValueError("A model should be passed only once to the `setup` method.")

if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
raise MisconfigurationException("An optimizer should be passed only once to the `setup` method.")
raise ValueError("An optimizer should be passed only once to the `setup` method.")

@staticmethod
def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders):
raise MisconfigurationException("A dataloader should be passed only once to the `setup_dataloaders` method")
raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method")

if any(not isinstance(dl, DataLoader) for dl in dataloaders):
raise MisconfigurationException("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
17 changes: 8 additions & 9 deletions src/pytorch_lightning/lite/wrappers.py
@@ -21,10 +21,10 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from lightning_lite.plugins import Precision
from lightning_lite.strategies import Strategy
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.strategies import Strategy

T_destination = TypeVar("T_destination", bound=Dict[str, Any])

@@ -56,21 +56,20 @@ def optimizer(self) -> Optimizer:
return self._optimizer

def state_dict(self) -> Dict[str, Tensor]:
return self._strategy.optimizer_state(self.optimizer)
return self._strategy.get_optimizer_state(self.optimizer)

def step(self, closure: Optional[Callable] = None) -> Any:
closure = closure or _do_nothing_closure
def step(self, closure: Optional[Callable] = None, module: Optional["_LiteModule"] = None) -> Any:
kwargs = dict(closure=closure) if closure is not None else {}
return self._strategy.optimizer_step(
self.optimizer,
opt_idx=0,
closure=closure,
model=self._strategy.model,
model=(module if module is not None else getattr(self._strategy, "model", None)),
**kwargs,
)


class _LiteModule(_DeviceDtypeModuleMixin):
def __init__(
self, forward_module: nn.Module, precision_plugin: PrecisionPlugin, original_module: Optional[nn.Module] = None
self, forward_module: nn.Module, precision_plugin: Precision, original_module: Optional[nn.Module] = None
) -> None:
"""The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
automatically for the forward pass.
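
Sketch (not part of the diff) of how the optional closure on _LiteOptimizer.step is used from user code, for example with optimizers such as LBFGS that re-evaluate the loss. The lite, model, optimizer and batch objects are assumed to come from LightningLite.setup as in the example above.

import torch.nn.functional as F


def training_step(lite, model, optimizer, batch):
    x, y = batch

    def closure():
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)
        lite.backward(loss)
        return loss

    # the closure is forwarded to Strategy.optimizer_step only when one is given (see diff above)
    return optimizer.step(closure)
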