Tf32 warnings #6816

Merged
merged 27 commits on Aug 7, 2023
Changes from 13 commits

Commits (27)
5f23835 rename precision doc (qingpeng9802, Aug 3, 2023)
6014216 add `version_geq` (qingpeng9802, Aug 3, 2023)
e271ae2 detect default tf32 settings (qingpeng9802, Aug 3, 2023)
5d286aa refactor `is_tf32_env()` (qingpeng9802, Aug 3, 2023)
3b2345a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2023)
35438b8 fix style E402 (qingpeng9802, Aug 3, 2023)
51f3270 Merge branch 'tf32-warnings' of https://github.com/qingpeng9802/MONAI… (qingpeng9802, Aug 3, 2023)
5697a31 fix style E722 (qingpeng9802, Aug 3, 2023)
948ccc3 [MONAI] code formatting (monai-bot, Aug 3, 2023)
96ca146 refactor the usage of `detect_default_tf32()` (qingpeng9802, Aug 4, 2023)
d8a65bb improve `is_tf32_env()` (qingpeng9802, Aug 4, 2023)
93ff777 [MONAI] code formatting (monai-bot, Aug 4, 2023)
0ad3a73 Merge branch 'dev' into tf32-warnings (wyli, Aug 4, 2023)
348f089 resolve `torch.cuda` initialization order issue (qingpeng9802, Aug 5, 2023)
04b71c3 Merge branch 'tf32-warnings' of https://github.com/qingpeng9802/MONAI… (qingpeng9802, Aug 5, 2023)
9a4310f [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2023)
a783158 [MONAI] code formatting (monai-bot, Aug 5, 2023)
13cd277 use `pynvml` to avoid `torch.cuda` call (qingpeng9802, Aug 7, 2023)
fddc87f minor fix (qingpeng9802, Aug 7, 2023)
2a6ff88 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 7, 2023)
e1b2075 [MONAI] code formatting (monai-bot, Aug 7, 2023)
73cb104 Merge branch 'dev' into tf32-warnings (wyli, Aug 7, 2023)
72c8a72 fix import `pynvml` (qingpeng9802, Aug 7, 2023)
d1ed95b Merge branch 'tf32-warnings' of https://github.com/qingpeng9802/MONAI… (qingpeng9802, Aug 7, 2023)
055ed46 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 7, 2023)
f4571f1 [MONAI] code formatting (monai-bot, Aug 7, 2023)
147224c Merge branch 'dev' into tf32-warnings (wyli, Aug 7, 2023)
4 changes: 2 additions & 2 deletions docs/source/index.rst
@@ -60,9 +60,9 @@ Technical documentation is available at `docs.monai.io <https://docs.monai.io>`_

.. toctree::
   :maxdepth: 1
   :caption: Precision and Accelerating

   precision_accelerating

.. toctree::
   :maxdepth: 1
@@ -29,11 +29,11 @@ by TF32 mode so the impact is very wide.
torch.backends.cuda.matmul.allow_tf32 = False # in PyTorch 1.12 and later.
torch.backends.cudnn.allow_tf32 = True
```
Please note that there are environment variables that can override the flags above. For example, the environment variable `NVIDIA_TF32_OVERRIDE` mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.

If you are using an [NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), the container includes a layer `ENV TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1`.
The default value of `torch.backends.cuda.matmul.allow_tf32` will be overridden to `True`.

We recommend that users print out these two flags for confirmation when unsure.

If you can confirm through experiments that your model has no accuracy or convergence issues in TF32 mode and you have NVIDIA Ampere GPUs or above, you can set the two flags above to `True` to speed up your model.
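The "print out these two flags" recommendation can be sketched as a short confirmation snippet. This is an illustrative example, not part of the PR; it only reads the documented `torch.backends` attributes and the environment variables named above:

```python
import os

import torch

# Print the two TF32 flags discussed above for confirmation.
print("torch.backends.cuda.matmul.allow_tf32 =", torch.backends.cuda.matmul.allow_tf32)
print("torch.backends.cudnn.allow_tf32 =", torch.backends.cudnn.allow_tf32)

# Environment variables mentioned above that can override the defaults.
for name in ("NVIDIA_TF32_OVERRIDE", "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"):
    print(f"{name} = {os.environ.get(name, '<unset>')}")
```

Running this before and after setting the flags (or inside an NGC container) makes any override visible immediately.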
10 changes: 10 additions & 0 deletions monai/__init__.py
@@ -78,3 +78,13 @@
"utils",
"visualize",
]

try:
from .utils.tf32 import detect_default_tf32

detect_default_tf32()
except BaseException:
from .utils.misc import MONAIEnvVars

if MONAIEnvVars.debug():
raise
2 changes: 2 additions & 0 deletions monai/utils/__init__.py
@@ -115,6 +115,7 @@
require_pkg,
run_debug,
run_eval,
version_geq,
version_leq,
)
from .nvtx import Range
@@ -128,6 +129,7 @@
torch_profiler_time_end_to_end,
)
from .state_cacher import StateCacher
from .tf32 import detect_default_tf32, has_ampere_or_later
from .type_conversion import (
convert_data_type,
convert_to_cupy,
68 changes: 52 additions & 16 deletions monai/utils/module.py
@@ -25,7 +25,7 @@
from pydoc import locate
from re import match
from types import FunctionType, ModuleType
from typing import Any, Iterable, cast

import torch

@@ -55,6 +55,7 @@
"get_package_version",
"get_torch_version_tuple",
"version_leq",
"version_geq",
"pytorch_after",
]

@@ -518,24 +519,11 @@ def get_torch_version_tuple():
return tuple(int(x) for x in torch.__version__.split(".")[:2])


def parse_version_strs(lhs: str, rhs: str) -> tuple[Iterable[int | str], Iterable[int | str]]:
"""
Parse the version strings.
"""

def _try_cast(val: str) -> int | str:
val = val.strip()
try:
@@ -554,7 +542,28 @@ def _try_cast(val: str) -> int | str:
# parse the version strings in this basic way without `packaging` package
lhs_ = map(_try_cast, lhs.split("."))
rhs_ = map(_try_cast, rhs.split("."))
return lhs_, rhs_


def version_leq(lhs: str, rhs: str) -> bool:
"""
Returns True if version `lhs` is earlier or equal to `rhs`.

Args:
lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`.
rhs: version name to compare with `lhs`, return True if later or equal to `lhs`.

"""

lhs, rhs = str(lhs), str(rhs)
pkging, has_ver = optional_import("pkg_resources", name="packaging")
if has_ver:
try:
return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs))
except pkging.version.InvalidVersion:
return True

lhs_, rhs_ = parse_version_strs(lhs, rhs)
for l, r in zip(lhs_, rhs_):
if l != r:
if isinstance(l, int) and isinstance(r, int):
@@ -564,6 +573,33 @@ def _try_cast(val: str) -> int | str:
return True


def version_geq(lhs: str, rhs: str) -> bool:
"""
Returns True if version `lhs` is later or equal to `rhs`.

Args:
lhs: version name to compare with `rhs`, return True if later or equal to `rhs`.
rhs: version name to compare with `lhs`, return True if earlier or equal to `lhs`.

"""
lhs, rhs = str(lhs), str(rhs)
pkging, has_ver = optional_import("pkg_resources", name="packaging")
if has_ver:
try:
return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs))
except pkging.version.InvalidVersion:
return True

lhs_, rhs_ = parse_version_strs(lhs, rhs)
for l, r in zip(lhs_, rhs_):
if l != r:
if isinstance(l, int) and isinstance(r, int):
return l > r
return f"{l}" > f"{r}"

return True
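The fallback comparison path used when `packaging` is unavailable can be exercised as a self-contained sketch. This re-implements the logic for illustration rather than importing MONAI, and the real `version_geq` prefers `packaging.version.Version` when it is installed:

```python
from __future__ import annotations


def _try_cast(val: str) -> int | str:
    # Numeric parts compare as ints; anything else falls back to strings,
    # mirroring the basic parser above (e.g. "1.13.0a0" -> [1, 13, "0a0"]).
    val = val.strip()
    try:
        return int(val)
    except ValueError:
        return val


def version_geq(lhs: str, rhs: str) -> bool:
    lhs_ = [_try_cast(p) for p in str(lhs).split(".")]
    rhs_ = [_try_cast(p) for p in str(rhs).split(".")]
    for l, r in zip(lhs_, rhs_):
        if l != r:
            if isinstance(l, int) and isinstance(r, int):
                return l > r
            return f"{l}" > f"{r}"
    return True  # all compared parts equal


print(version_geq("11.2", "11.0"))  # True
print(version_geq("10.2", "11.0"))  # False
```

Note that, like the implementation above, equal prefixes of different lengths (e.g. `"1.0"` vs `"1.0.1"`) compare as equal because `zip` stops at the shorter sequence.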


@functools.lru_cache(None)
def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: str | None = None) -> bool:
"""
76 changes: 76 additions & 0 deletions monai/utils/tf32.py
@@ -0,0 +1,76 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import os
import warnings

import torch

from monai.utils.module import pytorch_after, version_geq

__all__ = ["has_ampere_or_later", "detect_default_tf32"]


@functools.lru_cache(None)
def has_ampere_or_later() -> bool:
"""
Check if there is any Ampere and later GPU.
"""
if not torch.cuda.is_available():
return False
if not version_geq(f"{torch.version.cuda}", "11.0"):
return False
for i in range(torch.cuda.device_count()):
major, _ = torch.cuda.get_device_capability(i)
if major >= 8: # Ampere and later
return True
return False


@functools.lru_cache(None)
def detect_default_tf32() -> bool:
"""
Detect if there is anything that may enable TF32 mode by default.
If any, show a warning message.
"""
may_enable_tf32 = False
try:
if not has_ampere_or_later():
return False

if pytorch_after(1, 7, 0) and not pytorch_after(1, 12, 0):
warnings.warn(
"torch.backends.cuda.matmul.allow_tf32 = True by default.\n"
" This value defaults to True when PyTorch version in [1.7, 1.11] and may affect precision.\n"
" See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating"
)
may_enable_tf32 = True

override_tf32_env_vars = {"NVIDIA_TF32_OVERRIDE": "1", "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE": "1"}
for name, override_val in override_tf32_env_vars.items():
if os.environ.get(name) == override_val:
warnings.warn(
f"Environment variable `{name} = {override_val}` is set.\n"
f" This environment variable may enable TF32 mode accidentally and affect precision.\n"
f" See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating"
)
may_enable_tf32 = True

return may_enable_tf32
except BaseException:
from monai.utils.misc import MONAIEnvVars

if MONAIEnvVars.debug():
raise
return False
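The environment-variable half of `detect_default_tf32()` can be exercised on its own. The sketch below is illustrative rather than MONAI code; it mirrors the override check with the same two variables:

```python
from __future__ import annotations

import os
import warnings

# Env vars that may silently enable TF32, as checked by `detect_default_tf32()`.
_TF32_OVERRIDES = {"NVIDIA_TF32_OVERRIDE": "1", "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE": "1"}


def tf32_override_vars() -> list[str]:
    """Return the names of override variables currently set to the enabling value."""
    return [name for name, val in _TF32_OVERRIDES.items() if os.environ.get(name) == val]


for name in tf32_override_vars():
    warnings.warn(f"`{name} = 1` is set and may enable TF32 mode, affecting precision.")
```

This is why the PR emits one warning per variable: each override can flip the effective precision independently of the `torch.backends` flags set in code.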
9 changes: 7 additions & 2 deletions tests/test_version_leq.py → tests/test_version.py
@@ -16,7 +16,7 @@

from parameterized import parameterized

from monai.utils import version_leq
from monai.utils import version_geq, version_leq


# from pkg_resources
@@ -76,10 +76,15 @@ def _pairwise(iterable):

class TestVersionCompare(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_compare_leq(self, a, b, expected=True):
"""Test version_leq with `a` and `b`"""
self.assertEqual(version_leq(a, b), expected)

@parameterized.expand(TEST_CASES)
def test_compare_geq(self, a, b, expected=True):
"""Test version_geq with `b` and `a`"""
self.assertEqual(version_geq(b, a), expected)


if __name__ == "__main__":
unittest.main()
16 changes: 6 additions & 10 deletions tests/utils.py
@@ -46,7 +46,8 @@
from monai.data.meta_tensor import MetaTensor, get_track_meta
from monai.networks import convert_to_onnx, convert_to_torchscript
from monai.utils import optional_import
from monai.utils.module import pytorch_after
from monai.utils.tf32 import detect_default_tf32
from monai.utils.type_conversion import convert_data_type

nib, _ = optional_import("nibabel")
@@ -172,19 +173,14 @@ def test_is_quick():

def is_tf32_env():
"""
When we may be using TF32 mode, check the precision of matrix operation.
If the checking result is greater than the threshold 0.001,
set _tf32_enabled=True (and relax _rtol for tests).
"""
global _tf32_enabled
if _tf32_enabled is None:
_tf32_enabled = False
if (
torch.cuda.is_available()
and not version_leq(f"{torch.version.cuda}", "10.100")
and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0"
and torch.cuda.device_count() > 0 # at least 11.0
):
if detect_default_tf32() or torch.backends.cuda.matmul.allow_tf32:
try:
# with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result
g_gpu = torch.Generator(device="cuda")
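The precision probe that `is_tf32_env()`'s docstring describes can be sketched independently of the test harness. This is an illustrative approximation, not the PR's exact check; the 0.001 threshold comes from the docstring above:

```python
import torch


def matmul_precision_gap(n: int = 256, seed: int = 0) -> float:
    # Compare an FP32 matmul against a float64 reference. With TF32 active,
    # the result loses roughly two decimal digits, pushing the gap past ~0.001.
    g = torch.Generator().manual_seed(seed)
    a = torch.randn(n, n, generator=g)
    b = torch.randn(n, n, generator=g)
    ref = (a.double() @ b.double()).float()
    return (a @ b - ref).abs().max().item()


# On CPU (plain FP32) the gap stays small; moving the same tensors to an
# Ampere GPU with TF32 enabled produces a markedly larger gap, which is
# the signal the test utility uses to relax its tolerances.
print(matmul_precision_gap())
```

The ~8x speedup and ~2-digit precision loss mentioned in the diff's comment are exactly what this kind of probe surfaces.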