[UnitTests] Added cuDNN to default test targets (apache#8383)
* [Target][UnitTests] Look up target requirements based on tvm.target.Target

- Read target.kind.name instead of using string manipulation.

- Querying device parameters for a non-existent device is no longer an
  error.  This occurs when expanding `vulkan -from_device=0` on a
  machine without a GPU (see the sketch below).
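
For illustration only (not part of the diff; assumes a machine with no Vulkan device), the expected behavior is roughly:

# Hedged sketch: expanding -from_device on a machine without the device.
# Previously this aborted; now it logs a note and keeps the default attributes.
import tvm

target = tvm.target.Target("vulkan -from_device=0")
print(target.kind.name)  # "vulkan"
print(target.attrs)      # default attributes, since device 0 does not exist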

* [UnitTests] Added cuDNN target to default test targets

Some unit tests explicitly test cudnn in addition to
`tvm.testing.enabled_targets()`.  This moves the cudnn checks into the
same framework as all other targets and adds a cudnn-enabled cuda
target to the default list of targets to be run.  It also adds
`@tvm.testing.requires_cudnn` for tests specific to cudnn (see the
example below).
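
A hypothetical test using the new marker (test name and body are made up for illustration):

import tvm.testing


@tvm.testing.requires_cudnn
def test_something_with_cudnn():
    # Runs only when TVM is built with CUDA and cuDNN and a GPU is present;
    # otherwise pytest reports the test as skipped.
    ...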

* [UnitTests] pytest.xfail for CuDNN conv2d with asymmetric padding

* [Topi][CuDNN] Added handling of dilation to conv2d_cudnn

* [Topi] Skip dynamic batch matmul on cudnn, vulkan, opencl

Previously, only the cuda/nvptx targets were excluded.  The check now
looks up target.kind.name instead of string matching, and also excludes
vulkan/opencl, since the dynamic lookup currently doesn't work on those
backends (see the sketch below).
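
A minimal sketch of the kind-based lookup (the target strings are examples only):

import tvm

for target_str in ["cuda -model=unknown -libs=cudnn", "vulkan -from_device=0", "llvm -device=arm_cpu"]:
    # kind.name ignores extra attributes such as -libs or -from_device,
    # so no string manipulation is needed.
    print(target_str, "->", tvm.target.Target(target_str).kind.name)
# prints cuda, vulkan, llvm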

Co-authored-by: Eric Lunderberg <[email protected]>
2 people authored and ylc committed Jan 13, 2022
1 parent 9f00893 commit cd2690d
Showing 5 changed files with 116 additions and 71 deletions.
44 changes: 33 additions & 11 deletions python/tvm/testing.py
@@ -70,7 +70,8 @@ def test_something():
import tvm.tir
import tvm.te
import tvm._ffi
from tvm.contrib import nvcc

from tvm.contrib import nvcc, cudnn
from tvm.error import TVMError


@@ -418,6 +419,7 @@ def _get_targets(target_str=None):
"llvm",
"llvm -device=arm_cpu",
"cuda",
"cuda -model=unknown -libs=cudnn",
"nvptx",
"vulkan -from_device=0",
"opencl",
@@ -557,6 +559,26 @@ def requires_cuda(*args):
return _compose(args, _requires_cuda)


def requires_cudnn(*args):
    """Mark a test as requiring the cuDNN library.

    This also marks the test as requiring a cuda gpu.

    Parameters
    ----------
    f : function
        Function to mark
    """
    requirements = [
        pytest.mark.skipif(
            not cudnn.exists(), reason="cuDNN library not enabled, or not installed"
        ),
        *requires_cuda(),
    ]
    return _compose(args, requirements)


def requires_nvptx(*args):
"""Mark a test as requiring the NVPTX compilation on the CUDA runtime
@@ -740,24 +762,24 @@ def requires_rpc(*args):

def _target_to_requirement(target):
if isinstance(target, str):
target_kind = target.split()[0]
else:
target_kind = target.kind.name
target = tvm.target.Target(target)

# mapping from target to decorator
if target_kind == "cuda":
if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []):
return requires_cudnn()
if target.kind.name == "cuda":
return requires_cuda()
if target_kind == "rocm":
if target.kind.name == "rocm":
return requires_rocm()
if target_kind == "vulkan":
if target.kind.name == "vulkan":
return requires_vulkan()
if target_kind == "nvptx":
if target.kind.name == "nvptx":
return requires_nvptx()
if target_kind == "metal":
if target.kind.name == "metal":
return requires_metal()
if target_kind == "opencl":
if target.kind.name == "opencl":
return requires_opencl()
if target_kind == "llvm":
if target.kind.name == "llvm":
return requires_llvm()
return []
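
For context, an illustrative (not part of the diff) use of the helper above; real tests normally go through tvm.testing.enabled_targets() or parametrize_targets, and the default target list can be overridden via the TVM_TEST_TARGETS environment variable:

import tvm.testing

# Hypothetical direct call to the private helper, for illustration only.
marks = tvm.testing._target_to_requirement("cuda -model=unknown -libs=cudnn")
# `marks` is a list of pytest marks covering both the CUDA GPU requirement
# and the cuDNN library requirement; tests tagged with them are skipped
# cleanly when either is unavailable.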

11 changes: 5 additions & 6 deletions python/tvm/topi/cuda/conv2d.py
@@ -86,14 +86,13 @@ def conv2d_cudnn(
# handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
KH_dilated = (KH - 1) * dilation_h + 1
KW_dilated = (KW - 1) * dilation_w + 1

if (
isinstance(padding, (list, tuple))
and len(padding) == 4
and (padding[0] != padding[2] or padding[1] != padding[3])
):
pt, pl, pb, pr = get_pad_tuple(padding, (KH_dilated, KW_dilated))
if (pt != pb) or (pl != pr):
raise ValueError("Cudnn doesn't support asymmetric padding.")
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))

OH = (H + pt + pb - KH) // stride_h + 1
OW = (W + pl + pr - KW) // stride_w + 1
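
To make the padding handling concrete, a standalone sketch (the kernel size, dilation, and the get_pad_tuple import path are assumptions, not from the diff):

from tvm.topi.nn.utils import get_pad_tuple

KH, KW = 3, 3
dilation_h, dilation_w = 2, 2
KH_dilated = (KH - 1) * dilation_h + 1  # 5
KW_dilated = (KW - 1) * dilation_w + 1  # 5

# "SAME" padding on a 5x5 dilated kernel yields symmetric padding (2, 2, 2, 2),
# so the cuDNN path accepts it; an asymmetric result would raise instead.
pt, pl, pb, pr = get_pad_tuple("SAME", (KH_dilated, KW_dilated))
if (pt != pb) or (pl != pr):
    raise ValueError("Cudnn doesn't support asymmetric padding.")
print(pt, pl, pb, pr)  # 2 2 2 2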

125 changes: 72 additions & 53 deletions src/target/target.cc
@@ -59,6 +59,9 @@ class TargetInternal {
n->host = target_host;
return (Target)n;
}

private:
static std::unordered_map<String, ObjectRef> QueryDevice(int device_id, const TargetNode* target);
};

/********** Helper functions **********/
@@ -731,63 +734,16 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
}
}

// if requested, query attributes from the device
// If requested, query attributes from the device. User-specified
// parameters take precedence over queried parameters.
if (attrs.count("from_device")) {
int device_id = Downcast<Integer>(attrs.at("from_device"));
attrs.erase("from_device");
auto device_params = QueryDevice(device_id, target.get());

Device device{static_cast<DLDeviceType>(target->kind->device_type), device_id};

auto api = runtime::DeviceAPI::Get(device, true);
ICHECK(api) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id
<< ", but support for this runtime wasn't enabled at compile-time.";

TVMRetValue ret;
api->GetAttr(device, runtime::kExist, &ret);
ICHECK(ret) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id << ", but device_id " << device_id
<< " doesn't exist.";

for (const auto& kv : target->kind->key2vtype_) {
const String& key = kv.first;
const TargetKindNode::ValueTypeInfo& type_info = kv.second;

// Don't overwrite explicitly-specified values
if (attrs.count(key)) {
continue;
}

TVMRetValue ret;
api->GetTargetProperty(device, key, &ret);

switch (ret.type_code()) {
case kTVMNullptr:
// Nothing returned for this parameter, move on to the next one.
continue;

case kTVMArgInt:
if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
attrs[key] = Integer(static_cast<int64_t>(ret));
} else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
attrs[key] = Bool(static_cast<bool>(ret));
} else {
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received integer from device api";
}
break;

case kTVMStr:
ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex())
<< "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received string from device api";
attrs[key] = String(ret.operator std::string());
break;

default:
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api";
break;
for (const auto& kv : device_params) {
if (attrs.count(kv.first) == 0) {
attrs[kv.first] = kv.second;
}
}
}
@@ -807,6 +763,69 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
return target;
}

std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
const TargetNode* target) {
std::unordered_map<String, ObjectRef> output;

Device device{static_cast<DLDeviceType>(target->kind->device_type), device_id};

auto api = runtime::DeviceAPI::Get(device, true);
if (!api) {
LOG(INFO) << "Requested reading the parameters for " << target->kind->name << " from device_id "
<< device_id << ", but support for this runtime wasn't enabled at compile-time. "
<< "Using default target parameters.";
return output;
}

TVMRetValue ret;
api->GetAttr(device, runtime::kExist, &ret);
if (!ret) {
LOG(INFO) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id << ", but device_id " << device_id
<< " doesn't exist. Using default target parameters.";
return output;
}

for (const auto& kv : target->kind->key2vtype_) {
const String& key = kv.first;
const TargetKindNode::ValueTypeInfo& type_info = kv.second;

TVMRetValue ret;
api->GetTargetProperty(device, key, &ret);

switch (ret.type_code()) {
case kTVMNullptr:
// Nothing returned for this parameter, move on to the next one.
continue;

case kTVMArgInt:
if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
output[key] = Integer(static_cast<int64_t>(ret));
} else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
output[key] = Bool(static_cast<bool>(ret));
} else {
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received integer from device api";
}
break;

case kTVMStr:
ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex())
<< "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received string from device api";
output[key] = String(ret.operator std::string());
break;

default:
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api";
break;
}
}

return output;
}

/********** Registry **********/

TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher);
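
From the Python side, the precedence rule in QueryDevice means explicitly passed attributes survive the device query; a hypothetical example (the attribute name is chosen for illustration and assumes a CUDA-enabled build):

import tvm

tgt = tvm.target.Target("cuda -from_device=0 -max_threads_per_block=256")
# max_threads_per_block stays 256 even if device 0 reports another value;
# attributes not given explicitly are filled in from the device when it exists.
print(tgt.attrs["max_threads_per_block"])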
3 changes: 2 additions & 1 deletion tests/python/topi/python/test_topi_batch_matmul.py
@@ -85,7 +85,8 @@ def check_device(target, dev):
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

for target, dev in tvm.testing.enabled_targets():
if dynamic and (target == "cuda" or target == "nvptx"):
target_kind = tvm.target.Target(target).kind.name
if dynamic and target_kind in ["cuda", "nvptx", "vulkan", "opencl"]:
print("Dynamic batch matmul test is skippped on %s" % target)
continue

4 changes: 4 additions & 0 deletions tests/python/topi/python/test_topi_conv2d_nchw.py
@@ -129,6 +129,10 @@ def test_conv2d_nchw(
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right

has_asymmetric_padding = (pad_top != pad_bottom) or (pad_left != pad_right)
if is_cudnn_target and has_asymmetric_padding:
pytest.xfail("CuDNN does not support asymmetric padding")

a_np, w_np, b_np, c_np = ref_data

A = te.placeholder(a_np.shape, name="A", dtype=dtype)
