From 692b4c0a2bb1ce48a38f76534bd44b649a2d8f0c Mon Sep 17 00:00:00 2001
From: Bing Xu
Date: Sun, 17 Mar 2019 10:59:43 -0700
Subject: [PATCH] [Relay][Pass] Fix Depthwise AlterLayout

---
 python/tvm/relay/op/nn/_nn.py                 | 25 ++++++++
 python/tvm/relay/op/nn/nn.py                  | 64 +++++++++++++++++++
 src/relay/op/nn/convolution.cc                | 52 +++++++++++++++
 tests/lint/pylintrc                           |  2 +-
 topi/python/topi/nn/depthwise_conv2d.py       |  4 +-
 topi/python/topi/x86/conv2d.py                | 30 +++++----
 topi/python/topi/x86/depthwise_conv2d.py      |  4 +-
 .../python/test_topi_depthwise_conv2d.py      |  6 +-
 8 files changed, 168 insertions(+), 19 deletions(-)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 0c2733ecae924..58de44c2e0b51 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -345,3 +345,28 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
 reg.register_pattern("nn.contrib_conv2d_NCHWc",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
+
+@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
+def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
+    """Compute definition of depthwise conv2d NCHWc"""
+    # pylint: disable=assignment-from-no-return
+    padding = attrs.get_int_tuple("padding")
+    strides = attrs.get_int_tuple("strides")
+    dilation = attrs.get_int_tuple("dilation")
+    data_layout = attrs.get_str("data_layout")
+    out_layout = attrs.get_str("out_layout")
+    out_dtype = attrs.get_str("out_dtype")
+    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
+
+    out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
+                                         data_layout, out_layout, out_dtype)
+    return [out]
+
+@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
+def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
+    """Schedule definition of contrib_depthwise_conv2d_NCHWc"""
+    with target:
+        return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)
+
+reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
+                     OpPattern.OUT_ELEMWISE_FUSABLE)

diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 41b2148ec3904..ad8b287bb3973 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -927,6 +927,70 @@ def contrib_conv2d_nchwc(data,
                              groups, channels, kernel_size, data_layout,
                              kernel_layout, out_layout, out_dtype)
+
+
+def contrib_depthwise_conv2d_nchwc(data,
+                                   kernel,
+                                   strides=(1, 1),
+                                   padding=(0, 0),
+                                   dilation=(1, 1),
+                                   groups=1,
+                                   channels=None,
+                                   kernel_size=None,
+                                   data_layout="NCHW8c",
+                                   kernel_layout="OIHW",
+                                   out_layout="",
+                                   out_dtype=""):
+    r"""Variant of 2D depthwise convolution.
+
+    This operator takes the weight as the depthwise convolution kernel
+    and depthwise convolves it with data to produce an output, following a
+    specialized NCHWc data layout.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    kernel : tvm.relay.Expr
+        The kernel expression.
+
+    strides : tuple of int, optional
+        The strides of convolution.
+
+    padding : tuple of int, optional
+        The padding of convolution on both sides of the input before convolution.
+
+    dilation : tuple of int, optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution.
+
+    kernel_size : tuple of int, optional
+        The spatial dimensions of the convolution kernel.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    kernel_layout : str, optional
+        Layout of the weight.
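+        For the x86 NCHWc lowering in this patch, _alter_conv2d_layout
+        rewrites this layout to the 6-D packed form 'OIHW1i%do' % oc_bn.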
+
+    out_layout : str, optional
+        Layout of the output. By default, out_layout is the same as data_layout.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.contrib_depthwise_conv2d_NCHWc(data, kernel, strides, padding, dilation,
+                                                groups, channels, kernel_size, data_layout,
+                                                kernel_layout, out_layout, out_dtype)
 
 
 def contrib_conv2d_winograd_weight_transform(weight,
                                              tile_size):

diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 963257a149611..b53f57464e813 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -582,5 +582,57 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
                                Conv2DInferCorrectLayout<Conv2DAttrs>);
 
+
+// Positional relay function to create depthwise conv2d NCHWc operator
+// used by frontend FFI.
+Expr MakeDepthwiseConv2DNCHWc(Expr data,
+                              Expr kernel,
+                              Array<IndexExpr> strides,
+                              Array<IndexExpr> padding,
+                              Array<IndexExpr> dilation,
+                              int groups,
+                              IndexExpr channels,
+                              Array<IndexExpr> kernel_size,
+                              std::string data_layout,
+                              std::string kernel_layout,
+                              std::string out_layout,
+                              DataType out_dtype) {
+  auto attrs = make_node<Conv2DAttrs>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->channels = channels;
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc");
+  return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 12>(MakeDepthwiseConv2DNCHWc, args, rv);
+  });
+
+
+RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
+.describe(R"code(Compute depthwise conv2d with NCHWc data layout. Only supports NCHW layout.
+- **data**: Input is 5D packed tensor.
+- **weight**: 6D packed tensor.
+
+- **out**: Output is 5D packed tensor.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.Conv2DAttrs")
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(10)
+.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               Conv2DInferCorrectLayout<Conv2DAttrs>);
+
+
 }  // namespace relay
 }  // namespace tvm

diff --git a/tests/lint/pylintrc b/tests/lint/pylintrc
index 355e2ad5acd1f..f3b60492bc5d6 100644
--- a/tests/lint/pylintrc
+++ b/tests/lint/pylintrc
@@ -114,7 +114,7 @@ single-line-if-stmt=no
 no-space-check=trailing-comma,dict-separator
 
 # Maximum number of lines in a module
-max-module-lines=1000
+max-module-lines=1500
 
 # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
 # tab).
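Usage note: once this patch is applied, the new operator is reachable from Python as
relay.nn.contrib_depthwise_conv2d_nchwc. A minimal sketch follows; the shapes, the
bn=8 packing, and the attribute values are illustrative, not taken from the patch:

    import tvm
    from tvm import relay

    # 32 input channels packed with bn=8: data is 5-D NCHW8c, and the
    # depthwise kernel is the 6-D OIHW1i8o layout introduced by this patch.
    data = relay.var("data", shape=(1, 4, 56, 56, 8), dtype="float32")
    kernel = relay.var("kernel", shape=(4, 1, 3, 3, 1, 8), dtype="float32")
    out = relay.nn.contrib_depthwise_conv2d_nchwc(
        data, kernel, strides=(1, 1), padding=(1, 1), dilation=(1, 1),
        groups=32, channels=32, kernel_size=(3, 3),
        data_layout="NCHW8c", kernel_layout="OIHW1i8o",
        out_layout="NCHW8c", out_dtype="float32")
    func = relay.Function([data, kernel], out)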
diff --git a/topi/python/topi/nn/depthwise_conv2d.py b/topi/python/topi/nn/depthwise_conv2d.py
index ca24b08dd0bb5..abb638039f48e 100644
--- a/topi/python/topi/nn/depthwise_conv2d.py
+++ b/topi/python/topi/nn/depthwise_conv2d.py
@@ -292,7 +292,7 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
         5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
 
     Filter : tvm.Tensor
-        4-D with shape [out_channel_chunk, filter_height, filter_width, out_channel_block]
+        6-D with shape [out_channel_chunk, 1, filter_height, filter_width, 1, out_channel_block]
         In NCHWc depthwise convolution,
         we group kernel's in_channel and channel_multiplier together then do the tiling.
 
@@ -317,6 +317,6 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
     Returns
     -------
     Output : tvm.Tensor
-        4-D with shape [batch, out_channel, out_height, out_width]
+        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
     """
     raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")

diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index f0ef226c11172..32ea15a61844d 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -1,6 +1,8 @@
 # pylint: disable=invalid-name,unused-variable,unused-argument,no-member
 """Conv2D schedule on x86"""
 
+import logging
+
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.topi_integration import deserialize_args
@@ -16,6 +18,8 @@
 from . import conv2d_avx_1x1, conv2d_avx_common
 
+logger = logging.getLogger('topi')
+
 def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
     """
     Get default schedule config for the workload
@@ -290,7 +294,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
 
     batch_size, in_channel, height, width = get_const_tuple(data.shape)
     groups = attrs.get_int("groups")
-    out_channel = attrs.get_int("channels")
+    out_channel = attrs.get_int("channels") if F == sym else attrs.get_int("channels").value
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
     dilation = attrs.get_int_tuple("dilation")
@@ -330,16 +334,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
         new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
                                    dtype=data.dtype)
 
-    if is_depthwise:
-        # channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block
-        # in which out_channel = merge(channel, channel_multiplier)
-        kernel_sym = copy_inputs[1]
-        kernel_sym = sym.reshape(kernel_sym, shape=(out_channel//oc_bn, oc_bn, kh, kw))
-        kernel_sym = sym.transpose(kernel_sym, axes=(0, 2, 3, 1))
-        copy_inputs[1] = kernel_sym
+    if is_depthwise:
+        new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
 
         # Store altered operator's config
-        new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn), dtype=kernel.dtype)
+        new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel.dtype)
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
              new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
@@ -356,9 +355,16 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
              new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
         dispatch_ctx.update(target, new_workload, cfg)
 
-    if F == sym:
-        return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
-    return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
+
+    if is_depthwise:
+        if F == sym:
+            logger.warning("depthwise conv2d NCHWc is not supported for NNVM; "
+                           "falling back to the native NCHW op.")
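+            # The NNVM symbol frontend has no contrib_depthwise_conv2d_nchwc
+            # wrapper; returning None leaves the original NCHW depthwise
+            # conv2d in place so compilation proceeds unaltered.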
+            return None
+        return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
+    else:
+        if F == sym:
+            return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
+        return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
 
 
 @autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')

diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py
index 64858df91cdc3..3c0673a29a961 100644
--- a/topi/python/topi/x86/depthwise_conv2d.py
+++ b/topi/python/topi/x86/depthwise_conv2d.py
@@ -58,7 +58,7 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
                                 layout, out_layout, out_dtype=None):
     out_dtype = data.dtype if out_dtype is None else out_dtype
     batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape)
-    out_channel_chunk, filter_height, filter_width, out_channel_block \
+    out_channel_chunk, _, filter_height, filter_width, __, out_channel_block \
         = get_const_tuple(kernel.shape)
 
     strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
@@ -102,7 +102,7 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
                  oh*HSTR+kh, ow*WSTR+kw,
                  ((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block]
              .astype(out_dtype) *
-             kernel[oco, kh, kw, oci].astype(out_dtype)),
+             kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)),
             axis=[kh, kw]),
         name='DepthwiseConv2d', tag="depthwise_conv2d_NCHWc")
     return Output

diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py
index 98c93dff99935..264b80421cbdc 100644
--- a/topi/tests/python/test_topi_depthwise_conv2d.py
+++ b/topi/tests/python/test_topi_depthwise_conv2d.py
@@ -216,7 +216,8 @@ def _transform_kernel(kernel, bn):
     out_channel = channel * channel_multiplier
     kernel = np.reshape(kernel, (out_channel//bn, bn, kh, kw))
     kernel = np.transpose(kernel, (0, 2, 3, 1))
-    return kernel
+    out_channel_chunk, kh, kw, out_channel_block = kernel.shape
+    return kernel.reshape(out_channel_chunk, 1, kh, kw, 1, out_channel_block)
 
 def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier,
                                          filter_height, stride, padding, dilation=1):
     in_width = in_height
@@ -246,7 +247,7 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
     # placeholder
     Input = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block),
                             name='Input')
-    Filter = tvm.placeholder((out_channel//oc_block, filter_height, filter_width, oc_block), name='Filter')
+    Filter = tvm.placeholder((out_channel//oc_block, 1, filter_height, filter_width, 1, oc_block), name='Filter')
     in_layout = "NCHW%dc" % ic_block
     out_layout = "NCHW%dc" % oc_block
     dtype = 'float32'
@@ -297,6 +298,7 @@ def get_ref_data():
 
     input_tvm = tvm.nd.array(input_np, ctx)
     filter_tvm = tvm.nd.array(filter_np, ctx)
+
     depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),
                                                  dtype=DepthwiseConv2d.dtype), ctx)
     relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
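The kernel repacking implied by the new 'OIHW1i%do' layout can be reproduced
standalone; the NumPy sketch below mirrors _transform_kernel above
(pack_depthwise_kernel is an illustrative name, not part of the patch):

    import numpy as np

    def pack_depthwise_kernel(kernel, bn):
        # kernel: (channel, channel_multiplier, kh, kw), the frontend NCHW layout
        channel, channel_multiplier, kh, kw = kernel.shape
        out_channel = channel * channel_multiplier
        # fuse channel and channel_multiplier, then tile the fused axis by bn
        kernel = kernel.reshape(out_channel // bn, bn, kh, kw)
        kernel = kernel.transpose(0, 2, 3, 1)  # -> (oc_chunk, kh, kw, oc_block)
        oc_chunk, kh, kw, oc_block = kernel.shape
        # insert the unit I and unit-i axes to match OIHW1i{bn}o
        return kernel.reshape(oc_chunk, 1, kh, kw, 1, oc_block)

    k = np.random.rand(32, 1, 3, 3).astype("float32")
    assert pack_depthwise_kernel(k, 8).shape == (4, 1, 3, 3, 1, 8)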