
Commit

Add info message for Ampere GPUs to enable tf32 matmuls (#16037)
carmocca authored Dec 13, 2022
1 parent 53bf714 commit 3e664c9
Showing 5 changed files with 82 additions and 7 deletions.
31 changes: 30 additions & 1 deletion src/lightning_lite/accelerators/cuda.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
import warnings
from contextlib import contextmanager
@@ -20,7 +21,13 @@
import torch

from lightning_lite.accelerators.accelerator import Accelerator
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14
+from lightning_lite.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_1_13,
+    _TORCH_GREATER_EQUAL_1_14,
+)

+_log = logging.getLogger(__name__)


class CUDAAccelerator(Accelerator):
@@ -34,6 +41,7 @@ def setup_device(self, device: torch.device) -> None:
        """
        if device.type != "cuda":
            raise ValueError(f"Device should be CUDA, got {device} instead.")
+        _check_cuda_matmul_precision(device)
        torch.cuda.set_device(device)

    def teardown(self) -> None:
@@ -179,3 +187,24 @@ def _device_count_nvml() -> int:
        return -1
    except AttributeError:
        return -1


+def _check_cuda_matmul_precision(device: torch.device) -> None:
+    if not _TORCH_GREATER_EQUAL_1_12:
+        # before 1.12, tf32 was used by default
+        return
+    major, _ = torch.cuda.get_device_capability(device)
+    ampere_or_later = major >= 8  # Ampere and later leverage tensor cores, where this setting becomes useful
+    if not ampere_or_later:
+        return
+    # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
+    # `set_float32_matmul_precision`
+    if torch.get_float32_matmul_precision() == "highest":  # default
+        _log.info(
+            f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. To properly"
+            " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"
+            " precision for performance. For more details, read https://pytorch.org/docs/stable/generated/"
+            "torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision"
+        )
+    # note: no need change `torch.backends.cudnn.allow_tf32` as it's enabled by default:
+    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
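
For context, a minimal sketch of what the logged message asks the user to do, and why reading `torch.get_float32_matmul_precision()` also covers users who flipped `allow_tf32` directly (not part of this commit; assumes PyTorch >= 1.12):

import torch

# PyTorch >= 1.12 defaults fp32 matmuls to full precision ("highest"), which leaves
# Ampere Tensor Cores unused for float32 work.
assert torch.get_float32_matmul_precision() == "highest"

# Following the info message: trade a little precision for TF32 Tensor Core throughput.
torch.set_float32_matmul_precision("high")  # or "medium"

# The legacy flag is kept in sync with the string setting, so the check above only needs
# to inspect the string to detect either way of opting in.
assert torch.backends.cuda.matmul.allow_tf32
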
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -35,9 +35,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))


- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))


+- Added info message for Ampere CUDA GPU users to enable tf32 matmul precision ([#16037](https://github.com/Lightning-AI/lightning/pull/16037))

### Changed

- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))
3 changes: 2 additions & 1 deletion src/pytorch_lightning/accelerators/cuda.py
@@ -20,7 +20,7 @@
import torch

import pytorch_lightning as pl
-from lightning_lite.accelerators.cuda import num_cuda_devices
+from lightning_lite.accelerators.cuda import _check_cuda_matmul_precision, num_cuda_devices
from lightning_lite.utilities.device_parser import _parse_gpu_ids
from lightning_lite.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -40,6 +40,7 @@ def setup_device(self, device: torch.device) -> None:
        """
        if device.type != "cuda":
            raise MisconfigurationException(f"Device should be GPU, got {device} instead")
+        _check_cuda_matmul_precision(device)
        torch.cuda.set_device(device)

    def setup(self, trainer: "pl.Trainer") -> None:
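
With the check wired into `CUDAAccelerator.setup_device` on both the Lite and PyTorch Lightning sides, the hint appears as soon as training is set up on an Ampere GPU that still uses the default precision. A minimal sketch of how a user acts on it before creating the Trainer (illustrative arguments, not part of this diff):

import torch
from pytorch_lightning import Trainer

# Opt in to TF32 matmuls before the Trainer touches the GPU; `setup_device` then sees a
# non-default precision and `_check_cuda_matmul_precision` logs nothing.
torch.set_float32_matmul_precision("high")  # or "medium"

trainer = Trainer(accelerator="cuda", devices=1)
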
49 changes: 45 additions & 4 deletions tests/tests_lite/accelerators/test_cuda.py
@@ -12,15 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
+import logging
import os
from unittest import mock
+from unittest.mock import Mock

import pytest
import torch
from tests_lite.helpers.runif import RunIf

import lightning_lite
-from lightning_lite.accelerators.cuda import CUDAAccelerator, is_cuda_available, num_cuda_devices
+from lightning_lite.accelerators.cuda import (
+    _check_cuda_matmul_precision,
+    CUDAAccelerator,
+    is_cuda_available,
+    num_cuda_devices,
+)


@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
@@ -51,9 +58,11 @@ def test_get_parallel_devices(devices, expected):


@mock.patch("torch.cuda.set_device")
-def test_set_cuda_device(set_device_mock):
-    CUDAAccelerator().setup_device(torch.device("cuda", 1))
-    set_device_mock.assert_called_once_with(torch.device("cuda", 1))
+@mock.patch("torch.cuda.get_device_capability", return_value=(7, 0))
+def test_set_cuda_device(_, set_device_mock):
+    device = torch.device("cuda", 1)
+    CUDAAccelerator().setup_device(device)
+    set_device_mock.assert_called_once_with(device)


@mock.patch("lightning_lite.accelerators.cuda._device_count_nvml", return_value=-1)
@@ -73,3 +82,35 @@ def test_force_nvml_based_cuda_check():
    importlib.reload(lightning_lite)  # reevaluate top-level code, without becoming a different object

    assert os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] == "1"


+@RunIf(min_torch="1.12")
+@mock.patch("torch.cuda.get_device_capability", return_value=(10, 1))
+@mock.patch("torch.cuda.get_device_name", return_value="Z100")
+def test_tf32_message(_, __, caplog):
+    device = Mock()
+    expected = "Z100') that has Tensor Cores"
+    assert torch.get_float32_matmul_precision() == "highest"  # default in torch
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert expected in caplog.text

+    caplog.clear()
+    torch.backends.cuda.matmul.allow_tf32 = True  # changing this changes the string
+    assert torch.get_float32_matmul_precision() == "high"
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert not caplog.text

+    caplog.clear()
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.set_float32_matmul_precision("medium")  # also the other way around
+    assert torch.backends.cuda.matmul.allow_tf32
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert not caplog.text

+    torch.set_float32_matmul_precision("highest")  # can be reverted
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert expected in caplog.text
3 changes: 2 additions & 1 deletion tests/tests_pytorch/strategies/test_ddp.py
@@ -74,7 +74,8 @@ def test_torch_distributed_backend_invalid(cuda_count_2, tmpdir):

@RunIf(skip_windows=True)
@mock.patch("torch.cuda.set_device")
-def test_ddp_torch_dist_is_available_in_setup(mock_set_device, cuda_count_1, tmpdir):
+@mock.patch("pytorch_lightning.accelerators.cuda._check_cuda_matmul_precision")
+def test_ddp_torch_dist_is_available_in_setup(_, __, cuda_count_1, tmpdir):
    """Test to ensure torch distributed is available within the setup hook using ddp."""

    class TestModel(BoringModel):
