
Commit

[UnitTest] Disable ptx mma tests on unsupported nvcc versions. (apache#10229)

* [UnitTest] Disable ptx mma tests on unsupported nvcc versions.

- Modified `tvm.contrib.nvcc.get_cuda_version` to return a
  `(major,minor,release)` tuple rather than a float.

- Implemented `tvm.testing.requires_nvcc_version` decorator to specify
  the minimum `(major,minor,release)` version needed to run a unit
  test.

- Applied the decorator to unit tests in `test_tir_ptx_mma.py` that fail
  on earlier nvcc versions.

* Fix lint errors.

* Updated a few of the cuda version checks.

* More lint fixes.

* Only compare major/minor in find_libdevice, not release version.
Lunderberg authored Feb 18, 2022
1 parent 81e4eaf commit ffd517a
Showing 4 changed files with 89 additions and 16 deletions.
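
As a quick illustration of the changes described in the commit message, here is a minimal sketch of the new API; the test name and body are hypothetical and it assumes a local CUDA toolkit is installed, so it is not part of this commit:

import tvm.testing
from tvm.contrib import nvcc

# get_cuda_version() now returns a tuple such as (11, 2, 152) instead of a
# float such as 11.2, so plain tuple comparison expresses version checks.
print(nvcc.get_cuda_version() >= (11, 2))  # e.g. True on a CUDA 11.2+ toolchain

# Tests that emit ptx mma instructions unavailable before CUDA 11 can declare
# the minimum nvcc version they need; the decorator also implies requires_cuda.
@tvm.testing.requires_nvcc_version(11)
def test_hypothetical_ptx_mma_kernel():
    ...  # build and run a schedule that lowers to ptx mma instructions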
7 changes: 3 additions & 4 deletions python/tvm/contrib/cutlass/build.py
@@ -21,7 +21,7 @@
import multiprocessing
import tvm
from tvm import runtime, relay
from tvm.contrib.nvcc import find_cuda_path, get_cuda_version
from tvm.contrib.nvcc import get_cuda_version
from .gen_gemm import CutlassGemmProfiler
from .gen_conv2d import CutlassConv2DProfiler
from .library import ConvKind
@@ -61,9 +61,8 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
]
if use_fast_math:
kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
cuda_path = find_cuda_path()
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver >= 11.2:
cuda_ver = get_cuda_version()
if cuda_ver >= (11, 2):
ncpu = multiprocessing.cpu_count() if threads < 0 else threads
kwargs["options"].append("-t %d" % ncpu)
return kwargs
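The motivation for comparing tuples rather than floats: once a minor version reaches two digits, the float encoding collapses and orders versions incorrectly. A small illustrative check (the 11.10 version is hypothetical, used only to show the hazard):

# As a float literal, 11.10 is just 11.1, so a hypothetical CUDA 11.10 would
# incorrectly compare as older than CUDA 11.2.
assert 11.10 < 11.2

# As tuples, comparison is element-wise and gives the intended ordering.
assert (11, 2) < (11, 10)
assert (11, 2, 67) >= (11, 2)
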
39 changes: 27 additions & 12 deletions python/tvm/contrib/nvcc.py
@@ -140,27 +140,33 @@ def find_cuda_path():
raise RuntimeError("Cannot find cuda path")


def get_cuda_version(cuda_path):
def get_cuda_version(cuda_path=None):
"""Utility function to get cuda version
Parameters
----------
cuda_path : str
Path to cuda root.
cuda_path : Optional[str]
Path to cuda root. If None is passed, will use
`find_cuda_path()` as default.
Returns
-------
version : float
The cuda version
"""
if cuda_path is None:
cuda_path = find_cuda_path()

version_file_path = os.path.join(cuda_path, "version.txt")
if not os.path.exists(version_file_path):
# Debian/Ubuntu repackaged CUDA path
version_file_path = os.path.join(cuda_path, "lib", "cuda", "version.txt")
try:
with open(version_file_path) as f:
version_str = f.readline().replace("\n", "").replace("\r", "")
return float(version_str.split(" ")[2][:2])
version_str = f.read().strip().split()[-1]
return tuple(int(field) for field in version_str.split("."))
except FileNotFoundError:
pass

@@ -171,9 +177,8 @@ def get_cuda_version(cuda_path):
if proc.returncode == 0:
release_line = [l for l in out.split("\n") if "release" in l][0]
release_fields = [s.strip() for s in release_line.split(",")]
release_version = [f[1:] for f in release_fields if f.startswith("V")][0]
major_minor = ".".join(release_version.split(".")[:2])
return float(major_minor)
version_str = [f[1:] for f in release_fields if f.startswith("V")][0]
return tuple(int(field) for field in version_str.split("."))
raise RuntimeError("Cannot read cuda version file")
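
For reference, a rough sketch of what the new parsing yields for the two sources it reads; the helper names and sample strings below are illustrative only, not part of the diff:

def parse_version_txt(contents):
    # Mirrors the new version.txt handling: take the last whitespace-separated
    # token and split it on ".".
    version_str = contents.strip().split()[-1]
    return tuple(int(field) for field in version_str.split("."))

def parse_nvcc_release_line(line):
    # Mirrors the new `nvcc --version` handling: pick the "V<major>.<minor>.<release>"
    # field from the comma-separated release line.
    fields = [s.strip() for s in line.split(",")]
    version_str = [f[1:] for f in fields if f.startswith("V")][0]
    return tuple(int(field) for field in version_str.split("."))

print(parse_version_txt("CUDA Version 11.2.152"))  # -> (11, 2, 152)
print(parse_nvcc_release_line("Cuda compilation tools, release 10.1, V10.1.243"))  # -> (10, 1, 243)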


@@ -206,7 +211,18 @@ def find_libdevice_path(arch):
selected_ver = 0
selected_path = None
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2, 11.3):
major_minor = (cuda_ver[0], cuda_ver[1])
if major_minor in (
(9, 0),
(9, 1),
(10, 0),
(10, 1),
(10, 2),
(11, 0),
(11, 1),
(11, 2),
(11, 3),
):
path = os.path.join(lib_path, "libdevice.10.bc")
else:
for fn in os.listdir(lib_path):
@@ -358,9 +374,8 @@ def have_tensorcore(compute_version=None, target=None):
def have_cudagraph():
"""Either CUDA Graph support is provided"""
try:
cuda_path = find_cuda_path()
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver < 10.0:
cuda_ver = get_cuda_version()
if cuda_ver < (10, 0):
return False
return True
except RuntimeError:
43 changes: 43 additions & 0 deletions python/tvm/testing/utils.py
@@ -635,6 +635,49 @@ def requires_nvptx(*args):
return _compose(args, _requires_nvptx)


def requires_nvcc_version(major_version, minor_version=0, release_version=0):
"""Mark a test as requiring at least a specific version of nvcc.
Unit tests marked with this decorator will run only if the
installed version of NVCC is at least `(major_version,
minor_version, release_version)`.
This also marks the test as requiring cuda support.
Parameters
----------
major_version: int
The major version of the (major,minor,release) version tuple.
minor_version: int
The minor version of the (major,minor,release) version tuple.
release_version: int
The release version of the (major,minor,release) version tuple.
"""

try:
nvcc_version = nvcc.get_cuda_version()
except RuntimeError:
nvcc_version = (0, 0, 0)

min_version = (major_version, minor_version, release_version)
version_str = ".".join(str(v) for v in min_version)
requires = [
pytest.mark.skipif(nvcc_version < min_version, reason=f"Requires NVCC >= {version_str}"),
*requires_cuda(),
]

def inner(func):
return _compose([func], requires)

return inner


def requires_cudagraph(*args):
"""Mark a test as requiring the CUDA Graph Feature
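The tests below only pin the major version, but the decorator also accepts minor and release components; a short usage sketch (the test name is hypothetical):

import tvm.testing

# Skipped on toolchains older than nvcc 11.2.0, with the reason
# "Requires NVCC >= 11.2.0"; also marks the test as requiring CUDA.
@tvm.testing.requires_nvcc_version(11, 2)
def test_hypothetical_feature_needing_cuda_11_2():
    ...
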
16 changes: 16 additions & 0 deletions tests/python/unittest/test_tir_ptx_mma.py
@@ -310,6 +310,10 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
)


# This test uses mma instructions that are not available on NVCC 10.1.
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
def test_gemm_mma_m8n8k16_row_col_s8s8s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8s8s32)
@@ -384,6 +388,10 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
)


# This test uses mma instructions that are not available on NVCC 10.1.
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
def test_gemm_mma_m8n8k16_row_col_s8u8s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8u8s32)
@@ -458,6 +466,10 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle):
)


# This test uses mma instructions that are not available on NVCC 10.1.
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
def test_gemm_mma_m8n8k32_row_col_s4s4s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4s4s32)
@@ -524,6 +536,10 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
)


# This test uses mma instructions that are not available on NVCC 10.1.
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
def test_gemm_mma_m8n8k32_row_col_s4u4s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4u4s32)
