[cuDNN] Add cuDNN grouped convolution support #5319

Merged · 1 commit · Apr 23, 2020
python/tvm/contrib/cudnn.py (36 changes: 25 additions & 11 deletions)
@@ -182,7 +182,8 @@ def conv_output_shape(tensor_format,
x_shape,
w_shape,
data_dtype,
conv_dtype):
conv_dtype,
groups=1):
"""Get output shape of 2D or 3D convolution

Parameters
@@ -205,6 +206,8 @@ def conv_output_shape(tensor_format,
data type
conv_dtype: str
convolution type
groups: int
number of groups

Returns
-------
@@ -228,7 +231,8 @@ def conv_output_shape(tensor_format,
_get_np_int32_array_handle(wshape),
_get_np_int32_array_handle(oshape),
data_dtype,
conv_dtype)
conv_dtype,
groups)
return list(oshape)


@@ -240,7 +244,8 @@ def conv_find_algo(tensor_format,
w_shape,
y_shape,
data_dtype,
conv_dtype):
conv_dtype,
groups=1):
"""Choose the best algo for the given input.

Parameters
@@ -265,6 +270,8 @@ def conv_find_algo(tensor_format,
data type
conv_dtype: str
convolution type
groups: int
number of groups

Returns
-------
@@ -287,7 +294,8 @@ def conv_find_algo(tensor_format,
_get_np_int32_array_handle(wshape),
_get_np_int32_array_handle(yshape),
data_dtype,
conv_dtype)
conv_dtype,
groups)


def conv_forward(x,
@@ -298,7 +306,8 @@ def conv_forward(x,
conv_mode,
tensor_format,
algo,
conv_dtype):
conv_dtype,
groups=1):
"""Create an extern op that compute 2D or 3D convolution with CuDNN

Parameters
@@ -325,6 +334,8 @@ def conv_forward(x,
if algo == -1, the best algo will be chosen by CUDNN
conv_dtype: str
convolution type
groups: int
the number of groups

Returns
-------
@@ -335,8 +346,7 @@ def conv_forward(x,
assert dims in (4, 5)

conv_dtype = x.dtype if conv_dtype is None else conv_dtype
pad, stride, dilation, _, _ = \
_prepare_global_func_params(dims - 2, pad, stride, dilation)
pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)

oshape = conv_output_shape(tensor_format,
pad,
@@ -345,7 +355,8 @@ def conv_forward(x,
list(x.shape),
list(w.shape),
x.dtype,
conv_dtype)
conv_dtype,
groups)
if algo == -1:
# For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
# using INT8 data type, cuDNN will crash.
@@ -361,7 +372,8 @@ def conv_forward(x,
list(w.shape),
oshape,
x.dtype,
conv_dtype)
conv_dtype,
groups)

if dims == 4:
return te.extern(
@@ -380,7 +392,8 @@ def conv_forward(x,
ins[0],
ins[1],
outs[0],
conv_dtype), name="y")
conv_dtype,
groups), name="y")

return te.extern(
oshape, [x, w],
@@ -401,7 +414,8 @@ def conv_forward(x,
ins[0],
ins[1],
outs[0],
conv_dtype), name="y")
conv_dtype,
groups), name="y")

def softmax(x, axis=-1):
"""Compute softmax using CuDNN
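A minimal usage sketch of the updated Python wrapper (illustrative, not taken from this PR): it assumes a TVM build with CUDA and cuDNN enabled, and the placeholder names and shapes are made up. The filter follows the OIHW convention, so its input-channel dimension is C // groups.

import tvm
from tvm import te
from tvm.contrib import cudnn

groups = 4
x = te.placeholder((1, 32, 56, 56), name="x", dtype="float32")           # NCHW input
w = te.placeholder((64, 32 // groups, 3, 3), name="w", dtype="float32")  # OIHW filter, I = C // groups
y = cudnn.conv_forward(x, w,
                       pad=[1, 1],
                       stride=[1, 1],
                       dilation=[1, 1],
                       conv_mode=1,       # cross-correlation
                       tensor_format=0,   # CUDNN_TENSOR_NCHW
                       algo=-1,           # let conv_find_algo choose the algorithm
                       conv_dtype="float32",
                       groups=groups)
s = te.create_schedule(y.op)              # the extern op can then be built for the "cuda" target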
python/tvm/relay/op/strategy/cuda.py (20 changes: 18 additions & 2 deletions)
@@ -161,7 +161,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_cudnn, True),
wrap_compute_conv2d(topi.cuda.conv2d_cudnn,
need_data_layout=True,
has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
name="conv2d_cudnn.cuda",
plevel=15)
@@ -181,6 +183,20 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
# add cudnn implementation, if any
cudnn_impl = False
if target.target_name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_cudnn,
need_data_layout=True,
has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
name="conv2d_cudnn.cuda",
plevel=15)
cudnn_impl = True

if layout == 'NCHW':
# TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
@@ -194,7 +210,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
name="group_conv2d_NCHWc_int8.cuda")
else:
elif not cudnn_impl:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy

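With this strategy change, a grouped relay.nn.conv2d can be dispatched to conv2d_cudnn.cuda by listing cudnn in the target libs. A sketch under those assumptions (illustrative shapes, cuDNN-enabled build):

import tvm
from tvm import relay

groups = 4
data = relay.var("data", shape=(1, 32, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(64, 32 // groups, 3, 3), dtype="float32")  # OIHW
out = relay.nn.conv2d(data, weight, padding=(1, 1), groups=groups,
                      channels=64, kernel_size=(3, 3))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# "cudnn" in target.libs takes the cudnn_impl path added above
graph, lib, params = relay.build(mod, target="cuda -libs=cudnn")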
src/runtime/contrib/cudnn/conv_forward.cc (37 changes: 25 additions & 12 deletions)
@@ -35,6 +35,7 @@ void ConvolutionForward(
int format,
int algo,
int dims,
int groups,
const int pad[],
const int stride[],
const int dilation[],
@@ -62,8 +63,10 @@

// Note: For 2D tensor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
// in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int

CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
if (dims == 2) {
// Set Desc
// Set Desc
CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
pad[0],
pad[1],
Expand Down Expand Up @@ -183,6 +186,7 @@ void ConvolutionForward(
void OutputShape(
int format,
int dims,
int groups,
const int pad[],
const int stride[],
const int dilation[],
Expand All @@ -202,6 +206,7 @@ void OutputShape(
int full_dims = dims + 2;

// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
dims,
pad,
@@ -240,6 +245,7 @@ void OutputShape(
// Set Input
std::vector<int> tensor_stride(full_dims);
GetCudnnStride(full_dims, x_dim, tensor_stride.data());

CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
data_type,
full_dims,
@@ -264,6 +270,7 @@ void OutputShape(
void FindAlgo(
int format,
int dims,
int groups,
const int pad[],
const int stride[],
const int dilation[],
@@ -284,6 +291,7 @@ void FindAlgo(
int full_dims = dims + 2;

// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
dims,
pad,
@@ -360,16 +368,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
int algo = args[2];
int pad_v[2], stride_v[2], dilation_v[2];
for (int i = 0; i < 2; i++) {
pad_v[i] = args[3 + i];
stride_v[i] = args[5 + i];
dilation_v[i] = args[7 + i];
pad_v[i] = args[3 + i];
stride_v[i] = args[5 + i];
dilation_v[i] = args[7 + i];
}
DLTensor* x = args[9];
DLTensor* w = args[10];
DLTensor* y = args[11];
std::string conv_dtype = args[12];
int groups = args[13];

ConvolutionForward(mode, format, algo, 2, pad_v, stride_v, dilation_v, x, w, y, conv_dtype);
ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v,
dilation_v, x, w, y, conv_dtype);
});


@@ -380,17 +390,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
int algo = args[2];
int pad_v[3], stride_v[3], dilation_v[3];
for (int i = 0; i < 3; i++) {
pad_v[i] = args[3 + i];
stride_v[i] = args[6 + i];
dilation_v[i] = args[9 + i];
pad_v[i] = args[3 + i];
stride_v[i] = args[6 + i];
dilation_v[i] = args[9 + i];
}
DLTensor *x = args[12];
DLTensor *w = args[13];
DLTensor *y = args[14];
std::string conv_dtype = args[15];
int groups = args[16];

ConvolutionForward(mode, format, algo, 3, pad_v, stride_v, dilation_v, x, w, y,
conv_dtype);
ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v,
dilation_v, x, w, y, conv_dtype);
});


@@ -406,8 +417,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape")
void* out_shape = args[7];
std::string data_dtype = args[8];
std::string conv_dtype = args[9];
int groups = args[10];

OutputShape(format, dims, pad, stride, dilation, x_dim,
OutputShape(format, dims, groups, pad, stride, dilation, x_dim,
w_dim, out_shape, data_dtype, conv_dtype);
});

@@ -424,8 +436,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
std::string data_dtype = args[8];
std::string conv_dtype = args[9];
int groups = args[10];

FindAlgo(format, dims, pad, stride, dilation, x_dim,
FindAlgo(format, dims, groups, pad, stride, dilation, x_dim,
w_dim, y_dim, data_dtype, conv_dtype, ret);
});

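The registered packed functions now expect groups as their final argument; the Python wrappers above pass it through automatically. A quick shape check through the updated wrapper (illustrative, assumes a cuDNN-enabled build):

from tvm.contrib import cudnn

# 1x32x56x56 NCHW input, 64 output channels, 3x3 kernel, groups=4,
# so the filter carries 32 // 4 = 8 input channels per group.
oshape = cudnn.conv_output_shape(tensor_format=0,   # CUDNN_TENSOR_NCHW
                                 pad=[1, 1],
                                 stride=[1, 1],
                                 dilation=[1, 1],
                                 x_shape=[1, 32, 56, 56],
                                 w_shape=[64, 8, 3, 3],
                                 data_dtype="float32",
                                 conv_dtype="float32",
                                 groups=4)
# expected: [1, 64, 56, 56]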
src/runtime/contrib/cudnn/cudnn_utils.h (1 change: 0 additions & 1 deletion)
@@ -78,7 +78,6 @@ struct ConvEntry {
runtime::DeviceAPI *cuda_api;
void *workspace{nullptr};
size_t workspace_size{0};
int group_count {0};
ConvEntry();
~ConvEntry();
void UpdateWorkspace(const size_t wsize);