From 7e499a09a8b50135b575230dda6a78c2dbe0fde1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Sep 2022 15:45:55 +0200 Subject: [PATCH 01/20] squash all commits --- src/lightning_lite/connector.py | 603 ++++++++++++++++++ src/pytorch_lightning/lite/lite.py | 98 +-- src/pytorch_lightning/lite/wrappers.py | 13 +- tests/tests_lite/test_connector.py | 720 ++++++++++++++++++++++ tests/tests_pytorch/lite/test_lite.py | 30 +- tests/tests_pytorch/lite/test_wrappers.py | 8 +- 6 files changed, 1366 insertions(+), 106 deletions(-) create mode 100644 src/lightning_lite/connector.py create mode 100644 tests/tests_lite/test_connector.py diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py new file mode 100644 index 0000000000000..c1243d2b41926 --- /dev/null +++ b/src/lightning_lite/connector.py @@ -0,0 +1,603 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import Counter +from typing import Dict, List, Optional, Union + +import torch + +from lightning_lite.accelerators import ACCELERATOR_REGISTRY +from lightning_lite.accelerators.accelerator import Accelerator +from lightning_lite.accelerators.cuda import CUDAAccelerator +from lightning_lite.accelerators.mps import MPSAccelerator +from lightning_lite.accelerators.tpu import TPUAccelerator +from lightning_lite.plugins import ( + CheckpointIO, + DeepSpeedPrecision, + NativeMixedPrecision, + Precision, + TPUBf16Precision, + TPUPrecision, +) +from lightning_lite.plugins.environments import ( + ClusterEnvironment, + KubeflowEnvironment, + LightningEnvironment, + LSFEnvironment, + SLURMEnvironment, + TorchElasticEnvironment, +) +from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.strategies import ( + DDPShardedStrategy, + DDPSpawnShardedStrategy, + DDPSpawnStrategy, + DDPStrategy, + DeepSpeedStrategy, + SingleDeviceStrategy, + SingleTPUStrategy, + Strategy, + STRATEGY_REGISTRY, + XLAStrategy, +) +from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES +from lightning_lite.utilities import _StrategyType, device_parser, rank_zero_deprecation, rank_zero_info, rank_zero_warn +from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE + +_PLUGIN = Union[Strategy, Precision, ClusterEnvironment, CheckpointIO] +_PLUGIN_INPUT = Union[_PLUGIN, str] + + +class _Connector: + """The Connector parses several Lite arguments and instantiates the Strategy including other components such as + the Accelerator and Precision plugins. + + A. accelerator flag could be: + 1. accelerator class + 2. accelerator str + 3. accelerator auto + + B. strategy flag could be : + 1. strategy class + 2. strategy str registered with STRATEGY_REGISTRY + 3. strategy str in _strategy_type enum which listed in each strategy as + backend (registed these too, and _strategy_type could be deprecated) + + C. plugins flag could be: + 1. List of str, which could contain: + i. 
precision str (Not supported in the old accelerator_connector version) + ii. checkpoint_io str (Not supported in the old accelerator_connector version) + iii. cluster_environment str (Not supported in the old accelerator_connector version) + 2. List of class, which could contains: + i. precision class (should be removed, and precision flag should allow user pass classes) + ii. checkpoint_io class + iii. cluster_environment class + + + priorities which to take when: + A. Class > str + B. Strategy > Accelerator/precision/plugins + """ + + def __init__( + self, + accelerator: Optional[Union[str, Accelerator]] = None, + strategy: Optional[Union[str, Strategy]] = None, + devices: Optional[Union[List[int], str, int]] = None, + num_nodes: int = 1, + precision: Union[int, str] = 32, + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + tpu_cores: Optional[Union[List[int], str, int]] = None, # deprecated + gpus: Optional[Union[List[int], str, int]] = None, # deprecated + ) -> None: + # 1. Parsing flags + # Get registered strategies, built-in accelerators and precision plugins + self._registered_strategies = STRATEGY_REGISTRY.available_strategies() + self._registered_accelerators = ACCELERATOR_REGISTRY.available_accelerators() + self._precision_types = ("16", "32", "64", "bf16", "mixed") + + # Raise an exception if there are conflicts between flags + # Set each valid flag to `self._x_flag` after validation + # For devices: Assign gpus, ipus, etc. to the accelerator flag and devices flag + self._strategy_flag: Optional[Union[Strategy, str]] = None + self._accelerator_flag: Optional[Union[Accelerator, str]] = None + self._precision_flag: Optional[Union[int, str]] = None + self._precision_plugin_flag: Optional[Precision] = None + self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None + self._parallel_devices: List[Union[int, torch.device, str]] = [] + self.checkpoint_io: Optional[CheckpointIO] = None + + self._check_config_and_set_final_flags( + strategy=strategy, + accelerator=accelerator, + precision=precision, + plugins=plugins, + ) + self._check_device_config_and_set_final_flags( + devices=devices, num_nodes=num_nodes, gpus=gpus, tpu_cores=tpu_cores + ) + + # 2. Instantiate Accelerator + # handle `auto`, `None` and `gpu` + if self._accelerator_flag == "auto" or self._accelerator_flag is None: + self._accelerator_flag = self._choose_auto_accelerator() + elif self._accelerator_flag == "gpu": + self._accelerator_flag = self._choose_gpu_accelerator_backend() + + self._set_parallel_devices_and_init_accelerator() + + # 3. Instantiate ClusterEnvironment + self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment() + + # 4. Instantiate Strategy - Part 1 + if self._strategy_flag is None: + self._strategy_flag = self._choose_strategy() + # In specific cases, ignore user selection and fall back to a different strategy + self._check_strategy_and_fallback() + self._init_strategy() + + # 5. Instantiate Precision Plugin + self.precision_plugin = self._check_and_init_precision() + + # 6. Instantiate Strategy - Part 2 + self._lazy_init_strategy() + + def _check_config_and_set_final_flags( + self, + strategy: Optional[Union[str, Strategy]], + accelerator: Optional[Union[str, Accelerator]], + precision: Union[int, str], + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]], + ) -> None: + """This method checks: + + 1. strategy: whether the strategy name is valid, and sets the internal flags if it is. + 2. 
accelerator: if the value of the accelerator argument is a type of accelerator (instance or string), + set self._accelerator_flag accordingly. + 3. precision: The final value of the precision flag may be determined either by the precision argument or + by a plugin instance. + 4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others. + Additionally, other flags such as `precision` can populate the list with the + corresponding plugin instances. + """ + if plugins is not None: + plugins = [plugins] if not isinstance(plugins, list) else plugins + + if isinstance(strategy, str): + strategy = strategy.lower() + + if strategy is not None: + self._strategy_flag = strategy + + if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy): + raise ValueError( + f"You selected an invalid strategy name: `strategy={strategy!r}`." + " Example choices: ddp, ddp_spawn, deepspeed, dp, ..." + " Find a complete list of options in our documentation at https://lightning.ai" + ) + + if ( + accelerator is not None + and accelerator not in self._registered_accelerators + and accelerator not in ("auto", "gpu") + and not isinstance(accelerator, Accelerator) + ): + raise ValueError( + f"You selected an invalid accelerator name: `accelerator={accelerator!r}`." + f" Available names are: {', '.join(self._registered_accelerators)}." + ) + + self._accelerator_flag = accelerator + + if precision is not None: + if str(precision) not in self._precision_types: + raise ValueError( + f"Precision {repr(precision)} is invalid. Allowed precision values: {self._precision_types}" + ) + self._precision_flag = precision + + if plugins: + plugins_flags_types: Dict[str, int] = Counter() + for plugin in plugins: + if isinstance(plugin, Precision): + self._precision_plugin_flag = plugin + plugins_flags_types[Precision.__name__] += 1 + elif isinstance(plugin, CheckpointIO): + self.checkpoint_io = plugin + plugins_flags_types[CheckpointIO.__name__] += 1 + elif isinstance(plugin, ClusterEnvironment): + self._cluster_environment_flag = plugin + plugins_flags_types[ClusterEnvironment.__name__] += 1 + else: + raise TypeError( + f"Found invalid type for plugin {plugin}. Expected one of: Precision, " + "CheckpointIO, ClusterEnviroment." + ) + + duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1] + if duplicated_plugin_key: + raise ValueError( + f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`." + " Expected one value for each type at most." + ) + + # handle the case when the user passes in a strategy instance which has an accelerator, precision, + # checkpoint io or cluster env set up + # TODO: improve the error messages below + if self._strategy_flag and isinstance(self._strategy_flag, Strategy): + if self._strategy_flag._accelerator: + if self._accelerator_flag: + raise ValueError("accelerator set through both strategy class and accelerator flag, choose one") + else: + self._accelerator_flag = self._strategy_flag._accelerator + if self._strategy_flag._precision_plugin: + # [RFC] handle precision plugin set up conflict? 
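+                # A Precision plugin can be supplied both on the Strategy instance and through `plugins`;
+                # raising here keeps a single source of truth instead of silently preferring one of them.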
+ if self._precision_plugin_flag: + raise ValueError("precision set through both strategy class and plugins, choose one") + else: + self._precision_plugin_flag = self._strategy_flag._precision_plugin + if self._strategy_flag._checkpoint_io: + if self.checkpoint_io: + raise ValueError("checkpoint_io set through both strategy class and plugins, choose one") + else: + self.checkpoint_io = self._strategy_flag._checkpoint_io + if getattr(self._strategy_flag, "cluster_environment", None): + if self._cluster_environment_flag: + raise ValueError("cluster_environment set through both strategy class and plugins, choose one") + else: + self._cluster_environment_flag = getattr(self._strategy_flag, "cluster_environment") + + if hasattr(self._strategy_flag, "parallel_devices"): + if self._strategy_flag.parallel_devices: + if self._strategy_flag.parallel_devices[0].type == "cpu": + if self._accelerator_flag and self._accelerator_flag not in ("auto", "cpu"): + raise ValueError( + f"CPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," + f" but accelerator set to {self._accelerator_flag}, please choose one device type" + ) + self._accelerator_flag = "cpu" + if self._strategy_flag.parallel_devices[0].type == "cuda": + if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"): + raise ValueError( + f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," + f" but accelerator set to {self._accelerator_flag}, please choose one device type" + ) + self._accelerator_flag = "cuda" + self._parallel_devices = self._strategy_flag.parallel_devices + + def _check_device_config_and_set_final_flags( + self, + devices: Optional[Union[List[int], str, int]], + num_nodes: int, + gpus: Optional[Union[List[int], str, int]], + tpu_cores: Optional[Union[List[int], str, int]], + ) -> None: + self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1 + self._devices_flag = devices + + if self._devices_flag in ([], 0, "0"): + accelerator_name = ( + self._accelerator_flag.__class__.__qualname__ + if isinstance(self._accelerator_flag, Accelerator) + else self._accelerator_flag + ) + raise ValueError( + f"`Lite(devices={self._devices_flag!r})` value is not a valid input" + f" using {accelerator_name} accelerator." + ) + + # TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed + self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag(devices, gpus, tpu_cores) + + if self._devices_flag == "auto" and self._accelerator_flag is None: + raise ValueError( + f"You passed `devices={devices}` but haven't specified" + " `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping." + ) + + def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag( + self, + devices: Optional[Union[List[int], str, int]], + gpus: Optional[Union[List[int], str, int]], + tpu_cores: Optional[Union[List[int], str, int]], + ) -> None: + """Emit deprecation warnings for num_processes, gpus, ipus, tpu_cores and set the `devices_flag` and + `accelerator_flag`.""" + if gpus is not None: + rank_zero_deprecation( + f"Setting `Lite(gpus={gpus!r})` is deprecated in v1.7 and will be removed" + f" in v2.0. Please use `Lite(accelerator='gpu', devices={gpus!r})` instead." + ) + if tpu_cores is not None: + rank_zero_deprecation( + f"Setting `Lite(tpu_cores={tpu_cores!r})` is deprecated in v1.7 and will be removed" + f" in v2.0. Please use `Lite(accelerator='tpu', devices={tpu_cores!r})` instead." 
+ ) + self._gpus: Optional[Union[List[int], str, int]] = gpus + self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores + deprecated_devices_specific_flag = gpus or tpu_cores + if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"): + if devices: + # TODO: improve error message + rank_zero_warn( + f"The flag `devices={devices}` will be ignored, " + f"instead the device specific number {deprecated_devices_specific_flag} will be used" + ) + + if [(gpus is not None), (tpu_cores is not None)].count(True) > 1: + # TODO: improve error message + rank_zero_warn("more than one device specific flag has been set") + self._devices_flag = deprecated_devices_specific_flag + + if self._accelerator_flag is None: + # set accelerator type based on num_processes, gpus, ipus, tpu_cores + if tpu_cores: + self._accelerator_flag = "tpu" + if gpus: + self._accelerator_flag = "cuda" + + def _choose_auto_accelerator(self) -> str: + """Choose the accelerator type (str) based on availability when ``accelerator='auto'``.""" + if self._accelerator_flag == "auto": + if _TPU_AVAILABLE: + return "tpu" + if _IPU_AVAILABLE: + return "ipu" + if _HPU_AVAILABLE: + return "hpu" + if MPSAccelerator.is_available(): + return "mps" + if CUDAAccelerator.is_available(): + return "cuda" + return "cpu" + + @staticmethod + def _choose_gpu_accelerator_backend() -> str: + if MPSAccelerator.is_available(): + return "mps" + if CUDAAccelerator.is_available(): + return "cuda" + + raise RuntimeError("No supported gpu backend found!") + + def _set_parallel_devices_and_init_accelerator(self) -> None: + if isinstance(self._accelerator_flag, Accelerator): + self.accelerator: Accelerator = self._accelerator_flag + else: + assert self._accelerator_flag is not None + self.accelerator = ACCELERATOR_REGISTRY.get(self._accelerator_flag) + + if not self.accelerator.is_available(): + available_accelerator = [ + acc_str for acc_str in self._registered_accelerators if ACCELERATOR_REGISTRY.get(acc_str).is_available() + ] + raise RuntimeError( + f"{self.accelerator.__class__.__qualname__} can not run on your system" + " since the accelerator is not available. The following accelerator(s)" + " is available and can be passed into `accelerator` argument of" + f" `Lite`: {available_accelerator}." 
+ ) + + self._set_devices_flag_if_auto_passed() + + self._gpus = self._devices_flag if not self._gpus else self._gpus + self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores + + self._devices_flag = self.accelerator.parse_devices(self._devices_flag) + if not self._parallel_devices: + self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) + + def _set_devices_flag_if_auto_passed(self) -> None: + if self._devices_flag == "auto" or self._devices_flag is None: + self._devices_flag = self.accelerator.auto_device_count() + + def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: + if isinstance(self._cluster_environment_flag, ClusterEnvironment): + return self._cluster_environment_flag + if self._is_slurm_managing_tasks(): + rank_zero_info("Multiprocessing is handled by SLURM.") + return SLURMEnvironment() + for env_type in (TorchElasticEnvironment, KubeflowEnvironment, LSFEnvironment): + if env_type.detect(): + # Ignore type error because it is a false positive: https://github.com/python/mypy/issues/13044 + return env_type() # type: ignore[abstract] + return LightningEnvironment() + + def _is_slurm_managing_tasks(self) -> bool: + """used by choosing cluster enviroment.""" + # TODO(lite): Remove this, see: https://github.com/Lightning-AI/lightning/pull/14300 + if not SLURMEnvironment.detect() or SLURMEnvironment.job_name() == "bash": + return False + + total_requested_devices = len(self._parallel_devices) * self._num_nodes_flag + num_slurm_tasks = int(os.environ["SLURM_NTASKS"], 0) + return num_slurm_tasks == total_requested_devices + + def _choose_strategy(self) -> Union[Strategy, str]: + if self._accelerator_flag == "tpu": + if self._parallel_devices and len(self._parallel_devices) > 1: + return "tpu_spawn" + else: + # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device" + return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore + if self._num_nodes_flag > 1: + return "ddp" + if len(self._parallel_devices) <= 1: + # TODO: Change this once gpu accelerator was renamed to cuda accelerator + if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") + ): + device = device_parser.determine_root_gpu_device(self._parallel_devices) + else: + device = "cpu" + # TODO: lazy initialized device, then here could be self._strategy_flag = "single_device" + return SingleDeviceStrategy(device=device) # type: ignore + if len(self._parallel_devices) > 1: + if _IS_INTERACTIVE: + return "ddp_fork" + return "ddp_spawn" + + return "ddp" + + def _check_strategy_and_fallback(self) -> None: + """Checks edge cases when the strategy selection was a string input, and we need to fall back to a + different choice depending on other parameters or the environment.""" + # current fallback and check logic only apply to user pass in str config and object config + # TODO this logic should apply to both str and object config + strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag + + if strategy_flag in ("ddp_spawn", "ddp_spawn_find_unused_parameters_false") and ( + TorchElasticEnvironment.detect() or KubeflowEnvironment.detect() or self._is_slurm_managing_tasks() + ): + strategy_flag = "ddp" + if strategy_flag == "dp" and self._accelerator_flag == "cpu": + rank_zero_warn(f"{strategy_flag!r} is not supported on CPUs, hence setting `strategy='ddp'`.") + 
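            # DataParallel ("dp") is only supported on GPU here, so on a CPU accelerator the flag is
+            # rewritten to "ddp" (the warning above informs the user about the change).
+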
strategy_flag = "ddp" + if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods(): + raise ValueError( + f"You selected `Lite(strategy='{strategy_flag}')` but process forking is not supported on this" + f" platform. We recommed `Lite(strategy='ddp_spawn')` instead." + ) + if strategy_flag: + self._strategy_flag = strategy_flag + + def _init_strategy(self) -> None: + """Instantiate the Strategy given depending on the setting of ``_strategy_flag``.""" + if isinstance(self._strategy_flag, str): + self.strategy = STRATEGY_REGISTRY.get(self._strategy_flag) + elif isinstance(self._strategy_flag, Strategy): + self.strategy = self._strategy_flag + else: + raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}") + + def _check_and_init_precision(self) -> Precision: + self._validate_precision_choice() + if isinstance(self._precision_plugin_flag, Precision): + return self._precision_plugin_flag + + if isinstance(self.accelerator, TPUAccelerator): + if self._precision_flag == 32: + return TPUPrecision() + elif self._precision_flag in (16, "bf16"): + if self._precision_flag == 16: + rank_zero_warn( + "You passed `Lite(accelerator='tpu', precision=16)` but AMP" + " is not supported with TPUs. Using `precision='bf16'` instead." + ) + return TPUBf16Precision() + if isinstance(self.strategy, DeepSpeedStrategy): + return DeepSpeedPrecision(self._precision_flag, amp_type="native", amp_level=None) # type: ignore + + if self._precision_flag == 32: + return Precision() + if self._precision_flag == 64: + return DoublePrecision() + + if self._precision_flag == 16 and self._accelerator_flag == "cpu": + rank_zero_warn( + "You passed `Lite(accelerator='cpu', precision=16)` but native AMP is not supported on CPU." + " Using `precision='bf16'` instead." + ) + self._precision_flag = "bf16" + + if self._precision_flag in (16, "bf16"): + rank_zero_info( + "Using 16-bit Automatic Mixed Precision (AMP)" + if self._precision_flag == 16 + else "Using bfloat16 Automatic Mixed Precision (AMP)" + ) + + device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + return NativeMixedPrecision(self._precision_flag, device) + + raise RuntimeError("No precision set") + + def _validate_precision_choice(self) -> None: + """Validate the combination of choices for precision, and accelerator.""" + if isinstance(self.accelerator, TPUAccelerator): + if self._precision_flag == 64: + raise NotImplementedError( + "`Lite(accelerator='tpu', precision=64)` is not implemented." + " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" + " requesting this feature." + ) + if self._precision_plugin_flag and not isinstance( + self._precision_plugin_flag, (TPUPrecision, TPUBf16Precision) + ): + raise ValueError( + f"The `TPUAccelerator` can only be used with a `TPUPrecision` plugin," + f" found: {self._precision_plugin_flag}." 
+ ) + + def _lazy_init_strategy(self) -> None: + """Lazily set missing attributes on the previously instantiated strategy.""" + self.strategy.accelerator = self.accelerator + if self.precision_plugin: + self.strategy.precision_plugin = self.precision_plugin + if self.checkpoint_io: + self.strategy.checkpoint_io = self.checkpoint_io + if hasattr(self.strategy, "cluster_environment"): + self.strategy.cluster_environment = self.cluster_environment + if hasattr(self.strategy, "parallel_devices"): + if self.strategy.parallel_devices: + self._parallel_devices = self.strategy.parallel_devices + else: + self.strategy.parallel_devices = self._parallel_devices + if hasattr(self.strategy, "num_nodes"): + self.strategy._num_nodes = self._num_nodes_flag + if hasattr(self.strategy, "set_world_ranks"): + self.strategy.set_world_ranks() + self.strategy._configure_launcher() + + from lightning_lite.utilities import _IS_INTERACTIVE + + if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible: + raise RuntimeError( + f"`Lite(strategy={self._strategy_flag!r})` is not compatible with an interactive" + " environment. Run your code as a script, or choose one of the compatible strategies:" + f" Lite(strategy=None|{'|'.join(_StrategyType.interactive_compatible_types())})." + " In case you are spawning processes yourself, make sure to include the Lite" + " creation inside the worker function." + ) + + # TODO: should be moved to _check_strategy_and_fallback(). + # Current test check precision first, so keep this check here to meet error order + if isinstance(self.accelerator, TPUAccelerator) and not isinstance( + self.strategy, (SingleTPUStrategy, XLAStrategy) + ): + raise ValueError( + "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `XLAStrategy`," + f" found {self.strategy.__class__.__name__}." + ) + + @property + def is_distributed(self) -> bool: + # TODO: deprecate this property + # Used for custom plugins. + # Custom plugins should implement is_distributed property. 
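+        # Strategies that expose their own `is_distributed` are trusted directly, except on a TPUAccelerator,
+        # where the class-based check below is combined with the strategy's own value.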
+ if hasattr(self.strategy, "is_distributed") and not isinstance(self.accelerator, TPUAccelerator): + return self.strategy.is_distributed + distributed_strategy = ( + DDPStrategy, + DDPSpawnShardedStrategy, + DDPShardedStrategy, + DDPSpawnStrategy, + DeepSpeedStrategy, + XLAStrategy, + ) + is_distributed = isinstance(self.strategy, distributed_strategy) + if isinstance(self.accelerator, TPUAccelerator): + is_distributed |= self.strategy.is_distributed + return is_distributed diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 331495e04ce06..2d4a72b174e68 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -25,7 +25,12 @@ from torch.optim import Optimizer from torch.utils.data import BatchSampler, DataLoader, DistributedSampler -from lightning_lite.utilities import _AcceleratorType, _StrategyType, move_data_to_device +from lightning_lite.accelerators.accelerator import Accelerator +from lightning_lite.connector import _Connector, _PLUGIN_INPUT +from lightning_lite.plugins import Precision +from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy +from lightning_lite.strategies.strategy import TBroadcast +from lightning_lite.utilities import move_data_to_device from lightning_lite.utilities.apply_func import convert_to_tensors from lightning_lite.utilities.data import ( _auto_add_worker_init_fn, @@ -33,15 +38,10 @@ _update_dataloader, has_iterable_dataset, ) +from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.seed import seed_everything -from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper -from pytorch_lightning.plugins import PLUGIN_INPUT -from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy, TPUSpawnStrategy -from pytorch_lightning.strategies.strategy import TBroadcast -from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.utilities.exceptions import MisconfigurationException class LightningLite(ABC): @@ -76,34 +76,23 @@ def __init__( devices: Optional[Union[List[int], str, int]] = None, num_nodes: int = 1, precision: Union[int, str] = 32, - plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, gpus: Optional[Union[List[int], str, int]] = None, tpu_cores: Optional[Union[List[int], str, int]] = None, ) -> None: - self._check_accelerator_support(accelerator) - self._check_strategy_support(strategy) - self._accelerator_connector = AcceleratorConnector( - num_processes=None, - devices=devices, - tpu_cores=tpu_cores, - ipus=None, + self._connector = _Connector( accelerator=accelerator, strategy=strategy, - gpus=gpus, + devices=devices, num_nodes=num_nodes, - sync_batchnorm=False, # TODO: add support? 
- benchmark=False, - replace_sampler_ddp=True, - deterministic=False, precision=precision, - amp_type="native", - amp_level=None, plugins=plugins, - auto_select_gpus=False, + tpu_cores=tpu_cores, + gpus=gpus, ) - self._strategy = self._accelerator_connector.strategy - self._accelerator = self._strategy.accelerator - self._precision_plugin = self._strategy.precision_plugin + self._strategy: Strategy = self._connector.strategy + self._accelerator: Accelerator = self._strategy.accelerator + self._precision_plugin: Precision = self._strategy.precision_plugin self._models_setup: int = 0 # wrap the run method so we can inject setup logic or spawn processes for the user @@ -173,7 +162,7 @@ def setup( model = self._move_model_to_device(model=model, optimizers=list(optimizers)) # Let accelerator/plugin wrap and connect the models and optimizers - model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers)) + model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) model = _LiteModule(model, self._precision_plugin, original_module=original_model) optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] self._models_setup += 1 @@ -234,7 +223,7 @@ def _setup_dataloader( _auto_add_worker_init_fn(dataloader, self.global_rank) dataloader = self._strategy.process_dataloader(dataloader) - device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnStrategy) else None + device = self.device if move_to_device and not isinstance(self._strategy, XLAStrategy) else None lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device) lite_dataloader = cast(DataLoader, lite_dataloader) return lite_dataloader @@ -269,7 +258,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No # requires to attach the current `DeepSpeedEngine` for the `_LiteOptimizer.step` call. self._strategy.model = module - self._precision_plugin._run_backward(tensor, module, *args, **kwargs) + self._precision_plugin.backward(tensor, module, *args, **kwargs) @contextmanager def autocast(self) -> Generator[None, None, None]: @@ -305,10 +294,7 @@ def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tens A reference to the object that was moved to the new device. """ if isinstance(obj, nn.Module): - if self.device.type == "cuda": - # need to call this manually here again in case we spawned with DDPSpawnStrategy - # TODO: refactor to let accelerator handle this cleanly (see Accelerator.setup_device) - torch.cuda.set_device(self.device) + self._accelerator.setup_device(self.device) return obj.to(self.device) return move_data_to_device(obj, device=self.device) @@ -404,13 +390,13 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() - with self._strategy.model_sharded_context(), _replace_dunder_methods( + with self._strategy.module_sharded_context(), _replace_dunder_methods( DataLoader, "dataset" ), _replace_dunder_methods(BatchSampler): return run_method(*args, **kwargs) def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: - if isinstance(self._strategy, TPUSpawnStrategy): + if isinstance(self._strategy, XLAStrategy): # When the user creates the optimizer, they reference the parameters on the CPU. 
# However, when running with TPU the parameters get copied and the reference in the optimizer # remains invalid. We need to update the references to point to the parameter tensors on the device. @@ -429,7 +415,7 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) - def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: return ( - self._accelerator_connector.is_distributed + self._connector.is_distributed and not isinstance(dataloader.sampler, DistributedSampler) and not has_iterable_dataset(dataloader) ) @@ -437,47 +423,9 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: @staticmethod def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler: kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) + # TODO(lite): Bring the DistributedSamplerWrapper to Lite package return DistributedSamplerWrapper(dataloader.sampler, **kwargs) - def _check_accelerator_support(self, accelerator: Optional[Union[str, Accelerator]]) -> None: - supported = [t.value.lower() for t in self._supported_device_types()] + ["gpu", "auto"] - valid = accelerator is None or isinstance(accelerator, Accelerator) or accelerator in supported - if not valid: - raise MisconfigurationException( - f"`accelerator={repr(accelerator)}` is not a valid choice." - f" Choose one of {supported} or pass in a `Accelerator` instance." - ) - - def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> None: - supported = [t.lower() for t in self._supported_strategy_types()] - valid = strategy is None or isinstance(strategy, Strategy) or strategy in supported - if not valid: - raise MisconfigurationException( - f"`strategy={repr(strategy)}` is not a valid choice." - f" Choose one of {supported} or pass in a `Strategy` instance." 
- ) - - @staticmethod - def _supported_device_types() -> Sequence[_AcceleratorType]: - return ( - _AcceleratorType.CPU, - _AcceleratorType.CUDA, - _AcceleratorType.TPU, - _AcceleratorType.MPS, - ) - - @staticmethod - def _supported_strategy_types() -> Sequence[_StrategyType]: - return ( - _StrategyType.DP, - _StrategyType.DDP, - _StrategyType.DDP_SPAWN, - _StrategyType.DDP_FORK, - _StrategyType.DEEPSPEED, - _StrategyType.DDP_SHARDED, - _StrategyType.DDP_SHARDED_SPAWN, - ) - @staticmethod def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: if isinstance(model, _LiteModule): diff --git a/src/pytorch_lightning/lite/wrappers.py b/src/pytorch_lightning/lite/wrappers.py index 29a0c17341666..3d9c3f4f7368b 100644 --- a/src/pytorch_lightning/lite/wrappers.py +++ b/src/pytorch_lightning/lite/wrappers.py @@ -21,10 +21,10 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader +from lightning_lite.plugins import Precision +from lightning_lite.strategies import Strategy from lightning_lite.utilities.apply_func import move_data_to_device from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin -from pytorch_lightning.plugins import PrecisionPlugin -from pytorch_lightning.strategies import Strategy T_destination = TypeVar("T_destination", bound=Dict[str, Any]) @@ -56,21 +56,20 @@ def optimizer(self) -> Optimizer: return self._optimizer def state_dict(self) -> Dict[str, Tensor]: - return self._strategy.optimizer_state(self.optimizer) + return self._strategy.get_optimizer_state(self.optimizer) - def step(self, closure: Optional[Callable] = None) -> Any: + def step(self, closure: Optional[Callable] = None, module: Optional["_LiteModule"] = None) -> Any: closure = closure or _do_nothing_closure return self._strategy.optimizer_step( self.optimizer, - opt_idx=0, + model=(module if module is not None else getattr(self._strategy, "model", None)), closure=closure, - model=self._strategy.model, ) class _LiteModule(_DeviceDtypeModuleMixin): def __init__( - self, forward_module: nn.Module, precision_plugin: PrecisionPlugin, original_module: Optional[nn.Module] = None + self, forward_module: nn.Module, precision_plugin: Precision, original_module: Optional[nn.Module] = None ) -> None: """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast automatically for the forward pass. diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py new file mode 100644 index 0000000000000..8f9f9984ef53b --- /dev/null +++ b/tests/tests_lite/test_connector.py @@ -0,0 +1,720 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import os +from typing import Any, Dict +from unittest import mock + +import pytest +import torch +import torch.distributed +from tests_lite.helpers.runif import RunIf + +import lightning_lite +from lightning_lite.accelerators.accelerator import Accelerator +from lightning_lite.accelerators.cpu import CPUAccelerator +from lightning_lite.accelerators.cuda import CUDAAccelerator +from lightning_lite.accelerators.mps import MPSAccelerator +from lightning_lite.connector import _Connector +from lightning_lite.plugins import DoublePrecision, Precision +from lightning_lite.plugins.environments import ( + KubeflowEnvironment, + LightningEnvironment, + SLURMEnvironment, + TorchElasticEnvironment, +) +from lightning_lite.plugins.io import TorchCheckpointIO +from lightning_lite.strategies import ( + DataParallelStrategy, + DDPShardedStrategy, + DDPSpawnShardedStrategy, + DDPSpawnStrategy, + DDPStrategy, + DeepSpeedStrategy, + SingleDeviceStrategy, +) +from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES +from lightning_lite.utilities.exceptions import MisconfigurationException + + +def test_accelerator_choice_cpu(tmpdir): + connector = _Connector() + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, SingleDeviceStrategy) + + +@RunIf(skip_windows=True, standalone=True) +def test_strategy_choice_ddp_on_cpu(tmpdir): + """Test that selecting DDPStrategy on CPU works.""" + _test_strategy_choice_ddp_and_cpu(ddp_strategy_class=DDPStrategy) + + +@RunIf(skip_windows=True) +def test_strategy_choice_ddp_spawn_on_cpu(tmpdir): + """Test that selecting DDPSpawnStrategy on CPU works.""" + _test_strategy_choice_ddp_and_cpu(ddp_strategy_class=DDPSpawnStrategy) + + +def _test_strategy_choice_ddp_and_cpu(ddp_strategy_class): + connector = _Connector( + strategy=ddp_strategy_class(find_unused_parameters=True), + accelerator="cpu", + devices=2, + ) + assert isinstance(connector.strategy, ddp_strategy_class) + assert isinstance(connector.accelerator, CPUAccelerator) + assert connector.strategy.num_processes == 2 + assert connector.strategy.parallel_devices == [torch.device("cpu")] * 2 + + +@mock.patch.dict( + os.environ, + { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) +def test_custom_cluster_environment_in_slurm_environment(_): + """Test that we choose the custom cluster even when SLURM or TE flags are around.""" + + class CustomCluster(LightningEnvironment): + @property + def main_address(self): + return "asdf" + + @property + def creates_processes_externally(self) -> bool: + return True + + connector = _Connector( + plugins=[CustomCluster()], + accelerator="cpu", + strategy="ddp", + devices=2, + ) + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert isinstance(connector.strategy.cluster_environment, CustomCluster) + + +@mock.patch.dict( + os.environ, + { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) +def test_custom_accelerator(*_): + class Accel(Accelerator): + def setup_device(self, device: 
torch.device) -> None: + pass + + def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + pass + + def teardown(self) -> None: + pass + + @staticmethod + def parse_devices(devices): + return devices + + @staticmethod + def get_parallel_devices(devices): + return [torch.device("cpu")] * devices + + @staticmethod + def auto_device_count() -> int: + return 1 + + @staticmethod + def is_available() -> bool: + return True + + @staticmethod + def name() -> str: + return "custom_acc_name" + + class Prec(Precision): + pass + + class Strat(SingleDeviceStrategy): + pass + + strategy = Strat(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec()) + connector = _Connector(strategy=strategy, devices=2) + assert isinstance(connector.accelerator, Accel) + assert isinstance(connector.strategy, Strat) + assert isinstance(connector.precision_plugin, Prec) + assert connector.strategy is strategy + + class Strat(DDPStrategy): + pass + + strategy = Strat(accelerator=Accel(), precision_plugin=Prec()) + connector = _Connector(strategy=strategy, devices=2) + assert isinstance(connector.accelerator, Accel) + assert isinstance(connector.strategy, Strat) + assert isinstance(connector.precision_plugin, Prec) + assert connector.strategy is strategy + + +@mock.patch.dict( + os.environ, + { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=0) +def test_dist_backend_accelerator_mapping(*_): + connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2) + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert connector.strategy.local_rank == 0 + + +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser._get_all_available_mps_gpus", return_value=[0, 1]) +def test_ipython_incompatible_backend_error(_, __, monkeypatch): + monkeypatch.setattr(lightning_lite.utilities, "_IS_INTERACTIVE", True) + with pytest.raises(RuntimeError, match=r"strategy='ddp'\)`.*is not compatible"): + _Connector(strategy="ddp", accelerator="gpu", devices=2) + + with pytest.raises(RuntimeError, match=r"strategy='ddp_spawn'\)`.*is not compatible"): + _Connector(strategy="ddp_spawn", accelerator="gpu", devices=2) + + with pytest.raises(RuntimeError, match=r"strategy='ddp_sharded_spawn'\)`.*is not compatible"): + _Connector(strategy="ddp_sharded_spawn", accelerator="gpu", devices=2) + + with pytest.raises(RuntimeError, match=r"strategy='ddp'\)`.*is not compatible"): + # Edge case: _Connector maps dp to ddp if accelerator != gpu + _Connector(strategy="dp") + + +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch): + monkeypatch.setattr(lightning_lite.utilities, "_IS_INTERACTIVE", True) + connector = _Connector(strategy="dp", accelerator="gpu") + assert connector.strategy.launcher is None + + +@RunIf(skip_windows=True) +@mock.patch("lightning_lite.accelerators.tpu.TPUAccelerator.is_available", return_value=True) +def test_ipython_compatible_strategy_tpu(_, monkeypatch): + monkeypatch.setattr(lightning_lite.utilities, "_IS_INTERACTIVE", True) + connector = _Connector(accelerator="tpu") + assert connector.strategy.launcher.is_interactive_compatible + + +@RunIf(skip_windows=True) +def 
test_ipython_compatible_strategy_ddp_fork(monkeypatch): + monkeypatch.setattr(lightning_lite.utilities, "_IS_INTERACTIVE", True) + connector = _Connector(strategy="ddp_fork", accelerator="cpu") + assert connector.strategy.launcher.is_interactive_compatible + + +@pytest.mark.parametrize( + ["strategy", "strategy_class"], + [ + ("ddp", DDPStrategy), + ("ddp_spawn", DDPSpawnStrategy), + ("ddp_sharded", DDPShardedStrategy), + ("ddp_sharded_spawn", DDPSpawnShardedStrategy), + pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)), + ], +) +@pytest.mark.parametrize("devices", [1, 2]) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser._get_all_available_mps_gpus", return_value=[0, 1]) +def test_accelerator_choice_multi_node_gpu(_, __, ___, strategy, strategy_class, devices): + connector = _Connector(num_nodes=2, accelerator="gpu", strategy=strategy, devices=devices) + assert isinstance(connector.strategy, strategy_class) + + +@mock.patch("lightning_lite.accelerators.cuda.device_parser.num_cuda_devices", return_value=0) +def test_accelerator_cpu(*_): + connector = _Connector(accelerator="cpu") + assert isinstance(connector.accelerator, CPUAccelerator) + + with pytest.raises( + RuntimeError, + match="CUDAAccelerator can not run on your system since the accelerator is not available", + ): + with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed"): + _Connector(gpus=1) + + with pytest.raises( + RuntimeError, + match="CUDAAccelerator can not run on your system since the accelerator is not available.", + ): + _Connector(accelerator="cuda") + + with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed"): + _Connector(accelerator="cpu", gpus=1) + + +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@pytest.mark.parametrize("device_count", (["0"], [0, "1"], ["GPU"], [["0", "1"], [0, 1]], [False])) +def test_accelererator_invalid_type_devices(_, __, device_count): + with pytest.raises( + MisconfigurationException, match=r"must be an int, a string, a sequence of ints or None, but you" + ): + _ = _Connector(accelerator="gpu", devices=device_count) + + +@RunIf(min_cuda_gpus=1) +def test_accelerator_gpu(): + connector = _Connector(accelerator="gpu", devices=1) + assert isinstance(connector.accelerator, CUDAAccelerator) + + connector = _Connector(accelerator="gpu") + assert isinstance(connector.accelerator, CUDAAccelerator) + + connector = _Connector(accelerator="auto", devices=1) + assert isinstance(connector.accelerator, CUDAAccelerator) + + +@pytest.mark.parametrize(["devices", "strategy_class"], [(1, SingleDeviceStrategy), (5, DDPSpawnStrategy)]) +def test_accelerator_cpu_with_devices(devices, strategy_class): + connector = _Connector(accelerator="cpu", devices=devices) + assert connector._parallel_devices == [torch.device("cpu")] * devices + assert isinstance(connector.strategy, strategy_class) + assert isinstance(connector.accelerator, CPUAccelerator) + + +@RunIf(min_cuda_gpus=2) +@pytest.mark.parametrize( + ["devices", "strategy_class"], [(1, SingleDeviceStrategy), ([1], SingleDeviceStrategy), (2, DDPSpawnStrategy)] +) +def test_accelerator_gpu_with_devices(devices, strategy_class): + connector = _Connector(accelerator="gpu", 
devices=devices) + assert len(connector._parallel_devices) == len(devices) if isinstance(devices, list) else devices + assert isinstance(connector.strategy, strategy_class) + assert isinstance(connector.accelerator, CUDAAccelerator) + + +@RunIf(min_cuda_gpus=1) +def test_accelerator_auto_with_devices_gpu(): + connector = _Connector(accelerator="auto", devices=1) + assert isinstance(connector.accelerator, CUDAAccelerator) + assert connector._parallel_devices == [torch.device("cuda", 0)] + + +def test_set_devices_if_none_cpu(): + connector = _Connector(accelerator="cpu", devices=3) + assert connector._parallel_devices == [torch.device("cpu")] * 3 + + +def test_unsupported_strategy_types_on_cpu_and_fallback(): + with pytest.warns(UserWarning, match="is not supported on CPUs, hence setting `strategy='ddp"): + connector = _Connector(strategy="dp", devices=2) + assert isinstance(connector.strategy, DDPStrategy) + + +def test_invalid_accelerator_choice(): + with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='cocofruit'`"): + _Connector(accelerator="cocofruit") + + +def test_invalid_strategy_choice(): + with pytest.raises(ValueError, match="You selected an invalid strategy name: `strategy='cocofruit'`"): + _Connector(strategy="cocofruit") + + +@pytest.mark.parametrize( + ["strategy", "strategy_class"], + [ + ("ddp_spawn", DDPSpawnStrategy), + ("ddp_spawn_find_unused_parameters_false", DDPSpawnStrategy), + ("ddp", DDPStrategy), + ("ddp_find_unused_parameters_false", DDPStrategy), + ], +) +def test_strategy_choice_cpu_str(strategy, strategy_class): + connector = _Connector(strategy=strategy, accelerator="cpu", devices=2) + assert isinstance(connector.strategy, strategy_class) + + +@pytest.mark.parametrize("strategy_class", [DDPSpawnStrategy, DDPStrategy]) +def test_strategy_choice_cpu_instance(strategy_class): + connector = _Connector(strategy=strategy_class(), accelerator="cpu", devices=2) + assert isinstance(connector.strategy, strategy_class) + + +@RunIf(min_cuda_gpus=2) +@pytest.mark.parametrize( + ["strategy", "strategy_class"], + [ + ("ddp_spawn", DDPSpawnStrategy), + ("ddp_spawn_find_unused_parameters_false", DDPSpawnStrategy), + ("ddp", DDPStrategy), + ("ddp_find_unused_parameters_false", DDPStrategy), + ("dp", DataParallelStrategy), + ("ddp_sharded", DDPShardedStrategy), + ("ddp_sharded_spawn", DDPSpawnShardedStrategy), + pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)), + ], +) +def test_strategy_choice_gpu_str(strategy, strategy_class): + connector = _Connector(strategy=strategy, accelerator="gpu", devices=2) + assert isinstance(connector.strategy, strategy_class) + + +@RunIf(min_cuda_gpus=2) +@pytest.mark.parametrize("strategy_class", [DDPSpawnStrategy, DDPStrategy]) +def test_strategy_choice_gpu_instance(strategy_class): + connector = _Connector(strategy=strategy_class(), accelerator="gpu", devices=2) + assert isinstance(connector.strategy, strategy_class) + + +@RunIf(min_cuda_gpus=2) +@pytest.mark.parametrize("strategy_class", [DDPSpawnStrategy, DDPStrategy]) +def test_device_type_when_strategy_instance_gpu_passed(strategy_class): + connector = _Connector(strategy=strategy_class(), accelerator="gpu", devices=2) + assert isinstance(connector.strategy, strategy_class) + assert isinstance(connector.accelerator, CUDAAccelerator) + + +@pytest.mark.parametrize("precision", [1, 12, "invalid"]) +def test_validate_precision_type(precision): + with pytest.raises(ValueError, match=f"Precision {repr(precision)} is invalid"): + 
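        # Only 16, 32, 64, "bf16" and "mixed" pass the validation in `_check_config_and_set_final_flags`.
+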
_Connector(precision=precision) + + +def test_strategy_choice_ddp_spawn_cpu(): + connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2) + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, DDPSpawnStrategy) + assert isinstance(connector.strategy.cluster_environment, LightningEnvironment) + assert connector.strategy.launcher._start_method == "spawn" + + +@RunIf(skip_windows=True) +@mock.patch("lightning_lite.connector._IS_INTERACTIVE", True) +def test_strategy_choice_ddp_fork_in_interactive(): + """Test that when accelerator and strategy are unspecified, the connector chooses DDP Fork in interactive + environments by default.""" + connector = _Connector(devices=2) + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, DDPSpawnStrategy) + assert isinstance(connector.strategy.cluster_environment, LightningEnvironment) + assert connector.strategy.launcher._start_method == "fork" + + +@RunIf(skip_windows=True) +def test_strategy_choice_ddp_fork_cpu(): + connector = _Connector(strategy="ddp_fork", accelerator="cpu", devices=2) + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, DDPSpawnStrategy) + assert isinstance(connector.strategy.cluster_environment, LightningEnvironment) + assert connector.strategy.launcher._start_method == "fork" + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False) +def test_strategy_choice_ddp(*_): + connector = _Connector(strategy="ddp", accelerator="gpu", devices=1) + assert isinstance(connector.accelerator, CUDAAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert isinstance(connector.strategy.cluster_environment, LightningEnvironment) + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False) +def test_strategy_choice_ddp_spawn(*_): + connector = _Connector(strategy="ddp_spawn", accelerator="gpu", devices=1) + assert isinstance(connector.accelerator, CUDAAccelerator) + assert isinstance(connector.strategy, DDPSpawnStrategy) + assert isinstance(connector.strategy.cluster_environment, LightningEnvironment) + + +@RunIf(min_cuda_gpus=2) +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "SLURM_PROCID": "1", + "SLURM_LOCALID": "1", + }, +) +@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True) +@pytest.mark.parametrize("strategy", ["ddp", DDPStrategy()]) +def test_strategy_choice_ddp_slurm(_, strategy): + connector = _Connector(strategy=strategy, accelerator="gpu", devices=2) + assert connector._is_slurm_managing_tasks() + assert isinstance(connector.accelerator, CUDAAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert isinstance(connector.strategy.cluster_environment, SLURMEnvironment) + assert connector.strategy.cluster_environment.local_rank() == 1 + assert 
connector.strategy.local_rank == 1 + + +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "WORLD_SIZE": "2", + "LOCAL_WORLD_SIZE": "2", + "RANK": "1", + "LOCAL_RANK": "1", + "GROUP_RANK": "0", + "TORCHELASTIC_RUN_ID": "1", + }, +) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False) +def test_strategy_choice_ddp_te(*_): + connector = _Connector(strategy="ddp", accelerator="gpu", devices=2) + assert isinstance(connector.accelerator, CUDAAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert isinstance(connector.strategy.cluster_environment, TorchElasticEnvironment) + assert connector.strategy.cluster_environment.local_rank() == 1 + assert connector.strategy.local_rank == 1 + + +@mock.patch.dict( + os.environ, + { + "WORLD_SIZE": "2", + "LOCAL_WORLD_SIZE": "2", + "RANK": "1", + "LOCAL_RANK": "1", + "GROUP_RANK": "0", + "TORCHELASTIC_RUN_ID": "1", + }, +) +def test_strategy_choice_ddp_cpu_te(): + connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2) + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert isinstance(connector.strategy.cluster_environment, TorchElasticEnvironment) + assert connector.strategy.cluster_environment.local_rank() == 1 + assert connector.strategy.local_rank == 1 + + +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0", + "KUBERNETES_PORT": "tcp://127.0.0.1:443", + "MASTER_ADDR": "1.2.3.4", + "MASTER_PORT": "500", + "WORLD_SIZE": "20", + "RANK": "1", + }, +) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) +@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True) +@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False) +def test_strategy_choice_ddp_kubeflow(*_): + connector = _Connector(strategy="ddp", accelerator="gpu", devices=1) + assert isinstance(connector.accelerator, CUDAAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert isinstance(connector.strategy.cluster_environment, KubeflowEnvironment) + assert connector.strategy.cluster_environment.local_rank() == 0 + assert connector.strategy.local_rank == 0 + + +@mock.patch.dict( + os.environ, + { + "KUBERNETES_PORT": "tcp://127.0.0.1:443", + "MASTER_ADDR": "1.2.3.4", + "MASTER_PORT": "500", + "WORLD_SIZE": "20", + "RANK": "1", + }, +) +def test_strategy_choice_ddp_cpu_kubeflow(): + connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2) + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, DDPStrategy) + assert isinstance(connector.strategy.cluster_environment, KubeflowEnvironment) + assert connector.strategy.cluster_environment.local_rank() == 0 + assert connector.strategy.local_rank == 0 + + +@mock.patch.dict( + os.environ, + { + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + }, +) +@pytest.mark.parametrize("strategy", ["ddp", DDPStrategy()]) +def test_strategy_choice_ddp_cpu_slurm(strategy): + connector = _Connector(strategy=strategy, accelerator="cpu", devices=2) + assert isinstance(connector.accelerator, CPUAccelerator) + assert isinstance(connector.strategy, 
DDPStrategy)
+    assert isinstance(connector.strategy.cluster_environment, SLURMEnvironment)
+    assert connector.strategy.local_rank == 0
+
+
+@mock.patch("lightning_lite.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_unsupported_tpu_choice(*_):
+
+    with pytest.raises(NotImplementedError, match=r"accelerator='tpu', precision=64\)` is not implemented"):
+        _Connector(accelerator="tpu", precision=64)
+
+    # if the user didn't set a strategy, the _Connector will choose `SingleTPUStrategy` or `XLAStrategy`
+    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"):
+        with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"):
+            _Connector(accelerator="tpu", precision=16, strategy="ddp")
+
+
+@mock.patch("lightning_lite.accelerators.cuda.CUDAAccelerator.is_available", return_value=False)
+@mock.patch("lightning_lite.accelerators.tpu.TPUAccelerator.is_available", return_value=False)
+@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False)
+def test_devices_auto_choice_cpu(*_):
+    connector = _Connector(accelerator="auto", devices="auto")
+    assert isinstance(connector.accelerator, CPUAccelerator)
+    assert isinstance(connector.strategy, SingleDeviceStrategy)
+    assert connector.strategy.root_device == torch.device("cpu")
+
+
+@RunIf(mps=False)
+@mock.patch("lightning_lite.utilities.device_parser.is_cuda_available", return_value=True)
+@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=2)
+def test_devices_auto_choice_gpu(*_):
+    connector = _Connector(accelerator="auto", devices="auto")
+    assert isinstance(connector.accelerator, CUDAAccelerator)
+    assert isinstance(connector.strategy, DDPSpawnStrategy)
+    assert len(connector._parallel_devices) == 2
+
+
+@RunIf(mps=True)
+def test_devices_auto_choice_mps():
+    connector = _Connector(accelerator="auto", devices="auto")
+    assert isinstance(connector.accelerator, MPSAccelerator)
+    assert isinstance(connector.strategy, SingleDeviceStrategy)
+    assert connector.strategy.root_device == torch.device("mps", 0)
+    assert connector._parallel_devices == [torch.device("mps", 0)]
+
+
+@pytest.mark.parametrize(
+    ["parallel_devices", "accelerator"],
+    [([torch.device("cpu")], "cuda"), ([torch.device("cuda", i) for i in range(8)], "tpu")],
+)
+def test_parallel_devices_in_strategy_conflict_with_accelerator(parallel_devices, accelerator):
+    with pytest.raises(ValueError, match=r"parallel_devices set through"):
+        _Connector(strategy=DDPStrategy(parallel_devices=parallel_devices), accelerator=accelerator)
+
+
+@pytest.mark.parametrize(
+    ["plugins", "expected"],
+    [
+        ([LightningEnvironment(), SLURMEnvironment()], "ClusterEnvironment"),
+        ([TorchCheckpointIO(), TorchCheckpointIO()], "CheckpointIO"),
+        (
+            [Precision(), DoublePrecision(), LightningEnvironment(), SLURMEnvironment()],
+            "Precision, ClusterEnvironment",
+        ),
+    ],
+)
+def test_plugin_only_one_instance_for_one_type(plugins, expected):
+    with pytest.raises(ValueError, match=f"Received multiple values for {expected}"):
+        _Connector(plugins=plugins)
+
+
+@pytest.mark.parametrize("accelerator", ("cpu", "cuda", "mps", "tpu"))
+@pytest.mark.parametrize("devices", ("0", 0, []))
+def test_passing_zero_and_empty_list_to_devices_flag(accelerator, devices):
+    with pytest.raises(ValueError, match="value is not a valid input using"):
+        _Connector(accelerator=accelerator, devices=devices)
+
+
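+# The tests below cover how the "gpu" alias is resolved to a concrete backend. A rough
+# sketch (assuming a machine where at least one of CUDA or MPS is available) is
+#
+#     connector = _Connector(accelerator="gpu")
+#     assert connector._accelerator_flag in ("cuda", "mps")
+#     assert isinstance(connector.accelerator, (CUDAAccelerator, MPSAccelerator))
+#
+# and a RuntimeError is raised when no GPU backend is available at all.
+
+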
+@pytest.mark.parametrize( + "expected_accelerator_flag,expected_accelerator_class", + [ + pytest.param("cuda", CUDAAccelerator, marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", MPSAccelerator, marks=RunIf(mps=True)), + ], +) +def test_gpu_accelerator_backend_choice(expected_accelerator_flag, expected_accelerator_class): + connector = _Connector(accelerator="gpu") + assert connector._accelerator_flag == expected_accelerator_flag + assert isinstance(connector.accelerator, expected_accelerator_class) + + +@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False) +@mock.patch("lightning_lite.utilities.device_parser.num_cuda_devices", return_value=1) +def test_gpu_accelerator_backend_choice_cuda(*_): + connector = _Connector(accelerator="gpu") + assert connector._accelerator_flag == "cuda" + assert isinstance(connector.accelerator, CUDAAccelerator) + + +@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=True) +@mock.patch("lightning_lite.utilities.device_parser._get_all_available_mps_gpus", return_value=[0]) +@mock.patch("torch.device", return_value="mps") # necessary because torch doesn't allow creation of mps devices +def test_gpu_accelerator_backend_choice_mps(*_): + connector = _Connector(accelerator="gpu") + assert connector._accelerator_flag == "mps" + assert isinstance(connector.accelerator, MPSAccelerator) + + +@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False) +@mock.patch("lightning_lite.accelerators.cuda.CUDAAccelerator.is_available", return_value=False) +def test_gpu_accelerator_no_gpu_backend_found_error(*_): + with pytest.raises(RuntimeError, match="No supported gpu backend found!"): + _Connector(accelerator="gpu") + + +@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES) +@mock.patch( + "lightning_lite.connector.torch.multiprocessing.get_all_start_methods", + return_value=[], +) +def test_ddp_fork_on_unsupported_platform(_, strategy): + with pytest.raises(ValueError, match="process forking is not supported on this platform"): + _Connector(strategy=strategy) + + +@RunIf(skip_windows=True) +@pytest.mark.parametrize("strategy", _DDP_FORK_ALIASES) +@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) +def test_strategy_choice_ddp_spawn_in_interactive_when_fork_disabled(strategy): + """Test there is an error when forking is disabled via the environment variable and the user requests fork.""" + with pytest.raises(ValueError, match="Forking is disabled in this environment"): + _Connector(devices=2, strategy=strategy) diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index e7b5c61a67727..c3c27d55aa9b3 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -23,13 +23,13 @@ from torch import nn from torch.utils.data import DataLoader, DistributedSampler, Sampler +from lightning_lite.plugins import Precision +from lightning_lite.strategies import DeepSpeedStrategy, Strategy from lightning_lite.utilities import _StrategyType +from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.seed import pl_worker_init_function from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer -from pytorch_lightning.plugins import PrecisionPlugin -from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy -from pytorch_lightning.utilities.exceptions import 
MisconfigurationException from tests_pytorch.helpers.runif import RunIf @@ -48,18 +48,6 @@ def forward(self, x): return torch.nn.functional.mse_loss(x, torch.ones_like(x)) -def test_unsupported_accelerator(): - accelerator = "coconut" - with pytest.raises(MisconfigurationException, match=f"`accelerator={repr(accelerator)}` is not a valid choice"): - EmptyLite(accelerator=accelerator) - - -def test_unsupported_strategy(): - strategy = "coconut" - with pytest.raises(MisconfigurationException, match=f"`strategy={repr(strategy)}` is not a valid choice"): - EmptyLite(strategy=strategy) - - def test_run_input_output(): """Test that the dynamically patched run() method receives the input arguments and returns the result.""" @@ -80,7 +68,7 @@ def run(self, *args, **kwargs): assert lite.run_kwargs == {"three": 3} -@mock.patch("pytorch_lightning.strategies.ddp.DistributedDataParallel") +@mock.patch("lightning_lite.strategies.ddp.DistributedDataParallel") def test_setup_model(ddp_mock): """Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model.""" lite = EmptyLite(accelerator="cpu", strategy="ddp", devices=2) @@ -282,7 +270,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy): # explicitly asking to replace when a custom sampler is already configured raises an exception lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2) - if lite._accelerator_connector.is_distributed: + if lite._connector.is_distributed: with pytest.raises(MisconfigurationException, match="You seem to have configured a sampler in your DataLoader"): lite.setup_dataloaders(dataloader, replace_sampler=True) @@ -307,7 +295,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy): def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy): """Test that Lite replaces the default samplers with DistributedSampler automatically.""" lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2) - is_distributed = lite._accelerator_connector.is_distributed + is_distributed = lite._connector.is_distributed lite_dataloader = lite.setup_dataloaders(DataLoader(range(3), shuffle=shuffle)) assert not is_distributed or isinstance(lite_dataloader.sampler, DistributedSampler) @@ -366,10 +354,10 @@ def test_rank_properties(): def test_backward(): """Test that backward() calls into the precision plugin.""" lite = EmptyLite() - lite._precision_plugin = Mock(spec=PrecisionPlugin) + lite._precision_plugin = Mock(spec=Precision) loss = Mock() lite.backward(loss, "arg", keyword="kwarg") - lite._precision_plugin._run_backward.assert_called_with(loss, None, "arg", keyword="kwarg") + lite._precision_plugin.backward.assert_called_with(loss, None, "arg", keyword="kwarg") @RunIf(deepspeed=True) @@ -383,7 +371,7 @@ def test_backward_model_input_required(): optimizer0 = torch.optim.Adam(model0.parameters()) optimizer1 = torch.optim.Adam(model1.parameters()) - lite._strategy._setup_model_and_optimizer = lambda *args: args + lite._strategy.setup_module_and_optimizers = lambda *args: args lite.setup(model0, optimizer0) lite.setup(model1, optimizer1) diff --git a/tests/tests_pytorch/lite/test_wrappers.py b/tests/tests_pytorch/lite/test_wrappers.py index acc05cfdcda8f..de2892e1dd01a 100644 --- a/tests/tests_pytorch/lite/test_wrappers.py +++ b/tests/tests_pytorch/lite/test_wrappers.py @@ -222,10 +222,12 @@ def test_lite_dataloader_device_placement(src_device_str, dest_device_str): iterator = iter(lite_dataloader) batch0 = next(iterator) - assert torch.equal(batch0, 
torch.tensor([0, 1], device=dest_device)) + # TODO: This should be torch.equal, but not supported on MPS at this time (torch 1.12) + assert torch.allclose(batch0, torch.tensor([0, 1], device=dest_device)) batch1 = next(iterator) - assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device)) + # TODO: This should be torch.equal, but not supported on MPS at this time (torch 1.12) + assert torch.allclose(batch1["data"], torch.tensor([2, 3], device=dest_device)) def test_lite_optimizer_wraps(): @@ -243,7 +245,7 @@ def test_lite_optimizer_state_dict(): strategy = Mock() lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy) lite_optimizer.state_dict() - strategy.optimizer_state.assert_called_with(optimizer) + strategy.get_optimizer_state.assert_called_with(optimizer) def test_lite_optimizer_steps(): From 346fc697fd619490cece456ec6f35c49da01183f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 15 Sep 2022 16:28:15 +0200 Subject: [PATCH 02/20] error handling --- src/pytorch_lightning/lite/lite.py | 17 ++++++++--------- tests/tests_pytorch/lite/test_lite.py | 10 +++++----- tests/tests_pytorch/lite/test_parity.py | 6 ++++-- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 2d4a72b174e68..d79f1f223419f 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -38,7 +38,6 @@ _update_dataloader, has_iterable_dataset, ) -from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.seed import seed_everything from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper @@ -245,18 +244,18 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No if isinstance(self._strategy, DeepSpeedStrategy): if model is None: if self._models_setup == 0: - raise MisconfigurationException( - "No models were setup for backward. Did you forget to call `self.setup()`?" + raise RuntimeError( + "No models were set up for backward. Did you forget to call `self.setup()`?" ) if self._models_setup > 1: - raise MisconfigurationException( + raise ValueError( "When using multiple models + deepspeed, please provide the model used to perform" " the optimization: `self.backward(loss, model=model)`" ) module = self._strategy.model else: # requires to attach the current `DeepSpeedEngine` for the `_LiteOptimizer.step` call. 
- self._strategy.model = module + self._strategy._deepspeed_engine = module self._precision_plugin.backward(tensor, module, *args, **kwargs) @@ -429,15 +428,15 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> Distribut @staticmethod def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: if isinstance(model, _LiteModule): - raise MisconfigurationException("A model should be passed only once to the `setup` method.") + raise ValueError("A model should be passed only once to the `setup` method.") if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): - raise MisconfigurationException("An optimizer should be passed only once to the `setup` method.") + raise ValueError("An optimizer should be passed only once to the `setup` method.") @staticmethod def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders): - raise MisconfigurationException("A dataloader should be passed only once to the `setup_dataloaders` method") + raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method") if any(not isinstance(dl, DataLoader) for dl in dataloaders): - raise MisconfigurationException("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") + raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index c3c27d55aa9b3..cb187a030f579 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -116,11 +116,11 @@ def test_setup_twice_fails(): optimizer = torch.optim.Adam(model.parameters()) lite_model, lite_optimizer = lite.setup(model, optimizer) - with pytest.raises(MisconfigurationException, match="A model should be passed only once to the"): + with pytest.raises(ValueError, match="A model should be passed only once to the"): lite.setup(lite_model, optimizer) lite_model, lite_optimizer = lite.setup(model, optimizer) - with pytest.raises(MisconfigurationException, match="An optimizer should be passed only once to the"): + with pytest.raises(ValueError, match="An optimizer should be passed only once to the"): lite.setup(model, lite_optimizer) @@ -141,7 +141,7 @@ def test_setup_tracks_num_models(): def test_setup_dataloaders_unsupported_type(): """Test that the setup_dataloaders method fails when provided with non-DataLoader objects.""" lite = EmptyLite() - with pytest.raises(MisconfigurationException, match="Only PyTorch DataLoader are currently supported"): + with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"): lite.setup_dataloaders(range(2)) # type: ignore @@ -205,7 +205,7 @@ def test_setup_dataloaders_twice_fails(): dataloader = DataLoader(range(2)) lite_dataloader = lite.setup_dataloaders(dataloader) - with pytest.raises(MisconfigurationException, match="A dataloader should be passed only once to the"): + with pytest.raises(ValueError, match="A dataloader should be passed only once to the"): lite.setup_dataloaders(lite_dataloader) @@ -378,7 +378,7 @@ def test_backward_model_input_required(): loss = model0(torch.randn(1, 1)).sum() - with pytest.raises(MisconfigurationException, match="please provide the model used to perform"): + with pytest.raises(ValueError, match="please provide the model used to perform"): lite.backward(loss) diff --git a/tests/tests_pytorch/lite/test_parity.py b/tests/tests_pytorch/lite/test_parity.py index 
eaada992da497..ffb95855154cb 100644 --- a/tests/tests_pytorch/lite/test_parity.py +++ b/tests/tests_pytorch/lite/test_parity.py @@ -133,10 +133,12 @@ def test_boring_lite_model_single_device(precision, strategy, devices, accelerat state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device) for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()): - assert not torch.equal(w_pure, w_lite) + # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) + assert not torch.allclose(w_pure, w_lite) for w_pure, w_lite in zip(pure_state_dict.values(), lite_state_dict.values()): - assert torch.equal(w_pure, w_lite) + # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12) + assert torch.allclose(w_pure, w_lite) def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir): From 9d9d77ae9107e7eee80b602271cbd4267a34e89e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 15 Sep 2022 16:31:34 +0200 Subject: [PATCH 03/20] fixes --- src/pytorch_lightning/lite/lite.py | 4 +--- src/pytorch_lightning/lite/wrappers.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index d79f1f223419f..d7beaf30a87f5 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -244,9 +244,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No if isinstance(self._strategy, DeepSpeedStrategy): if model is None: if self._models_setup == 0: - raise RuntimeError( - "No models were set up for backward. Did you forget to call `self.setup()`?" - ) + raise RuntimeError("No models were set up for backward. Did you forget to call `self.setup()`?") if self._models_setup > 1: raise ValueError( "When using multiple models + deepspeed, please provide the model used to perform" diff --git a/src/pytorch_lightning/lite/wrappers.py b/src/pytorch_lightning/lite/wrappers.py index 939de6f458bfd..05fc5b94f0d0d 100644 --- a/src/pytorch_lightning/lite/wrappers.py +++ b/src/pytorch_lightning/lite/wrappers.py @@ -59,11 +59,11 @@ def state_dict(self) -> Dict[str, Tensor]: return self._strategy.get_optimizer_state(self.optimizer) def step(self, closure: Optional[Callable] = None, module: Optional["_LiteModule"] = None) -> Any: - closure = closure or _do_nothing_closure + kwargs = dict(closure=closure) if closure is not None else {} return self._strategy.optimizer_step( self.optimizer, model=(module if module is not None else getattr(self._strategy, "model", None)), - closure=closure, + **kwargs, ) From 9a3f20fbaf95879adbd4a99b4d6e59d133652c25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Sep 2022 16:33:45 +0200 Subject: [PATCH 04/20] notebook --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 6d5634b794218..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d5634b7942180e6ba4a30bfbd74926d1c22f1eb From 392d45601fa8d9e2058cfddcc10e35b7cfc892f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Sep 2022 16:34:09 +0200 Subject: [PATCH 05/20] notebook --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..8a36a41548f34 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 
8a36a41548f34c44ac455d515a72994487e85813 From 81f473978f81adf857c050d9c2cf413db7d8076d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Sep 2022 10:34:37 -0400 Subject: [PATCH 06/20] Update src/lightning_lite/connector.py --- src/lightning_lite/connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py index c1243d2b41926..80a932eb529cf 100644 --- a/src/lightning_lite/connector.py +++ b/src/lightning_lite/connector.py @@ -69,7 +69,7 @@ class _Connector: 2. accelerator str 3. accelerator auto - B. strategy flag could be : + B. strategy flag could be: 1. strategy class 2. strategy str registered with STRATEGY_REGISTRY 3. strategy str in _strategy_type enum which listed in each strategy as From 7cd0b1d90eb0988bf61fa13899099ea7969a8288 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 15 Sep 2022 23:45:02 +0200 Subject: [PATCH 07/20] fix local rank bug --- src/lightning_lite/strategies/ddp_spawn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lightning_lite/strategies/ddp_spawn.py b/src/lightning_lite/strategies/ddp_spawn.py index 3e8b48b2a6b43..9f37d6423581e 100644 --- a/src/lightning_lite/strategies/ddp_spawn.py +++ b/src/lightning_lite/strategies/ddp_spawn.py @@ -200,8 +200,7 @@ def _setup_distributed(self) -> None: def _get_process_group_backend(self) -> str: return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device) - def _set_world_ranks(self, process_idx: int = 0) -> None: - self._local_rank = process_idx + def _set_world_ranks(self) -> None: if self.cluster_environment is None: return self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) From 34595ec5444a9d4d6aaca62e7b6d6cb667db3569 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 15 Sep 2022 23:52:37 +0200 Subject: [PATCH 08/20] fix mypy issue --- src/pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index d7beaf30a87f5..105fc8cac7874 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -90,7 +90,7 @@ def __init__( gpus=gpus, ) self._strategy: Strategy = self._connector.strategy - self._accelerator: Accelerator = self._strategy.accelerator + self._accelerator: Accelerator = self._connector.accelerator self._precision_plugin: Precision = self._strategy.precision_plugin self._models_setup: int = 0 From f3d69bf89d007722bd79ce847a0cc24ccd439612 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 15 Sep 2022 23:56:59 +0200 Subject: [PATCH 09/20] fix test --- tests/tests_pytorch/lite/test_wrappers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/lite/test_wrappers.py b/tests/tests_pytorch/lite/test_wrappers.py index de2892e1dd01a..b82a53cf2b5f2 100644 --- a/tests/tests_pytorch/lite/test_wrappers.py +++ b/tests/tests_pytorch/lite/test_wrappers.py @@ -257,4 +257,12 @@ def test_lite_optimizer_steps(): step_output = lite_optimizer.step() assert step_output == 123 strategy.optimizer_step.assert_called_once() - strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model) + strategy.optimizer_step.assert_called_with(optimizer, model=strategy.model) + + strategy.optimizer_step.reset_mock() + + # with closure as input + closure = Mock() + lite_optimizer.step(closure=closure) + 
strategy.optimizer_step.assert_called_once() + strategy.optimizer_step.assert_called_with(optimizer, model=strategy.model, closure=closure) From 2bb6d8abfc88dab2063d539040ef7623d221316f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 16 Sep 2022 00:01:46 +0200 Subject: [PATCH 10/20] remove placeholder files --- src/lightning_lite/lite.py | 3 --- tests/tests_lite/test_lite.py | 12 ------------ 2 files changed, 15 deletions(-) delete mode 100644 src/lightning_lite/lite.py delete mode 100644 tests/tests_lite/test_lite.py diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py deleted file mode 100644 index 65fee1bf09834..0000000000000 --- a/src/lightning_lite/lite.py +++ /dev/null @@ -1,3 +0,0 @@ -class LightningLite: - # Placeholder for real implementation - pass diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py deleted file mode 100644 index a7df3089cb5ac..0000000000000 --- a/tests/tests_lite/test_lite.py +++ /dev/null @@ -1,12 +0,0 @@ -from tests_lite.helpers.runif import RunIf - -from lightning_lite.lite import LightningLite # noqa: F401 - - -def test_placeholder(tmpdir): - assert True - - -@RunIf(min_cuda_gpus=2, standalone=True) -def test_placeholder_standalone(tmpdir): - assert True From 428807b99f781a86864e520c921cc271eaf9cafc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 16 Sep 2022 00:19:59 +0200 Subject: [PATCH 11/20] fixes --- src/lightning_lite/__init__.py | 9 ++- src/lightning_lite/strategies/ddp_spawn.py | 4 -- src/lightning_lite/utilities/distributed.py | 59 ++++++++++++++++++- src/pytorch_lightning/lite/lite.py | 4 +- .../overrides/distributed.py | 58 +----------------- .../trainer/connectors/data_connector.py | 3 +- tests/tests_pytorch/lite/test_lite.py | 2 +- .../trainer/connectors/test_data_connector.py | 2 +- 8 files changed, 73 insertions(+), 68 deletions(-) diff --git a/src/lightning_lite/__init__.py b/src/lightning_lite/__init__.py index 6c16dcbf6c393..04d2ee778fe4b 100644 --- a/src/lightning_lite/__init__.py +++ b/src/lightning_lite/__init__.py @@ -12,10 +12,15 @@ _logger.addHandler(logging.StreamHandler()) _logger.propagate = False -from lightning_lite.lite import LightningLite # noqa: E402 +# TODO(lite): Re-enable this import +# from lightning_lite.lite import LightningLite # noqa: E402 from lightning_lite.utilities.seed import seed_everything # noqa: E402 -__all__ = ["LightningLite", "seed_everything"] +__all__ = [ + # TODO(lite): Re-enable this import + # "LightningLite", + "seed_everything", +] # for compatibility with namespace packages __import__("pkg_resources").declare_namespace(__name__) diff --git a/src/lightning_lite/strategies/ddp_spawn.py b/src/lightning_lite/strategies/ddp_spawn.py index 9f37d6423581e..def19d4ac0f24 100644 --- a/src/lightning_lite/strategies/ddp_spawn.py +++ b/src/lightning_lite/strategies/ddp_spawn.py @@ -114,10 +114,6 @@ def setup_module(self, module: Module) -> Module: return DistributedDataParallel(module=module, device_ids=self._determine_ddp_device_ids(), **self._ddp_kwargs) def module_to_device(self, module: Module) -> None: - if self.root_device.type == "cuda": - # TODO(lite): This should be handled outside module_to_device, by a call to accelerator.setup_device() - # set the device on the spawned subprocesses - torch.cuda.set_device(self.root_device) module.to(self.root_device) def reduce( diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index 26fa3e1e230d0..6614eeb1e8510 100644 --- a/src/lightning_lite/utilities/distributed.py 
+++ b/src/lightning_lite/utilities/distributed.py @@ -1,13 +1,15 @@ import logging import os -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Iterable, Sized, Iterator import torch from lightning_utilities.core.rank_zero import rank_zero_deprecation from torch import Tensor from torch.nn import functional as F +from torch.utils.data import Dataset, Sampler, DistributedSampler from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment +from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE from lightning_lite.utilities.rank_zero import rank_zero_info as new_rank_zero_info @@ -262,3 +264,58 @@ def _get_process_group_backend_from_env() -> Optional[str]: " Specify `process_group_backend` directly on the strategy constructor." ) return torch_backend + + +# TODO(lite): The error messsages refer to 'replace_sampler_ddp' in PL but Lite has it named 'replace_sampler' +class _DatasetSamplerWrapper(Dataset): + """Dataset to create indexes from `Sampler` or `Iterable`""" + + def __init__(self, sampler: Union[Sampler, Iterable]) -> None: + if not isinstance(sampler, Sized): + raise TypeError( + "You seem to have configured a sampler in your DataLoader which" + " does not provide `__len__` method. The sampler was about to be" + " replaced by `DistributedSamplerWrapper` since `replace_sampler_ddp`" + " is True and you are using distributed training. Either provide `__len__`" + " method in your sampler, remove it from DataLoader or set `replace_sampler_ddp=False`" + " if you want to handle distributed sampling yourself." + ) + if len(sampler) == float("inf"): + raise TypeError( + "You seem to have configured a sampler in your DataLoader which" + " does not provide finite `__len__` method. The sampler was about to be" + " replaced by `DistributedSamplerWrapper` since `replace_sampler_ddp`" + " is True and you are using distributed training. Either provide `__len__`" + " method in your sampler which returns a finite number, remove it from DataLoader" + " or set `replace_sampler_ddp=False` if you want to handle distributed sampling yourself." + ) + self._sampler = sampler + # defer materializing an iterator until it is necessary + self._sampler_list: Optional[List[Any]] = None + + def __getitem__(self, index: int) -> Any: + if self._sampler_list is None: + self._sampler_list = list(self._sampler) + return self._sampler_list[index] + + def __len__(self) -> int: + return len(self._sampler) + + def reset(self) -> None: + """Reset the sampler list in order to get new sampling.""" + self._sampler_list = list(self._sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """Wrapper over ``Sampler`` for distributed training. + + Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed + mode if sampler replacement is enabled. 
+ """ + + def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None: + super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs) + + def __iter__(self) -> Iterator: + self.dataset.reset() + return (self.dataset[index] for index in super().__iter__()) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 105fc8cac7874..27bded2068a02 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -40,7 +40,7 @@ ) from lightning_lite.utilities.seed import seed_everything from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer -from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper +from lightning_lite.utilities.distributed import DistributedSamplerWrapper class LightningLite(ABC): @@ -292,7 +292,7 @@ def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tens """ if isinstance(obj, nn.Module): self._accelerator.setup_device(self.device) - return obj.to(self.device) + return self._strategy.module_to_device(obj) return move_data_to_device(obj, device=self.device) def print(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/pytorch_lightning/overrides/distributed.py b/src/pytorch_lightning/overrides/distributed.py index 3ecac8c1eea04..5a38742972925 100644 --- a/src/pytorch_lightning/overrides/distributed.py +++ b/src/pytorch_lightning/overrides/distributed.py @@ -17,11 +17,11 @@ import torch from torch import Tensor from torch.nn.parallel import DistributedDataParallel -from torch.utils.data import BatchSampler, Dataset, DistributedSampler, Sampler +from torch.utils.data import BatchSampler, DistributedSampler, Sampler import pytorch_lightning as pl +from lightning_lite.utilities.distributed import _DatasetSamplerWrapper from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException class LightningDistributedModule(_LightningModuleWrapperBase): @@ -109,60 +109,6 @@ def __iter__(self) -> Iterator[List[int]]: return iter(indices) -class _DatasetSamplerWrapper(Dataset): - """Dataset to create indexes from `Sampler` or `Iterable`""" - - def __init__(self, sampler: Union[Sampler, Iterable]) -> None: - if not isinstance(sampler, Sized): - raise MisconfigurationException( - "You seem to have configured a sampler in your DataLoader which" - " does not provide `__len__` method. The sampler was about to be" - " replaced by `DistributedSamplerWrapper` since `replace_sampler_ddp`" - " is True and you are using distributed training. Either provide `__len__`" - " method in your sampler, remove it from DataLoader or set `replace_sampler_ddp=False`" - " if you want to handle distributed sampling yourself." - ) - if len(sampler) == float("inf"): - raise MisconfigurationException( - "You seem to have configured a sampler in your DataLoader which" - " does not provide finite `__len__` method. The sampler was about to be" - " replaced by `DistributedSamplerWrapper` since `replace_sampler_ddp`" - " is True and you are using distributed training. Either provide `__len__`" - " method in your sampler which returns a finite number, remove it from DataLoader" - " or set `replace_sampler_ddp=False` if you want to handle distributed sampling yourself." 
- ) - self._sampler = sampler - # defer materializing an iterator until it is necessary - self._sampler_list: Optional[List[Any]] = None - - def __getitem__(self, index: int) -> Any: - if self._sampler_list is None: - self._sampler_list = list(self._sampler) - return self._sampler_list[index] - - def __len__(self) -> int: - return len(self._sampler) - - def reset(self) -> None: - """Reset the sampler list in order to get new sampling.""" - self._sampler_list = list(self._sampler) - - -class DistributedSamplerWrapper(DistributedSampler): - """Wrapper over ``Sampler`` for distributed training. - - Allows you to use any sampler in distributed mode. It will be automatically used by PyTorch Lightning in distributed - mode if `replace_sampler_ddp=True` - """ - - def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None: - super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs) - - def __iter__(self) -> Iterator: - self.dataset.reset() - return (self.dataset[index] for index in super().__iter__()) - - class UnrepeatedDistributedSamplerWrapper(UnrepeatedDistributedSampler): """Equivalent class to ``DistributedSamplerWrapper`` but for the ``UnrepeatedDistributedSampler``.""" diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index b2a6dbe0c8a5a..b98f005000bbe 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -25,7 +25,8 @@ import pytorch_lightning as pl from lightning_lite.utilities.data import _auto_add_worker_init_fn, _replace_dunder_methods, has_iterable_dataset from pytorch_lightning.accelerators.ipu import IPUAccelerator -from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper, UnrepeatedDistributedSamplerWrapper +from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSamplerWrapper +from lightning_lite.utilities.distributed import DistributedSamplerWrapper from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index cb187a030f579..8b8c999580e25 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -271,7 +271,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy): # explicitly asking to replace when a custom sampler is already configured raises an exception lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2) if lite._connector.is_distributed: - with pytest.raises(MisconfigurationException, match="You seem to have configured a sampler in your DataLoader"): + with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"): lite.setup_dataloaders(dataloader, replace_sampler=True) # setting `replace_sampler=False` leaves the sampler untouched diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 847922c05294a..2118d8e131ded 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -24,7 +24,7 @@ from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning import Trainer from 
pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset -from pytorch_lightning.overrides.distributed import DistributedSamplerWrapper +from lightning_lite.utilities.distributed import DistributedSamplerWrapper from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache from pytorch_lightning.trainer.states import RunningStage, TrainerFn From 39d41af34f156542b54e6eba4fd048ac91a0e02d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 16 Sep 2022 00:20:34 +0200 Subject: [PATCH 12/20] isort --- src/lightning_lite/utilities/distributed.py | 4 ++-- src/pytorch_lightning/lite/lite.py | 2 +- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- tests/tests_pytorch/trainer/connectors/test_data_connector.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index 6614eeb1e8510..aefbd2313c520 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -1,12 +1,12 @@ import logging import os -from typing import Any, List, Optional, Tuple, Union, Iterable, Sized, Iterator +from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union import torch from lightning_utilities.core.rank_zero import rank_zero_deprecation from torch import Tensor from torch.nn import functional as F -from torch.utils.data import Dataset, Sampler, DistributedSampler +from torch.utils.data import Dataset, DistributedSampler, Sampler from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment from lightning_lite.utilities.exceptions import MisconfigurationException diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 27bded2068a02..1901454d6a65d 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -38,9 +38,9 @@ _update_dataloader, has_iterable_dataset, ) +from lightning_lite.utilities.distributed import DistributedSamplerWrapper from lightning_lite.utilities.seed import seed_everything from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer -from lightning_lite.utilities.distributed import DistributedSamplerWrapper class LightningLite(ABC): diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index b98f005000bbe..7543172de9450 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -24,9 +24,9 @@ import pytorch_lightning as pl from lightning_lite.utilities.data import _auto_add_worker_init_fn, _replace_dunder_methods, has_iterable_dataset +from lightning_lite.utilities.distributed import DistributedSamplerWrapper from pytorch_lightning.accelerators.ipu import IPUAccelerator from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSamplerWrapper -from lightning_lite.utilities.distributed import DistributedSamplerWrapper from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 
2118d8e131ded..ea5b825283680 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -21,10 +21,10 @@ from torch import Tensor from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler +from lightning_lite.utilities.distributed import DistributedSamplerWrapper from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset -from lightning_lite.utilities.distributed import DistributedSamplerWrapper from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache from pytorch_lightning.trainer.states import RunningStage, TrainerFn From 605d53ab1beb9a6b376e1c17ce59a495287de007 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 16 Sep 2022 00:21:29 +0200 Subject: [PATCH 13/20] unused import --- tests/tests_pytorch/lite/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/lite/test_wrappers.py b/tests/tests_pytorch/lite/test_wrappers.py index b82a53cf2b5f2..c4fc83bf99145 100644 --- a/tests/tests_pytorch/lite/test_wrappers.py +++ b/tests/tests_pytorch/lite/test_wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import ANY, Mock +from unittest.mock import Mock import pytest import torch From 6bc2c9f9d91e36f6fa9c64660c36683d4f4cba76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Sep 2022 22:23:27 +0000 Subject: [PATCH 14/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/__init__.py | 2 +- src/lightning_lite/utilities/distributed.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/__init__.py b/src/lightning_lite/__init__.py index 04d2ee778fe4b..dccaeae932c70 100644 --- a/src/lightning_lite/__init__.py +++ b/src/lightning_lite/__init__.py @@ -13,7 +13,7 @@ _logger.propagate = False # TODO(lite): Re-enable this import -# from lightning_lite.lite import LightningLite # noqa: E402 +# from lightning_lite.lite import LightningLite from lightning_lite.utilities.seed import seed_everything # noqa: E402 __all__ = [ diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index aefbd2313c520..dd165fd95cb31 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -309,8 +309,8 @@ def reset(self) -> None: class DistributedSamplerWrapper(DistributedSampler): """Wrapper over ``Sampler`` for distributed training. - Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed - mode if sampler replacement is enabled. + Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed mode if + sampler replacement is enabled. 
""" def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None: From 906e49b982823f98639c7344b6c763f5d8d6b693 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 16 Sep 2022 00:27:51 +0200 Subject: [PATCH 15/20] move dead code --- src/pytorch_lightning/lite/wrappers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/pytorch_lightning/lite/wrappers.py b/src/pytorch_lightning/lite/wrappers.py index 05fc5b94f0d0d..91f9e984764c9 100644 --- a/src/pytorch_lightning/lite/wrappers.py +++ b/src/pytorch_lightning/lite/wrappers.py @@ -29,10 +29,6 @@ T_destination = TypeVar("T_destination", bound=Dict[str, Any]) -def _do_nothing_closure() -> None: - return None - - class _LiteOptimizer: def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None: """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer From 794335e6f971eec536a8db1d1330d4d71ba55123 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 16 Sep 2022 00:32:32 +0200 Subject: [PATCH 16/20] unused import --- src/lightning_lite/utilities/distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index dd165fd95cb31..5483c8c242a76 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -9,7 +9,6 @@ from torch.utils.data import Dataset, DistributedSampler, Sampler from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment -from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE from lightning_lite.utilities.rank_zero import rank_zero_info as new_rank_zero_info From 28c8ab50c56e50e42f606e3bfddcc5c68c6184f6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 16 Sep 2022 01:38:08 +0200 Subject: [PATCH 17/20] fix to device --- src/pytorch_lightning/lite/lite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 1901454d6a65d..cd18139375d6e 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -292,7 +292,8 @@ def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tens """ if isinstance(obj, nn.Module): self._accelerator.setup_device(self.device) - return self._strategy.module_to_device(obj) + self._strategy.module_to_device(obj) + return obj return move_data_to_device(obj, device=self.device) def print(self, *args: Any, **kwargs: Any) -> None: From 912590e5e88b469386eec9d64f3cc8c1eed0ab2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Sep 2022 19:38:28 -0400 Subject: [PATCH 18/20] Update src/pytorch_lightning/lite/lite.py --- src/pytorch_lightning/lite/lite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index cd18139375d6e..7a361231352f6 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -421,7 +421,6 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: @staticmethod def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler: kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) - # TODO(lite): Bring the DistributedSamplerWrapper to Lite package return DistributedSamplerWrapper(dataloader.sampler, **kwargs) @staticmethod From 
f89d45e3532e7cf689849b958576241bb316fd1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Sep 2022 12:35:38 -0400 Subject: [PATCH 19/20] Update src/lightning_lite/utilities/distributed.py --- src/lightning_lite/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index 5483c8c242a76..ebcac26083922 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -265,7 +265,7 @@ def _get_process_group_backend_from_env() -> Optional[str]: return torch_backend -# TODO(lite): The error messsages refer to 'replace_sampler_ddp' in PL but Lite has it named 'replace_sampler' +# TODO(lite): The error messages refer to 'replace_sampler_ddp' in PL but Lite has it named 'replace_sampler' class _DatasetSamplerWrapper(Dataset): """Dataset to create indexes from `Sampler` or `Iterable`""" From 28a33393bda29de5e28bb6935ea9ad504396bfa4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 16 Sep 2022 18:39:41 +0200 Subject: [PATCH 20/20] remove optional module arg from optimizer wrapper --- src/pytorch_lightning/lite/wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/lite/wrappers.py b/src/pytorch_lightning/lite/wrappers.py index 91f9e984764c9..aed21b3aa5192 100644 --- a/src/pytorch_lightning/lite/wrappers.py +++ b/src/pytorch_lightning/lite/wrappers.py @@ -54,11 +54,11 @@ def optimizer(self) -> Optimizer: def state_dict(self) -> Dict[str, Tensor]: return self._strategy.get_optimizer_state(self.optimizer) - def step(self, closure: Optional[Callable] = None, module: Optional["_LiteModule"] = None) -> Any: + def step(self, closure: Optional[Callable] = None) -> Any: kwargs = dict(closure=closure) if closure is not None else {} return self._strategy.optimizer_step( self.optimizer, - model=(module if module is not None else getattr(self._strategy, "model", None)), + model=getattr(self._strategy, "model", None), **kwargs, )
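A minimal sketch of how the reworked `_LiteOptimizer.step` forwards arguments after this change (illustrative only; the `Mock`-based strategy mirrors the updated unit tests and is an assumption, not part of the patch):

    from unittest.mock import Mock

    import torch
    from pytorch_lightning.lite.wrappers import _LiteOptimizer

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    strategy = Mock()
    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)

    # without a closure, only the optimizer and the strategy's model are passed through
    lite_optimizer.step()
    strategy.optimizer_step.assert_called_with(optimizer, model=strategy.model)

    # a user-provided closure is forwarded as a keyword argument
    closure = Mock()
    lite_optimizer.step(closure=closure)
    strategy.optimizer_step.assert_called_with(optimizer, model=strategy.model, closure=closure)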