From 04a36d4967e0dd8af41b85d8bc43a94f5910558d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 1 Jul 2021 10:41:06 -0700 Subject: [PATCH] [UnitTests] Added cuDNN target to default test targets Some unit tests explicitly test cudnn in addition to tvm.testing.enabled_targets(). This moved the cudnn checks into the same framework as all other targets, and adds it to the default list of targets to be run. Also, added `@tvm.testing.requires_cudnn` for tests specific to cudnn. --- python/tvm/testing.py | 70 ++++++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 4721c0050656c..e143353a2b241 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -70,7 +70,8 @@ def test_something(): import tvm.tir import tvm.te import tvm._ffi -from tvm.contrib import nvcc + +from tvm.contrib import nvcc, cudnn from tvm.error import TVMError @@ -375,11 +376,12 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): def _get_targets(target_str=None): if target_str is None: target_str = os.environ.get("TVM_TEST_TARGETS", "") + # Use dict instead of set for de-duplication so that the + # targets stay in the order specified. + target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()}) - if len(target_str) == 0: - target_str = DEFAULT_TEST_TARGETS - - target_names = set(t.strip() for t in target_str.split(";") if t.strip()) + if len(target_names) == 0: + target_names = DEFAULT_TEST_TARGETS targets = [] for target in target_names: @@ -413,10 +415,19 @@ def _get_targets(target_str=None): return targets -DEFAULT_TEST_TARGETS = ( - "llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;" - "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu" -) +DEFAULT_TEST_TARGETS = [ + "llvm", + "llvm -device=arm_cpu", + "cuda", + "cuda -model=unknown -libs=cudnn", + "nvptx", + "vulkan -from_device=0", + "opencl", + "opencl -device=mali,aocl_sw_emu", + "opencl -device=intel_graphics", + "metal", + "rocm", +] def device_enabled(target): @@ -548,6 +559,24 @@ def requires_cuda(*args): return _compose(args, _requires_cuda) +def requires_cudnn(*args): + """Mark a test as requiring the cuDNN library. + + This also marks the test as requiring a cuda gpu. + + Parameters + ---------- + f : function + Function to mark + """ + + requirements = [ + pytest.mark.skipif(not cudnn.exists(), reason="cuDNN library not enabled, or not installed") + * requires_cuda(), + ] + return _compose(args, requirements) + + def requires_nvptx(*args): """Mark a test as requiring the NVPTX compilation on the CUDA runtime @@ -730,20 +759,27 @@ def requires_rpc(*args): def _target_to_requirement(target): + if isinstance(target, str): + target = tvm.target.Target(target) + # mapping from target to decorator - if target.startswith("cuda"): - return requires_cuda() - if target.startswith("rocm"): + if target.kind.name == "cuda": + if "cudnn" in target.attrs.get("libs", []): + return requires_cudnn() + else: + return requires_cuda() + + if target.kind.name == "rocm": return requires_rocm() - if target.startswith("vulkan"): + if target.kind.name == "vulkan": return requires_vulkan() - if target.startswith("nvptx"): + if target.kind.name == "nvptx": return requires_nvptx() - if target.startswith("metal"): + if target.kind.name == "metal": return requires_metal() - if target.startswith("opencl"): + if target.kind.name == "opencl": return requires_opencl() - if target.startswith("llvm"): + if target.kind.name == "llvm": return requires_llvm() return []