Commit

add context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible
speediedan committed Oct 4, 2022
1 parent 5e7f225 commit 51f7669
Showing 3 changed files with 29 additions and 9 deletions.
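
The change replaces the permanent monkey-patch of torch.cuda.is_available at the two GradScaler call sites with a patch that is scoped to a with block and always undone. A minimal usage sketch of the new helper, mirroring the updated call sites below (assuming a CUDA-enabled PyTorch build):

import torch
from lightning_lite.accelerators.cuda import patch_cuda_is_available

# While the block is active, torch.cuda.is_available may be swapped for an
# NVML-based check that does not initialize the CUDA context, so strategies
# that fork worker processes afterwards are not broken by an eager CUDA init.
with patch_cuda_is_available():
    scaler = torch.cuda.amp.GradScaler()
# On exit, the original torch.cuda.is_available is restored.
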
22 changes: 21 additions & 1 deletion src/lightning_lite/accelerators/cuda.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 import os
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, Generator, List, Optional, Set, Union
 
 import torch
 
@@ -77,6 +78,25 @@ def _get_all_available_cuda_gpus() -> List[int]:
     return list(range(num_cuda_devices()))
 
 
+@contextmanager
+def patch_cuda_is_available() -> Generator:
+    """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
+    possible."""
+    orig_check = None
+    new_check = torch.cuda.device_count if _TORCH_GREATER_EQUAL_1_13 else _device_count_nvml
+
+    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
+        # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
+        # otherwise, patching is_available could lead to attribute errors or infinite recursion
+        orig_check = torch.cuda.is_available
+        torch.cuda.is_available = new_check  # type: ignore[assignment]
+    try:
+        yield
+    finally:
+        if orig_check:
+            torch.cuda.is_available = orig_check
+
+
 @lru_cache(1)
 def num_cuda_devices() -> int:
     """Returns the number of available CUDA devices.
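For context, an NVML-based check asks the NVIDIA driver (via NVML) for the device count, which does not create a CUDA context in the calling process, unlike the stock CUDA-runtime check. A rough standalone illustration of the same idea using the pynvml package (an assumption for illustration only; the commit itself relies on the _device_count_nvml helper shown above, not on pynvml):

import pynvml  # illustration only; the diff above uses _device_count_nvml instead

def nvml_device_count() -> int:
    # Count CUDA devices via NVML without initializing the CUDA runtime.
    try:
        pynvml.nvmlInit()
        try:
            return pynvml.nvmlDeviceGetCount()
        finally:
            pynvml.nvmlShutdown()
    except pynvml.NVMLError:
        return -1  # a negative result plays the role of the `>= 0` guard above
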
8 changes: 4 additions & 4 deletions src/lightning_lite/plugins/precision/native_amp.py
@@ -20,7 +20,7 @@
 from torch.optim import LBFGS
 from typing_extensions import Literal
 
-from lightning_lite.accelerators.cuda import is_cuda_available
+from lightning_lite.accelerators.cuda import patch_cuda_is_available
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -48,9 +48,9 @@ def __init__(
         if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
             raise ImportError("To use bfloat16 with native amp you must install torch greater or equal to 1.10.")
         if scaler is None and precision == 16:
-            # if possible, we defer CUDA initialization to support strategies that will attempt forks
-            torch.cuda.is_available = is_cuda_available
-            scaler = torch.cuda.amp.GradScaler()
+            with patch_cuda_is_available():
+                # if possible, we defer CUDA initialization to support strategies that will attempt forks
+                scaler = torch.cuda.amp.GradScaler()
         if scaler is not None and precision == "bf16":
             raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
         self.precision = precision
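Because the restore happens in the context manager's finally block, the patch no longer leaks past GradScaler construction the way the previous in-place assignment did. A quick sketch of that behavior (the attribute is only actually swapped when torch was built with CUDA and the NVML count succeeds):

import torch
from lightning_lite.accelerators.cuda import patch_cuda_is_available

original = torch.cuda.is_available
with patch_cuda_is_available():
    pass  # here torch.cuda.is_available may point at the NVML-based check
assert torch.cuda.is_available is original  # restored (or never touched) on exit
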
8 changes: 4 additions & 4 deletions src/pytorch_lightning/plugins/precision/native_amp.py
@@ -19,7 +19,7 @@
 from torch.optim import LBFGS
 
 import pytorch_lightning as pl
-from lightning_lite.accelerators.cuda import is_cuda_available
+from lightning_lite.accelerators.cuda import patch_cuda_is_available
 from lightning_lite.utilities.types import Steppable
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
@@ -51,9 +51,9 @@ def __init__(
                 "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
             )
         if scaler is None and precision == 16:
-            # if possible, we defer CUDA initialization to support strategies that will attempt forks
-            torch.cuda.is_available = is_cuda_available
-            scaler = torch.cuda.amp.GradScaler()
+            with patch_cuda_is_available():
+                # if possible, we defer CUDA initialization to support strategies that will attempt forks
+                scaler = torch.cuda.amp.GradScaler()
         if scaler is not None and precision == "bf16":
             raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
         self.precision = precision
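The comment about deferring CUDA initialization refers to fork-based launch strategies: once a process has created a CUDA context, CUDA cannot be re-initialized in a forked child. A rough illustration of the failure mode this commit works around (not part of the change; fork is Linux-only, and at least one CUDA device is assumed):

import multiprocessing as mp
import torch

def _worker() -> None:
    torch.zeros(1, device="cuda")  # the child needs its own CUDA context

if __name__ == "__main__":
    # If the parent has already initialized CUDA (which the stock
    # torch.cuda.is_available() can do), a forked child typically fails with
    # "Cannot re-initialize CUDA in forked subprocess". Keeping the parent's
    # availability check NVML-based leaves the parent without a CUDA context.
    ctx = mp.get_context("fork")
    p = ctx.Process(target=_worker)
    p.start()
    p.join()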
