Skip to content

Commit

Permalink
[UnitTests] Added cuDNN target to default test targets
Browse files Browse the repository at this point in the history
Some unit tests explicitly test cudnn in addition to
tvm.testing.enabled_targets().  This moved the cudnn checks into the
same framework as all other targets, and adds it to the default list
of targets to be run.  Also, added `@tvm.testing.requires_cudnn` for
tests specific to cudnn.
  • Loading branch information
Lunderberg committed Jul 22, 2021
1 parent 07243a8 commit 373d8d1
Showing 1 changed file with 52 additions and 16 deletions.
68 changes: 52 additions & 16 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def test_something():
import tvm.tir
import tvm.te
import tvm._ffi
from tvm.contrib import nvcc

from tvm.contrib import nvcc, cudnn
from tvm.error import TVMError


Expand Down Expand Up @@ -375,11 +376,12 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap):
def _get_targets(target_str=None):
if target_str is None:
target_str = os.environ.get("TVM_TEST_TARGETS", "")
# Use dict instead of set for de-duplication so that the
# targets stay in the order specified.
target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()})

if len(target_str) == 0:
target_str = DEFAULT_TEST_TARGETS

target_names = set(t.strip() for t in target_str.split(";") if t.strip())
if len(target_names) == 0:
target_names = DEFAULT_TEST_TARGETS

targets = []
for target in target_names:
Expand Down Expand Up @@ -413,10 +415,19 @@ def _get_targets(target_str=None):
return targets


DEFAULT_TEST_TARGETS = (
"llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;"
"llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
)
DEFAULT_TEST_TARGETS = [
"llvm",
"llvm -device=arm_cpu",
"cuda",
"cuda -model=unknown -libs=cudnn",
"nvptx",
"vulkan -from_device=0",
"opencl",
"opencl -device=mali,aocl_sw_emu",
"opencl -device=intel_graphics",
"metal",
"rocm",
]


def device_enabled(target):
Expand Down Expand Up @@ -548,6 +559,26 @@ def requires_cuda(*args):
return _compose(args, _requires_cuda)


def requires_cudnn(*args):
"""Mark a test as requiring the cuDNN library.
This also marks the test as requiring a cuda gpu.
Parameters
----------
f : function
Function to mark
"""

requirements = [
pytest.mark.skipif(
not cudnn.exists(), reason="cuDNN library not enabled, or not installed"
),
*requires_cuda(),
]
return _compose(args, requirements)


def requires_nvptx(*args):
"""Mark a test as requiring the NVPTX compilation on the CUDA runtime
Expand Down Expand Up @@ -730,20 +761,25 @@ def requires_rpc(*args):


def _target_to_requirement(target):
if isinstance(target, str):
target = tvm.target.Target(target)

# mapping from target to decorator
if target.startswith("cuda"):
if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []):
return requires_cudnn()
if target.kind.name == "cuda":
return requires_cuda()
if target.startswith("rocm"):
if target.kind.name == "rocm":
return requires_rocm()
if target.startswith("vulkan"):
if target.kind.name == "vulkan":
return requires_vulkan()
if target.startswith("nvptx"):
if target.kind.name == "nvptx":
return requires_nvptx()
if target.startswith("metal"):
if target.kind.name == "metal":
return requires_metal()
if target.startswith("opencl"):
if target.kind.name == "opencl":
return requires_opencl()
if target.startswith("llvm"):
if target.kind.name == "llvm":
return requires_llvm()
return []

Expand Down

0 comments on commit 373d8d1

Please sign in to comment.