Support ddp_fork strategy with native AMP by attempting NVML-based CUDA availability assessment (#14984)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
4 people authored Oct 5, 2022
1 parent 7fed7a1 commit 3b75c52
Showing 8 changed files with 65 additions and 6 deletions.
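
For context: `ddp_fork` launches worker processes with `fork`, and CUDA cannot be used in a process that was forked after its parent already initialized (or flagged) CUDA. Before this change, enabling native AMP triggered exactly that, because constructing the `GradScaler` calls `torch.cuda.is_available()`, which on the affected torch versions goes through the CUDA runtime check and marks later forks as unusable. A hedged sketch of the failure mode (hypothetical repro, not part of this diff):

import multiprocessing

import torch

# On torch builds where is_available() uses the CUDA runtime check, this call
# causes torch to flag any subsequently forked child as a "bad fork".
torch.cuda.is_available()

with multiprocessing.get_context("fork").Pool(1) as pool:
    # CUDA cannot be (re)initialized in the child; torch reports this via a flag.
    print(pool.apply(torch.cuda._is_in_bad_fork))  # expected: True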
20 changes: 19 additions & 1 deletion src/lightning_lite/accelerators/cuda.py
@@ -13,8 +13,9 @@
# limitations under the License.
import os
import warnings
from contextlib import contextmanager
from functools import lru_cache
from typing import Dict, List, Optional, Set, Union
from typing import Dict, Generator, List, Optional, Set, Union

import torch

@@ -77,6 +78,23 @@ def _get_all_available_cuda_gpus() -> List[int]:
return list(range(num_cuda_devices()))


@contextmanager
def _patch_cuda_is_available() -> Generator:
"""Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
possible."""
if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
        # we can safely patch is_available if torch was compiled with CUDA and the NVML-based count succeeds
# otherwise, patching is_available could lead to attribute errors or infinite recursion
orig_check = torch.cuda.is_available
torch.cuda.is_available = is_cuda_available
try:
yield
finally:
torch.cuda.is_available = orig_check
else:
yield


@lru_cache(1)
def num_cuda_devices() -> int:
"""Returns the number of available CUDA devices.
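The context manager above is the core of the fix: inside the `with` block, `torch.cuda.is_available` is temporarily replaced by Lightning's NVML-based `is_cuda_available`, and the original function is restored afterwards even if the body raises. The patch is only applied when torch was compiled with CUDA and the NVML device count succeeds; otherwise the manager is a no-op. A minimal usage sketch (assuming the import path added in this commit):

import torch

from lightning_lite.accelerators.cuda import _patch_cuda_is_available

with _patch_cuda_is_available():
    # Availability checks in here take the NVML-based path and therefore do not
    # poison later forks; outside the block, torch's original check is restored.
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
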
2 changes: 1 addition & 1 deletion src/lightning_lite/lite.py
@@ -26,9 +26,9 @@
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

from lightning_lite.plugins import Precision # avoid circular imports: # isort: split
from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
from lightning_lite.plugins import Precision
from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy
from lightning_lite.strategies.strategy import TBroadcast
from lightning_lite.utilities import move_data_to_device
5 changes: 4 additions & 1 deletion src/lightning_lite/plugins/precision/native_amp.py
@@ -20,6 +20,7 @@
from torch.optim import LBFGS
from typing_extensions import Literal

from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -47,7 +48,9 @@ def __init__(
if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
raise ImportError("To use bfloat16 with native amp you must install torch greater or equal to 1.10.")
if scaler is None and precision == 16:
scaler = torch.cuda.amp.GradScaler()
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and precision == "bf16":
raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
self.precision = precision
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added native AMP support for `ddp_fork` (and associated alias strategies) with CUDA GPUs ([#14983](https://github.com/Lightning-AI/lightning/pull/14983))


- Added `BatchSizeFinder` callback ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))


5 changes: 4 additions & 1 deletion src/pytorch_lightning/plugins/precision/native_amp.py
@@ -19,6 +19,7 @@
from torch.optim import LBFGS

import pytorch_lightning as pl
from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
@@ -50,7 +51,9 @@ def __init__(
"To use bfloat16 with native amp you must install torch greater or equal to 1.10."
)
if scaler is None and precision == 16:
scaler = torch.cuda.amp.GradScaler()
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and precision == "bf16":
raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
self.precision = precision
3 changes: 2 additions & 1 deletion tests/tests_lite/helpers/runif.py
@@ -21,6 +21,7 @@
from pkg_resources import get_distribution

from lightning_lite.accelerators import TPUAccelerator
from lightning_lite.accelerators.cuda import num_cuda_devices
from lightning_lite.accelerators.mps import MPSAccelerator
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
@@ -74,7 +75,7 @@ def __new__(
reasons = []

if min_cuda_gpus:
conditions.append(torch.cuda.device_count() < min_cuda_gpus)
conditions.append(num_cuda_devices() < min_cuda_gpus)
reasons.append(f"GPUs>={min_cuda_gpus}")
# used in conftest.py::pytest_collection_modifyitems
kwargs["min_cuda_gpus"] = True
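The test helper now gates GPU tests on `num_cuda_devices()` rather than `torch.cuda.device_count()`, so collecting or skipping tests no longer initializes CUDA in the pytest process, and fork-based tests remain runnable afterwards. A hedged sketch of the intended difference, assuming these helpers behave as described in this commit:

from lightning_lite.accelerators.cuda import num_cuda_devices

# Counts CUDA devices without initializing CUDA in the calling process,
# unlike torch.cuda.device_count() on the affected torch versions.
if num_cuda_devices() >= 1:
    print("GPU tests can run, and this process is still safe to fork")
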
3 changes: 2 additions & 1 deletion tests/tests_pytorch/helpers/runif.py
@@ -20,6 +20,7 @@
from packaging.version import Version
from pkg_resources import get_distribution

from lightning_lite.accelerators.cuda import num_cuda_devices
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.accelerators.mps import MPSAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
@@ -124,7 +125,7 @@ def __new__(
reasons = []

if min_cuda_gpus:
conditions.append(torch.cuda.device_count() < min_cuda_gpus)
conditions.append(num_cuda_devices() < min_cuda_gpus)
reasons.append(f"GPUs>={min_cuda_gpus}")
# used in conftest.py::pytest_collection_modifyitems
kwargs["min_cuda_gpus"] = True
30 changes: 30 additions & 0 deletions tests/tests_pytorch/models/test_ddp_fork_amp.py
@@ -0,0 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing

import torch

from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from tests_pytorch.helpers.runif import RunIf


# needs to be standalone to avoid other processes initializing CUDA
@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
def test_amp_gpus_ddp_fork():
"""Ensure the use of native AMP with `ddp_fork` (or associated alias strategies) does not generate CUDA
initialization errors."""
_ = NativeMixedPrecisionPlugin(precision=16, device="cuda")
with multiprocessing.get_context("fork").Pool(1) as pool:
in_bad_fork = pool.apply(torch.cuda._is_in_bad_fork)
assert not in_bad_fork
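
The assertion hinges on `torch.cuda._is_in_bad_fork()`, which returns `True` in a child forked after CUDA was initialized (or flagged) in its parent; a `False` result therefore confirms that constructing `NativeMixedPrecisionPlugin` with `precision=16` built its `GradScaler` without initializing CUDA in the test process.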
