diff --git a/CHANGELOG.md b/CHANGELOG.md index 906603f55e3db..f60ed8f5658a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,9 +78,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330)) +- Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224)) -- Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224)) +- Added Apple Silicon Support via `MPSAccelerator` ([#13123](https://github.com/PyTorchLightning/pytorch-lightning/pull/13123)) + ### Changed diff --git a/docs/source-pytorch/accelerators/mps.rst b/docs/source-pytorch/accelerators/mps.rst new file mode 100644 index 0000000000000..53e8609a0dc1d --- /dev/null +++ b/docs/source-pytorch/accelerators/mps.rst @@ -0,0 +1,32 @@ +.. _mps: + +Accelerator: Apple Silicon training +=================================== + +.. raw:: html + +
+
+ +.. Add callout items below this line + +.. displayitem:: + :header: Prepare your code (Optional) + :description: Prepare your code to run on any hardware + :col_css: col-md-4 + :button_link: accelerator_prepare.html + :height: 150 + :tag: basic + +.. displayitem:: + :header: Basic + :description: Learn the basics of Apple silicon gpu training. + :col_css: col-md-4 + :button_link: mps_basic.html + :height: 150 + :tag: basic + +.. raw:: html + +
+
diff --git a/docs/source-pytorch/accelerators/mps_basic.rst b/docs/source-pytorch/accelerators/mps_basic.rst new file mode 100644 index 0000000000000..15e6ab929ba14 --- /dev/null +++ b/docs/source-pytorch/accelerators/mps_basic.rst @@ -0,0 +1,48 @@ +:orphan: + +.. _mps_basic: + +MPS training (basic) +==================== +**Audience:** Users looking to train on their Apple silicon GPUs. + +.. warning:: + + Both the MPS accelerator and the PyTorch backend are still experimental. + As such, not all operations are currently supported. However, with ongoing development from the PyTorch team, an increasingly large number of operations are becoming available. + You can use ``PYTORCH_ENABLE_MPS_FALLBACK=1 python your_script.py`` to fall back to CPU for unsupported operations. + + +---- + +What is Apple silicon? +---------------------- +Apple silicon chips are unified systems on a chip (SoC) developed by Apple, based on the ARM architecture. +Among other things, they combine CPU cores, GPU cores, a neural engine, and memory that is shared between all of these components. + +---- + +So it's a CPU? +-------------- +Apple silicon does include CPU cores, but that is only one part of the chip. Because the M-series SoCs also feature a GPU and a neural engine, running on the ``CPUAccelerator`` alone leaves most of their hardware acceleration untapped. + +To use the GPU, Lightning provides the ``MPSAccelerator``. + +---- + +Run on Apple silicon GPUs +------------------------- +Pass the following Trainer arguments to run on Apple silicon GPUs (MPS devices). + +.. code:: + + trainer = Trainer(accelerator="mps", devices=1) + +.. note:: + The ``MPSAccelerator`` only supports one device at a time. Currently there are no machines with more than one MPS-capable GPU. + +---- + +What does MPS stand for? +------------------------ +MPS is short for `Metal Performance Shaders <https://developer.apple.com/metal/>`_, the Apple technology used under the hood for GPU communication and computation. diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst index dd5a0672485c7..fdfe9660b90aa 100644 --- a/docs/source-pytorch/extensions/accelerator.rst +++ b/docs/source-pytorch/extensions/accelerator.rst @@ -4,7 +4,7 @@ Accelerator ########### -The Accelerator connects a Lightning Trainer to arbitrary hardware (CPUs, GPUs, TPUs, IPUs, ...). +The Accelerator connects a Lightning Trainer to arbitrary hardware (CPUs, GPUs, TPUs, IPUs, MPS, ...). Currently there are accelerators for: - CPU @@ -12,6 +12,7 @@ Currently there are accelerators for: - :doc:`TPU <../accelerators/tpu>` - :doc:`IPU <../accelerators/ipu>` - :doc:`HPU <../accelerators/hpu>` +- :doc:`MPS <../accelerators/mps>` The Accelerator is part of the Strategy which manages communication across multiple devices (distributed communication). Whenever the Trainer, the loops or any other component in Lightning needs to talk to hardware, it calls into the Strategy and the Strategy calls into the Accelerator.
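A minimal, self-contained sketch of the usage documented in ``mps_basic.rst`` above, assuming a torch>=1.12 build with MPS support on an Apple silicon machine; it only restates the ``Trainer(accelerator="mps", devices=1)`` call from the docs and reuses the ``BoringModel`` demo class that the new tests in this patch already import:

    # Illustrative sketch only (not part of the patch).
    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel

    model = BoringModel()
    # "mps" selects the new MPSAccelerator; it supports exactly one device.
    trainer = Trainer(accelerator="mps", devices=1, max_epochs=1)
    trainer.fit(model)

If the model hits an operation that the MPS backend does not support yet, the same script can be launched with ``PYTORCH_ENABLE_MPS_FALLBACK=1`` to fall back to CPU for that operation, as noted in the warning above.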
@@ -127,4 +128,5 @@ Accelerator API GPUAccelerator HPUAccelerator IPUAccelerator + MPSAccelerator TPUAccelerator diff --git a/docs/source-pytorch/index.rst b/docs/source-pytorch/index.rst index fad7cb006079d..40a3d63d787e0 100644 --- a/docs/source-pytorch/index.rst +++ b/docs/source-pytorch/index.rst @@ -210,6 +210,7 @@ Current Lightning Users Train on single or multiple HPUs Train on single or multiple IPUs Train on single or multiple TPUs + Train on MPS Use a pretrained model model/own_your_loop diff --git a/src/pytorch_lightning/accelerators/__init__.py b/src/pytorch_lightning/accelerators/__init__.py index 1ab90e025b087..e7d757cd73149 100644 --- a/src/pytorch_lightning/accelerators/__init__.py +++ b/src/pytorch_lightning/accelerators/__init__.py @@ -15,6 +15,7 @@ from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.hpu import HPUAccelerator # noqa: F401 from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401 +from pytorch_lightning.accelerators.mps import MPSAccelerator # noqa: F401 from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators # noqa: F401 from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401 diff --git a/src/pytorch_lightning/accelerators/gpu.py b/src/pytorch_lightning/accelerators/gpu.py index 15495a3bae095..898ce09b91431 100644 --- a/src/pytorch_lightning/accelerators/gpu.py +++ b/src/pytorch_lightning/accelerators/gpu.py @@ -74,7 +74,7 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: @staticmethod def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: """Accelerator device parsing logic.""" - return device_parser.parse_gpu_ids(devices) + return device_parser.parse_gpu_ids(devices, include_cuda=True) @staticmethod def get_parallel_devices(devices: List[int]) -> List[torch.device]: diff --git a/src/pytorch_lightning/accelerators/mps.py b/src/pytorch_lightning/accelerators/mps.py new file mode 100644 index 0000000000000..5c35b618b55fc --- /dev/null +++ b/src/pytorch_lightning/accelerators/mps.py @@ -0,0 +1,95 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Union + +import torch + +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities import device_parser +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _PSUTIL_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 +from pytorch_lightning.utilities.types import _DEVICE + +_MPS_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() + + +class MPSAccelerator(Accelerator): + """Accelerator for Metal Apple Silicon GPU devices.""" + + def setup_environment(self, root_device: torch.device) -> None: + """ + Raises: + MisconfigurationException: + If the selected device is not MPS. 
+ """ + super().setup_environment(root_device) + if root_device.type != "mps": + raise MisconfigurationException(f"Device should be MPS, got {root_device} instead.") + + def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + """Get M1 (cpu + gpu) stats from ``psutil`` package.""" + return get_device_stats() + + @staticmethod + def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + """Accelerator device parsing logic.""" + parsed_devices = device_parser.parse_gpu_ids(devices, include_mps=True) + return parsed_devices + + @staticmethod + def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + """Gets parallel devices for the Accelerator.""" + parsed_devices = MPSAccelerator.parse_devices(devices) + assert parsed_devices is not None + + return [torch.device("mps", i) for i in range(len(parsed_devices))] + + @staticmethod + def auto_device_count() -> int: + """Get the devices when set to auto.""" + return 1 + + @staticmethod + def is_available() -> bool: + """MPS is only available for certain torch builds starting at torch>=1.12.""" + return _MPS_AVAILABLE + + @classmethod + def register_accelerators(cls, accelerator_registry: Dict) -> None: + accelerator_registry.register( + "mps", + cls, + description=cls.__class__.__name__, + ) + + +# device metrics +_VM_PERCENT = "M1_vm_percent" +_PERCENT = "M1_percent" +_SWAP_PERCENT = "M1_swap_percent" + + +def get_device_stats() -> Dict[str, float]: + if not _PSUTIL_AVAILABLE: + raise ModuleNotFoundError( + "Fetching M1 device stats requires `psutil` to be installed." + " Install it by running `pip install -U psutil`." + ) + import psutil + + return { + _VM_PERCENT: psutil.virtual_memory().percent, + _PERCENT: psutil.cpu_percent(), + _SWAP_PERCENT: psutil.swap_memory().percent, + } diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 4f3f28fcab931..4dfcde177f953 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -466,6 +466,7 @@ def _supported_device_types() -> Sequence[_AcceleratorType]: _AcceleratorType.CPU, _AcceleratorType.GPU, _AcceleratorType.TPU, + _AcceleratorType.MPS, ) @staticmethod diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index e4ef59aca3cf4..265cfdaf13f08 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -25,6 +25,7 @@ from pytorch_lightning.accelerators.gpu import GPUAccelerator from pytorch_lightning.accelerators.hpu import HPUAccelerator from pytorch_lightning.accelerators.ipu import IPUAccelerator +from pytorch_lightning.accelerators.mps import MPSAccelerator from pytorch_lightning.accelerators.registry import AcceleratorRegistry from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.plugins import ( @@ -178,7 +179,7 @@ def __init__( self._precision_flag: Optional[Union[int, str]] = None self._precision_plugin_flag: Optional[PrecisionPlugin] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: List[Union[int, torch.device]] = [] + self._parallel_devices: List[Union[int, torch.device, str]] = [] self._layer_sync: Optional[LayerSync] = NativeSyncBatchNorm() if sync_batchnorm else None self.checkpoint_io: Optional[CheckpointIO] = None self._amp_type_flag: Optional[LightningEnum] = None @@ -407,7 +408,7 
@@ def _check_device_config_and_set_final_flags( if self._devices_flag == "auto" and self._accelerator_flag is None: raise MisconfigurationException( f"You passed `devices={devices}` but haven't specified" - " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu)` for the devices mapping." + " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu'|'hpu'|'mps')` for the devices mapping." ) def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag( @@ -484,6 +485,8 @@ def _choose_accelerator(self) -> str: return "ipu" if _HPU_AVAILABLE: return "hpu" + if MPSAccelerator.is_available(): + return "mps" if torch.cuda.is_available() and torch.cuda.device_count() > 0: return "gpu" return "cpu" @@ -571,11 +574,13 @@ def _choose_strategy(self) -> Union[Strategy, str]: if self._num_nodes_flag > 1: return DDPStrategy.strategy_name if len(self._parallel_devices) <= 1: - device = ( - device_parser.determine_root_gpu_device(self._parallel_devices) # type: ignore - if self._accelerator_flag == "gpu" - else "cpu" - ) + # TODO: Change this once the gpu accelerator is renamed to the cuda accelerator + if isinstance(self._accelerator_flag, (GPUAccelerator, MPSAccelerator)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("gpu", "mps") + ): + device = device_parser.determine_root_gpu_device(self._parallel_devices) + else: + device = "cpu" # TODO: lazy initialized device, then here could be self._strategy_flag = "single_device" return SingleDeviceStrategy(device=device) # type: ignore if len(self._parallel_devices) > 1: diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 5dc3266319d0a..46774395fd5e2 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -36,7 +36,14 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, HPUAccelerator, IPUAccelerator, TPUAccelerator +from pytorch_lightning.accelerators import ( + Accelerator, + GPUAccelerator, + HPUAccelerator, + IPUAccelerator, + MPSAccelerator, + TPUAccelerator, +) from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule @@ -188,7 +195,7 @@ def __init__( Args: - accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "auto") + accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") as well as custom accelerator instances. ..
deprecated:: v1.5 @@ -1731,9 +1738,19 @@ def __setup_profiler(self) -> None: self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir) def _log_device_info(self) -> None: - rank_zero_info( - f"GPU available: {torch.cuda.is_available()}, used: {isinstance(self.accelerator, GPUAccelerator)}" - ) + + if GPUAccelerator.is_available(): + gpu_available = True + gpu_type = " (cuda)" + elif MPSAccelerator.is_available(): + gpu_available = True + gpu_type = " (mps)" + else: + gpu_available = False + gpu_type = "" + + gpu_used = isinstance(self.accelerator, (GPUAccelerator, MPSAccelerator)) + rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}") num_tpu_cores = self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0 rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores") @@ -1744,6 +1761,7 @@ def _log_device_info(self) -> None: num_hpus = self.num_devices if isinstance(self.accelerator, HPUAccelerator) else 0 rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs") + # TODO: Integrate MPS Accelerator here, once gpu maps to both if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator): rank_zero_warn( "GPU available but not used. Set `accelerator` and `devices` using" @@ -1769,6 +1787,12 @@ def _log_device_info(self) -> None: f" `Trainer(accelerator='hpu', devices={HPUAccelerator.auto_device_count()})`." ) + if MPSAccelerator.is_available() and not isinstance(self.accelerator, MPSAccelerator): + rank_zero_warn( + "MPS available but not used. Set `accelerator` and `devices` using" + f" `Trainer(accelerator='mps', devices={MPSAccelerator.auto_device_count()})`." + ) + """ Data loading methods """ diff --git a/src/pytorch_lightning/utilities/device_parser.py b/src/pytorch_lightning/utilities/device_parser.py index 1aa1ee7662e59..881a02a809ec2 100644 --- a/src/pytorch_lightning/utilities/device_parser.py +++ b/src/pytorch_lightning/utilities/device_parser.py @@ -53,17 +53,23 @@ def _parse_devices( gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, tpu_cores: Optional[Union[List[int], str, int]], + include_cuda: bool = False, + include_mps: bool = False, ) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]: if auto_select_gpus and isinstance(gpus, int): gpus = pick_multiple_gpus(gpus) # TODO (@seannaren, @kaushikb11): Include IPU parsing logic here - gpu_ids = parse_gpu_ids(gpus) + gpu_ids = parse_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) tpu_cores = parse_tpu_cores(tpu_cores) return gpu_ids, tpu_cores -def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]: +def parse_gpu_ids( + gpus: Optional[Union[int, str, List[int]]], + include_cuda: bool = False, + include_mps: bool = False, +) -> Optional[List[int]]: """ Parses the GPU ids given in the format as accepted by the :class:`~pytorch_lightning.trainer.Trainer`. @@ -74,6 +80,8 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i indicates specific GPUs to use. An int 0 means that no GPUs should be used. Any int N > 0 indicates that GPUs [0..N) should be used. + include_cuda: A boolean indicating whether to include cuda devices for gpu parsing. + include_mps: A boolean indicating whether to include mps devices for gpu parsing. 
Returns: a list of gpus to be used or ``None`` if no GPUs were requested @@ -81,6 +89,10 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i Raises: MisconfigurationException: If no GPUs are available but the value of gpus variable indicates request for GPUs + + .. note:: + ``include_cuda`` and ``include_mps`` default to ``False`` so that you only + have to specify which device type to use and not disabling all the others. """ # Check that gpus param is None, Int, String or Sequence of Ints _check_data_type(gpus) @@ -92,17 +104,21 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i # We know user requested GPUs therefore if some of the # requested GPUs are not available an exception is thrown. gpus = _normalize_parse_gpu_string_input(gpus) - gpus = _normalize_parse_gpu_input_to_list(gpus) + gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) if not gpus: raise MisconfigurationException("GPUs requested but none are available.") - if TorchElasticEnvironment.detect() and len(gpus) != 1 and len(_get_all_available_gpus()) == 1: + if ( + TorchElasticEnvironment.detect() + and len(gpus) != 1 + and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1 + ): # omit sanity check on torchelastic as by default shows one visible GPU per process return gpus # Check that gpus are unique. Duplicate gpus are not supported by the backend. _check_unique(gpus) - return _sanitize_gpu_ids(gpus) + return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]: @@ -167,7 +183,7 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: List[int]) -> List[int]: +def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -181,7 +197,9 @@ def _sanitize_gpu_ids(gpus: List[int]) -> List[int]: MisconfigurationException: If machine has fewer available GPUs than requested. """ - all_available_gpus = _get_all_available_gpus() + if sum((include_cuda, include_mps)) == 0: + raise ValueError("At least one gpu type should be specified!") + all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) for gpu in gpus: if gpu not in all_available_gpus: raise MisconfigurationException( @@ -190,7 +208,9 @@ def _sanitize_gpu_ids(gpus: List[int]) -> List[int]: return gpus -def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int], Tuple[int, ...]]) -> Optional[List[int]]: +def _normalize_parse_gpu_input_to_list( + gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool +) -> Optional[List[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): return list(gpus) @@ -199,15 +219,36 @@ def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int], Tuple[int, .. 
if not gpus: # gpus==0 return None if gpus == -1: - return _get_all_available_gpus() + return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) return list(range(gpus)) -def _get_all_available_gpus() -> List[int]: +def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: + """ + Returns: + a list of all available gpus + """ + cuda_gpus = _get_all_available_cuda_gpus() if include_cuda else [] + mps_gpus = _get_all_available_mps_gpus() if include_mps else [] + return cuda_gpus + mps_gpus + + +def _get_all_available_mps_gpus() -> List[int]: + """ + Returns: + a list of all available MPS gpus + """ + # lazy import to avoid circular dependencies + from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE + + return [0] if _MPS_AVAILABLE else [] + + +def _get_all_available_cuda_gpus() -> List[int]: """ Returns: - a list of all available gpus + a list of all available CUDA gpus """ return list(range(torch.cuda.device_count())) diff --git a/src/pytorch_lightning/utilities/enums.py b/src/pytorch_lightning/utilities/enums.py index f4b0f29d8be41..b7f714d230971 100644 --- a/src/pytorch_lightning/utilities/enums.py +++ b/src/pytorch_lightning/utilities/enums.py @@ -254,6 +254,7 @@ class _AcceleratorType(LightningEnum): IPU = "IPU" TPU = "TPU" HPU = "HPU" + MPS = "MPS" class _FaultTolerantMode(LightningEnum): diff --git a/tests/tests_pytorch/accelerators/test_accelerator_connector.py b/tests/tests_pytorch/accelerators/test_accelerator_connector.py index c5480fad089fc..20cac155f9915 100644 --- a/tests/tests_pytorch/accelerators/test_accelerator_connector.py +++ b/tests/tests_pytorch/accelerators/test_accelerator_connector.py @@ -25,6 +25,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.gpu import GPUAccelerator +from pytorch_lightning.accelerators.mps import MPSAccelerator from pytorch_lightning.plugins import DoublePrecisionPlugin, LayerSync, NativeSyncBatchNorm, PrecisionPlugin from pytorch_lightning.plugins.environments import ( KubeflowEnvironment, @@ -714,12 +715,21 @@ def test_devices_auto_choice_cpu( @mock.patch("torch.cuda.is_available", return_value=True) @mock.patch("torch.cuda.device_count", return_value=2) +@RunIf(mps=False) def test_devices_auto_choice_gpu(is_gpu_available_mock, device_count_mock): + trainer = Trainer(accelerator="auto", devices="auto") assert isinstance(trainer.accelerator, GPUAccelerator) assert trainer.num_devices == 2 +@RunIf(mps=True) +def test_devices_auto_choice_mps(): + trainer = Trainer(accelerator="auto", devices="auto") + assert isinstance(trainer.accelerator, MPSAccelerator) + assert trainer.num_devices == 1 + + @pytest.mark.parametrize( ["parallel_devices", "accelerator"], [([torch.device("cpu")], "gpu"), ([torch.device("cuda", i) for i in range(8)], ("tpu"))], diff --git a/tests/tests_pytorch/accelerators/test_accelerator_registry.py b/tests/tests_pytorch/accelerators/test_accelerator_registry.py index 4e2b521873408..11c806a21c740 100644 --- a/tests/tests_pytorch/accelerators/test_accelerator_registry.py +++ b/tests/tests_pytorch/accelerators/test_accelerator_registry.py @@ -63,4 +63,4 @@ def is_available(): def test_available_accelerators_in_registry(): - assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "hpu", "ipu", "tpu"] + assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "hpu", "ipu", "mps", "tpu"] diff --git 
a/tests/tests_pytorch/accelerators/test_mps.py b/tests/tests_pytorch/accelerators/test_mps.py new file mode 100644 index 0000000000000..01e13e937b4d0 --- /dev/null +++ b/tests/tests_pytorch/accelerators/test_mps.py @@ -0,0 +1,164 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple + +import pytest +import torch + +import tests_pytorch.helpers.pipelines as tpipes +from pytorch_lightning import Trainer +from pytorch_lightning.accelerators import MPSAccelerator +from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY +from tests_pytorch.helpers.imports import Batch, Dataset, Example, Field, LabelField +from tests_pytorch.helpers.runif import RunIf + + +@RunIf(mps=True) +def test_get_mps_stats(): + current_device = torch.device("mps") + device_stats = MPSAccelerator().get_device_stats(current_device) + fields = ["M1_vm_percent", "M1_percent", "M1_swap_percent"] + + for f in fields: + assert any(f in h for h in device_stats.keys()) + + +@RunIf(mps=True) +def test_mps_availability(): + assert MPSAccelerator.is_available() + + +@RunIf(mps=True) +def test_warning_if_mps_not_used(): + with pytest.warns(UserWarning, match="MPS available but not used. 
Set `accelerator` and `devices`"): + Trainer() + + +@RunIf(mps=True) +@pytest.mark.parametrize("accelerator_value", ["mps", MPSAccelerator()]) +def test_trainer_mps_accelerator(accelerator_value): + trainer = Trainer(accelerator=accelerator_value) + assert isinstance(trainer.accelerator, MPSAccelerator) + assert trainer.num_devices == 1 + + +@RunIf(mps=True) +@pytest.mark.parametrize("devices", [1, [0], "-1"]) +def test_single_gpu_model(tmpdir, devices): + """Make sure single GPU works.""" + trainer_options = dict( + default_root_dir=tmpdir, + enable_progress_bar=False, + max_epochs=1, + limit_train_batches=0.1, + limit_val_batches=0.1, + accelerator="mps", + devices=devices, + ) + + model = BoringModel() + tpipes.run_model_test(trainer_options, model) + + +@RunIf(mps=True) +def test_single_gpu_batch_parse(): + trainer = Trainer(accelerator="mps", devices=1) + + # non-transferrable types + primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}] + for batch in primitive_objects: + data = trainer.strategy.batch_to_device(batch, torch.device("mps")) + assert data == batch + + # batch is just a tensor + batch = torch.rand(2, 3) + batch = trainer.strategy.batch_to_device(batch, torch.device("mps")) + assert batch.device.index == 0 and batch.type() == "torch.mps.FloatTensor" + + # tensor list + batch = [torch.rand(2, 3), torch.rand(2, 3)] + batch = trainer.strategy.batch_to_device(batch, torch.device("mps")) + assert batch[0].device.index == 0 and batch[0].type() == "torch.mps.FloatTensor" + assert batch[1].device.index == 0 and batch[1].type() == "torch.mps.FloatTensor" + + # tensor list of lists + batch = [[torch.rand(2, 3), torch.rand(2, 3)]] + batch = trainer.strategy.batch_to_device(batch, torch.device("mps")) + assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.mps.FloatTensor" + assert batch[0][1].device.index == 0 and batch[0][1].type() == "torch.mps.FloatTensor" + + # tensor dict + batch = [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)}] + batch = trainer.strategy.batch_to_device(batch, torch.device("mps")) + assert batch[0]["a"].device.index == 0 and batch[0]["a"].type() == "torch.mps.FloatTensor" + assert batch[0]["b"].device.index == 0 and batch[0]["b"].type() == "torch.mps.FloatTensor" + + # tuple of tensor list and list of tensor dict + batch = ([torch.rand(2, 3) for _ in range(2)], [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)} for _ in range(2)]) + batch = trainer.strategy.batch_to_device(batch, torch.device("mps")) + assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.mps.FloatTensor" + + assert batch[1][0]["a"].device.index == 0 + assert batch[1][0]["a"].type() == "torch.mps.FloatTensor" + + assert batch[1][0]["b"].device.index == 0 + assert batch[1][0]["b"].type() == "torch.mps.FloatTensor" + + # namedtuple of tensor + BatchType = namedtuple("BatchType", ["a", "b"]) + batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] + batch = trainer.strategy.batch_to_device(batch, torch.device("mps")) + assert batch[0].a.device.index == 0 + assert batch[0].a.type() == "torch.mps.FloatTensor" + + # non-Tensor that has `.to()` defined + class CustomBatchType: + def __init__(self): + self.a = torch.rand(2, 2) + + def to(self, *args, **kwargs): + self.a = self.a.to(*args, **kwargs) + return self + + batch = trainer.strategy.batch_to_device(CustomBatchType(), torch.device("mps")) + assert batch.a.type() == "torch.mps.FloatTensor" + + # torchtext.data.Batch + if not _TORCHTEXT_LEGACY: + return + + samples = 
[ + {"text": "PyTorch Lightning is awesome!", "label": 0}, + {"text": "Please make it work with torchtext", "label": 1}, + ] + + text_field = Field() + label_field = LabelField() + fields = {"text": ("text", text_field), "label": ("label", label_field)} + + examples = [Example.fromdict(sample, fields) for sample in samples] + dataset = Dataset(examples=examples, fields=fields.values()) + # Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first + text_field.build_vocab(dataset) + label_field.build_vocab(dataset) + + batch = Batch(data=examples, dataset=dataset) + + with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"): + batch = trainer.strategy.batch_to_device(batch, torch.device("mps")) + + assert batch.text.type() == "torch.mps.LongTensor" + assert batch.label.type() == "torch.mps.LongTensor" diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 9a5813575d709..f56e35d684b2f 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -368,6 +368,7 @@ def test_callbacks_restore(tmpdir): } if _TORCH_GREATER_EQUAL_1_11: expected["maximize"] = False + assert callback._internal_optimizer_metadata[0][0] == expected # new param group diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 8ffefb9bda3f8..632a3a7c1b41f 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -162,9 +162,11 @@ def test_swa_callback_ddp_cpu(tmpdir): train_with_swa(tmpdir, strategy="ddp_spawn", accelerator="cpu", devices=2) -@RunIf(min_cuda_gpus=1) -def test_swa_callback_1_gpu(tmpdir): - train_with_swa(tmpdir, accelerator="gpu", devices=1) +@pytest.mark.parametrize( + "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))] +) +def test_swa_callback_1_gpu(tmpdir, accelerator): + train_with_swa(tmpdir, accelerator=accelerator, devices=1) @pytest.mark.parametrize("batchnorm", (True, False)) diff --git a/tests/tests_pytorch/callbacks/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/test_tqdm_progress_bar.py index 43e68011dbc56..55c047075f918 100644 --- a/tests/tests_pytorch/callbacks/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/test_tqdm_progress_bar.py @@ -30,8 +30,14 @@ from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from tests_pytorch.helpers.runif import RunIf +if _TORCH_GREATER_EQUAL_1_12: + torch_test_assert_close = torch.testing.assert_close +else: + torch_test_assert_close = torch.testing.assert_allclose + class MockTqdm(Tqdm): def __init__(self, *args, **kwargs): @@ -415,7 +421,7 @@ def training_step(self, batch, batch_idx): ) trainer.fit(TestModel()) - torch.testing.assert_allclose(trainer.progress_bar_metrics["a"], 0.123) + torch_test_assert_close(trainer.progress_bar_metrics["a"], 0.123) assert trainer.progress_bar_metrics["b"] == {"b1": 1.0} assert trainer.progress_bar_metrics["c"] == {"c1": 2.0} pbar = trainer.progress_bar_callback.main_progress_bar diff --git 
a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index bfcb49a306d88..04299fa582869 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -265,13 +265,19 @@ def test_full_loop(tmpdir): assert result[0]["test_acc"] > 0.6 -@RunIf(min_cuda_gpus=1) +@pytest.mark.parametrize( + "accelerator,device", + [ + pytest.param("gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", "mps:0", marks=RunIf(mps=True)), + ], +) @mock.patch( "pytorch_lightning.strategies.Strategy.lightning_module", new_callable=PropertyMock, ) -def test_dm_apply_batch_transfer_handler(get_module_mock): - expected_device = torch.device("cuda", 0) +def test_dm_apply_batch_transfer_handler(get_module_mock, accelerator, device): + expected_device = torch.device(device) class CustomBatch: def __init__(self, data): @@ -312,7 +318,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long))) - trainer = Trainer(accelerator="gpu", devices=1) + trainer = Trainer(accelerator=accelerator, devices=1) model.trainer = trainer # running .fit() would require us to implement custom data loaders, we mock the model reference instead get_module_mock.return_value = model diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 373376191ca11..56dbbafa47c3e 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -276,11 +276,17 @@ def configure_optimizers(self): trainer.fit(model) -@RunIf(min_cuda_gpus=1) -def test_device_placement(tmpdir): +@pytest.mark.parametrize( + "accelerator,device", + [ + pytest.param("gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", "mps:0", marks=RunIf(mps=True)), + ], +) +def test_device_placement(tmpdir, accelerator, device): model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="gpu", devices=1) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator, devices=1) trainer.fit(model) def assert_device(device: torch.device) -> None: @@ -289,8 +295,8 @@ def assert_device(device: torch.device) -> None: assert p.device == device assert_device(torch.device("cpu")) - model.to(torch.device("cuda:0")) - assert_device(torch.device("cuda:0")) + model.to(torch.device(device)) + assert_device(torch.device(device)) trainer.test(model) assert_device(torch.device("cpu")) trainer.predict(model, dataloaders=model.train_dataloader()) diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 78e2d6fa2541b..e6b3f8a1c682e 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -298,8 +298,15 @@ def lightning_log(fx, *args, **kwargs): batch_idx = None -@pytest.mark.parametrize("device", ("cpu", pytest.param("cuda", marks=RunIf(min_cuda_gpus=1)))) -def test_lightning_module_logging_result_collection(tmpdir, device): +@pytest.mark.parametrize( + "accelerator,device", + ( + ("cpu", "cpu"), + pytest.param("gpu", "cuda", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", "mps", marks=RunIf(mps=True)), + ), +) +def test_lightning_module_logging_result_collection(tmpdir, accelerator, device): class LoggingModel(BoringModel): def __init__(self): super().__init__() 
@@ -348,7 +355,7 @@ def on_save_checkpoint(self, checkpoint) -> None: limit_train_batches=2, limit_val_batches=2, callbacks=[ckpt], - accelerator="gpu" if device == "cuda" else "cpu", + accelerator=accelerator, devices=1, ) trainer.fit(model) @@ -474,10 +481,15 @@ def test_result_collection_reload(tmpdir): result_collection_reload(default_root_dir=tmpdir) -@RunIf(min_cuda_gpus=1) +@pytest.mark.parametrize( + "accelerator", + [ + pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + ], +) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -def test_result_collection_reload_1_gpu_ddp(tmpdir): - result_collection_reload(default_root_dir=tmpdir, strategy="ddp", accelerator="gpu") +def test_result_collection_reload_1_gpu_ddp(tmpdir, accelerator): + result_collection_reload(default_root_dir=tmpdir, strategy="ddp", accelerator=accelerator) @RunIf(min_cuda_gpus=2, standalone=True) diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 1fc7ca893bf98..669fbac19b431 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -20,6 +20,7 @@ from packaging.version import Version from pkg_resources import get_distribution +from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.imports import ( _APEX_AVAILABLE, @@ -74,6 +75,7 @@ def __new__( tpu: bool = False, ipu: bool = False, hpu: bool = False, + mps: Optional[bool] = None, horovod: bool = False, horovod_nccl: bool = False, skip_windows: bool = False, @@ -102,6 +104,8 @@ def __new__( tpu: Require that TPU is available. ipu: Require that IPU is available. hpu: Require that HPU is available. + mps: If True: Require that MPS (Apple Silicon) is available, + if False: Explicitly Require that MPS is not available horovod: Require that Horovod is installed. horovod_nccl: Require that Horovod is installed with NCCL support. skip_windows: Skip for Windows platform. 
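To make the new tri-state ``mps`` argument of ``RunIf`` concrete, a short sketch of how it is meant to be used; the test names below are hypothetical, but both decorator forms mirror usages added elsewhere in this patch (e.g. in ``test_mps.py`` and ``test_accelerator_connector.py``):

    # Illustrative sketch only (not part of the patch).
    from tests_pytorch.helpers.runif import RunIf

    @RunIf(mps=True)   # skipped unless MPS (Apple Silicon) is available
    def test_only_on_apple_silicon():
        ...

    @RunIf(mps=False)  # skipped on MPS machines, e.g. for CUDA-only behavior
    def test_only_where_mps_is_absent():
        ...

    # Leaving mps at its default (None) keeps the test independent of MPS availability.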
@@ -180,6 +184,14 @@ def __new__( conditions.append(not _HPU_AVAILABLE) reasons.append("HPU") + if mps is not None: + if mps: + conditions.append(not _MPS_AVAILABLE) + reasons.append("MPS") + else: + conditions.append(_MPS_AVAILABLE) + reasons.append("not MPS") + if horovod: conditions.append(not _HOROVOD_AVAILABLE) reasons.append("Horovod") diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index f38ec9c294e02..7166be0981846 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -312,29 +312,32 @@ def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy): @pytest.mark.parametrize( "accelerator, expected", [ - ("cpu", torch.device("cpu")), - pytest.param("gpu", torch.device("cuda", 0), marks=RunIf(min_cuda_gpus=1)), - pytest.param("tpu", torch.device("xla", 0), marks=RunIf(tpu=True)), + ("cpu", "cpu"), + pytest.param("gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("tpu", "xla:0", marks=RunIf(tpu=True)), + pytest.param("mps", "mps:0", marks=RunIf(mps=True)), ], ) def test_to_device(accelerator, expected): """Test that the to_device method can move various objects to the device determined by the accelerator.""" lite = EmptyLite(accelerator=accelerator, devices=1) + expected_device = torch.device(expected) + # module module = torch.nn.Linear(2, 3) module = lite.to_device(module) - assert all(param.device == expected for param in module.parameters()) + assert all(param.device == expected_device for param in module.parameters()) # tensor tensor = torch.rand(2, 2) tensor = lite.to_device(tensor) - assert tensor.device == expected + assert tensor.device == expected_device # collection collection = {"data": torch.rand(2, 2), "int": 1} collection = lite.to_device(collection) - assert collection["data"].device == expected + assert collection["data"].device == expected_device def test_rank_properties(): @@ -410,7 +413,7 @@ def run(self): optimizer.step() for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()): - assert not torch.equal(mw_b, mw_a) + assert not torch.allclose(mw_b, mw_a) self.seed_everything(42) model_1 = BoringModel() @@ -421,7 +424,7 @@ def run(self): optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001) for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): - assert torch.equal(mw_1, mw_2) + assert torch.allclose(mw_1, mw_2) model_1, optimizer_1 = self.setup(model_1, optimizer_1) model_2, optimizer_2 = self.setup(model_2, optimizer_2) @@ -438,7 +441,7 @@ def run(self): optimizer_1.step() for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): - assert not torch.equal(mw_1, mw_2) + assert not torch.allclose(mw_1, mw_2) for data in data_list: optimizer_2.zero_grad() @@ -448,11 +451,11 @@ def run(self): optimizer_2.step() for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): - assert torch.equal(mw_1, mw_2) + assert torch.allclose(mw_1, mw_2) # Verify collectives works as expected ranks = self.all_gather(torch.tensor([self.local_rank]).to(self.device)) - assert torch.equal(ranks.cpu(), torch.tensor([[0], [1]])) + assert torch.allclose(ranks.cpu(), torch.tensor([[0], [1]])) assert self.broadcast(True) assert self.is_global_zero == (self.local_rank == 0) diff --git a/tests/tests_pytorch/lite/test_parity.py b/tests/tests_pytorch/lite/test_parity.py index 8a9f8d5abb5e5..e294094799196 100644 --- a/tests/tests_pytorch/lite/test_parity.py +++ 
b/tests/tests_pytorch/lite/test_parity.py @@ -112,6 +112,7 @@ def precision_context(precision, accelerator) -> Generator[None, None, None]: pytest.param(32, None, 1, "gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param(16, None, 1, "gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("bf16", None, 1, "gpu", marks=RunIf(min_cuda_gpus=1, min_torch="1.10", bf16_cuda=True)), + pytest.param(32, None, 1, "mps", marks=RunIf(mps=True)), ], ) def test_boring_lite_model_single_device(precision, strategy, devices, accelerator, tmpdir): diff --git a/tests/tests_pytorch/lite/test_wrappers.py b/tests/tests_pytorch/lite/test_wrappers.py index 953d6bb9a7372..1098d2f0f8459 100644 --- a/tests/tests_pytorch/lite/test_wrappers.py +++ b/tests/tests_pytorch/lite/test_wrappers.py @@ -69,26 +69,47 @@ def __init__(self): _ = lite_module.not_exists -@RunIf(min_cuda_gpus=1) @pytest.mark.parametrize( - "precision, input_type, expected_type", + "precision, input_type, expected_type, accelerator, device_str", [ - (32, torch.float16, torch.float32), - (32, torch.float32, torch.float32), - (32, torch.float64, torch.float32), - (32, torch.int, torch.int), - (16, torch.float32, torch.float16), - (16, torch.float64, torch.float16), - (16, torch.long, torch.long), - pytest.param("bf16", torch.float32, torch.bfloat16, marks=RunIf(min_torch="1.10", bf16_cuda=True)), - pytest.param("bf16", torch.float64, torch.bfloat16, marks=RunIf(min_torch="1.10", bf16_cuda=True)), - pytest.param("bf16", torch.bool, torch.bool, marks=RunIf(min_torch="1.10", bf16_cuda=True)), + pytest.param(32, torch.float16, torch.float32, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param(32, torch.float32, torch.float32, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param(32, torch.float64, torch.float32, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param(32, torch.int, torch.int, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param(16, torch.float32, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param(16, torch.float64, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param(16, torch.long, torch.long, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param( + "bf16", + torch.float32, + torch.bfloat16, + "gpu", + "cuda:0", + marks=RunIf(min_cuda_gpus=1, min_torch="1.10", bf16_cuda=True), + ), + pytest.param( + "bf16", + torch.float64, + torch.bfloat16, + "gpu", + "cuda:0", + marks=RunIf(min_cuda_gpus=1, min_torch="1.10", bf16_cuda=True), + ), + pytest.param( + "bf16", + torch.bool, + torch.bool, + "gpu", + "cuda:0", + marks=RunIf(min_cuda_gpus=1, min_torch="1.10", bf16_cuda=True), + ), + pytest.param(32, torch.float32, torch.float32, "mps", "mps:0", marks=RunIf(mps=True)), ], ) -def test_lite_module_forward_conversion(precision, input_type, expected_type): +def test_lite_module_forward_conversion(precision, input_type, expected_type, accelerator, device_str): """Test that the LiteModule performs autocasting on the input tensors and during forward().""" - lite = EmptyLite(precision=precision, accelerator="gpu", devices=1) - device = torch.device("cuda", 0) + lite = EmptyLite(precision=precision, accelerator=accelerator, devices=1) + device = torch.device(device_str) def check_autocast(forward_input): assert precision != 16 or torch.is_autocast_enabled() @@ -102,12 +123,19 @@ def check_autocast(forward_input): @pytest.mark.parametrize( - "device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_cuda_gpus=1))] + "device_str", + [ + 
"cpu", + pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", marks=RunIf(mps=True)), + ], ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) -def test_lite_module_device_dtype_propagation(device, dtype): +def test_lite_module_device_dtype_propagation(device_str, dtype): """Test that the LiteModule propagates device and dtype properties to its submodules (e.g. torchmetrics).""" + device = torch.device(device_str) + class DeviceModule(DeviceDtypeModuleMixin): pass @@ -144,15 +172,20 @@ def test_lite_dataloader_iterator(): @pytest.mark.parametrize( - "src_device, dest_device", + "src_device_str, dest_device_str", [ - (torch.device("cpu"), torch.device("cpu")), - pytest.param(torch.device("cpu"), torch.device("cuda", 0), marks=RunIf(min_cuda_gpus=1)), - pytest.param(torch.device("cuda", 0), torch.device("cpu"), marks=RunIf(min_cuda_gpus=1)), + ("cpu", "cpu"), + pytest.param("cpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("cuda:0", "cpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param("cpu", "mps", marks=RunIf(mps=True)), + pytest.param("mps", "cpu", marks=RunIf(mps=True)), ], ) -def test_lite_dataloader_device_placement(src_device, dest_device): +def test_lite_dataloader_device_placement(src_device_str, dest_device_str): """Test that the LiteDataLoader moves data to the device in its iterator.""" + src_device = torch.device(src_device_str) + dest_device = torch.device(dest_device_str) + sample0 = torch.tensor(0, device=src_device) sample1 = torch.tensor(1, device=src_device) sample2 = {"data": torch.tensor(2, device=src_device)} diff --git a/tests/tests_pytorch/loops/test_all.py b/tests/tests_pytorch/loops/test_all.py index dabaa81c489cc..1cab9e9fdf2a7 100644 --- a/tests/tests_pytorch/loops/test_all.py +++ b/tests/tests_pytorch/loops/test_all.py @@ -11,65 +11,83 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + from pytorch_lightning import Callback, Trainer from pytorch_lightning.demos.boring_classes import BoringModel from tests_pytorch.helpers.runif import RunIf +def _device_check_helper(batch_device, module_device): + assert batch_device.type == module_device.type + if batch_device.index is not None and module_device.index is not None: + assert batch_device.index == module_device.index + else: + # devices with index None are the same as with index 0 + assert batch_device.index in (0, None) + assert module_device.index in (0, None) + + class BatchHookObserverCallback(Callback): def on_train_batch_start(self, trainer, pl_module, batch, *args): - assert batch.device == pl_module.device + _device_check_helper(batch.device, pl_module.device) def on_train_batch_end(self, trainer, pl_module, outputs, batch, *args): - assert batch.device == pl_module.device + _device_check_helper(batch.device, pl_module.device) def on_validation_batch_start(self, trainer, pl_module, batch, *args): - assert batch.device == pl_module.device + _device_check_helper(batch.device, pl_module.device) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, *args): - assert batch.device == pl_module.device + _device_check_helper(batch.device, pl_module.device) def on_test_batch_start(self, trainer, pl_module, batch, *args): - assert batch.device == pl_module.device + _device_check_helper(batch.device, pl_module.device) def on_test_batch_end(self, trainer, pl_module, outputs, batch, *args): - assert batch.device == pl_module.device + _device_check_helper(batch.device, pl_module.device) def on_predict_batch_start(self, trainer, pl_module, batch, *args): - assert batch.device == pl_module.device + _device_check_helper(batch.device, pl_module.device) def on_predict_batch_end(self, trainer, pl_module, outputs, batch, *args): - assert batch.device == pl_module.device + _device_check_helper(batch.device, pl_module.device) class BatchHookObserverModel(BoringModel): def on_train_batch_start(self, batch, *args): - assert batch.device == self.device + _device_check_helper(batch.device, self.device) def on_train_batch_end(self, outputs, batch, *args): - assert batch.device == self.device + _device_check_helper(batch.device, self.device) def on_validation_batch_start(self, batch, *args): - assert batch.device == self.device + _device_check_helper(batch.device, self.device) def on_validation_batch_end(self, outputs, batch, *args): - assert batch.device == self.device + _device_check_helper(batch.device, self.device) def on_test_batch_start(self, batch, *args): - assert batch.device == self.device + _device_check_helper(batch.device, self.device) def on_test_batch_end(self, outputs, batch, *args): - assert batch.device == self.device + _device_check_helper(batch.device, self.device) def on_predict_batch_start(self, batch, *args): - assert batch.device == self.device + _device_check_helper(batch.device, self.device) def on_predict_batch_end(self, outputs, batch, *args): - assert batch.device == self.device + _device_check_helper(batch.device, self.device) -@RunIf(min_cuda_gpus=1) -def test_callback_batch_on_device(tmpdir): +@pytest.mark.parametrize( + "accelerator", + [ + pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", marks=RunIf(mps=True)), + ], +) +def test_callback_batch_on_device(tmpdir, accelerator): """Test that the batch object sent to the on_*_batch_start/end hooks is on the right device.""" batch_callback = BatchHookObserverCallback() @@ -82,7 +100,7 @@ def 
test_callback_batch_on_device(tmpdir): limit_val_batches=1, limit_test_batches=1, limit_predict_batches=1, - accelerator="gpu", + accelerator=accelerator, devices=1, callbacks=[batch_callback], ) diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index cd531aaa2f80b..4ab898699f478 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -93,7 +93,11 @@ def on_validation_end(self): @RunIf(min_cuda_gpus=1) def test_memory_consumption_validation(tmpdir): - """Test that the training batch is no longer in GPU memory when running validation.""" + """Test that the training batch is no longer in GPU memory when running validation. + + Cannot run with MPS, since there we can only measure shared memory and not dedicated, which device has how much + memory allocated. + """ initial_memory = torch.cuda.memory_allocated(0) diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py index 5868ff2079228..ffd093e6ee0e3 100644 --- a/tests/tests_pytorch/models/test_gpu.py +++ b/tests/tests_pytorch/models/test_gpu.py @@ -147,30 +147,30 @@ def test_determine_root_gpu_device(devices, expected_root_gpu): ], ) def test_parse_gpu_ids(mocked_device_count, devices, expected_gpu_ids): - assert device_parser.parse_gpu_ids(devices) == expected_gpu_ids + assert device_parser.parse_gpu_ids(devices, include_cuda=True) == expected_gpu_ids @pytest.mark.parametrize("devices", [0.1, -2, False, [-1], [None], ["0"], [0, 0]]) def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, devices): with pytest.raises(MisconfigurationException): - device_parser.parse_gpu_ids(devices) + device_parser.parse_gpu_ids(devices, include_cuda=True) @pytest.mark.parametrize("devices", [[1, 2, 19], -1, "-1"]) def test_parse_gpu_fail_on_non_existent_id(mocked_device_count_0, devices): with pytest.raises(MisconfigurationException): - device_parser.parse_gpu_ids(devices) + device_parser.parse_gpu_ids(devices, include_cuda=True) def test_parse_gpu_fail_on_non_existent_id_2(mocked_device_count): with pytest.raises(MisconfigurationException): - device_parser.parse_gpu_ids([1, 2, 19]) + device_parser.parse_gpu_ids([1, 2, 19], include_cuda=True) @pytest.mark.parametrize("devices", [-1, "-1"]) def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_count_0, devices): with pytest.raises(MisconfigurationException): - device_parser.parse_gpu_ids(devices) + device_parser.parse_gpu_ids(devices, include_cuda=True) @mock.patch.dict( @@ -195,10 +195,10 @@ def test_torchelastic_gpu_parsing(mocked_device_count, mocked_is_available, gpus trainer = Trainer(gpus=gpus) assert isinstance(trainer._accelerator_connector.cluster_environment, TorchElasticEnvironment) # when use gpu - if device_parser.parse_gpu_ids(gpus) is not None: + if device_parser.parse_gpu_ids(gpus, include_cuda=True) is not None: assert isinstance(trainer.accelerator, GPUAccelerator) assert trainer.num_devices == len(gpus) if isinstance(gpus, list) else gpus - assert trainer.device_ids == device_parser.parse_gpu_ids(gpus) + assert trainer.device_ids == device_parser.parse_gpu_ids(gpus, include_cuda=True) # fall back to cpu else: assert isinstance(trainer.accelerator, CPUAccelerator) @@ -284,7 +284,6 @@ def to(self, *args, **kwargs): examples = [Example.fromdict(sample, fields) for sample in samples] dataset = Dataset(examples=examples, fields=fields.values()) - # Batch runs field.process() that numericalizes 
tokens, but it requires to build dictionary first text_field.build_vocab(dataset) label_field.build_vocab(dataset) diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 39b18cf0d2ad4..a2235c592d5fb 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -111,13 +111,19 @@ def on_train_batch_end(self, outputs, batch, batch_idx): assert overridden_model.len_outputs == overridden_model.num_train_batches -@RunIf(min_cuda_gpus=1) +@pytest.mark.parametrize( + "accelerator,expected_device_str", + [ + pytest.param("gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", "mps:0", marks=RunIf(mps=True)), + ], +) @mock.patch( "pytorch_lightning.strategies.Strategy.lightning_module", new_callable=PropertyMock, ) -def test_apply_batch_transfer_handler(model_getter_mock): - expected_device = torch.device("cuda", 0) +def test_apply_batch_transfer_handler(model_getter_mock, accelerator, expected_device_str): + expected_device = torch.device(expected_device_str) class CustomBatch: def __init__(self, data): @@ -156,7 +162,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): model = CurrentTestModel() batch = CustomBatch((torch.zeros(5, 32), torch.ones(5, 1, dtype=torch.long))) - trainer = Trainer(accelerator="gpu", devices=1) + trainer = Trainer(accelerator=accelerator, devices=1) # running .fit() would require us to implement custom data loaders, we mock the model reference instead model_getter_mock.return_value = model diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 7428b4b976ad9..9b379d6d34c99 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -39,11 +39,13 @@ def test_model_saves_with_input_sample(tmpdir): assert os.path.getsize(file_path) > 4e2 -@RunIf(min_cuda_gpus=1) -def test_model_saves_on_gpu(tmpdir): +@pytest.mark.parametrize( + "accelerator", [pytest.param("mps", marks=RunIf(mps=True)), pytest.param("gpu", marks=RunIf(min_cuda_gpus=True))] +) +def test_model_saves_on_gpu(tmpdir, accelerator): """Test that model saves on gpu.""" model = BoringModel() - trainer = Trainer(accelerator="gpu", devices=1, fast_dev_run=True) + trainer = Trainer(accelerator=accelerator, devices=1, fast_dev_run=True) trainer.fit(model) file_path = os.path.join(tmpdir, "model.onnx") diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py index 127664af332ca..150ea86044be6 100644 --- a/tests/tests_pytorch/models/test_torchscript.py +++ b/tests/tests_pytorch/models/test_torchscript.py @@ -78,10 +78,17 @@ def test_torchscript_input_output_trace(): assert torch.allclose(script_output, model_output) -@RunIf(min_cuda_gpus=1) -@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda", 0)]) -def test_torchscript_device(device): +@pytest.mark.parametrize( + "device_str", + [ + "cpu", + pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps:0", marks=RunIf(mps=True)), + ], +) +def test_torchscript_device(device_str): """Test that scripted module is on the correct device.""" + device = torch.device(device_str) model = BoringModel().to(device) model.example_input_array = torch.randn(5, 32) diff --git a/tests/plugins/environments/test_xla_environment.py b/tests/tests_pytorch/plugins/environments/test_xla_environment.py similarity index 98% rename from tests/plugins/environments/test_xla_environment.py rename to 
tests/tests_pytorch/plugins/environments/test_xla_environment.py index 21ef9bb5bf171..8c6bae204ed17 100644 --- a/tests/plugins/environments/test_xla_environment.py +++ b/tests/tests_pytorch/plugins/environments/test_xla_environment.py @@ -15,10 +15,10 @@ from unittest import mock import pytest -from tests.helpers.runif import RunIf import pytorch_lightning as pl from pytorch_lightning.plugins.environments import XLAEnvironment +from tests_pytorch.helpers.runif import RunIf @RunIf(tpu=True) diff --git a/tests/tests_pytorch/strategies/test_common.py b/tests/tests_pytorch/strategies/test_common.py index 489ecaed2968d..479b222e25a9d 100644 --- a/tests/tests_pytorch/strategies/test_common.py +++ b/tests/tests_pytorch/strategies/test_common.py @@ -18,11 +18,17 @@ from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.seed import seed_everything from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.strategies.test_dp import CustomClassificationModelDP +if _TORCH_GREATER_EQUAL_1_12: + torch_test_assert_close = torch.testing.assert_close +else: + torch_test_assert_close = torch.testing.assert_allclose + @pytest.mark.parametrize( "trainer_kwargs", @@ -30,6 +36,7 @@ pytest.param(dict(accelerator="gpu", devices=1), marks=RunIf(min_cuda_gpus=1)), pytest.param(dict(strategy="dp", accelerator="gpu", devices=2), marks=RunIf(min_cuda_gpus=2)), pytest.param(dict(strategy="ddp_spawn", accelerator="gpu", devices=2), marks=RunIf(min_cuda_gpus=2)), + pytest.param(dict(accelerator="mps", devices=1), marks=RunIf(mps=True)), ), ) def test_evaluate(tmpdir, trainer_kwargs): @@ -51,7 +58,7 @@ def test_evaluate(tmpdir, trainer_kwargs): # make sure weights didn't change new_weights = model.layer_0.weight.clone().detach().cpu() - torch.testing.assert_allclose(old_weights, new_weights) + torch_test_assert_close(old_weights, new_weights) def test_model_parallel_setup_called(tmpdir): diff --git a/tests/tests_pytorch/strategies/test_single_device_strategy.py b/tests/tests_pytorch/strategies/test_single_device_strategy.py index 65933e30a87d7..0c41e2b17bc6a 100644 --- a/tests/tests_pytorch/strategies/test_single_device_strategy.py +++ b/tests/tests_pytorch/strategies/test_single_device_strategy.py @@ -38,7 +38,10 @@ def on_train_start(self) -> None: @RunIf(min_cuda_gpus=1, skip_windows=True) def test_single_gpu(): - """Tests if device is set correctly when training and after teardown for single GPU strategy.""" + """Tests if device is set correctly when training and after teardown for single GPU strategy. + + Cannot run this test on MPS, since its shared memory does not allow measuring dedicated GPU memory utilization.
+ """ trainer = Trainer(accelerator="gpu", devices=1, fast_dev_run=True) # assert training strategy attributes for device setting assert isinstance(trainer.strategy, SingleDeviceStrategy) diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index b21abf51e8464..7d91870c067e5 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -691,8 +691,14 @@ def val_dataloader(self): trainer.fit(model) -@RunIf(min_cuda_gpus=1) -def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir): +@pytest.mark.parametrize( + "accelerator", + [ + pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", marks=RunIf(mps=True)), + ], +) +def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir, accelerator): class TestModel(BoringModel): def validation_step(self, *args): x = torch.tensor(2.0, requires_grad=True, device=self.device) @@ -705,13 +711,13 @@ def validation_step(self, *args): def validation_epoch_end(self, outputs): # the step outputs were not moved - assert all(o.device == self.device for o in outputs), outputs + assert all(o.device == self.device for o in outputs) # but the logging results were assert self.trainer.callback_metrics["foo"].device.type == "cpu" model = TestModel() trainer = Trainer( - default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, accelerator="gpu", devices=1 + default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, accelerator=accelerator, devices=1 ) trainer.validate(model, verbose=False) diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index 1d5f10f571339..1b805682cd784 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -398,8 +398,16 @@ def validation_step(self, batch, batch_idx): return super().validation_step(batch, batch_idx) -@pytest.mark.parametrize("devices", [1, pytest.param(2, marks=RunIf(min_cuda_gpus=2, skip_windows=True))]) -def test_logging_sync_dist_true(tmpdir, devices): +@pytest.mark.parametrize( + "devices, accelerator", + [ + (1, "cpu"), + (2, "cpu"), + pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)), + ], +) +def test_logging_sync_dist_true(tmpdir, devices, accelerator): + """Tests to ensure that the sync_dist flag works (should just return the original value)""" fake_result = 1 model = LoggingSyncDistModel(fake_result) @@ -412,7 +420,7 @@ def test_logging_sync_dist_true(tmpdir, devices): limit_val_batches=3, enable_model_summary=False, strategy="ddp_spawn" if use_multiple_devices else None, - accelerator="auto", + accelerator=accelerator, devices=devices, ) trainer.fit(model) @@ -556,8 +564,14 @@ def on_train_epoch_end(self, trainer, pl_module): assert trainer.callback_metrics == expected -@RunIf(min_cuda_gpus=1) -def test_metric_are_properly_reduced(tmpdir): +# mps not yet supported by torchmetrics, see https://github.com/PyTorchLightning/metrics/issues/1044 +@pytest.mark.parametrize( + "accelerator", + [ + pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + ], +) +def test_metric_are_properly_reduced(tmpdir, accelerator): class TestingModel(BoringModel): def __init__(self, *args, **kwargs) -> None: super().__init__() @@ -584,7 +598,7 @@ def validation_step(self, batch, batch_idx): model = TestingModel() trainer = Trainer( 
default_root_dir=tmpdir, - accelerator="gpu", + accelerator=accelerator, devices=1, max_epochs=2, limit_train_batches=5, diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py index fd44184645b12..43edff94b171a 100644 --- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py +++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py @@ -24,8 +24,14 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.strategies import Strategy +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from tests_pytorch.helpers.runif import RunIf +if _TORCH_GREATER_EQUAL_1_12: + torch_test_assert_close = torch.testing.assert_close +else: + torch_test_assert_close = torch.testing.assert_allclose + class ManualOptModel(BoringModel): def __init__(self): @@ -198,8 +204,9 @@ def training_epoch_end(self, outputs) -> None: assert set(trainer.logged_metrics) == {"a_step", "a_epoch"} -@RunIf(min_cuda_gpus=1) -def test_multiple_optimizers_manual_native_amp(tmpdir): +# precision = 16 not yet working properly with mps backend +@pytest.mark.parametrize("accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1))]) +def test_multiple_optimizers_manual_native_amp(tmpdir, accelerator): model = ManualOptModel() model.val_dataloader = None @@ -212,7 +219,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir): log_every_n_steps=1, enable_model_summary=False, precision=16, - accelerator="gpu", + accelerator=accelerator, devices=1, ) diff --git a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py index 2f897171c4d5e..662f577e59975 100644 --- a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py @@ -17,6 +17,12 @@ import pytorch_lightning as pl from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 + +if _TORCH_GREATER_EQUAL_1_12: + torch_test_assert_close = torch.testing.assert_close +else: + torch_test_assert_close = torch.testing.assert_allclose class MultiOptModel(BoringModel): @@ -52,7 +58,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): for k, v in model.actual.items(): assert torch.equal(trainer.callback_metrics[f"loss_{k}_step"], v[-1]) # test loss is properly reduced - torch.testing.assert_allclose(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean()) + torch_test_assert_close(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean()) def test_multiple_optimizers(tmpdir): diff --git a/tests/tests_pytorch/trainer/properties/test_get_model.py b/tests/tests_pytorch/trainer/properties/test_get_model.py index f6aec2210adaa..4d095234253e0 100644 --- a/tests/tests_pytorch/trainer/properties/test_get_model.py +++ b/tests/tests_pytorch/trainer/properties/test_get_model.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel from tests_pytorch.helpers.runif import RunIf @@ -56,8 +58,14 @@ def test_get_model_ddp_cpu(tmpdir): trainer.fit(model) -@RunIf(min_cuda_gpus=1) -def test_get_model_gpu(tmpdir): +@pytest.mark.parametrize( + "accelerator", + [ + pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", marks=RunIf(mps=True)), + ], +) +def test_get_model_gpu(tmpdir, accelerator): """Tests that `trainer.lightning_module` extracts the model correctly when using GPU.""" model = TrainerGetModel() @@ -68,7 +76,7 @@ def test_get_model_gpu(tmpdir): limit_train_batches=limit_train_batches, limit_val_batches=2, max_epochs=1, - accelerator="gpu", + accelerator=accelerator, devices=1, ) trainer.fit(model) diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 3069b589bb448..5966f4a41267e 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -56,7 +56,7 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException -from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE +from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 from pytorch_lightning.utilities.seed import seed_everything from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen @@ -66,6 +66,11 @@ if _OMEGACONF_AVAILABLE: from omegaconf import OmegaConf +if _TORCH_GREATER_EQUAL_1_12: + torch_test_assert_close = torch.testing.assert_close +else: + torch_test_assert_close = torch.testing.assert_allclose + @pytest.mark.parametrize("url_ckpt", [True, False]) def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): @@ -1067,7 +1072,7 @@ def configure_gradient_clipping(self, *args, **kwargs): # test that gradient is clipped correctly parameters = self.parameters() grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - torch.testing.assert_allclose(grad_norm, torch.tensor(0.05, device=self.device)) + torch_test_assert_close(grad_norm, torch.tensor(0.05, device=self.device)) self.assertion_called = True model = TestModel() @@ -1098,7 +1103,7 @@ def configure_gradient_clipping(self, *args, **kwargs): parameters = self.parameters() grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] grad_max = torch.max(torch.stack(grad_max_list)) - torch.testing.assert_allclose(grad_max.abs(), torch.tensor(1e-10, device=self.device)) + torch_test_assert_close(grad_max.abs(), torch.tensor(1e-10, device=self.device)) self.assertion_called = True model = TestModel() @@ -1440,9 +1445,15 @@ def test_trainer_predict_standalone(tmpdir, kwargs): predict(tmpdir, accelerator="gpu", **kwargs) -@RunIf(min_cuda_gpus=1) -def test_trainer_predict_1_gpu(tmpdir): - predict(tmpdir, accelerator="gpu", devices=1) +@pytest.mark.parametrize( + "accelerator", + [ + pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", marks=RunIf(mps=True)), + ], +) +def test_trainer_predict_1_gpu(tmpdir, accelerator): + predict(tmpdir, accelerator=accelerator, devices=1) @RunIf(skip_windows=True) @@ -1529,8 +1540,15 @@ def configure_optimizers(self): 
trainer.fit(model, train_data) -@RunIf(min_cuda_gpus=1) -def test_setup_hook_move_to_device_correctly(tmpdir): +@pytest.mark.parametrize( + "accelerator", + [ + pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", marks=RunIf(mps=True)), + ], +) +def test_setup_hook_move_to_device_correctly(tmpdir, accelerator): + """Verify that if a user defines a layer in the setup hook function, this is moved to the correct device.""" class TestModel(BoringModel): @@ -1549,7 +1567,7 @@ def training_step(self, batch, batch_idx): # model model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator="gpu", devices=1) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator, devices=1) trainer.fit(model, train_data) diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py index 59d758f7dc6f4..47051d4efd098 100644 --- a/tests/tests_pytorch/utilities/test_auto_restart.py +++ b/tests/tests_pytorch/utilities/test_auto_restart.py @@ -59,9 +59,14 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher -from pytorch_lightning.utilities.imports import _fault_tolerant_training +from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_12 from tests_pytorch.helpers.runif import RunIf +if _TORCH_GREATER_EQUAL_1_12: + torch_test_assert_close = torch.testing.assert_close +else: + torch_test_assert_close = torch.testing.assert_allclose + def test_fast_forward_getattr(): dataset = range(15) @@ -967,9 +972,9 @@ def run(should_fail, resume): pre_fail_train_batches, pre_fail_val_batches = run(should_fail=True, resume=False) post_fail_train_batches, post_fail_val_batches = run(should_fail=False, resume=True) - torch.testing.assert_allclose(total_train_batches, pre_fail_train_batches + post_fail_train_batches) + torch_test_assert_close(total_train_batches, pre_fail_train_batches + post_fail_train_batches) for k in total_val_batches: - torch.testing.assert_allclose(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k]) + torch_test_assert_close(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k]) class TestAutoRestartModelUnderSignal(BoringModel): @@ -1512,6 +1517,6 @@ def configure_optimizers(self): trainer.train_dataloader = None restart_batches = model.batches - torch.testing.assert_allclose(total_batches, failed_batches + restart_batches) + torch_test_assert_close(total_batches, failed_batches + restart_batches) assert not torch.equal(total_weight, failed_weight) assert torch.equal(total_weight, model.layer.weight) diff --git a/tests/tests_pytorch/utilities/test_cli.py b/tests/tests_pytorch/utilities/test_cli.py index 8e801299aa23c..655d9849a64ca 100644 --- a/tests/tests_pytorch/utilities/test_cli.py +++ b/tests/tests_pytorch/utilities/test_cli.py @@ -609,7 +609,8 @@ def on_fit_start(self): raise MisconfigurationException("Error on fit start") -@RunIf(skip_windows=True) +# mps not yet supported by distributed +@RunIf(skip_windows=True, mps=False) @pytest.mark.parametrize("logger", (False, True)) @pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp")) def test_cli_distributed_save_config_callback(tmpdir, logger, strategy): diff --git a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py 
b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py index ec53816d2d5fc..38f72b555d52d 100644 --- a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py +++ b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py @@ -46,13 +46,24 @@ def on_train_batch_start(self, trainer, model, batch, batch_idx): assert model.device == model.module.module.device -@pytest.mark.parametrize("dst_dtype", [torch.float, torch.double, torch.half]) -@pytest.mark.parametrize("dst_device", [torch.device("cpu"), torch.device("cuda", 0)]) +@pytest.mark.parametrize( + "dst_device_str,dst_dtype", + [ + ("cpu", torch.half), + ("cpu", torch.float), + ("cpu", torch.double), + pytest.param("cuda:0", torch.half, marks=RunIf(min_cuda_gpus=1)), + pytest.param("cuda:0", torch.float, marks=RunIf(min_cuda_gpus=1)), + pytest.param("cuda:0", torch.double, marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps:0", torch.float, marks=RunIf(mps=True)), # double and half are not yet supported. + ], +) -@RunIf(min_cuda_gpus=1) -def test_submodules_device_and_dtype(dst_device, dst_dtype): +def test_submodules_device_and_dtype(dst_device_str, dst_dtype): """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule).""" + dst_device = torch.device(dst_device_str) + model = TopModule() assert model.device == torch.device("cpu") model = model.to(device=dst_device, dtype=dst_dtype) diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py index 50c9b85970bf0..e9ab01387f7f6 100644 --- a/tests/tests_pytorch/utilities/test_fetching.py +++ b/tests/tests_pytorch/utilities/test_fetching.py @@ -194,8 +194,10 @@ def test_dataloader(self): @pytest.mark.flaky(reruns=3) -@RunIf(min_cuda_gpus=1) -def test_trainer_num_prefetch_batches(tmpdir): +@pytest.mark.parametrize( + "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))] +) +def test_trainer_num_prefetch_batches(tmpdir, accelerator): model = RecommenderModel() @@ -211,7 +213,7 @@ def on_train_epoch_end(self, trainer, lightning_module): trainer_kwargs = dict( default_root_dir=tmpdir, max_epochs=1, - accelerator="gpu", + accelerator=accelerator, devices=1, limit_train_batches=4, limit_val_batches=0, diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py index e2f903725369b..daaf929fd07ff 100644 --- a/tests/tests_pytorch/utilities/test_model_summary.py +++ b/tests/tests_pytorch/utilities/test_model_summary.py @@ -155,11 +155,18 @@ def test_empty_model_summary_shapes(max_depth): assert summary.param_nums == [] -@RunIf(min_cuda_gpus=1) @pytest.mark.parametrize("max_depth", [-1, 1]) -@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda", 0)]) -def test_linear_model_summary_shapes(device, max_depth): +@pytest.mark.parametrize( + "device_str", + [ + "cpu", + pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps:0", marks=RunIf(mps=True)), + ], +) +def test_linear_model_summary_shapes(device_str, max_depth): """Test that the model summary correctly computes the input- and output shapes.""" + device = torch.device(device_str) model = UnorderedModel().to(device) model.train() summary = summarize(model, max_depth=max_depth) @@ -288,13 +295,21 @@ def test_empty_model_size(max_depth): assert 0.0 == summary.model_size -@RunIf(min_cuda_gpus=1) -def
test_model_size_precision(tmpdir): +@pytest.mark.parametrize( + "accelerator", + [ + pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps", marks=RunIf(mps=True)), + ], +) +def test_model_size_precision(tmpdir, accelerator): """Test model size for half and full precision.""" model = PreCalculatedModel() # fit model - trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=1, max_steps=1, max_epochs=1, precision=32) + trainer = Trainer( + default_root_dir=tmpdir, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32 + ) trainer.fit(model) summary = summarize(model) assert model.pre_calculated_model_size == summary.model_size