From b4628a727d22c4c234c449c5fd6949b3c96fc3a9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 3 Aug 2021 11:31:20 -0500 Subject: [PATCH] [Testing] Enable `Target` object as argument to _target_to_requirement Previously, tvm.testing._target_to_requirement required the argument to be a string. This commit allows it to be either a string or a `tvm.target.Target`. --- python/tvm/testing.py | 42 +++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 3a885aa5357db..7c2110fee26b4 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -375,11 +375,12 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): def _get_targets(target_str=None): if target_str is None: target_str = os.environ.get("TVM_TEST_TARGETS", "") + # Use dict instead of set for de-duplication so that the + # targets stay in the order specified. + target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()}) - if len(target_str) == 0: - target_str = DEFAULT_TEST_TARGETS - - target_names = set(t.strip() for t in target_str.split(";") if t.strip()) + if not target_names: + target_names = DEFAULT_TEST_TARGETS targets = [] for target in target_names: @@ -413,10 +414,18 @@ def _get_targets(target_str=None): return targets -DEFAULT_TEST_TARGETS = ( - "llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;" - "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu" -) +DEFAULT_TEST_TARGETS = [ + "llvm", + "llvm -device=arm_cpu", + "cuda", + "nvptx", + "vulkan -from_device=0", + "opencl", + "opencl -device=mali,aocl_sw_emu", + "opencl -device=intel_graphics", + "metal", + "rocm", +] def device_enabled(target): @@ -730,20 +739,23 @@ def requires_rpc(*args): def _target_to_requirement(target): + if isinstance(target, str): + target = tvm.target.Target(target) + # mapping from target to decorator - if target.startswith("cuda"): + if target.kind.name == "cuda": return requires_cuda() - if target.startswith("rocm"): + if target.kind.name == "rocm": return requires_rocm() - if target.startswith("vulkan"): + if target.kind.name == "vulkan": return requires_vulkan() - if target.startswith("nvptx"): + if target.kind.name == "nvptx": return requires_nvptx() - if target.startswith("metal"): + if target.kind.name == "metal": return requires_metal() - if target.startswith("opencl"): + if target.kind.name == "opencl": return requires_opencl() - if target.startswith("llvm"): + if target.kind.name == "llvm": return requires_llvm() return []