Support ddp_fork strategy with native AMP by attempting NVML-based CUDA availability assessment #14984 (Merged)

20 changes: 19 additions & 1 deletion src/lightning_lite/accelerators/cuda.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 import os
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, Generator, List, Optional, Set, Union

 import torch

@@ -77,6 +78,23 @@ def _get_all_available_cuda_gpus() -> List[int]:
     return list(range(num_cuda_devices()))


+@contextmanager
+def _patch_cuda_is_available() -> Generator:
+    """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
+    possible."""
+    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
+        # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
+        # otherwise, patching is_available could lead to attribute errors or infinite recursion
+        orig_check = torch.cuda.is_available
+        torch.cuda.is_available = is_cuda_available
+        try:
+            yield
+        finally:
+            torch.cuda.is_available = orig_check
+    else:
+        yield


 @lru_cache(1)
 def num_cuda_devices() -> int:
     """Returns the number of available CUDA devices.
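
The new context manager is the heart of the change: while it is active, `torch.cuda.is_available()` is answered by the NVML-based check (when torch was built with CUDA and the NVML device count succeeds) instead of by initializing a CUDA context, and the original function is restored on exit. A minimal usage sketch; the wrapped call is illustrative and not part of the PR:

```python
import torch

from lightning_lite.accelerators.cuda import _patch_cuda_is_available

# Hypothetical caller: query availability without creating a CUDA context,
# so that a later fork of this process can still initialize CUDA itself.
with _patch_cuda_is_available():
    cuda_ready = torch.cuda.is_available()  # answered via NVML when possible

# Outside the block, the original torch.cuda.is_available is back in place.
```
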
2 changes: 1 addition & 1 deletion src/lightning_lite/lite.py
@@ -26,9 +26,9 @@
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

+from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_lite.accelerators.accelerator import Accelerator
 from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
-from lightning_lite.plugins import Precision
from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy
from lightning_lite.strategies.strategy import TBroadcast
from lightning_lite.utilities import move_data_to_device
5 changes: 4 additions & 1 deletion src/lightning_lite/plugins/precision/native_amp.py
@@ -20,6 +20,7 @@
from torch.optim import LBFGS
from typing_extensions import Literal

+from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -47,7 +48,9 @@ def __init__(
         if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
             raise ImportError("To use bfloat16 with native amp you must install torch greater or equal to 1.10.")
         if scaler is None and precision == 16:
-            scaler = torch.cuda.amp.GradScaler()
+            with _patch_cuda_is_available():
+                # if possible, we defer CUDA initialization to support strategies that will attempt forks
+                scaler = torch.cuda.amp.GradScaler()
         if scaler is not None and precision == "bf16":
             raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
         self.precision = precision
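
The effect is that constructing the native-AMP precision plugin (and its `GradScaler`) no longer initializes CUDA in the parent process, which is what previously broke fork-based strategies. A rough sketch of the property being relied on, mirroring the new test added further down (assumes a machine with at least one CUDA GPU):

```python
import multiprocessing

import torch

from lightning_lite.accelerators.cuda import _patch_cuda_is_available

# Create the scaler while the NVML-based check is patched in; no CUDA context
# is created in the parent process.
with _patch_cuda_is_available():
    scaler = torch.cuda.amp.GradScaler()

# A forked child can still use CUDA because the parent never initialized it.
with multiprocessing.get_context("fork").Pool(1) as pool:
    assert not pool.apply(torch.cuda._is_in_bad_fork)
```

The same pattern is applied to the `pytorch_lightning` precision plugin below.
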
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

+- Added native AMP support for `ddp_fork` (and associated alias strategies) with CUDA GPUs ([#14983](https://github.com/Lightning-AI/lightning/pull/14983))


- Added `BatchSizeFinder` callback ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))


5 changes: 4 additions & 1 deletion src/pytorch_lightning/plugins/precision/native_amp.py
@@ -19,6 +19,7 @@
from torch.optim import LBFGS

import pytorch_lightning as pl
+from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
@@ -50,7 +51,9 @@ def __init__(
                 "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
             )
         if scaler is None and precision == 16:
-            scaler = torch.cuda.amp.GradScaler()
+            with _patch_cuda_is_available():
+                # if possible, we defer CUDA initialization to support strategies that will attempt forks
+                scaler = torch.cuda.amp.GradScaler()
         if scaler is not None and precision == "bf16":
             raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
         self.precision = precision

3 changes: 2 additions & 1 deletion tests/tests_lite/helpers/runif.py
@@ -21,6 +21,7 @@
from pkg_resources import get_distribution

from lightning_lite.accelerators import TPUAccelerator
+from lightning_lite.accelerators.cuda import num_cuda_devices
from lightning_lite.accelerators.mps import MPSAccelerator
from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
@@ -74,7 +75,7 @@ def __new__(
         reasons = []

         if min_cuda_gpus:
-            conditions.append(torch.cuda.device_count() < min_cuda_gpus)
+            conditions.append(num_cuda_devices() < min_cuda_gpus)
             reasons.append(f"GPUs>={min_cuda_gpus}")
             # used in conftest.py::pytest_collection_modifyitems
             kwargs["min_cuda_gpus"] = True

3 changes: 2 additions & 1 deletion tests/tests_pytorch/helpers/runif.py
@@ -20,6 +20,7 @@
from packaging.version import Version
from pkg_resources import get_distribution

+from lightning_lite.accelerators.cuda import num_cuda_devices
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.accelerators.mps import MPSAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
@@ -124,7 +125,7 @@ def __new__(
         reasons = []

         if min_cuda_gpus:
-            conditions.append(torch.cuda.device_count() < min_cuda_gpus)
+            conditions.append(num_cuda_devices() < min_cuda_gpus)
             reasons.append(f"GPUs>={min_cuda_gpus}")
             # used in conftest.py::pytest_collection_modifyitems
             kwargs["min_cuda_gpus"] = True
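
Both `RunIf` helpers now count devices through `num_cuda_devices()`, which counts GPUs without initializing CUDA in the process that collects the tests (via NVML or another fork-safe mechanism), so fork-based tests collected in the same process keep working. A hypothetical standalone equivalent of that gating; the test name and GPU threshold are made up for illustration:

```python
import pytest

from lightning_lite.accelerators.cuda import num_cuda_devices


# Counting devices this way avoids creating a CUDA context at collection time.
@pytest.mark.skipif(num_cuda_devices() < 2, reason="requires at least 2 CUDA GPUs")
def test_something_multi_gpu():
    ...
```
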
30 changes: 30 additions & 0 deletions tests/tests_pytorch/models/test_ddp_fork_amp.py
@@ -0,0 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing

import torch

from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from tests_pytorch.helpers.runif import RunIf


# needs to be standalone to avoid other processes initializing CUDA
@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
def test_amp_gpus_ddp_fork():
    """Ensure the use of native AMP with `ddp_fork` (or associated alias strategies) does not generate CUDA
    initialization errors."""
    _ = NativeMixedPrecisionPlugin(precision=16, device="cuda")
    with multiprocessing.get_context("fork").Pool(1) as pool:
        in_bad_fork = pool.apply(torch.cuda._is_in_bad_fork)
    assert not in_bad_fork
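
For context, the user-facing setup this PR unblocks looks roughly like the following; the exact arguments are an assumption based on the strategy and precision names referenced in the PR, not part of the diff:

```python
import pytorch_lightning as pl

# Hypothetical end-user configuration: fork-based DDP together with native AMP.
# Before this change, building the precision plugin initialized CUDA in the main
# process, and the forked workers could then no longer use CUDA.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp_fork",  # or an alias such as "ddp_notebook"
    precision=16,
)
```
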