diff --git a/src/lightning_lite/accelerators/cuda.py b/src/lightning_lite/accelerators/cuda.py
index f852462b084f6..ca11ca1cfa0a2 100644
--- a/src/lightning_lite/accelerators/cuda.py
+++ b/src/lightning_lite/accelerators/cuda.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 import warnings
 from contextlib import contextmanager
@@ -20,7 +21,13 @@
 import torch
 
 from lightning_lite.accelerators.accelerator import Accelerator
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14
+from lightning_lite.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_1_13,
+    _TORCH_GREATER_EQUAL_1_14,
+)
+
+_log = logging.getLogger(__name__)
 
 
 class CUDAAccelerator(Accelerator):
@@ -34,6 +41,7 @@ def setup_device(self, device: torch.device) -> None:
         """
         if device.type != "cuda":
             raise ValueError(f"Device should be CUDA, got {device} instead.")
+        _check_cuda_matmul_precision(device)
         torch.cuda.set_device(device)
 
     def teardown(self) -> None:
@@ -179,3 +187,24 @@ def _device_count_nvml() -> int:
             return -1
     except AttributeError:
         return -1
+
+
+def _check_cuda_matmul_precision(device: torch.device) -> None:
+    if not _TORCH_GREATER_EQUAL_1_12:
+        # before 1.12, tf32 was used by default
+        return
+    major, _ = torch.cuda.get_device_capability(device)
+    ampere_or_later = major >= 8  # Ampere and later leverage tensor cores, where this setting becomes useful
+    if not ampere_or_later:
+        return
+    # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
+    # `set_float32_matmul_precision`
+    if torch.get_float32_matmul_precision() == "highest":  # default
+        _log.info(
+            f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. To properly"
+            " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')`, which will trade"
+            " off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/"
+            "torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision"
+        )
+    # note: no need to change `torch.backends.cudnn.allow_tf32` as it's enabled by default:
+    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index e184893436a42..24be138ef22cd 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -35,9 +35,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
+
 - Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configuration object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))
+- Added info message for Ampere CUDA GPU users to enable tf32 matmul precision ([#16037](https://github.com/Lightning-AI/lightning/pull/16037))
+
 ### Changed
 
 - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))
diff --git a/src/pytorch_lightning/accelerators/cuda.py b/src/pytorch_lightning/accelerators/cuda.py
index 58dcee0ef2d76..8472b62e8aaab 100644
--- a/src/pytorch_lightning/accelerators/cuda.py
+++ b/src/pytorch_lightning/accelerators/cuda.py
@@ -20,7 +20,7 @@
 import torch
 
 import pytorch_lightning as pl
-from lightning_lite.accelerators.cuda import num_cuda_devices
+from lightning_lite.accelerators.cuda import _check_cuda_matmul_precision, num_cuda_devices
 from lightning_lite.utilities.device_parser import _parse_gpu_ids
 from lightning_lite.utilities.types import _DEVICE
 from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -40,6 +40,7 @@ def setup_device(self, device: torch.device) -> None:
         """
         if device.type != "cuda":
             raise MisconfigurationException(f"Device should be GPU, got {device} instead")
+        _check_cuda_matmul_precision(device)
         torch.cuda.set_device(device)
 
     def setup(self, trainer: "pl.Trainer") -> None:
diff --git a/tests/tests_lite/accelerators/test_cuda.py b/tests/tests_lite/accelerators/test_cuda.py
index 8cc7a7eee3961..ab0c54b59c7fe 100644
--- a/tests/tests_lite/accelerators/test_cuda.py
+++ b/tests/tests_lite/accelerators/test_cuda.py
@@ -12,15 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
+import logging
 import os
 from unittest import mock
+from unittest.mock import Mock
 
 import pytest
 import torch
 from tests_lite.helpers.runif import RunIf
 
 import lightning_lite
-from lightning_lite.accelerators.cuda import CUDAAccelerator, is_cuda_available, num_cuda_devices
+from lightning_lite.accelerators.cuda import (
+    _check_cuda_matmul_precision,
+    CUDAAccelerator,
+    is_cuda_available,
+    num_cuda_devices,
+)
 
 
 @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
@@ -51,9 +58,11 @@ def test_get_parallel_devices(devices, expected):
 
 
 @mock.patch("torch.cuda.set_device")
-def test_set_cuda_device(set_device_mock):
-    CUDAAccelerator().setup_device(torch.device("cuda", 1))
-    set_device_mock.assert_called_once_with(torch.device("cuda", 1))
+@mock.patch("torch.cuda.get_device_capability", return_value=(7, 0))
+def test_set_cuda_device(_, set_device_mock):
+    device = torch.device("cuda", 1)
+    CUDAAccelerator().setup_device(device)
+    set_device_mock.assert_called_once_with(device)
 
 
 @mock.patch("lightning_lite.accelerators.cuda._device_count_nvml", return_value=-1)
@@ -73,3 +82,35 @@ def test_force_nvml_based_cuda_check():
     importlib.reload(lightning_lite)  # reevaluate top-level code, without becoming a different object
     assert os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] == "1"
+
+
+@RunIf(min_torch="1.12")
+@mock.patch("torch.cuda.get_device_capability", return_value=(10, 1))
+@mock.patch("torch.cuda.get_device_name", return_value="Z100")
+def test_tf32_message(_, __, caplog):
+    device = Mock()
+    expected = "Z100') that has Tensor Cores"
+    assert torch.get_float32_matmul_precision() == "highest"  # default in torch
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert expected in caplog.text
+
+    caplog.clear()
+    torch.backends.cuda.matmul.allow_tf32 = True  # changing this changes the string
+    assert torch.get_float32_matmul_precision() == "high"
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert not caplog.text
+
+    caplog.clear()
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.set_float32_matmul_precision("medium")  # also the other way around
+    assert torch.backends.cuda.matmul.allow_tf32
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert not caplog.text
+
+    torch.set_float32_matmul_precision("highest")  # can be reverted
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert expected in caplog.text
diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py
index d95c76e20d4a5..529b9ff21c4d4 100644
--- a/tests/tests_pytorch/strategies/test_ddp.py
+++ b/tests/tests_pytorch/strategies/test_ddp.py
@@ -74,7 +74,8 @@ def test_torch_distributed_backend_invalid(cuda_count_2, tmpdir):
 
 @RunIf(skip_windows=True)
 @mock.patch("torch.cuda.set_device")
-def test_ddp_torch_dist_is_available_in_setup(mock_set_device, cuda_count_1, tmpdir):
+@mock.patch("pytorch_lightning.accelerators.cuda._check_cuda_matmul_precision")
+def test_ddp_torch_dist_is_available_in_setup(_, __, cuda_count_1, tmpdir):
     """Test to ensure torch distributed is available within the setup hook using ddp."""
 
     class TestModel(BoringModel):
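
Reviewer note, not part of the patch: `_check_cuda_matmul_precision` only inspects `torch.get_float32_matmul_precision()`. That single check is sufficient because PyTorch keeps this setting and the older `torch.backends.cuda.matmul.allow_tf32` flag in sync, which is exactly what `test_tf32_message` exercises above. A minimal standalone sketch of that interplay, assuming PyTorch >= 1.12 (the flags are process-global, so no GPU is needed to run it):

import torch

# Fresh-process defaults: fp32 matmuls run at full precision and TF32 is off.
assert torch.get_float32_matmul_precision() == "highest"
assert not torch.backends.cuda.matmul.allow_tf32

# What a user does after seeing the new info message: trade matmul precision
# for Tensor Core throughput on Ampere-or-later GPUs.
torch.set_float32_matmul_precision("high")
# The legacy flag mirrors the new API, so the patch only needs to check one.
assert torch.backends.cuda.matmul.allow_tf32

# The mirroring works in the other direction as well.
torch.set_float32_matmul_precision("highest")
torch.backends.cuda.matmul.allow_tf32 = True
assert torch.get_float32_matmul_precision() == "high"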