Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UnitTests] Added cuDNN to default test targets #8383

Merged
merged 5 commits into from
Aug 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def test_something():
import tvm.tir
import tvm.te
import tvm._ffi
from tvm.contrib import nvcc

from tvm.contrib import nvcc, cudnn
from tvm.error import TVMError


Expand Down Expand Up @@ -418,6 +419,7 @@ def _get_targets(target_str=None):
"llvm",
"llvm -device=arm_cpu",
"cuda",
"cuda -model=unknown -libs=cudnn",
"nvptx",
"vulkan -from_device=0",
"opencl",
Expand Down Expand Up @@ -557,6 +559,26 @@ def requires_cuda(*args):
return _compose(args, _requires_cuda)


def requires_cudnn(*args):
"""Mark a test as requiring the cuDNN library.

This also marks the test as requiring a cuda gpu.

Parameters
----------
f : function
Function to mark
"""

requirements = [
pytest.mark.skipif(
not cudnn.exists(), reason="cuDNN library not enabled, or not installed"
),
*requires_cuda(),
]
return _compose(args, requirements)


def requires_nvptx(*args):
"""Mark a test as requiring the NVPTX compilation on the CUDA runtime

Expand Down Expand Up @@ -740,24 +762,24 @@ def requires_rpc(*args):

def _target_to_requirement(target):
if isinstance(target, str):
target_kind = target.split()[0]
else:
target_kind = target.kind.name
target = tvm.target.Target(target)

# mapping from target to decorator
if target_kind == "cuda":
if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []):
return requires_cudnn()
if target.kind.name == "cuda":
return requires_cuda()
if target_kind == "rocm":
if target.kind.name == "rocm":
return requires_rocm()
if target_kind == "vulkan":
if target.kind.name == "vulkan":
return requires_vulkan()
if target_kind == "nvptx":
if target.kind.name == "nvptx":
return requires_nvptx()
if target_kind == "metal":
if target.kind.name == "metal":
return requires_metal()
if target_kind == "opencl":
if target.kind.name == "opencl":
return requires_opencl()
if target_kind == "llvm":
if target.kind.name == "llvm":
return requires_llvm()
return []

Expand Down
11 changes: 5 additions & 6 deletions python/tvm/topi/cuda/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,13 @@ def conv2d_cudnn(
# handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
KH_dilated = (KH - 1) * dilation_h + 1
KW_dilated = (KW - 1) * dilation_h + 1

if (
isinstance(padding, (list, tuple))
and len(padding) == 4
and (padding[0] != padding[2] or padding[1] != padding[3])
):
pt, pl, pb, pr = get_pad_tuple(padding, (KH_dilated, KW_dilated))
if (pt != pb) or (pl != pr):
raise ValueError("Cudnn doesn't support asymmetric padding.")
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))

OH = (H + pt + pb - KH) // stride_h + 1
OW = (W + pl + pr - KW) // stride_w + 1

Expand Down
125 changes: 72 additions & 53 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class TargetInternal {
n->host = target_host;
return (Target)n;
}

private:
static std::unordered_map<String, ObjectRef> QueryDevice(int device_id, const TargetNode* target);
};

/********** Helper functions **********/
Expand Down Expand Up @@ -731,63 +734,16 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
}
}

// if requested, query attributes from the device
// If requested, query attributes from the device. User-specified
// parameters take precedence over queried parameters.
if (attrs.count("from_device")) {
int device_id = Downcast<Integer>(attrs.at("from_device"));
attrs.erase("from_device");
auto device_params = QueryDevice(device_id, target.get());

Device device{static_cast<DLDeviceType>(target->kind->device_type), device_id};

auto api = runtime::DeviceAPI::Get(device, true);
ICHECK(api) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id
<< ", but support for this runtime wasn't enabled at compile-time.";

TVMRetValue ret;
api->GetAttr(device, runtime::kExist, &ret);
ICHECK(ret) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id << ", but device_id " << device_id
<< " doesn't exist.";

for (const auto& kv : target->kind->key2vtype_) {
const String& key = kv.first;
const TargetKindNode::ValueTypeInfo& type_info = kv.second;

// Don't overwrite explicitly-specified values
if (attrs.count(key)) {
continue;
}

TVMRetValue ret;
api->GetTargetProperty(device, key, &ret);

switch (ret.type_code()) {
case kTVMNullptr:
// Nothing returned for this parameter, move on to the next one.
continue;

case kTVMArgInt:
if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
attrs[key] = Integer(static_cast<int64_t>(ret));
} else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
attrs[key] = Bool(static_cast<bool>(ret));
} else {
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received integer from device api";
}
break;

case kTVMStr:
ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex())
<< "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received string from device api";
attrs[key] = String(ret.operator std::string());
break;

default:
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api";
break;
for (const auto& kv : device_params) {
if (attrs.count(kv.first) == 0) {
attrs[kv.first] = kv.second;
}
}
}
Expand All @@ -807,6 +763,69 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
return target;
}

std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
const TargetNode* target) {
std::unordered_map<String, ObjectRef> output;

Device device{static_cast<DLDeviceType>(target->kind->device_type), device_id};

auto api = runtime::DeviceAPI::Get(device, true);
if (!api) {
LOG(INFO) << "Requested reading the parameters for " << target->kind->name << " from device_id "
<< device_id << ", but support for this runtime wasn't enabled at compile-time. "
<< "Using default target parameters.";
return output;
}

TVMRetValue ret;
api->GetAttr(device, runtime::kExist, &ret);
if (!ret) {
ICHECK(ret) << "Requested reading the parameters for " << target->kind->name
<< " from device_id " << device_id << ", but device_id " << device_id
<< " doesn't exist. Using default target parameters.";
return output;
}

for (const auto& kv : target->kind->key2vtype_) {
const String& key = kv.first;
const TargetKindNode::ValueTypeInfo& type_info = kv.second;

TVMRetValue ret;
api->GetTargetProperty(device, key, &ret);

switch (ret.type_code()) {
case kTVMNullptr:
// Nothing returned for this parameter, move on to the next one.
continue;

case kTVMArgInt:
if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
output[key] = Integer(static_cast<int64_t>(ret));
} else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
output[key] = Bool(static_cast<bool>(ret));
} else {
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received integer from device api";
}
break;

case kTVMStr:
ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex())
<< "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received string from device api";
output[key] = String(ret.operator std::string());
break;

default:
LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key
<< "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api";
break;
}
}

return output;
}

/********** Registry **********/

TVM_REGISTER_GLOBAL("target.Target").set_body(TargetInternal::ConstructorDispatcher);
Expand Down
3 changes: 2 additions & 1 deletion tests/python/topi/python/test_topi_batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def check_device(target, dev):
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

for target, dev in tvm.testing.enabled_targets():
if dynamic and (target == "cuda" or target == "nvptx"):
target_kind = tvm.target.Target(target).kind.name
if dynamic and target_kind in ["cuda", "nvptx", "vulkan", "opencl"]:
print("Dynamic batch matmul test is skippped on %s" % target)
continue

Expand Down
4 changes: 4 additions & 0 deletions tests/python/topi/python/test_topi_conv2d_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def test_conv2d_nchw(
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right

has_asymmetric_padding = (pad_top != pad_bottom) or (pad_left != pad_right)
if is_cudnn_target and has_asymmetric_padding:
pytest.xfail("CuDNN does not support asymmetric padding")

a_np, w_np, b_np, c_np = ref_data

A = te.placeholder(a_np.shape, name="A", dtype=dtype)
Expand Down