From fe938c522df060a26f946aa712d4cd06bfb3c87d Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Sun, 2 Oct 2022 21:06:41 -0700 Subject: [PATCH 01/14] support DDP_FORK strategy with native AMP by attempting NVML-based CUDA availability assessment --- src/lightning_lite/lite.py | 2 +- .../plugins/precision/native_amp.py | 3 ++ .../plugins/precision/native_amp.py | 3 ++ tests/tests_pytorch/helpers/runif.py | 6 ++- tests/tests_pytorch/helpers/test_models.py | 42 ++++++++++++++++++ tests/tests_pytorch/models/test_amp.py | 43 +------------------ .../tests_pytorch/models/test_ddp_fork_amp.py | 41 ++++++++++++++++++ 7 files changed, 96 insertions(+), 44 deletions(-) create mode 100644 tests/tests_pytorch/models/test_ddp_fork_amp.py diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index 04b964e41c5a0..d1b94e1b46fbf 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -26,9 +26,9 @@ from torch.optim import Optimizer from torch.utils.data import BatchSampler, DataLoader, DistributedSampler +from lightning_lite.plugins import Precision # avoid circular imports: # isort: split from lightning_lite.accelerators.accelerator import Accelerator from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT -from lightning_lite.plugins import Precision from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy from lightning_lite.strategies.strategy import TBroadcast from lightning_lite.utilities import move_data_to_device diff --git a/src/lightning_lite/plugins/precision/native_amp.py b/src/lightning_lite/plugins/precision/native_amp.py index 34b4fb5591724..2fb618ba3c0cd 100644 --- a/src/lightning_lite/plugins/precision/native_amp.py +++ b/src/lightning_lite/plugins/precision/native_amp.py @@ -20,6 +20,7 @@ from torch.optim import LBFGS from typing_extensions import Literal +from lightning_lite.accelerators.cuda import is_cuda_available from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.utils import _convert_fp_tensor from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10 @@ -47,6 +48,8 @@ def __init__( if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10: raise ImportError("To use bfloat16 with native amp you must install torch greater or equal to 1.10.") if scaler is None and precision == 16: + # if possible, we defer CUDA initialization to support strategies that will attempt forks + torch.cuda.is_available = is_cuda_available scaler = torch.cuda.amp.GradScaler() if scaler is not None and precision == "bf16": raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.") diff --git a/src/pytorch_lightning/plugins/precision/native_amp.py b/src/pytorch_lightning/plugins/precision/native_amp.py index 6127aaed9c7db..0f2ac302a0c82 100644 --- a/src/pytorch_lightning/plugins/precision/native_amp.py +++ b/src/pytorch_lightning/plugins/precision/native_amp.py @@ -19,6 +19,7 @@ from torch.optim import LBFGS import pytorch_lightning as pl +from lightning_lite.accelerators.cuda import is_cuda_available from lightning_lite.utilities.types import Steppable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType @@ -50,6 +51,8 @@ def __init__( "To use bfloat16 with native amp you must install torch greater or equal to 1.10." 
) if scaler is None and precision == 16: + # if possible, we defer CUDA initialization to support strategies that will attempt forks + torch.cuda.is_available = is_cuda_available scaler = torch.cuda.amp.GradScaler() if scaler is not None and precision == "bf16": raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.") diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 98b1530500bc8..a424194a84639 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -20,6 +20,7 @@ from packaging.version import Version from pkg_resources import get_distribution +from lightning_lite.accelerators.cuda import num_cuda_devices from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.accelerators.mps import MPSAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator @@ -61,6 +62,9 @@ def test_wrapper(arg1): assert arg1 > 0.0 """ + # if possible, we defer CUDA initialization to support tests that will attempt forks + torch.cuda.is_available = num_cuda_devices + def __new__( self, *args, @@ -124,7 +128,7 @@ def __new__( reasons = [] if min_cuda_gpus: - conditions.append(torch.cuda.device_count() < min_cuda_gpus) + conditions.append(torch.cuda.is_available() < min_cuda_gpus) reasons.append(f"GPUs>={min_cuda_gpus}") # used in conftest.py::pytest_collection_modifyitems kwargs["min_cuda_gpus"] = True diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py index 0b38e31e0a219..398b641b1bf41 100644 --- a/tests/tests_pytorch/helpers/test_models.py +++ b/tests/tests_pytorch/helpers/test_models.py @@ -14,6 +14,7 @@ import os import pytest +import torch from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel @@ -22,6 +23,47 @@ from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel +class AMPTestModel(BoringModel): + def _step(self, batch): + self._assert_autocast_enabled() + output = self(batch) + is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" + assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 + loss = self.loss(batch, output) + return loss + + def loss(self, batch, prediction): + # todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported + if self.trainer.precision_plugin.device == "cpu": + prediction = prediction.float() + return super().loss(batch, prediction) + + def training_step(self, batch, batch_idx): + output = self._step(batch) + return {"loss": output} + + def validation_step(self, batch, batch_idx): + output = self._step(batch) + return {"x": output} + + def test_step(self, batch, batch_idx): + output = self._step(batch) + return {"y": output} + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + self._assert_autocast_enabled() + output = self(batch) + is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" + assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 + return output + + def _assert_autocast_enabled(self): + if self.trainer.precision_plugin.device == "cpu": + assert torch.is_autocast_cpu_enabled() + else: + assert torch.is_autocast_enabled() + + @pytest.mark.parametrize( "data_class,model_class", [ diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index 74bd4c20abeaf..3ba30b045fb81 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ 
b/tests/tests_pytorch/models/test_amp.py @@ -15,7 +15,6 @@ from unittest import mock import pytest -import torch from torch import optim from torch.utils.data import DataLoader @@ -24,47 +23,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from tests_pytorch.helpers.runif import RunIf - - -class AMPTestModel(BoringModel): - def _step(self, batch): - self._assert_autocast_enabled() - output = self(batch) - is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" - assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 - loss = self.loss(batch, output) - return loss - - def loss(self, batch, prediction): - # todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported - if self.trainer.precision_plugin.device == "cpu": - prediction = prediction.float() - return super().loss(batch, prediction) - - def training_step(self, batch, batch_idx): - output = self._step(batch) - return {"loss": output} - - def validation_step(self, batch, batch_idx): - output = self._step(batch) - return {"x": output} - - def test_step(self, batch, batch_idx): - output = self._step(batch) - return {"y": output} - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - self._assert_autocast_enabled() - output = self(batch) - is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" - assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 - return output - - def _assert_autocast_enabled(self): - if self.trainer.precision_plugin.device == "cpu": - assert torch.is_autocast_cpu_enabled() - else: - assert torch.is_autocast_enabled() +from tests_pytorch.helpers.test_models import AMPTestModel @RunIf(min_torch="1.10") diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py new file mode 100644 index 0000000000000..9bf9eb485cbc2 --- /dev/null +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from torch.utils.data import DataLoader + +import tests_pytorch.helpers.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.demos.boring_classes import RandomDataset +from tests_pytorch.helpers.runif import RunIf +from tests_pytorch.helpers.test_models import AMPTestModel + + +# needs to be standalone to avoid other processes initializing CUDA +@RunIf(min_cuda_gpus=2, min_torch="1.12", standalone=True) +def test_amp_gpus_ddp_fork(tmpdir): + """Make sure combinations of AMP and strategies work if supported.""" + tutils.reset_seed() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + accelerator="gpu", + devices=2, + strategy="ddp_fork", + precision=16, + ) + + model = AMPTestModel() + trainer.fit(model) + trainer.test(model) + trainer.predict(model, DataLoader(RandomDataset(32, 64))) From f91d46ca7842864d0c054d84bf3bfdafc1d50f43 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Mon, 3 Oct 2022 15:38:48 -0700 Subject: [PATCH 02/14] update CHANGELOG --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index ace0dc6c4a0ce..fee00b8c299d4 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added native AMP support for `ddp_fork` (and associated alias strategies) with CUDA GPUs ([#14983](https://github.com/Lightning-AI/lightning/pull/14983)) + + - Added `BatchSizeFinder` callback ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089)) From 5760519508bf93c8d54b967bc82324bc1fdbd22d Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Mon, 3 Oct 2022 16:24:37 -0700 Subject: [PATCH 03/14] narrow scope of RunIf NVML-based CUDA availability check to a new RunIf condition --- tests/tests_pytorch/helpers/runif.py | 17 +++++++++++++---- tests/tests_pytorch/models/test_ddp_fork_amp.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index a424194a84639..ae07aaac35baf 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -20,7 +20,6 @@ from packaging.version import Version from pkg_resources import get_distribution -from lightning_lite.accelerators.cuda import num_cuda_devices from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.accelerators.mps import MPSAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator @@ -62,9 +61,6 @@ def test_wrapper(arg1): assert arg1 > 0.0 """ - # if possible, we defer CUDA initialization to support tests that will attempt forks - torch.cuda.is_available = num_cuda_devices - def __new__( self, *args, @@ -91,6 +87,7 @@ def __new__( bagua: bool = False, psutil: bool = False, hivemind: bool = False, + min_cuda_gpus_no_init: bool = False, **kwargs, ): """ @@ -122,6 +119,7 @@ def __new__( bagua: Require that BaguaSys/bagua is installed. psutil: Require that psutil is installed. hivemind: Require that Hivemind is installed. + min_cuda_gpus_no_init: Require this number of gpus but use an NVML-based CUDA availabilty check. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. 
""" conditions = [] @@ -253,6 +251,17 @@ def __new__( conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin")) reasons.append("Hivemind") + if min_cuda_gpus_no_init: + # special condition to defer CUDA initialization if possible, supporting tests that will attempt forks after + # NVML-based CUDA availbility checks + # local import to avoid potential issues this import could cause other tests + from lightning_lite.accelerators.cuda import num_cuda_devices + + conditions.append(num_cuda_devices() < min_cuda_gpus) + reasons.append(f"GPUs>={min_cuda_gpus}") + # used in conftest.py::pytest_collection_modifyitems + kwargs["min_cuda_gpus_no_init"] = True + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index 9bf9eb485cbc2..f731969f61345 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -21,7 +21,7 @@ # needs to be standalone to avoid other processes initializing CUDA -@RunIf(min_cuda_gpus=2, min_torch="1.12", standalone=True) +@RunIf(min_cuda_gpus_no_init=2, min_torch="1.12", standalone=True) def test_amp_gpus_ddp_fork(tmpdir): """Make sure combinations of AMP and strategies work if supported.""" tutils.reset_seed() From 18a99b52ebe38599592d3d0aede3a71b719bcbd1 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Mon, 3 Oct 2022 16:58:23 -0700 Subject: [PATCH 04/14] clarify new runif condition, add skip_windows condition to new test --- tests/tests_pytorch/helpers/runif.py | 2 +- tests/tests_pytorch/models/test_ddp_fork_amp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index ae07aaac35baf..143cdaa196c42 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -254,7 +254,7 @@ def __new__( if min_cuda_gpus_no_init: # special condition to defer CUDA initialization if possible, supporting tests that will attempt forks after # NVML-based CUDA availbility checks - # local import to avoid potential issues this import could cause other tests + # local import to avoid potential issues this import could cause other tests (e.g. 
infinite recursion) from lightning_lite.accelerators.cuda import num_cuda_devices conditions.append(num_cuda_devices() < min_cuda_gpus) diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index f731969f61345..6c2e25aec7a4f 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -21,7 +21,7 @@ # needs to be standalone to avoid other processes initializing CUDA -@RunIf(min_cuda_gpus_no_init=2, min_torch="1.12", standalone=True) +@RunIf(min_cuda_gpus_no_init=2, skip_windows=True, min_torch="1.12", standalone=True) def test_amp_gpus_ddp_fork(tmpdir): """Make sure combinations of AMP and strategies work if supported.""" tutils.reset_seed() From 9908f3f83b5f723a808d2a9e25cf6693b91f52cd Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Mon, 3 Oct 2022 19:51:00 -0700 Subject: [PATCH 05/14] add context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible --- src/lightning_lite/accelerators/cuda.py | 22 ++++++++++++++++++- .../plugins/precision/native_amp.py | 8 +++---- .../plugins/precision/native_amp.py | 8 +++---- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index 9179a0015548c..93237b4c9c318 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -13,8 +13,9 @@ # limitations under the License. import os import warnings +from contextlib import contextmanager from functools import lru_cache -from typing import Dict, List, Optional, Set, Union +from typing import Dict, Generator, List, Optional, Set, Union import torch @@ -77,6 +78,25 @@ def _get_all_available_cuda_gpus() -> List[int]: return list(range(num_cuda_devices())) +@contextmanager +def _patch_cuda_is_available() -> Generator: + """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if + possible.""" + orig_check = None + new_check = torch.cuda.device_count if _TORCH_GREATER_EQUAL_1_13 else _device_count_nvml + + if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0: + # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding + # otherwise, patching is_available could lead to attribute errors or infinite recursion + orig_check = torch.cuda.is_available + torch.cuda.is_available = new_check # type: ignore[assignment] + try: + yield + finally: + if orig_check: + torch.cuda.is_available = orig_check + + @lru_cache(1) def num_cuda_devices() -> int: """Returns the number of available CUDA devices. 
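
A minimal usage sketch of the context manager added above (not part of the patch itself); the `GradScaler` construction mirrors the plugin changes in the following hunks, and the final assertion only documents the restore-on-exit behaviour:

    import torch

    from lightning_lite.accelerators.cuda import _patch_cuda_is_available

    stock_check = torch.cuda.is_available

    # Inside the context, torch.cuda.is_available() resolves through the NVML-based
    # check (when torch was built with CUDA and the NVML query succeeds), so creating
    # the GradScaler here does not initialize CUDA in this process and a later fork
    # (e.g. via ddp_fork) remains usable.
    with _patch_cuda_is_available():
        scaler = torch.cuda.amp.GradScaler()

    # On exit the stock implementation is restored, whether or not it was patched.
    assert torch.cuda.is_available is stock_check
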
diff --git a/src/lightning_lite/plugins/precision/native_amp.py b/src/lightning_lite/plugins/precision/native_amp.py index 2fb618ba3c0cd..f5a6d0459bb96 100644 --- a/src/lightning_lite/plugins/precision/native_amp.py +++ b/src/lightning_lite/plugins/precision/native_amp.py @@ -20,7 +20,7 @@ from torch.optim import LBFGS from typing_extensions import Literal -from lightning_lite.accelerators.cuda import is_cuda_available +from lightning_lite.accelerators.cuda import _patch_cuda_is_available from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.utils import _convert_fp_tensor from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10 @@ -48,9 +48,9 @@ def __init__( if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10: raise ImportError("To use bfloat16 with native amp you must install torch greater or equal to 1.10.") if scaler is None and precision == 16: - # if possible, we defer CUDA initialization to support strategies that will attempt forks - torch.cuda.is_available = is_cuda_available - scaler = torch.cuda.amp.GradScaler() + with _patch_cuda_is_available(): + # if possible, we defer CUDA initialization to support strategies that will attempt forks + scaler = torch.cuda.amp.GradScaler() if scaler is not None and precision == "bf16": raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.") self.precision = precision diff --git a/src/pytorch_lightning/plugins/precision/native_amp.py b/src/pytorch_lightning/plugins/precision/native_amp.py index 0f2ac302a0c82..a2e6dcca83841 100644 --- a/src/pytorch_lightning/plugins/precision/native_amp.py +++ b/src/pytorch_lightning/plugins/precision/native_amp.py @@ -19,7 +19,7 @@ from torch.optim import LBFGS import pytorch_lightning as pl -from lightning_lite.accelerators.cuda import is_cuda_available +from lightning_lite.accelerators.cuda import _patch_cuda_is_available from lightning_lite.utilities.types import Steppable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType @@ -51,9 +51,9 @@ def __init__( "To use bfloat16 with native amp you must install torch greater or equal to 1.10." 
) if scaler is None and precision == 16: - # if possible, we defer CUDA initialization to support strategies that will attempt forks - torch.cuda.is_available = is_cuda_available - scaler = torch.cuda.amp.GradScaler() + with _patch_cuda_is_available(): + # if possible, we defer CUDA initialization to support strategies that will attempt forks + scaler = torch.cuda.amp.GradScaler() if scaler is not None and precision == "bf16": raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.") self.precision = precision From 3bc011736f69538009d7fae9d5944c5cc1cfe9a6 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Tue, 4 Oct 2022 13:34:37 -0700 Subject: [PATCH 06/14] update min_cuda_gpus RunIf condition for both PL and Lite, narrow scope of new ddp_fork test --- tests/tests_lite/helpers/runif.py | 3 +- tests/tests_pytorch/helpers/runif.py | 18 +++------- .../tests_pytorch/models/test_ddp_fork_amp.py | 35 +++++++------------ 3 files changed, 18 insertions(+), 38 deletions(-) diff --git a/tests/tests_lite/helpers/runif.py b/tests/tests_lite/helpers/runif.py index 6a40a47b9a770..146761a59fe1a 100644 --- a/tests/tests_lite/helpers/runif.py +++ b/tests/tests_lite/helpers/runif.py @@ -20,6 +20,7 @@ from packaging.version import Version from pkg_resources import get_distribution +from lightning_lite.accelerators.cuda import num_cuda_devices from lightning_lite.accelerators import TPUAccelerator from lightning_lite.accelerators.mps import MPSAccelerator from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE @@ -74,7 +75,7 @@ def __new__( reasons = [] if min_cuda_gpus: - conditions.append(torch.cuda.device_count() < min_cuda_gpus) + conditions.append(num_cuda_devices() < min_cuda_gpus) reasons.append(f"GPUs>={min_cuda_gpus}") # used in conftest.py::pytest_collection_modifyitems kwargs["min_cuda_gpus"] = True diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 143cdaa196c42..987ec91067a11 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -20,6 +20,7 @@ from packaging.version import Version from pkg_resources import get_distribution +from lightning_lite.accelerators.cuda import num_cuda_devices from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.accelerators.mps import MPSAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator @@ -87,7 +88,6 @@ def __new__( bagua: bool = False, psutil: bool = False, hivemind: bool = False, - min_cuda_gpus_no_init: bool = False, **kwargs, ): """ @@ -119,14 +119,15 @@ def __new__( bagua: Require that BaguaSys/bagua is installed. psutil: Require that psutil is installed. hivemind: Require that Hivemind is installed. - min_cuda_gpus_no_init: Require this number of gpus but use an NVML-based CUDA availabilty check. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. 
""" conditions = [] reasons = [] if min_cuda_gpus: - conditions.append(torch.cuda.is_available() < min_cuda_gpus) + # defer CUDA initialization if possible, supporting tests that will attempt forks after NVML-based CUDA + # availability checks + conditions.append(num_cuda_devices() < min_cuda_gpus) reasons.append(f"GPUs>={min_cuda_gpus}") # used in conftest.py::pytest_collection_modifyitems kwargs["min_cuda_gpus"] = True @@ -251,17 +252,6 @@ def __new__( conditions.append(not _HIVEMIND_AVAILABLE or sys.platform in ("win32", "darwin")) reasons.append("Hivemind") - if min_cuda_gpus_no_init: - # special condition to defer CUDA initialization if possible, supporting tests that will attempt forks after - # NVML-based CUDA availbility checks - # local import to avoid potential issues this import could cause other tests (e.g. infinite recursion) - from lightning_lite.accelerators.cuda import num_cuda_devices - - conditions.append(num_cuda_devices() < min_cuda_gpus) - reasons.append(f"GPUs>={min_cuda_gpus}") - # used in conftest.py::pytest_collection_modifyitems - kwargs["min_cuda_gpus_no_init"] = True - reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index 6c2e25aec7a4f..f892a13f0070a 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -11,31 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torch.utils.data import DataLoader +import multiprocessing -import tests_pytorch.helpers.utils as tutils -from pytorch_lightning import Trainer -from pytorch_lightning.demos.boring_classes import RandomDataset +import torch + +from pytorch_lightning.plugins import NativeMixedPrecisionPlugin from tests_pytorch.helpers.runif import RunIf -from tests_pytorch.helpers.test_models import AMPTestModel # needs to be standalone to avoid other processes initializing CUDA -@RunIf(min_cuda_gpus_no_init=2, skip_windows=True, min_torch="1.12", standalone=True) -def test_amp_gpus_ddp_fork(tmpdir): - """Make sure combinations of AMP and strategies work if supported.""" - tutils.reset_seed() - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - accelerator="gpu", - devices=2, - strategy="ddp_fork", - precision=16, - ) - - model = AMPTestModel() - trainer.fit(model) - trainer.test(model) - trainer.predict(model, DataLoader(RandomDataset(32, 64))) +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +def test_amp_gpus_ddp_fork(): + """Ensure the use of native AMP with `ddp_fork` (or associated alias strategies) does not generate CUDA + initialization errors.""" + _ = NativeMixedPrecisionPlugin(precision=16, device="cuda") + with multiprocessing.get_context("fork").Pool(1) as pool: + in_bad_fork = pool.apply(torch.cuda._is_in_bad_fork) + assert not in_bad_fork From dac8f682f3e0814b582e520f044f26cf60d2ba60 Mon Sep 17 00:00:00 2001 From: Dan Dale Date: Tue, 4 Oct 2022 17:25:26 -0700 Subject: [PATCH 07/14] restructure to avoid unnecessary None check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- src/lightning_lite/accelerators/cuda.py | 7 +++---- 1 file changed, 3 insertions(+), 4 
deletions(-) diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index 93237b4c9c318..32d4d251dfd5b 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -90,10 +90,9 @@ def _patch_cuda_is_available() -> Generator: # otherwise, patching is_available could lead to attribute errors or infinite recursion orig_check = torch.cuda.is_available torch.cuda.is_available = new_check # type: ignore[assignment] - try: - yield - finally: - if orig_check: + try: + yield + finally: torch.cuda.is_available = orig_check From 4016b5d20c38684445673ea05fc1584659b86400 Mon Sep 17 00:00:00 2001 From: Dan Dale Date: Tue, 4 Oct 2022 17:25:56 -0700 Subject: [PATCH 08/14] remove unnecessary None assignment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- src/lightning_lite/accelerators/cuda.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index 32d4d251dfd5b..6b2c5e1476a1e 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -82,7 +82,6 @@ def _get_all_available_cuda_gpus() -> List[int]: def _patch_cuda_is_available() -> Generator: """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible.""" - orig_check = None new_check = torch.cuda.device_count if _TORCH_GREATER_EQUAL_1_13 else _device_count_nvml if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0: From 844172697328c65acfd5ac751996f8cc0954c83f Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Tue, 4 Oct 2022 17:41:26 -0700 Subject: [PATCH 09/14] cleanup verbose documentation, revert `AMPTestModel` location, refine new `_patch_cuda_is_available` context manager --- src/lightning_lite/accelerators/cuda.py | 4 +- tests/tests_pytorch/helpers/runif.py | 2 - tests/tests_pytorch/helpers/test_models.py | 42 --------------------- tests/tests_pytorch/models/test_amp.py | 43 +++++++++++++++++++++- 4 files changed, 44 insertions(+), 47 deletions(-) diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index 6b2c5e1476a1e..4bf2d9a8d0ec6 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -82,13 +82,13 @@ def _get_all_available_cuda_gpus() -> List[int]: def _patch_cuda_is_available() -> Generator: """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible.""" - new_check = torch.cuda.device_count if _TORCH_GREATER_EQUAL_1_13 else _device_count_nvml if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0: # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding # otherwise, patching is_available could lead to attribute errors or infinite recursion + new_check = is_cuda_available orig_check = torch.cuda.is_available - torch.cuda.is_available = new_check # type: ignore[assignment] + torch.cuda.is_available = new_check try: yield finally: diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 987ec91067a11..ae09f5b65eba7 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -125,8 +125,6 @@ def __new__( reasons = [] if min_cuda_gpus: - # defer CUDA initialization if possible, supporting tests that will attempt forks after 
NVML-based CUDA - # availability checks conditions.append(num_cuda_devices() < min_cuda_gpus) reasons.append(f"GPUs>={min_cuda_gpus}") # used in conftest.py::pytest_collection_modifyitems diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py index 398b641b1bf41..0b38e31e0a219 100644 --- a/tests/tests_pytorch/helpers/test_models.py +++ b/tests/tests_pytorch/helpers/test_models.py @@ -14,7 +14,6 @@ import os import pytest -import torch from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel @@ -23,47 +22,6 @@ from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel -class AMPTestModel(BoringModel): - def _step(self, batch): - self._assert_autocast_enabled() - output = self(batch) - is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" - assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 - loss = self.loss(batch, output) - return loss - - def loss(self, batch, prediction): - # todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported - if self.trainer.precision_plugin.device == "cpu": - prediction = prediction.float() - return super().loss(batch, prediction) - - def training_step(self, batch, batch_idx): - output = self._step(batch) - return {"loss": output} - - def validation_step(self, batch, batch_idx): - output = self._step(batch) - return {"x": output} - - def test_step(self, batch, batch_idx): - output = self._step(batch) - return {"y": output} - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - self._assert_autocast_enabled() - output = self(batch) - is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" - assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 - return output - - def _assert_autocast_enabled(self): - if self.trainer.precision_plugin.device == "cpu": - assert torch.is_autocast_cpu_enabled() - else: - assert torch.is_autocast_enabled() - - @pytest.mark.parametrize( "data_class,model_class", [ diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index 3ba30b045fb81..74bd4c20abeaf 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -15,6 +15,7 @@ from unittest import mock import pytest +import torch from torch import optim from torch.utils.data import DataLoader @@ -23,7 +24,47 @@ from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from tests_pytorch.helpers.runif import RunIf -from tests_pytorch.helpers.test_models import AMPTestModel + + +class AMPTestModel(BoringModel): + def _step(self, batch): + self._assert_autocast_enabled() + output = self(batch) + is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" + assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 + loss = self.loss(batch, output) + return loss + + def loss(self, batch, prediction): + # todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported + if self.trainer.precision_plugin.device == "cpu": + prediction = prediction.float() + return super().loss(batch, prediction) + + def training_step(self, batch, batch_idx): + output = self._step(batch) + return {"loss": output} + + def validation_step(self, batch, batch_idx): + output = self._step(batch) + return {"x": output} + + def test_step(self, batch, batch_idx): + output = self._step(batch) + return {"y": output} + + 
def predict_step(self, batch, batch_idx, dataloader_idx=0): + self._assert_autocast_enabled() + output = self(batch) + is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" + assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 + return output + + def _assert_autocast_enabled(self): + if self.trainer.precision_plugin.device == "cpu": + assert torch.is_autocast_cpu_enabled() + else: + assert torch.is_autocast_enabled() @RunIf(min_torch="1.10") From 6cc8e44b1c9c02618ed2cebcd17c49b4e2622cc0 Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Tue, 4 Oct 2022 18:14:35 -0700 Subject: [PATCH 10/14] add missing yield --- src/lightning_lite/accelerators/cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index 4bf2d9a8d0ec6..0360508c70d50 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -93,6 +93,8 @@ def _patch_cuda_is_available() -> Generator: yield finally: torch.cuda.is_available = orig_check + else: + yield @lru_cache(1) From 3a4b0b949478ea9600da0629b991e178de49adda Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Oct 2022 01:17:24 +0000 Subject: [PATCH 11/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_lite/helpers/runif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/helpers/runif.py b/tests/tests_lite/helpers/runif.py index 146761a59fe1a..3572fb107979d 100644 --- a/tests/tests_lite/helpers/runif.py +++ b/tests/tests_lite/helpers/runif.py @@ -20,8 +20,8 @@ from packaging.version import Version from pkg_resources import get_distribution -from lightning_lite.accelerators.cuda import num_cuda_devices from lightning_lite.accelerators import TPUAccelerator +from lightning_lite.accelerators.cuda import num_cuda_devices from lightning_lite.accelerators.mps import MPSAccelerator from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE From eb04abe2e493d79b69d919280a1f820436f14aee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 5 Oct 2022 11:15:41 +0200 Subject: [PATCH 12/14] nit --- src/lightning_lite/accelerators/cuda.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py index 0360508c70d50..1e8f82a057b2b 100644 --- a/src/lightning_lite/accelerators/cuda.py +++ b/src/lightning_lite/accelerators/cuda.py @@ -82,13 +82,11 @@ def _get_all_available_cuda_gpus() -> List[int]: def _patch_cuda_is_available() -> Generator: """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible.""" - if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0: # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding # otherwise, patching is_available could lead to attribute errors or infinite recursion - new_check = is_cuda_available orig_check = torch.cuda.is_available - torch.cuda.is_available = new_check + torch.cuda.is_available = is_cuda_available try: yield finally: From 635d9bd534eff7561b0e60bbb6f6cd5691bfd63e Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Wed, 5 Oct 2022 14:29:25 -0700 Subject: [PATCH 13/14] since we're not using my original ddp_fork 
full strategy test anymore, we only need 1 GPU --- tests/tests_pytorch/models/test_ddp_fork_amp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index f892a13f0070a..7cbc5ea84b524 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -20,7 +20,7 @@ # needs to be standalone to avoid other processes initializing CUDA -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True) def test_amp_gpus_ddp_fork(): """Ensure the use of native AMP with `ddp_fork` (or associated alias strategies) does not generate CUDA initialization errors.""" From bba3e82833a248fc187cc96a991b8c89b18362db Mon Sep 17 00:00:00 2001 From: Daniel Dale Date: Wed, 5 Oct 2022 14:40:54 -0700 Subject: [PATCH 14/14] looks like an unrequired import got added to native_amp sometime in the last few hours --- src/pytorch_lightning/plugins/precision/native_amp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/plugins/precision/native_amp.py b/src/pytorch_lightning/plugins/precision/native_amp.py index 350e2b0056a80..b486a2d9b8fcf 100644 --- a/src/pytorch_lightning/plugins/precision/native_amp.py +++ b/src/pytorch_lightning/plugins/precision/native_amp.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from lightning_lite.accelerators.cuda import _patch_cuda_is_available -from lightning_lite.utilities.types import Optimizable, Steppable +from lightning_lite.utilities.types import Optimizable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException
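
Taken together, the series enables the following user-facing configuration. This end-to-end sketch is adapted from the Trainer arguments of the original PATCH 01 test (it is not part of the final patches, which narrowed that test) and assumes at least two CUDA GPUs, a non-Windows platform, and torch >= 1.12:

    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel

    # Native AMP (precision=16) combined with the fork-based DDP strategy: the
    # GradScaler is now created behind the NVML-based availability check, so the
    # parent process has not initialized CUDA when ddp_fork forks its workers.
    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp_fork",
        precision=16,
        max_epochs=1,
    )
    trainer.fit(BoringModel())
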