Skip to content

Commit

Permalink
[UnitTest] Disable ptx mma tests on unsupported nvcc versions.
Browse files Browse the repository at this point in the history
- Modified `tvm.contrib.nvcc.get_cuda_version` to return a
  `(major,minor,release)` tuple rather than a float.

- Implemented `tvm.testing.requries_nvcc_version` decorator to specify
  the minimum `(major,minor,release)` version needed to run a unit
  test.

- Applied decorated to unit tests in `test_tir_ptx_mma.py` that fail
  on earlier nvcc versions.
  • Loading branch information
Lunderberg committed Feb 14, 2022
1 parent 5e4e239 commit 3b53d8d
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 11 deletions.
5 changes: 2 additions & 3 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
]
if use_fast_math:
kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
cuda_path = find_cuda_path()
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver >= 11.2:
cuda_ver = get_cuda_version()
if cuda_ver >= (11, 2):
ncpu = multiprocessing.cpu_count() if threads < 0 else threads
kwargs["options"].append("-t %d" % ncpu)
return kwargs
Expand Down
21 changes: 13 additions & 8 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,27 +140,33 @@ def find_cuda_path():
raise RuntimeError("Cannot find cuda path")


def get_cuda_version(cuda_path):
def get_cuda_version(cuda_path=None):
"""Utility function to get cuda version
Parameters
----------
cuda_path : str
Path to cuda root.
cuda_path : Optional[str]
Path to cuda root. If None is passed, will use
`find_cuda_path()` as default.
Returns
-------
version : float
The cuda version
"""
if cuda_path is None:
cuda_path = find_cuda_path()

version_file_path = os.path.join(cuda_path, "version.txt")
if not os.path.exists(version_file_path):
# Debian/Ubuntu repackaged CUDA path
version_file_path = os.path.join(cuda_path, "lib", "cuda", "version.txt")
try:
with open(version_file_path) as f:
version_str = f.readline().replace("\n", "").replace("\r", "")
return float(version_str.split(" ")[2][:2])
version_str = f.read().strip().split()[-1]
return tuple(int(field) for field in version_str.split("."))
except FileNotFoundError:
pass

Expand All @@ -171,9 +177,8 @@ def get_cuda_version(cuda_path):
if proc.returncode == 0:
release_line = [l for l in out.split("\n") if "release" in l][0]
release_fields = [s.strip() for s in release_line.split(",")]
release_version = [f[1:] for f in release_fields if f.startswith("V")][0]
major_minor = ".".join(release_version.split(".")[:2])
return float(major_minor)
version_str = [f[1:] for f in release_fields if f.startswith("V")][0]
return tuple(int(field) for field in version_str.split("."))
raise RuntimeError("Cannot read cuda version file")


Expand Down
43 changes: 43 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,49 @@ def requires_nvptx(*args):
return _compose(args, _requires_nvptx)


def requires_nvcc_version(major_version, minor_version=0, release_version=0):
"""Mark a test as requiring at least a specific version of nvcc.
Unit test marked with this decorator will run only if the
installed version of NVCC is at least `(major_version,
minor_version, release_version)`.
This also marks the test as requiring a cuda support.
Parameters
----------
major_version: int
The major version of the (major,minor,release) version tuple.
minor_version: int
The minor version of the (major,minor,release) version tuple.
release_version: int
The release version of the (major,minor,release) version tuple.
"""

try:
nvcc_version = nvcc.get_cuda_version()
except RuntimeError:
nvcc_version = (0, 0, 0)

min_version = (major_version, minor_version, release_version)
version_str = ".".join(str(v) for v in min_version)
requires = [
pytest.mark.skipif(nvcc_version < min_version, reason=f"Requires NVCC >= {version_str}"),
*requires_cuda(),
]

def inner(func):
return _compose([inner], requires)

return inner


def requires_cudagraph(*args):
"""Mark a test as requiring the CUDA Graph Feature
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_tir_ptx_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
)


# This test uses mma instructions that are not available on NVCC 10.1.
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
def test_gemm_mma_m8n8k16_row_col_s8s8s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8s8s32)
Expand Down Expand Up @@ -384,6 +388,10 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
)


# This test uses mma instructions that are not available on NVCC 10.1.
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
def test_gemm_mma_m8n8k16_row_col_s8u8s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8u8s32)
Expand Down Expand Up @@ -458,6 +466,10 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle):
)


# This test uses mma instructions that are not available on NVCC 10.1.
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
def test_gemm_mma_m8n8k32_row_col_s4s4s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4s4s32)
Expand Down Expand Up @@ -524,6 +536,10 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
)


# This test uses mma instructions that are not available on NVCC 10.1.
# Failure occurs during the external call to nvcc, when attempting to
# generate the .fatbin file.
@tvm.testing.requires_nvcc_version(11)
@tvm.testing.requires_cuda
def test_gemm_mma_m8n8k32_row_col_s4u4s32():
sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4u4s32)
Expand Down

0 comments on commit 3b53d8d

Please sign in to comment.