From c32137d7724f7a84ede071e32b4422a2d3065f45 Mon Sep 17 00:00:00 2001 From: optima2005 <56945758+optima2005@users.noreply.github.com> Date: Wed, 1 Jan 2020 17:42:54 +0800 Subject: [PATCH] [FRONTEND][TF] Add conv3d (#4604) * [FRONTEND][TF] Add conv3d * fix high rtol --- include/tvm/relay/attrs/nn.h | 6 +- python/tvm/relay/frontend/tensorflow.py | 133 +++++++++++++++++- python/tvm/relay/op/nn/_nn.py | 5 +- src/relay/op/nn/convolution.cc | 19 +-- src/relay/op/nn/convolution.h | 13 +- src/relay/op/op_common.h | 39 +++++ .../frontend/tensorflow/test_forward.py | 62 +++++++- tests/python/relay/test_op_level2.py | 96 +++++++++++++ topi/python/topi/cuda/conv3d.py | 57 ++++++-- topi/python/topi/generic/nn.py | 16 +++ topi/python/topi/nn/conv3d.py | 72 +++++++++- topi/python/topi/nn/util.py | 12 +- topi/python/topi/testing/__init__.py | 1 + .../topi/testing/conv3d_ncdhw_python.py | 37 ++--- .../topi/testing/conv3d_ndhwc_python.py | 82 +++++++++++ topi/tests/python/test_topi_conv3d_ncdhw.py | 20 ++- topi/tests/python/test_topi_conv3d_ndhwc.py | 79 +++++++++++ 17 files changed, 683 insertions(+), 66 deletions(-) create mode 100644 topi/python/topi/testing/conv3d_ndhwc_python.py create mode 100644 topi/tests/python/test_topi_conv3d_ndhwc.py diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index d724f8173832..a2cad94320d7 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -212,7 +212,11 @@ struct Conv3DAttrs : public tvm::AttrsNode { .describe("Specifies the strides of the convolution."); TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom," + "right)"); TVM_ATTR_FIELD(dilation).set_default(Array({1, 1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); TVM_ATTR_FIELD(groups).set_default(1) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f748fe828bfd..db037e49bded 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -66,16 +66,18 @@ def _impl(attr): kernel = attr['kernel_shape'] if len(kernel) == 2: return prefix + '2d' + surfix + if len(kernel) == 3: + return prefix + '3d' + surfix raise tvm.error.OpAttributeInvalid( - 'Only 2D kernels are supported for operator {}'.format(prefix + '2d')) + 'Only 2D or 3D kernels are supported for operator {}'.format(prefix + '2d or 3d')) return _impl def _dimension_constraint(): def _dim_check(attrs): - if len(attrs['kernel_shape']) == 2: + if len(attrs['kernel_shape']) in (2, 3): return True return False - return _dim_check, "Only 2d kernel supported." + return _dim_check, "Only 2d or 3d kernel supported." 
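For reference, the 1/3/6-integer padding forms documented in the Conv3DAttrs description above all expand to a single (front, top, left, back, bottom, right) tuple. A minimal illustrative sketch of that expansion rule follows; the helper name is hypothetical and not part of this patch:

def expand_conv3d_padding(padding):
    # 1 int: same padding on every side; 3 ints: (d, h, w) mirrored to both
    # sides; 6 ints: already (front, top, left, back, bottom, right).
    if isinstance(padding, int):
        padding = (padding,)
    if len(padding) == 1:
        d = h = w = padding[0]
        return (d, h, w, d, h, w)
    if len(padding) == 3:
        d, h, w = padding
        return (d, h, w, d, h, w)
    if len(padding) == 6:
        return tuple(padding)
    raise ValueError("padding must have 1, 3 or 6 elements")

# expand_conv3d_padding(1)         -> (1, 1, 1, 1, 1, 1)
# expand_conv3d_padding((1, 2, 3)) -> (1, 2, 3, 1, 2, 3)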
def _get_param(params, input_node): if isinstance(input_node, _expr.Constant): @@ -425,6 +427,130 @@ def _impl(inputs, attr, params): return out return _impl +def _conv3d(opname): + def _impl(inputs, attr, params): + attr['data_format'] = attr['data_format'].decode("utf-8") + flip_layout = False + + inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] + + # NCDHW Layout require weights transpose + if attr['data_format'] == 'NCDHW': + tmp_shape = attr['_input_shapes'][inputs[1]] + tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)] + inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) + attr['_input_shapes'][inputs[1]] = tmp_shape + + input_shape = attr['_input_shapes'][inputs_data] + weights_shape = attr['_input_shapes'][inputs[1]] + + if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC": + input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)] + inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3)) + weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)] + inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) + + attr['data_format'] = "NCDHW" + attr['strides'] = [attr['strides'][ii] for ii in (0, 4, 1, 2, 3)] + flip_layout = True + + if attr['data_format'] == 'NDHWC': + kernel_d, kernel_h, kernel_w, _, _ = weights_shape + attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w) + if opname == 'conv': + attr['channels'] = weights_shape[4] + elif opname == 'conv_transpose': + attr['channels'] = weights_shape[3] + + if 'dilations' in attr: + attr['dilations'] =\ + (attr['dilations'][1], attr['dilations'][2], attr['dilations'][3]) + attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3]) + elif attr['data_format'] == 'NCDHW': + _, _, kernel_d, kernel_h, kernel_w = weights_shape + attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w) + if opname == 'conv': + attr['channels'] = weights_shape[0] + elif opname == 'conv_transpose': + attr['channels'] = weights_shape[1] + + if 'dilations' in attr: + attr['dilations'] =\ + (attr['dilations'][2], attr['dilations'][3], attr['dilations'][4]) + attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4]) + else: + msg = 'Value {} in attribute "data_format" of operator Conv is ' \ + 'not valid.' 
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
+
+        # Fix padding
+        attr['padding'] = attr['padding'].decode("utf-8")
+
+        if attr['padding'] == 'VALID':
+            attr['padding'] = [0, 0, 0]
+        elif attr['padding'] == 'SAME':
+            stride_d, stride_h, stride_w = attr['strides']
+            kernel_d, kernel_h, kernel_w = attr['kernel_shape']
+
+            pdata_shape = input_shape
+            if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
+                pdata_shape = attr['_output_shapes'][0]
+
+            if attr['data_format'] == 'NDHWC':
+                in_d = pdata_shape[1]
+                in_h = pdata_shape[2]
+                in_w = pdata_shape[3]
+            else:
+                in_d = pdata_shape[2]
+                in_h = pdata_shape[3]
+                in_w = pdata_shape[4]
+
+            dilation_d = attr['dilations'][0]
+            dilation_h = attr['dilations'][1]
+            dilation_w = attr['dilations'][2]
+            dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
+            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+            pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d)
+            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
+
+            # (front, top, left, back, bottom, right)
+            attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
+
+        else:
+            msg = 'Value {} in attribute "padding" of operator Conv is not ' \
+                  'valid.'
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
+
+        if 'kernel_layout' not in attr:
+            attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW'
+
+        use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
+        channel_axis = 1 if attr['data_format'] == "NCDHW" else 4
+
+        # Ignore the new attributes from TF2.0, for now.
+        out = AttrCvt(
+            op_name=_dimension_picker('conv', \
+                surfix="_transpose" if opname == 'conv_transpose' else ""),
+            ignores=['explicit_paddings'],
+            transforms={
+                'kernel_shape': 'kernel_size',
+                'data_format': 'data_layout',
+                'dilations': ('dilation', (0, 0)),
+                'group': ('groups', 1)},
+            custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)
+
+        if use_bias:
+            out = _op.nn.bias_add(out,
+                                  inputs[2] if opname != 'conv_transpose' else inputs[3],
+                                  axis=channel_axis)
+
+        if flip_layout:
+            out = _op.transpose(out, axes=(0, 2, 3, 4, 1))
+
+        return out
+    return _impl
+
 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
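The SAME branch above reuses the frontend's existing _get_pad_pair helper to split the total padding into a (before, after) pair per spatial axis, which is then laid out as (front, top, left, back, bottom, right). Below is a minimal sketch of TensorFlow-style SAME padding for one axis, assuming the standard formula; the in-tree helper may differ in detail:

def get_pad_pair(input_size, dilated_kernel_size, stride):
    # Total padding chosen so that output size == ceil(input_size / stride);
    # the extra pixel, if any, goes after (back/bottom/right).
    if input_size % stride == 0:
        pad = max(dilated_kernel_size - stride, 0)
    else:
        pad = max(dilated_kernel_size - (input_size % stride), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]

# e.g. input_size=17, dilated_kernel_size=3, stride=2 -> pad=2 -> [1, 1],
# giving an output size of (17 + 2 - 3) // 2 + 1 = 9 == ceil(17 / 2).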
@@ -1442,6 +1568,7 @@ def _impl(inputs, attr, params): 'Concat' : _concat(), 'ConcatV2' : _concatV2(), 'Conv2D' : _conv('conv'), + 'Conv3D' : _conv3d('conv'), 'Conv2DBackpropInput' : _conv('conv_transpose'), 'CropAndResize' : _crop_and_resize(), 'DecodeJpeg' : _decode_image(), diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 322325819fba..452eb27f2b07 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -173,6 +173,7 @@ def _get_out_depth(): assert len(weight_shape) == 5 C, M, _, _, VC = weight_shape return C * VC * M + if groups == 1: out = topi.nn.conv2d( inputs[0], inputs[1], strides, padding, @@ -330,7 +331,7 @@ def compute_conv3d(attrs, inputs, out_type, target): out_dtype = (inputs[0].dtype if out_dtype in ("same", "") else out_dtype) - assert layout in ["NCDHW"] + assert layout in ["NCDHW", "NDHWC"] (dilation_d, dilation_h, dilation_w) = dilation if dilation_d < 1 or dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") @@ -353,6 +354,8 @@ def schedule_conv3d(attrs, outs, target): with target: if groups == 1 and layout == "NCDHW": return topi.generic.schedule_conv3d_ncdhw(outs) + elif groups == 1 and layout == "NDHWC": + return topi.generic.schedule_conv3d_ndhwc(outs) raise ValueError("No compatible schedule") diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 40c24462c8f7..e49c9d65c905 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -38,7 +38,7 @@ namespace relay { TVM_REGISTER_NODE_TYPE(Conv2DAttrs); template -Array > Conv2DInferCorrectLayout( +Array > ConvInferCorrectLayout( const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, @@ -105,7 +105,7 @@ with the layer input to produce a tensor of outputs. .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) .add_type_rel("Conv2D", Conv2DRel) -.set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); +.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv3d TVM_REGISTER_NODE_TYPE(Conv3DAttrs); @@ -163,7 +163,8 @@ with the layer input to produce a tensor of outputs. 
.add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) -.add_type_rel("Conv3D", Conv3DRel); +.add_type_rel("Conv3D", Conv3DRel) +.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv2d_transpose @@ -337,7 +338,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) .set_attr("FInferCorrectLayout", - Conv2DInferCorrectLayout) + ConvInferCorrectLayout) .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); @@ -635,7 +636,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") .set_support_level(10) .add_type_rel("Conv2DWinograd", Conv2DWinogradRel) .set_attr("FInferCorrectLayout", - Conv2DInferCorrectLayout); + ConvInferCorrectLayout); // relay.nn.contrib_conv2d_winograd_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs); @@ -744,7 +745,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) .add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel) -.set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); +.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs); @@ -854,7 +855,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8") .set_support_level(10) .add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel) .set_attr("FInferCorrectLayout", - Conv2DInferCorrectLayout); + ConvInferCorrectLayout); // Positional relay function to create conv2d NCHWc operator // used by frontend FFI. @@ -903,7 +904,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") .set_support_level(10) .add_type_rel("Conv2DNCHWc", Conv2DWinogradRel) .set_attr("FInferCorrectLayout", - Conv2DInferCorrectLayout); + ConvInferCorrectLayout); // Positional relay function to create depthwise conv2d NCHWc operator @@ -953,7 +954,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") .set_support_level(10) .add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", - Conv2DInferCorrectLayout); + ConvInferCorrectLayout); bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index efcf7dfe6906..0f4bb05883a0 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -28,6 +28,8 @@ #include #include +#include "../op_common.h" + namespace tvm { namespace relay { @@ -187,7 +189,7 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, param->kernel_size[1], param->kernel_size[2]}}; } - /*wshape = trans_kernel_layout.BackwardShape(wshape); */ + wshape = trans_kernel_layout.BackwardShape(wshape); channels = param->channels; dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; @@ -196,6 +198,7 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, if (weight != nullptr) { weight_dtype = weight->dtype; } + // assign result to reporter reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); } else { @@ -225,22 +228,24 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, // dilation Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); + IndexExpr pad_d, pad_h, pad_w; + 
GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); if (!dshape_ncdhw[2].as()) { - oshape.Set(2, indexdiv(dshape_ncdhw[2] + param->padding[0] * 2 - dilated_ksize_z, + oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1); } else { oshape.Set(2, dshape_ncdhw[2]); } if (!dshape_ncdhw[3].as()) { - oshape.Set(3, indexdiv(dshape_ncdhw[3] + param->padding[1] * 2 - dilated_ksize_y, + oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1); } else { oshape.Set(3, dshape_ncdhw[3]); } if (!dshape_ncdhw[4].as()) { - oshape.Set(4, indexdiv(dshape_ncdhw[4] + param->padding[2] * 2 - dilated_ksize_x, + oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1); } else { oshape.Set(4, dshape_ncdhw[4]); diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 53495ccff15d..04f26b93670e 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -162,6 +162,45 @@ inline void GetPaddingWidth(const Array& padding, IndexExpr* pad_w) { } } +/*! \brief A utility function to get padding height and width from a 1, 2, 4 ints tuple. */ +inline void GetPaddingHeightWidth(const Array& padding, IndexExpr* pad_h, + IndexExpr* pad_w) { + if (padding.size() == 1) { + *pad_h = padding[0] * 2; + *pad_w = padding[0] * 2; + } else if (padding.size() == 2) { + *pad_h = padding[0] * 2; + *pad_w = padding[1] * 2; + } else if (padding.size() == 4) { + *pad_h = padding[0] + padding[2]; + *pad_w = padding[1] + padding[3]; + } else { + CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " + << padding.size(); + } +} + +/*! \brief A utility function to get padding depth, height and width from a 1, 3, 6 ints tuple. */ +inline void GetPaddingDepthHeightWidth(const Array& padding, IndexExpr* pad_d, + IndexExpr* pad_h, IndexExpr* pad_w) { + if (padding.size() == 1) { + *pad_d = padding[0] * 2; + *pad_h = padding[0] * 2; + *pad_w = padding[0] * 2; + } else if (padding.size() == 3) { + *pad_d = padding[0] * 2; + *pad_h = padding[1] * 2; + *pad_w = padding[2] * 2; + } else if (padding.size() == 6) { + *pad_d = padding[0] + padding[3]; + *pad_h = padding[1] + padding[4]; + *pad_w = padding[2] + padding[5]; + } else { + CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " + << padding.size(); + } +} + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 9b7fe62306fd..97557d3c2a04 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -94,13 +94,14 @@ def vmobj_to_list(o): def run_tvm_graph(graph_def, input_data, input_node, num_output=1, - target='llvm', out_names=None, opt_level=3, mode='graph_runtime'): + target='llvm', out_names=None, opt_level=3, mode='graph_runtime', + cuda_layout="NCHW"): """ Generic function to compile on relay and execute on tvm """ input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) layout = None if target == "cuda": - layout = "NCHW" + layout = cuda_layout target_host = None shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} mod, params = relay.frontend.from_tensorflow(graph_def, @@ -160,7 +161,8 @@ def run_tf_graph(sess, input_data, input_node, output_node): def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, - no_gpu=False, opt_level=3, mode='graph_runtime'): + no_gpu=False, opt_level=3, 
mode='graph_runtime', + cuda_layout="NCHW"): """Generic function to generate and compare tensorflow and TVM output""" def name_without_num(name): return name.split(':')[0] if ":" in name else name @@ -191,7 +193,8 @@ def name_without_num(name): tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device, out_names=out_name, - num_output=len(out_name), opt_level=opt_level, mode=mode) + num_output=len(out_name), opt_level=opt_level, mode=mode, + cuda_layout=cuda_layout) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared for i in range(len(tf_output)): @@ -469,6 +472,57 @@ def test_forward_convolution(): 'NHWC', [1, 8, 8, 1]) +####################################################################### +# Convolution3D +# ----------- + + +def _test_convolution3d(opname, tensor_in_sizes, filter_in_sizes, + dilations, strides, padding, data_format, + deconv_output_shape=[]): + """ One iteration of 3D convolution with given shapes and attributes """ + + total_size_1 = np.prod(tensor_in_sizes) + total_size_2 = np.prod(filter_in_sizes) + # Initializes the input tensor with array containing incrementing + # numbers from 1. + data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] + filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32') + in_filter = constant_op.constant( + filter_array, shape=filter_in_sizes, dtype='float32') + if data_format == 'NDHWC': + strides = [1] + strides + [1] + dilations = [1] + dilations + [1] + else: + strides = [1, 1] + strides + dilations = [1, 1] + dilations + + if opname == 'conv': + nn_ops.conv3d(in_data, + in_filter, + strides=strides, + dilations=dilations, + padding=padding, + data_format=data_format) + + compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), + 'Placeholder:0', 'Conv3D:0', cuda_layout="NCDHW") + +def test_forward_convolution3d(): + if is_gpu_available(): + _test_convolution3d('conv', [4, 176, 8, 8, 8], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], 'SAME', 'NCDHW') + _test_convolution3d('conv', [4, 19, 17, 17, 17], [3, 3, 3, 19, 19], [1, 1, 1], [2, 2, 2], 'VALID', 'NCDHW') + _test_convolution3d('conv', [4, 124, 17, 17, 17], [1, 1, 1, 124, 19], [1, 1, 1], [1, 1, 1], 'SAME', 'NCDHW') + _test_convolution3d('conv', [4, 12, 17, 17, 17], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NCDHW') + _test_convolution3d('conv', [4, 8, 8, 8, 176], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], 'SAME', 'NDHWC') + _test_convolution3d('conv', [4, 17, 17, 17, 19], [3, 3, 3, 19, 19], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC') + _test_convolution3d('conv', [4, 17, 17, 17, 124], [1, 1, 1, 124, 19], [1, 1, 1], [1, 1, 1], 'SAME', 'NDHWC') + _test_convolution3d('conv', [4, 17, 17, 17, 12], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC') + + ####################################################################### # BiasAdd # ----------- diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 2f19f7a1f7d6..ceb5d093533e 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -294,6 +294,56 @@ def run_test_conv2d_cuda(dtype, out_dtype, scale, dshape, kshape, padding=(2, 2), channels=192, kernel_size=(7, 7)) +def test_conv3d_infer_type(): + # symbolic in batch dimension + n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32")) 
+ w = relay.var("w") + y = relay.nn.conv3d(x, w, + kernel_size=(3, 3, 3), + padding=(1, 1, 1), + channels=2) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 2, 224, 224, 224), "float32") + assert yy.args[1].checked_type == relay.TensorType( + (2, 10, 3, 3, 3), "float32") + + # infer by shape of w, mixed precision + n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) + w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8")) + y = relay.nn.conv3d(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 2, 222, 222, 222), "int32") + + # infer shape in case of different dtypes for input and weight. + n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8")) + w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8")) + y = relay.nn.conv3d(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 2, 222, 222, 222), "int32") + + # Infer with NDHWC + n, c, d, h, w = 4, 32, 224, 224, 224 + x = relay.var("x", relay.TensorType((n, d, h, w, c), "int8")) + wt = relay.var("w") + y = relay.nn.conv3d(x, wt, + kernel_size=(3, 3, 3), + padding=(1, 1, 1), + channels=16, + data_layout="NDHWC", + out_dtype="int32") + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, d, h, w, 16), "int32") + + def test_conv3d_run(): def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, padding=(1, 1, 1), @@ -338,6 +388,50 @@ def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, run_test_conv3d("float32", "float32", 1, dshape, kshape, padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3)) +def test_conv3d_ndhwc_run(): + def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, + padding=(1, 1, 1), + fref=None, + groups=1, + dilation=(1, 1, 1), + except_targets=None, + **attrs): + if except_targets is None: + except_targets = [] + + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", dtype=dtype) + y = relay.nn.conv3d(x, w, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NDHWC", kernel_layout="DHWIO", + **attrs) + func = relay.Function([x, w], y) + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation) + if fref is None: + ref_res = topi.testing.conv3d_ndhwc_python( + data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding) + else: + ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype)) + + + for target, ctx in ctx_list(): + if target in except_targets: + continue + + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data, kernel) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + # normal conv3d + dshape = (1, 5, 224, 224, 6) + kshape = (3, 3, 3, 6, 10) + run_test_conv3d("float32", "float32", 1, dshape, kshape, + padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3), except_targets=["cuda"]) + def test_conv2d_transpose_infer_type(): # symbolic in batch dimension @@ -993,6 +1087,7 @@ def test_bitpack_infer_type(): test_lrn() test_l2_normalize() test_conv2d_infer_type() + test_conv3d_infer_type() test_bitpack_infer_type() test_upsampling_infer_type() 
test_upsampling3d_infer_type() @@ -1006,6 +1101,7 @@ def test_bitpack_infer_type(): test_conv2d_run() test_conv2d_winograd() test_conv3d_run() + test_conv3d_ndhwc_run() test_bitserial_conv2d_infer_type() test_batch_flatten() test_upsampling() diff --git a/topi/python/topi/cuda/conv3d.py b/topi/python/topi/cuda/conv3d.py index 8d3c720b6a89..7d3c0b4afc1b 100644 --- a/topi/python/topi/cuda/conv3d.py +++ b/topi/python/topi/cuda/conv3d.py @@ -21,6 +21,7 @@ from tvm.contrib import cudnn from .. import nn, generic +from ..nn.util import get_pad_tuple3d from ..util import get_const_tuple, traverse_inline from .conv3d_direct import schedule_direct_3d_cuda @@ -44,8 +45,10 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o strides : int or a list/tuple of three ints stride size, or [stride_depth, stride_height, stride_width] - padding : int or a list/tuple of three ints - padding size, or [pad_depth, pad_height, pad_width] + padding : int or a list/tuple of 3 or 6 ints + padding size, or + [pad_depth, pad_height, pad_width] for 3 ints, or + [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right] for 6 ints dilation: int or a list/tuple of three ints dilation size, or [dilation_depth, dilation_height, dilation_width] @@ -77,25 +80,27 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o # handle dilation stride_d, stride_h, stride_w = (strides, strides, strides) if isinstance(strides, int) \ else strides - pad_d, pad_h, pad_w = (padding, padding, padding) if isinstance(padding, int) else padding + if isinstance(padding, (list, tuple)) and len(padding) > 3: + raise ValueError("Cudnn doesn't support asymmetric padding.") + pf, pt, pl, pk, pb, pr = get_pad_tuple3d(padding, (KD, KH, KW)) dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) if \ isinstance(dilation, int) else dilation - OD = (D + 2 * pad_d - KD) // stride_d + 1 - OH = (H + 2 * pad_h - KH) // stride_h + 1 - OW = (W + 2 * pad_w - KW) // stride_w + 1 - cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((DH - 1) * dilation_d + 1) *\ + OD = (D + pf + pk - KD) // stride_d + 1 + OH = (H + pt + pb - KH) // stride_h + 1 + OW = (W + pl + pr - KW) // stride_w + 1 + cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((KD - 1) * dilation_d + 1) *\ ((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1)) return cudnn.conv_forward(data, kernel, - [pad_d, pad_h, pad_w], + [pf, pt, pl], # cudnn padding pt, pl on both sides of input [stride_d, stride_h, stride_w], [dilation_d, dilation_h, dilation_w], conv_mode=1, tensor_format=tensor_format, algo=-1, # let CUDNN choose the best algo - conv_dtype=dtype) + conv_dtype=data.dtype) if layout == 'NCDHW': return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype) @@ -134,3 +139,37 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +@autotvm.register_topi_schedule(generic.schedule_conv3d_ndhwc, ["cuda", "gpu"], + ["direct"]) +def schedule_conv3d_ndhwc_cuda(cfg, outs): + """TOPI schedule callback of conv3d for cuda gpu + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + outs: Array of Tensor + The computation graph description of conv2d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for conv2d. 
+ """ + target = tvm.target.current_target() + if 'cudnn' in target.libs: + return generic.schedule_extern(outs) + + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'conv3d_ndhwc': + schedule_direct_3d_cuda(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 77f8cadb991e..980db65d9b8d 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -242,6 +242,22 @@ def schedule_conv3d_ncdhw(outs): """ return _default_schedule(outs, False) +@tvm.target.generic_func +def schedule_conv3d_ndhwc(outs): + """Schedule for conv3d_ndhwc + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv3d_ndhwc + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) @tvm.target.generic_func def schedule_conv2d_transpose_nchw(outs): diff --git a/topi/python/topi/nn/conv3d.py b/topi/python/topi/nn/conv3d.py index 928f32f51d75..21d893fd5656 100644 --- a/topi/python/topi/nn/conv3d.py +++ b/topi/python/topi/nn/conv3d.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, unused-variable, too-many-locals -# pylint: disable=unused-argument, redefined-builtin +# pylint: disable=unused-argument, redefined-builtin, no-else-return """Conv3D operators""" from __future__ import absolute_import as _abs import tvm @@ -58,6 +58,8 @@ def conv3d(input, filter, strides, padding, dilation, layout='NCDHW', out_dtype= # default declaration if layout == 'NCDHW': return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype) + elif layout == 'NDHWC': + return conv3d_ndhwc(input, filter, strides, padding, dilation, out_dtype) raise ValueError("not support this layout {} yet".format(layout)) @@ -128,3 +130,71 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None): xx * stride_w + rx * dilation_w].astype(out_dtype) * Filter[ff, rc, rz, ry, rx].astype(out_dtype), axis=[rc, rz, ry, rx]), tag="conv3d_ncdhw") + + +def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): + """Convolution operator in NDHWC layout. 
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        5-D with shape [batch, in_depth, in_height, in_width, in_channel]
+
+    Filter : tvm.Tensor
+        5-D with shape [filter_depth, filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of three ints
+        Stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or str or a list/tuple of 3 or 6 ints
+        Padding size, or ['VALID', 'SAME'], or
+        [pad_depth, pad_height, pad_width] for 3 ints, or
+        [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right] for 6 ints
+
+    dilation: int or a list/tuple of three ints
+        dilation size, or [dilation_depth, dilation_height, dilation_width]
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        5-D with shape [batch, out_depth, out_height, out_width, out_channel]
+    """
+    assert isinstance(stride, int) or len(stride) == 3
+    assert isinstance(dilation, int) or len(dilation) == 3
+
+    if isinstance(stride, int):
+        stride_d = stride_h = stride_w = stride
+    else:
+        stride_d, stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+
+    batch, in_depth, in_height, in_width, in_channel = Input.shape
+    kernel_d, kernel_h, kernel_w, channel, num_filter = Filter.shape
+    # compute the output shape
+    dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+
+    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+    out_channel = num_filter
+    out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
+    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
+    pad_before = [0, pad_front, pad_top, pad_left, 0]
+    pad_after = [0, pad_back, pad_down, pad_right, 0]
+    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    rz = tvm.reduce_axis((0, kernel_d), name='rz')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    Output = tvm.compute(
+        (batch, out_depth, out_height, out_width, out_channel),
+        lambda nn, zz, yy, xx, ff: tvm.sum(
+            PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
+                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
+            Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]),
+        name="Conv3dOutput", tag="conv3d_ndhwc")
+    return Output
diff --git a/topi/python/topi/nn/util.py b/topi/python/topi/nn/util.py
index 847a5c84daaa..c2c5c2bf6505 100644
--- a/topi/python/topi/nn/util.py
+++ b/topi/python/topi/nn/util.py
@@ -158,9 +158,15 @@ def get_pad_tuple3d(padding, kernel):
     """
     # compute the padding size
     if isinstance(padding, (tuple, list)):
-        pad_h = padding[0] * 2
-        pad_w = padding[1] * 2
-        pad_d = padding[2] * 2
+        if len(padding) == 3:
+            pad_d = padding[0] * 2
+            pad_h = padding[1] * 2
+            pad_w = padding[2] * 2
+        elif len(padding) == 6:
+            return padding[0], padding[1], padding[2], padding[3], \
+                   padding[4], padding[5]
+        else:
+            raise ValueError("Size of padding can only be 3 or 6")
     elif isinstance(padding, int):
         pad_d = pad_w = pad_h = padding * 2
     elif padding == "VALID":
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 2826a2b5fdc6..5e2f3fe848f0 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -25,6 +25,7 @@ from
.conv2d_nchw_python import conv2d_nchw_python from .conv2d_nhwc_python import conv2d_nhwc_python from .conv3d_ncdhw_python import conv3d_ncdhw_python +from .conv3d_ndhwc_python import conv3d_ndhwc_python from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python diff --git a/topi/python/topi/testing/conv3d_ncdhw_python.py b/topi/python/topi/testing/conv3d_ncdhw_python.py index 3a4db25da897..825ec622a1ec 100644 --- a/topi/python/topi/testing/conv3d_ncdhw_python.py +++ b/topi/python/topi/testing/conv3d_ncdhw_python.py @@ -18,6 +18,7 @@ """Convolution 3D in python""" import numpy as np import scipy.signal +from topi.nn.util import get_pad_tuple3d def _conv3d_ncdhw_python(a_np, w_np, stride, padding): @@ -27,20 +28,13 @@ def _conv3d_ncdhw_python(a_np, w_np, stride, padding): stride_d = stride_h = stride_w = stride else: stride_d, stride_h, stride_w = stride - if isinstance(padding, int): - pad_d = pad_h = pad_w = padding * 2 - elif isinstance(padding, (list, tuple)): - pad_d, pad_h, pad_w = padding[0] * 2, padding[1] * 2, padding[2] * 2 - else: - pad_d = 0 if padding == 'VALID' else kernel_d - 1 - pad_h = 0 if padding == 'VALID' else kernel_h - 1 - pad_w = 0 if padding == 'VALID' else kernel_w - 1 - pad_front = int(np.ceil(float(pad_d) / 2)) - pad_back = pad_d - pad_front - pad_top = int(np.ceil(float(pad_h) / 2)) - pad_bottom = pad_h - pad_top - pad_left = int(np.ceil(float(pad_w) / 2)) - pad_right = pad_w - pad_left + + pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = \ + get_pad_tuple3d(padding, (kernel_d, kernel_h, kernel_w)) + pad_d = pad_front + pad_back + pad_h = pad_top + pad_bottom + pad_w = pad_left + pad_right + # compute the output shape out_channel = num_filter out_depth = (in_depth - kernel_d + pad_d) // stride_d + 1 @@ -53,19 +47,8 @@ def _conv3d_ncdhw_python(a_np, w_np, stride, padding): for c in range(in_channel): if pad_d > 0 or pad_h > 0 or pad_w > 0: apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w)) - if pad_d == 0 and pad_h == 0: - apad[:, :, pad_left:-pad_right] = a_np[n, c] - elif pad_d == 0 and pad_w == 0: - apad[:, pad_top:-pad_bottom, :] = a_np[n, c] - elif pad_d == 0 and pad_h != 0 and pad_w != 0: - apad[:, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c] - elif pad_d != 0 and pad_h == 0: - apad[pad_front:-pad_back, :, pad_left:-pad_right] = a_np[n, c] - elif pad_d != 0 and pad_w == 0: - apad[pad_front:-pad_back, pad_top:-pad_bottom, :] = a_np[n, c] - elif pad_d != 0 and pad_h != 0 and pad_w != 0: - apad[pad_front:-pad_back, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c] - + apad[pad_front:pad_front + in_depth, pad_top:pad_top + in_height,\ + pad_left:pad_left + in_width] = a_np[n, c] else: apad = a_np[n, c] out = scipy.signal.convolve( diff --git a/topi/python/topi/testing/conv3d_ndhwc_python.py b/topi/python/topi/testing/conv3d_ndhwc_python.py new file mode 100644 index 000000000000..2810f72b094f --- /dev/null +++ b/topi/python/topi/testing/conv3d_ndhwc_python.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Convolution 3D in python"""
+import numpy as np
+import scipy.signal
+from topi.nn.util import get_pad_tuple3d
+
+
+def conv3d_ndhwc_python(a_np, w_np, stride, padding):
+    """Convolution 3D operator in NDHWC layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        5-D with shape [batch, in_depth, in_height, in_width, in_channel]
+
+    w_np : numpy.ndarray
+        5-D with shape [filter_depth, filter_height, filter_width, in_channel, num_filter]
+
+    stride : int or a list/tuple of three ints
+        Stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or str or a list/tuple of three ints
+        Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width]
+
+    Returns
+    -------
+    b_np : np.ndarray
+        5-D with shape [batch, out_depth, out_height, out_width, out_channel]
+    """
+    batch, in_depth, in_height, in_width, in_channel = a_np.shape
+    kernel_d, kernel_h, kernel_w, _, num_filter = w_np.shape
+    if isinstance(stride, int):
+        stride_d = stride_h = stride_w = stride
+    else:
+        stride_d, stride_h, stride_w = stride
+
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = \
+        get_pad_tuple3d(padding, (kernel_d, kernel_h, kernel_w))
+    pad_d = pad_front + pad_back
+    pad_h = pad_top + pad_bottom
+    pad_w = pad_left + pad_right
+    # compute the output shape
+    out_channel = num_filter
+    out_depth = (in_depth - kernel_d + pad_d) // stride_d + 1
+    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
+    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
+    # change the layout from NDHWC to NCDHW
+    at = a_np.transpose((0, 4, 1, 2, 3))
+    wt = w_np.transpose((4, 3, 0, 1, 2))
+    bt = np.zeros((batch, out_channel, out_depth, out_height, out_width))
+    # computation
+    for n in range(batch):
+        for f in range(out_channel):
+            for c in range(in_channel):
+                if pad_d > 0 or pad_h > 0 or pad_w > 0:
+                    apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w))
+                    apad[pad_front:pad_front + in_depth, pad_top:pad_top + in_height,\
+                         pad_left:pad_left + in_width] = at[n, c]
+                else:
+                    apad = at[n, c]
+                out = scipy.signal.convolve(
+                    apad, np.flip(wt[f, c]), mode='valid')
+                bt[n, f] += out[::stride_d, ::stride_h, ::stride_w]
+    return bt.transpose((0, 2, 3, 4, 1))
diff --git a/topi/tests/python/test_topi_conv3d_ncdhw.py b/topi/tests/python/test_topi_conv3d_ncdhw.py
index 78827e4ca9d1..681190633d67 100644
--- a/topi/tests/python/test_topi_conv3d_ncdhw.py
+++ b/topi/tests/python/test_topi_conv3d_ncdhw.py
@@ -22,12 +22,16 @@
 import topi
 import topi.testing
 from tvm.contrib.pickle_memoize import memoize
+from topi.nn.util import get_pad_tuple3d
 from topi.util import get_const_tuple

 from common import get_all_backend

 def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
-    print("Workload: (%d, %d, %d, %d, %d, %d,
%d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(padding, (kernel, kernel, kernel)) + padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, + padding_sum, dilation)) in_depth = in_height = in_width = in_size @@ -62,7 +66,7 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - C = topi.nn.conv3d(A, W, (stride, stride, stride), (padding, padding, padding), + C = topi.nn.conv3d(A, W, (stride, stride, stride), padding, (dilation, dilation, dilation), layout='NCDHW', out_dtype=dtype) if add_bias: C = topi.add(C, bias) @@ -75,10 +79,10 @@ def check_device(device): b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) if add_bias: - func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, b, c) else: - func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4) @@ -109,6 +113,14 @@ def test_conv3d_ncdhw(): verify_conv3d_ncdhw(2, 2, 2, 2, 2, 2, 2) verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 3) + # Asymmetric padding + verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, (0, 0, 0, 1, 1, 1)) + verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, (2, 1, 2, 1, 2, 1)) + verify_conv3d_ncdhw(1, 64, 56, 3, 3, 1, (2, 2, 2, 1, 1, 1), dilation=2) + verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, (0, 1, 1)) + verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, (2, 1, 0)) + verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID") + verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID") if __name__ == "__main__": diff --git a/topi/tests/python/test_topi_conv3d_ndhwc.py b/topi/tests/python/test_topi_conv3d_ndhwc.py new file mode 100644 index 000000000000..66ccf086275c --- /dev/null +++ b/topi/tests/python/test_topi_conv3d_ndhwc.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Example code to do convolution.""" +import os +import numpy as np +import tvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_depth = in_height = in_width = in_size + + A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A') + W = tvm.placeholder((kernel, kernel, kernel, in_channel, num_filter), name='W') + B = topi.nn.conv3d_ndhwc(A, W, stride, padding, dilation) + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv3d_ndhwc.verify_ndhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, dilation, 1, 1)) + b_np = topi.testing.conv3d_ndhwc_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_conv3d_ndhwc([B]) + ctx = tvm.context(device, 0) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm']: + check_device(device) + + +def test_conv3d_ndhwc(): + verify_conv3d_ndhwc(1, 16, 32, 16, 3, 1, "SAME") + verify_conv3d_ndhwc(4, 32, 16, 32, 5, 2, "SAME") + verify_conv3d_ndhwc(4, 32, 16, 64, 5, 2, "SAME") + verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "VALID") + verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "VALID") + verify_conv3d_ndhwc(4, 32, 16, 32, 5, 2, "VALID") + verify_conv3d_ndhwc(4, 32, 16, 64, 5, 2, "VALID") + # dilation = 2 + verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "SAME", dilation=2) + + +if __name__ == "__main__": + test_conv3d_ndhwc()
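As a quick usage sketch (not part of the patch; shapes are illustrative), the new NDHWC compute, generic schedule, and NumPy reference can be wired together the same way verify_conv3d_ndhwc does above:

import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple

# Tiny NDHWC workload: batch 1, 8x8x8 volume, 4 -> 8 channels, 3x3x3 kernel.
A = tvm.placeholder((1, 8, 8, 8, 4), name='A')
W = tvm.placeholder((3, 3, 3, 4, 8), name='W')
B = topi.nn.conv3d_ndhwc(A, W, 1, "SAME", 1)

a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv3d_ndhwc_python(a_np, w_np, 1, "SAME")

with tvm.target.create("llvm"):
    s = topi.generic.schedule_conv3d_ndhwc([B])
ctx = tvm.context("llvm", 0)
a, w = tvm.nd.array(a_np, ctx), tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
tvm.build(s, [A, W, B], "llvm")(a, w, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)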