From ffd517ad6ea2d9a558a8228ab6650de73ed87839 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 18 Feb 2022 07:05:43 -0600 Subject: [PATCH] [UnitTest] Disable ptx mma tests on unsupported nvcc versions. (#10229) * [UnitTest] Disable ptx mma tests on unsupported nvcc versions. - Modified `tvm.contrib.nvcc.get_cuda_version` to return a `(major,minor,release)` tuple rather than a float. - Implemented `tvm.testing.requires_nvcc_version` decorator to specify the minimum `(major,minor,release)` version needed to run a unit test. - Applied the decorator to unit tests in `test_tir_ptx_mma.py` that fail on earlier nvcc versions. * Fix lint errors. * Updated a few of the cuda version checks. * More lint fixes. * Only compare major/minor in find_libdevice, not release version. --- python/tvm/contrib/cutlass/build.py | 7 ++-- python/tvm/contrib/nvcc.py | 39 +++++++++++++------- python/tvm/testing/utils.py | 43 +++++++++++++++++++++++ tests/python/unittest/test_tir_ptx_mma.py | 16 +++++++++ 4 files changed, 89 insertions(+), 16 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 06c33f2f7ae0..bd372572c403 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -21,7 +21,7 @@ import multiprocessing import tvm from tvm import runtime, relay -from tvm.contrib.nvcc import find_cuda_path, get_cuda_version +from tvm.contrib.nvcc import get_cuda_version from .gen_gemm import CutlassGemmProfiler from .gen_conv2d import CutlassConv2DProfiler from .library import ConvKind @@ -61,9 +61,8 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): ] if use_fast_math: kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID") - cuda_path = find_cuda_path() - cuda_ver = get_cuda_version(cuda_path) - if cuda_ver >= 11.2: + cuda_ver = get_cuda_version() + if cuda_ver >= (11, 2): ncpu = multiprocessing.cpu_count() if threads < 0 else threads kwargs["options"].append("-t %d" % ncpu) return 
kwargs diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 19196dc3eefb..52b88d355602 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -140,27 +140,33 @@ def find_cuda_path(): raise RuntimeError("Cannot find cuda path") -def get_cuda_version(cuda_path): +def get_cuda_version(cuda_path=None): """Utility function to get cuda version Parameters ---------- - cuda_path : str - Path to cuda root. + cuda_path : Optional[str] + + Path to cuda root. If None is passed, will use + `find_cuda_path()` as default. Returns ------- version : float The cuda version + """ + if cuda_path is None: + cuda_path = find_cuda_path() + version_file_path = os.path.join(cuda_path, "version.txt") if not os.path.exists(version_file_path): # Debian/Ubuntu repackaged CUDA path version_file_path = os.path.join(cuda_path, "lib", "cuda", "version.txt") try: with open(version_file_path) as f: - version_str = f.readline().replace("\n", "").replace("\r", "") - return float(version_str.split(" ")[2][:2]) + version_str = f.read().strip().split()[-1] + return tuple(int(field) for field in version_str.split(".")) except FileNotFoundError: pass @@ -171,9 +177,8 @@ def get_cuda_version(cuda_path): if proc.returncode == 0: release_line = [l for l in out.split("\n") if "release" in l][0] release_fields = [s.strip() for s in release_line.split(",")] - release_version = [f[1:] for f in release_fields if f.startswith("V")][0] - major_minor = ".".join(release_version.split(".")[:2]) - return float(major_minor) + version_str = [f[1:] for f in release_fields if f.startswith("V")][0] + return tuple(int(field) for field in version_str.split(".")) raise RuntimeError("Cannot read cuda version file") @@ -206,7 +211,18 @@ def find_libdevice_path(arch): selected_ver = 0 selected_path = None cuda_ver = get_cuda_version(cuda_path) - if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2, 11.3): + major_minor = (cuda_ver[0], cuda_ver[1]) + if major_minor in ( + (9, 0), + 
(9, 1), + (10, 0), + (10, 1), + (10, 2), + (11, 0), + (11, 1), + (11, 2), + (11, 3), + ): path = os.path.join(lib_path, "libdevice.10.bc") else: for fn in os.listdir(lib_path): @@ -358,9 +374,8 @@ def have_tensorcore(compute_version=None, target=None): def have_cudagraph(): """Either CUDA Graph support is provided""" try: - cuda_path = find_cuda_path() - cuda_ver = get_cuda_version(cuda_path) - if cuda_ver < 10.0: + cuda_ver = get_cuda_version() + if cuda_ver < (10, 0): return False return True except RuntimeError: diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index dbaba46fdc9c..9ac1e9035235 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -635,6 +635,49 @@ def requires_nvptx(*args): return _compose(args, _requires_nvptx) +def requires_nvcc_version(major_version, minor_version=0, release_version=0): + """Mark a test as requiring at least a specific version of nvcc. + + Unit test marked with this decorator will run only if the + installed version of NVCC is at least `(major_version, + minor_version, release_version)`. + + This also marks the test as requiring a cuda support. + + Parameters + ---------- + major_version: int + + The major version of the (major,minor,release) version tuple. + + minor_version: int + + The minor version of the (major,minor,release) version tuple. + + release_version: int + + The release version of the (major,minor,release) version tuple. 
+ + """ + + try: + nvcc_version = nvcc.get_cuda_version() + except RuntimeError: + nvcc_version = (0, 0, 0) + + min_version = (major_version, minor_version, release_version) + version_str = ".".join(str(v) for v in min_version) + requires = [ + pytest.mark.skipif(nvcc_version < min_version, reason=f"Requires NVCC >= {version_str}"), + *requires_cuda(), + ] + + def inner(func): + return _compose([func], requires) + + return inner + + def requires_cudagraph(*args): """Mark a test as requiring the CUDA Graph Feature diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py index 4b8e3fcaffef..c304e818ef05 100644 --- a/tests/python/unittest/test_tir_ptx_mma.py +++ b/tests/python/unittest/test_tir_ptx_mma.py @@ -310,6 +310,10 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): ) +# This test uses mma instructions that are not available on NVCC 10.1. +# Failure occurs during the external call to nvcc, when attempting to +# generate the .fatbin file. +@tvm.testing.requires_nvcc_version(11) @tvm.testing.requires_cuda def test_gemm_mma_m8n8k16_row_col_s8s8s32(): sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8s8s32) @@ -384,6 +388,10 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): ) +# This test uses mma instructions that are not available on NVCC 10.1. +# Failure occurs during the external call to nvcc, when attempting to +# generate the .fatbin file. +@tvm.testing.requires_nvcc_version(11) @tvm.testing.requires_cuda def test_gemm_mma_m8n8k16_row_col_s8u8s32(): sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8u8s32) @@ -458,6 +466,10 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): ) +# This test uses mma instructions that are not available on NVCC 10.1. +# Failure occurs during the external call to nvcc, when attempting to +# generate the .fatbin file. 
+@tvm.testing.requires_nvcc_version(11) @tvm.testing.requires_cuda def test_gemm_mma_m8n8k32_row_col_s4s4s32(): sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4s4s32) @@ -524,6 +536,10 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): ) +# This test uses mma instructions that are not available on NVCC 10.1. +# Failure occurs during the external call to nvcc, when attempting to +# generate the .fatbin file. +@tvm.testing.requires_nvcc_version(11) @tvm.testing.requires_cuda def test_gemm_mma_m8n8k32_row_col_s4u4s32(): sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4u4s32)