From 3b75c52869baf1118d2d546aee4861bee133571d Mon Sep 17 00:00:00 2001
From: Dan Dale
Date: Wed, 5 Oct 2022 15:52:06 -0700
Subject: [PATCH] Support ddp_fork strategy with native AMP by attempting
 NVML-based CUDA availability assessment (#14984)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
---
 src/lightning_lite/accelerators/cuda.py        | 20 ++++++++++++-
 src/lightning_lite/lite.py                     |  2 +-
 .../plugins/precision/native_amp.py            |  5 +++-
 src/pytorch_lightning/CHANGELOG.md             |  3 ++
 .../plugins/precision/native_amp.py            |  5 +++-
 tests/tests_lite/helpers/runif.py              |  3 +-
 tests/tests_pytorch/helpers/runif.py           |  3 +-
 .../tests_pytorch/models/test_ddp_fork_amp.py  | 30 +++++++++++++++++++
 8 files changed, 65 insertions(+), 6 deletions(-)
 create mode 100644 tests/tests_pytorch/models/test_ddp_fork_amp.py

diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py
index 9179a0015548c..1e8f82a057b2b 100644
--- a/src/lightning_lite/accelerators/cuda.py
+++ b/src/lightning_lite/accelerators/cuda.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 import os
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, Generator, List, Optional, Set, Union

 import torch

@@ -77,6 +78,23 @@ def _get_all_available_cuda_gpus() -> List[int]:
     return list(range(num_cuda_devices()))


+@contextmanager
+def _patch_cuda_is_available() -> Generator:
+    """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
+    possible."""
+    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
+        # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
+        # otherwise, patching is_available could lead to attribute errors or infinite recursion
+        orig_check = torch.cuda.is_available
+        torch.cuda.is_available = is_cuda_available
+        try:
+            yield
+        finally:
+            torch.cuda.is_available = orig_check
+    else:
+        yield
+
+
 @lru_cache(1)
 def num_cuda_devices() -> int:
     """Returns the number of available CUDA devices.
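For illustration only (not part of the patch): a minimal sketch of how the new `_patch_cuda_is_available` context manager is meant to be used. Inside the block, `torch.cuda.is_available()` is temporarily backed by the NVML-based check when possible, so availability queries (for example the one made while constructing a `torch.cuda.amp.GradScaler`, as the precision plugins below do) do not initialize CUDA in the parent process and later fork-based workers stay usable. The surrounding code is an assumed usage pattern, not part of this diff.

    import torch

    from lightning_lite.accelerators.cuda import _patch_cuda_is_available

    with _patch_cuda_is_available():
        # torch.cuda.is_available() now resolves through NVML when possible,
        # so no CUDA context is created in this (soon-to-fork) process
        if torch.cuda.is_available():
            scaler = torch.cuda.amp.GradScaler()
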
diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py
index 04b964e41c5a0..d1b94e1b46fbf 100644
--- a/src/lightning_lite/lite.py
+++ b/src/lightning_lite/lite.py
@@ -26,9 +26,9 @@
 from torch.optim import Optimizer
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

+from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_lite.accelerators.accelerator import Accelerator
 from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
-from lightning_lite.plugins import Precision
 from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy
 from lightning_lite.strategies.strategy import TBroadcast
 from lightning_lite.utilities import move_data_to_device
diff --git a/src/lightning_lite/plugins/precision/native_amp.py b/src/lightning_lite/plugins/precision/native_amp.py
index 24ff14e8e08de..b09ac5647f89a 100644
--- a/src/lightning_lite/plugins/precision/native_amp.py
+++ b/src/lightning_lite/plugins/precision/native_amp.py
@@ -20,6 +20,7 @@
 from torch.optim import LBFGS
 from typing_extensions import Literal

+from lightning_lite.accelerators.cuda import _patch_cuda_is_available
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -47,7 +48,9 @@ def __init__(
         if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
             raise ImportError("To use bfloat16 with native amp you must install torch greater or equal to 1.10.")
         if scaler is None and precision == 16:
-            scaler = torch.cuda.amp.GradScaler()
+            with _patch_cuda_is_available():
+                # if possible, we defer CUDA initialization to support strategies that will attempt forks
+                scaler = torch.cuda.amp.GradScaler()
         if scaler is not None and precision == "bf16":
             raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
         self.precision = precision
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index f55adac572cdc..724ccca7cbdd8 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added native AMP support for `ddp_fork` (and associated alias strategies) with CUDA GPUs ([#14983](https://github.com/Lightning-AI/lightning/pull/14983))
+
+
 - Added `BatchSizeFinder` callback ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))
diff --git a/src/pytorch_lightning/plugins/precision/native_amp.py b/src/pytorch_lightning/plugins/precision/native_amp.py
index 0fccb387f046c..b486a2d9b8fcf 100644
--- a/src/pytorch_lightning/plugins/precision/native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/native_amp.py
@@ -19,6 +19,7 @@
 from torch.optim import LBFGS

 import pytorch_lightning as pl
+from lightning_lite.accelerators.cuda import _patch_cuda_is_available
 from lightning_lite.utilities.types import Optimizable
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
@@ -50,7 +51,9 @@ def __init__(
                 "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
             )
         if scaler is None and precision == 16:
-            scaler = torch.cuda.amp.GradScaler()
+            with _patch_cuda_is_available():
+                # if possible, we defer CUDA initialization to support strategies that will attempt forks
+                scaler = torch.cuda.amp.GradScaler()
         if scaler is not None and precision == "bf16":
             raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
         self.precision = precision
diff --git a/tests/tests_lite/helpers/runif.py b/tests/tests_lite/helpers/runif.py
index 6a40a47b9a770..3572fb107979d 100644
--- a/tests/tests_lite/helpers/runif.py
+++ b/tests/tests_lite/helpers/runif.py
@@ -21,6 +21,7 @@
 from pkg_resources import get_distribution

 from lightning_lite.accelerators import TPUAccelerator
+from lightning_lite.accelerators.cuda import num_cuda_devices
 from lightning_lite.accelerators.mps import MPSAccelerator
 from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
@@ -74,7 +75,7 @@ def __new__(
         reasons = []

         if min_cuda_gpus:
-            conditions.append(torch.cuda.device_count() < min_cuda_gpus)
+            conditions.append(num_cuda_devices() < min_cuda_gpus)
             reasons.append(f"GPUs>={min_cuda_gpus}")
             # used in conftest.py::pytest_collection_modifyitems
             kwargs["min_cuda_gpus"] = True
diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py
index 98b1530500bc8..ae09f5b65eba7 100644
--- a/tests/tests_pytorch/helpers/runif.py
+++ b/tests/tests_pytorch/helpers/runif.py
@@ -20,6 +20,7 @@
 from packaging.version import Version
 from pkg_resources import get_distribution

+from lightning_lite.accelerators.cuda import num_cuda_devices
 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.accelerators.mps import MPSAccelerator
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
@@ -124,7 +125,7 @@ def __new__(
         reasons = []

         if min_cuda_gpus:
-            conditions.append(torch.cuda.device_count() < min_cuda_gpus)
+            conditions.append(num_cuda_devices() < min_cuda_gpus)
             reasons.append(f"GPUs>={min_cuda_gpus}")
             # used in conftest.py::pytest_collection_modifyitems
             kwargs["min_cuda_gpus"] = True
diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py
new file mode 100644
index 0000000000000..7cbc5ea84b524
--- /dev/null
+++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py
@@ -0,0 +1,30 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import multiprocessing
+
+import torch
+
+from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
+from tests_pytorch.helpers.runif import RunIf
+
+
+# needs to be standalone to avoid other processes initializing CUDA
+@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
+def test_amp_gpus_ddp_fork():
+    """Ensure the use of native AMP with `ddp_fork` (or associated alias strategies) does not generate CUDA
+    initialization errors."""
+    _ = NativeMixedPrecisionPlugin(precision=16, device="cuda")
+    with multiprocessing.get_context("fork").Pool(1) as pool:
+        in_bad_fork = pool.apply(torch.cuda._is_in_bad_fork)
+        assert not in_bad_fork
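For illustration only (not part of the patch): a hedged sketch of the user-facing configuration this change enables, assuming a non-Windows machine with CUDA GPUs; the device count here is arbitrary. Before this change, requesting native AMP created the GradScaler eagerly and could initialize CUDA in the main process, which breaks fork-based strategies.

    from pytorch_lightning import Trainer

    trainer = Trainer(
        accelerator="cuda",
        devices=2,            # illustrative; any CUDA device count works
        strategy="ddp_fork",  # fork-based DDP (interactive-friendly aliases map to it)
        precision=16,         # native AMP; the GradScaler no longer initializes CUDA before the fork
    )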