Skip to content

Commit

Permalink
[cuDNN] Add cuDNN grouped convolutions support
Browse files Browse the repository at this point in the history
Signed-off-by: Wei Pan <[email protected]>
  • Loading branch information
wpan11nv committed Apr 13, 2020
1 parent fc75de9 commit 25b2ebd
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 82 deletions.
36 changes: 25 additions & 11 deletions python/tvm/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def conv_output_shape(tensor_format,
x_shape,
w_shape,
data_dtype,
conv_dtype):
conv_dtype,
groups=1):
"""Get output shape of 2D or 3D convolution
Paramters
Expand All @@ -205,6 +206,8 @@ def conv_output_shape(tensor_format,
data type
conv_dtype: str
convolution type
groups: int
number of groups
Returns
-------
Expand All @@ -228,7 +231,8 @@ def conv_output_shape(tensor_format,
_get_np_int32_array_handle(wshape),
_get_np_int32_array_handle(oshape),
data_dtype,
conv_dtype)
conv_dtype,
groups)
return list(oshape)


Expand All @@ -240,7 +244,8 @@ def conv_find_algo(tensor_format,
w_shape,
y_shape,
data_dtype,
conv_dtype):
conv_dtype,
groups=1):
"""Choose the best algo for the given input.
Paramters
Expand All @@ -265,6 +270,8 @@ def conv_find_algo(tensor_format,
data type
conv_dtype: str
convolution type
groups: int
number of groups
Returns
-------
Expand All @@ -287,7 +294,8 @@ def conv_find_algo(tensor_format,
_get_np_int32_array_handle(wshape),
_get_np_int32_array_handle(yshape),
data_dtype,
conv_dtype)
conv_dtype,
groups)


