From 25b2ebd9e56f07ba18f9872b18900eb933450bb4 Mon Sep 17 00:00:00 2001 From: Wei Pan Date: Fri, 3 Apr 2020 09:29:00 -0700 Subject: [PATCH] [cuDNN] Add cuDNN grouped convolutions support Signed-off-by: Wei Pan --- python/tvm/contrib/cudnn.py | 36 ++++-- src/runtime/contrib/cudnn/conv_forward.cc | 37 ++++-- src/runtime/contrib/cudnn/cudnn_utils.h | 1 - tests/python/contrib/test_cudnn.py | 114 +++++++++--------- topi/python/topi/cuda/conv2d.py | 7 +- .../python/topi/testing/conv2d_nhwc_python.py | 37 +++++- .../topi/testing/conv3d_ncdhw_python.py | 1 + 7 files changed, 151 insertions(+), 82 deletions(-) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 5043520ccf136..0650b934b9722 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -182,7 +182,8 @@ def conv_output_shape(tensor_format, x_shape, w_shape, data_dtype, - conv_dtype): + conv_dtype, + groups=1): """Get output shape of 2D or 3D convolution Paramters @@ -205,6 +206,8 @@ def conv_output_shape(tensor_format, data type conv_dtype: str convolution type + groups: int + number of groups Returns ------- @@ -228,7 +231,8 @@ def conv_output_shape(tensor_format, _get_np_int32_array_handle(wshape), _get_np_int32_array_handle(oshape), data_dtype, - conv_dtype) + conv_dtype, + groups) return list(oshape) @@ -240,7 +244,8 @@ def conv_find_algo(tensor_format, w_shape, y_shape, data_dtype, - conv_dtype): + conv_dtype, + groups=1): """Choose the best algo for the given input. 
Paramters @@ -265,6 +270,8 @@ def conv_find_algo(tensor_format, data type conv_dtype: str convolution type + groups: int + number of groups Returns ------- @@ -287,7 +294,8 @@ def conv_find_algo(tensor_format, _get_np_int32_array_handle(wshape), _get_np_int32_array_handle(yshape), data_dtype, - conv_dtype) + conv_dtype, + groups) def conv_forward(x, @@ -298,7 +306,8 @@ def conv_forward(x, conv_mode, tensor_format, algo, - conv_dtype): + conv_dtype, + groups=1): """Create an extern op that compute 2D or 3D convolution with CuDNN Parameters @@ -325,6 +334,8 @@ def conv_forward(x, if algo == -1, the best algo will be chosen by CUDNN conv_dtype: str convolution type + groups: int + the number of groups Returns ------- @@ -335,8 +346,7 @@ def conv_forward(x, assert dims in (4, 5) conv_dtype = x.dtype if conv_dtype is None else conv_dtype - pad, stride, dilation, _, _ = \ - _prepare_global_func_params(dims - 2, pad, stride, dilation) + pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) oshape = conv_output_shape(tensor_format, pad, @@ -345,7 +355,8 @@ def conv_forward(x, list(x.shape), list(w.shape), x.dtype, - conv_dtype) + conv_dtype, + groups) if algo == -1: # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when # using INT8 data type, CuDNN will crash down. 
@@ -361,7 +372,8 @@ def conv_forward(x, list(w.shape), oshape, x.dtype, - conv_dtype) + conv_dtype, + groups) if dims == 4: return te.extern( @@ -380,7 +392,8 @@ def conv_forward(x, ins[0], ins[1], outs[0], - conv_dtype), name="y") + conv_dtype, + groups), name="y") return te.extern( oshape, [x, w], @@ -401,7 +414,8 @@ def conv_forward(x, ins[0], ins[1], outs[0], - conv_dtype), name="y") + conv_dtype, + groups), name="y") def softmax(x, axis=-1): """Compute softmax using CuDNN diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 95811332bbfa9..c921d4b602188 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -35,6 +35,7 @@ void ConvolutionForward( int format, int algo, int dims, + int groups, const int pad[], const int stride[], const int dilation[], @@ -63,7 +64,8 @@ void ConvolutionForward( // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int if (dims == 2) { - // Set Desc + // Set Desc + CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, pad[0], pad[1], @@ -111,6 +113,7 @@ void ConvolutionForward( static_cast(y->shape[hi]), static_cast(y->shape[wi]))); } else { + CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, @@ -183,6 +186,7 @@ void ConvolutionForward( void OutputShape( int format, int dims, + int groups, const int pad[], const int stride[], const int dilation[], @@ -202,6 +206,7 @@ void OutputShape( int full_dims = dims + 2; // conv desc + CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, @@ -240,6 
+245,7 @@ void OutputShape( // Set Input std::vector tensor_stride(full_dims); GetCudnnStride(full_dims, x_dim, tensor_stride.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, @@ -264,6 +270,7 @@ void OutputShape( void FindAlgo( int format, int dims, + int groups, const int pad[], const int stride[], const int dilation[], @@ -284,6 +291,7 @@ void FindAlgo( int full_dims = dims + 2; // conv desc + CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, @@ -360,16 +368,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") int algo = args[2]; int pad_v[2], stride_v[2], dilation_v[2]; for (int i = 0; i < 2; i++) { - pad_v[i] = args[3 + i]; - stride_v[i] = args[5 + i]; - dilation_v[i] = args[7 + i]; + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; } DLTensor* x = args[9]; DLTensor* w = args[10]; DLTensor* y = args[11]; std::string conv_dtype = args[12]; + int groups = args[13]; - ConvolutionForward(mode, format, algo, 2, pad_v, stride_v, dilation_v, x, w, y, conv_dtype); + ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, + dilation_v, x, w, y, conv_dtype); }); @@ -380,17 +390,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") int algo = args[2]; int pad_v[3], stride_v[3], dilation_v[3]; for (int i = 0; i < 3; i++) { - pad_v[i] = args[3 + i]; - stride_v[i] = args[6 + i]; - dilation_v[i] = args[9 + i]; + pad_v[i] = args[3 + i]; + stride_v[i] = args[6 + i]; + dilation_v[i] = args[9 + i]; } DLTensor *x = args[12]; DLTensor *w = args[13]; DLTensor *y = args[14]; std::string conv_dtype = args[15]; + int groups = args[16]; - ConvolutionForward(mode, format, algo, 3, pad_v, stride_v, dilation_v, x, w, y, - conv_dtype); + ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, + dilation_v, x, w, y, conv_dtype); }); @@ -406,8 
+417,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape") void* out_shape = args[7]; std::string data_dtype = args[8]; std::string conv_dtype = args[9]; + int groups = args[10]; - OutputShape(format, dims, pad, stride, dilation, x_dim, + OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype, conv_dtype); }); @@ -424,8 +436,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") int* y_dim = static_cast(static_cast(args[7])); std::string data_dtype = args[8]; std::string conv_dtype = args[9]; + int groups = args[10]; - FindAlgo(format, dims, pad, stride, dilation, x_dim, + FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype, conv_dtype, ret); }); diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index ee6bb5089e38d..c2000d02e0c99 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -78,7 +78,6 @@ struct ConvEntry { runtime::DeviceAPI *cuda_api; void *workspace{nullptr}; size_t workspace_size{0}; - int group_count {0}; ConvEntry(); ~ConvEntry(); void UpdateWorkspace(const size_t wsize); diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 5d1f100c1fc47..60951f9b602bc 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -17,11 +17,11 @@ import tvm from tvm import te from tvm.contrib import cudnn +from tvm.contrib.nvcc import have_fp16 import numpy as np import topi.testing - -def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): +def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): in_channel = 4 out_channel = 16 filter_h = 3 @@ -34,7 +34,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): dilation_w = 1 batch = 3 height = 32 - weight = 32 + width = 32 if not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled...") @@ -42,12 +42,17 @@ def 
verify_conv2d(data_dtype, conv_dtype, tensor_format=0): if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): print("skip because cudnn is not enabled...") return + if data_dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): + print("Skip because gpu does not have fp16 support") + return + + # schedule if tensor_format == 0: - xshape = [batch, in_channel, height, weight] - wshape = [out_channel, in_channel, filter_h, filter_w] + xshape = [batch, in_channel, height, width] + wshape = [out_channel, in_channel // groups, filter_h, filter_w] else: - xshape = [batch, height, weight, in_channel] - wshape = [out_channel, filter_h, filter_w, in_channel] + xshape = [batch, height, width, in_channel] + wshape = [out_channel, filter_h, filter_w, in_channel // groups] X = te.placeholder(xshape, name='X', dtype=data_dtype) W = te.placeholder(wshape, name='W', dtype=data_dtype) @@ -59,39 +64,41 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): conv_mode=1, tensor_format=tensor_format, conv_dtype=conv_dtype, - algo=-1) + algo=-1, + groups=groups) yshape = [x.value for x in Y.shape] s = te.create_schedule(Y.op) - def verify(): - ctx = tvm.gpu(0) - f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d") - x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) - w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) - y_np = np.zeros(yshape).astype(data_dtype) - x = tvm.nd.array(x_np, ctx) - w = tvm.nd.array(w_np, ctx) - y = tvm.nd.array(y_np, ctx) - if tensor_format == 0: - c_np = topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1) - elif tensor_format == 1: - wt = w_np.transpose((1, 2, 3, 0)) #OHWI => HWIO - c_np = topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1) - - f(x, w, y) - tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-3) - - verify() + # validation + ctx = tvm.gpu(0) + f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d") + x_np = np.random.uniform(-1, 1, 
xshape).astype(data_dtype) + w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) + y_np = np.zeros(yshape).astype(data_dtype) + x = tvm.nd.array(x_np, ctx) + w = tvm.nd.array(w_np, ctx) + y = tvm.nd.array(y_np, ctx) + if tensor_format == 0: + c_np = topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1, groups=groups) + elif tensor_format == 1: + wt = w_np.transpose((1, 2, 3, 0)) #OHWI => HWIO + c_np = topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1, groups=groups) + + f(x, w, y) + tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-2, rtol=1e-2) def test_conv2d(): verify_conv2d("float32", "float32", tensor_format=0) verify_conv2d("float16", "float32", tensor_format=1) - #Not pass accuracy test, need check - #verify_conv2d("float16", "float16", tensor_format=0) + verify_conv2d("float16", "float16", tensor_format=0) verify_conv2d("int8", "int32", tensor_format=1) + verify_conv2d("float32", "float32", tensor_format=0, groups=2) + verify_conv2d("float16", "float32", tensor_format=1, groups=2) + verify_conv2d("float16", "float16", tensor_format=0, groups=2) + verify_conv2d("int8", "int32", tensor_format=1, groups=2) -def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): +def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): in_channel = 4 out_channel = 16 filter_d = 3 @@ -109,7 +116,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): batch = 3 depth = 32 height = 32 - weight = 32 + width = 32 if not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled...") @@ -118,8 +125,9 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): print("skip because cudnn is not enabled...") return - xshape = [batch, in_channel, depth, height, weight] - wshape = [out_channel, in_channel, filter_d, filter_h, filter_w] + # schedule + xshape = [batch, in_channel, depth, height, width] + wshape = [out_channel, in_channel // groups, filter_d, filter_h, filter_w] X = te.placeholder(xshape, name='X', dtype=data_dtype) W = 
te.placeholder(wshape, name='W', dtype=data_dtype) @@ -131,33 +139,31 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): conv_mode=1, tensor_format=tensor_format, algo=-1, - conv_dtype=conv_dtype) + conv_dtype=conv_dtype, + groups=groups) yshape = [x.value for x in Y.shape] s = te.create_schedule(Y.op) - def verify(): - ctx = tvm.gpu(0) - f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d") - x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) - w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) - y_np = np.zeros(yshape).astype(data_dtype) - x = tvm.nd.array(x_np, ctx) - w = tvm.nd.array(w_np, ctx) - y = tvm.nd.array(y_np, ctx) - if tensor_format == 0: - c_np = topi.testing.conv3d_ncdhw_python(x_np, w_np, 1, 1) - else: - raise AssertionError("For now, conv3d tensor format only support: 0(NCHW)") - - f(x, w, y) - tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4) - - verify() + # validation + ctx = tvm.gpu(0) + f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d") + x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) + w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) + y_np = np.zeros(yshape).astype(data_dtype) + x = tvm.nd.array(x_np, ctx) + w = tvm.nd.array(w_np, ctx) + y = tvm.nd.array(y_np, ctx) + if tensor_format == 0: + c_np = topi.testing.conv3d_ncdhw_python(x_np, w_np, 1, 1, groups) + else: + raise AssertionError("For now, conv3d tensor format only support: 0(NCHW)") + f(x, w, y) + tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4) def test_conv3d(): verify_conv3d("float32", "float32", tensor_format=0) - + verify_conv3d("float32", "float32", tensor_format=0, groups=2) def verify_softmax(shape, axis, dtype="float32"): A = te.placeholder(shape, dtype=dtype, name='A') @@ -206,4 +212,4 @@ def test_softmax(): if __name__ == "__main__": test_conv2d() test_conv3d() - test_softmax() + test_softmax() \ No newline at end of file diff --git 
a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index c7df3dc96a5e4..f744020b050d8 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -67,7 +67,7 @@ def _callback(op): @autotvm.register_topi_compute("conv2d_cudnn.cuda") def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', - out_dtype='float32'): + out_dtype='float32', groups=1): """Compute conv2d using CuDNN library""" if layout == 'NCHW': tensor_format = 0 # CUDNN_TENSOR_NCHW @@ -89,7 +89,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) OH = (H + pt + pb - KH) // stride_h + 1 OW = (W + pl + pr - KW) // stride_w + 1 - cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \ + cfg.add_flop(groups * 2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \ ((KW - 1) * dilation_w + 1)) if data.dtype == "int8" or kernel.dtype == "int8": @@ -107,7 +107,8 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', conv_mode=1, tensor_format=tensor_format, algo=-1, # let CUDNN choose the best algo - conv_dtype=dtype) + conv_dtype=dtype, + groups=groups) @autotvm.register_topi_schedule("conv2d_cudnn.cuda") diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py index dc5f915daa220..d8713110056a4 100644 --- a/topi/python/topi/testing/conv2d_nhwc_python.py +++ b/topi/python/topi/testing/conv2d_nhwc_python.py @@ -21,7 +21,7 @@ from topi.nn.util import get_pad_tuple -def conv2d_nhwc_python(a_np, w_np, stride, padding): +def _conv2d_nhwc_python(a_np, w_np, stride, padding): """Convolution operator in NHWC layout. 
Parameters @@ -77,3 +77,38 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): apad, np.rot90(np.rot90(wt[f, c])), mode='valid') bt[n, f] += out[::stride_h, ::stride_w] return bt.transpose((0, 2, 3, 1)) + +def conv2d_nhwc_python(a_np, w_np, stride, padding, groups=1): + """Convolution operator in NHWC layout. + + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_height, in_width, in_channel] + + w_np : numpy.ndarray + 4-D with shape [filter_height, filter_width, in_channel // groups, num_filter] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or str or a list/tuple of 2 or 4 ints + Padding size, or ['VALID', 'SAME'], or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + groups : int + Number of groups + + Returns + ------- + b_np : np.ndarray + 4-D with shape [batch, out_height, out_width, out_channel] + """ + + a_slices = np.array_split(a_np, groups, axis=3) + w_slices = np.array_split(w_np, groups, axis=3) + b_slices = [_conv2d_nhwc_python(a_slice, w_slice, stride, padding) + for a_slice, w_slice in zip(a_slices, w_slices)] + b_np = np.concatenate(b_slices, axis=3) + return b_np diff --git a/topi/python/topi/testing/conv3d_ncdhw_python.py b/topi/python/topi/testing/conv3d_ncdhw_python.py index 063c07d941330..0b2620fc290cc 100644 --- a/topi/python/topi/testing/conv3d_ncdhw_python.py +++ b/topi/python/topi/testing/conv3d_ncdhw_python.py @@ -73,6 +73,7 @@ def conv3d_ncdhw_python(a_np, w_np, stride, padding, groups=1): padding : int or str or a list/tuple of three ints Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width] + groups : int Number of groups