def conv_forward(x,
Expand All @@ -298,7 +306,8 @@ def conv_forward(x,
conv_mode,
tensor_format,
algo,
conv_dtype):
conv_dtype,
groups=1):
"""Create an extern op that compute 2D or 3D convolution with CuDNN
Parameters
Expand All @@ -325,6 +334,8 @@ def conv_forward(x,
if algo == -1, the best algo will be chosen by CUDNN
conv_dtype: str
convolution type
groups: int
the number of groups
Returns
-------
Expand All @@ -335,8 +346,7 @@ def conv_forward(x,
assert dims in (4, 5)

conv_dtype = x.dtype if conv_dtype is None else conv_dtype
pad, stride, dilation, _, _ = \
_prepare_global_func_params(dims - 2, pad, stride, dilation)
pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)

oshape = conv_output_shape(tensor_format,
pad,
Expand All @@ -345,7 +355,8 @@ def conv_forward(x,
list(x.shape),
list(w.shape),
x.dtype,
conv_dtype)
conv_dtype,
groups)
if algo == -1:
# For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
# using INT8 data type, CuDNN will crash down.
Expand All @@ -361,7 +372,8 @@ def conv_forward(x,
list(w.shape),
oshape,
x.dtype,
conv_dtype)
conv_dtype,
groups)

if dims == 4:
return te.extern(
Expand All @@ -380,7 +392,8 @@ def conv_forward(x,
ins[0],
ins[1],
outs[0],
conv_dtype), name="y")
conv_dtype,
groups), name="y")

return te.extern(
oshape, [x, w],
Expand All @@ -401,7 +414,8 @@ def conv_forward(x,
ins[0],
ins[1],
outs[0],
conv_dtype), name="y")
conv_dtype,
groups), name="y")

def softmax(x, axis=-1):
"""Compute softmax using CuDNN
Expand Down
37 changes: 25 additions & 12 deletions src/runtime/contrib/cudnn/conv_forward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ void ConvolutionForward(
int format,
int algo,
int dims,
int groups,
const int pad[],
const int stride[],
const int dilation[],
Expand Down Expand Up @@ -63,7 +64,8 @@ void ConvolutionForward(
// Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
// in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
if (dims == 2) {
// Set Desc
// Set Desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
pad[0],
pad[1],
Expand Down Expand Up @@ -111,6 +113,7 @@ void ConvolutionForward(
static_cast<int>(y->shape[hi]),
static_cast<int>(y->shape[wi])));
} else {
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
dims,
pad,
Expand Down Expand Up @@ -183,6 +186,7 @@ void ConvolutionForward(
void OutputShape(
int format,
int dims,
int groups,
const int pad[],
const int stride[],
const int dilation[],
Expand All @@ -202,6 +206,7 @@ void OutputShape(
int full_dims = dims + 2;

// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
dims,
pad,
Expand Down Expand Up @@ -240,6 +245,7 @@ void OutputShape(
// Set Input
std::vector<int> tensor_stride(full_dims);
GetCudnnStride(full_dims, x_dim, tensor_stride.data());

CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
data_type,
full_dims,
Expand All @@ -264,6 +270,7 @@ void OutputShape(
void FindAlgo(
int format,
int dims,
int groups,
const int pad[],
const int stride[],
const int dilation[],
Expand All @@ -284,6 +291,7 @@ void FindAlgo(
int full_dims = dims + 2;

// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
dims,
pad,
Expand Down Expand Up @@ -360,16 +368,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
int algo = args[2];
int pad_v[2], stride_v[2], dilation_v[2];
for (int i = 0; i < 2; i++) {
pad_v[i] = args[3 + i];
stride_v[i] = args[5 + i];
dilation_v[i] = args[7 + i];
pad_v[i] = args[3 + i];
stride_v[i] = args[5 + i];
dilation_v[i] = args[7 + i];
}
DLTensor* x = args[9];
DLTensor* w = args[10];
DLTensor* y = args[11];
std::string conv_dtype = args[12];
int groups = args[13];

ConvolutionForward(mode, format, algo, 2, pad_v, stride_v, dilation_v, x, w, y, conv_dtype);
ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v,
dilation_v, x, w, y, conv_dtype);
});


Expand All @@ -380,17 +390,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
int algo = args[2];
int pad_v[3], stride_v[3], dilation_v[3];
for (int i = 0; i < 3; i++) {
pad_v[i] = args[3 + i];
stride_v[i] = args[6 + i];
dilation_v[i] = args[9 + i];
pad_v[i] = args[3 + i];
stride_v[i] = args[6 + i];
dilation_v[i] = args[9 + i];
}
DLTensor *x = args[12];
DLTensor *w = args[13];
DLTensor *y = args[14];
std::string conv_dtype = args[15];
int groups = args[16];

ConvolutionForward(mode, format, algo, 3, pad_v, stride_v, dilation_v, x, w, y,
conv_dtype);
ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v,
dilation_v, x, w, y, conv_dtype);
});


Expand All @@ -406,8 +417,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape")
void* out_shape = args[7];
std::string data_dtype = args[8];
std::string conv_dtype = args[9];
int groups = args[10];

OutputShape(format, dims, pad, stride, dilation, x_dim,
OutputShape(format, dims, groups, pad, stride, dilation, x_dim,
w_dim, out_shape, data_dtype, conv_dtype);
});

Expand All @@ -424,8 +436,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
std::string data_dtype = args[8];
std::string conv_dtype = args[9];
int groups = args[10];

FindAlgo(format, dims, pad, stride, dilation, x_dim,
FindAlgo(format, dims, groups, pad, stride, dilation, x_dim,
w_dim, y_dim, data_dtype, conv_dtype, ret);
});

Expand Down
1 change: 0 additions & 1 deletion src/runtime/contrib/cudnn/cudnn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ struct ConvEntry {
runtime::DeviceAPI *cuda_api;
void *workspace{nullptr};
size_t workspace_size{0};
int group_count {0};
ConvEntry();
~ConvEntry();
void UpdateWorkspace(const size_t wsize);
Expand Down
Loading

0 comments on commit 25b2ebd

Please sign in to comment.