
Commit

[Relay][Pass] Fix Depthwise AlterLayout
Bing Xu committed Mar 19, 2019
1 parent a6844de commit 692b4c0
Showing 8 changed files with 168 additions and 19 deletions.
25 changes: 25 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -345,3 +345,28 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):

reg.register_pattern("nn.contrib_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of depthwise conv2d NCHWc"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
data_layout, out_layout, out_dtype)
return [out]

@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of contrib_conv2d_NCHWc"""
with target:
return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)

reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)
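
For context, here is a minimal sketch of exercising this compute/schedule pair directly through TOPI, assuming an x86 (llvm) target and the 6-D kernel layout this commit introduces; the shapes and the block size of 8 are illustrative, not part of the commit.

import tvm
import topi

# Illustrative NCHW8c shapes: 64 channels packed into 8 chunks of 8.
data = tvm.placeholder((1, 8, 32, 32, 8), name="data")
kernel = tvm.placeholder((8, 1, 3, 3, 1, 8), name="kernel")

with tvm.target.create("llvm"):
    out = topi.nn.depthwise_conv2d_NCHWc(data, kernel, (1, 1), (1, 1), (1, 1),
                                         "NCHW8c", "NCHW8c", "float32")
    s = topi.generic.schedule_depthwise_conv2d_NCHWc([out])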
64 changes: 64 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -927,6 +927,70 @@ def contrib_conv2d_nchwc(data,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)

def contrib_depthwise_conv2d_nchwc(data,
kernel,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW8c",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""Variant of 2D depthwise convolution.
This operator takes the weight as the depthwise convolution kernel
and depthwise convolves it with data to produce an output, following a specialized
NCHWc data layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
kernel : tvm.relay.Expr
The kernel expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output. By default, out_layout is the same as data_layout.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_depthwise_conv2d_NCHWc(data, kernel, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
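
A hedged usage sketch of this wrapper follows; the shapes and the packing factor of 8 are hypothetical, chosen only to satisfy the NCHW8c/OIHW1i8o layouts, and are not taken from the commit.

from tvm import relay

# data is NCHW8c: (batch, C//8, H, W, 8); kernel is OIHW1i8o: (C//8, 1, kh, kw, 1, 8).
data = relay.var("data", shape=(1, 8, 32, 32, 8))
kernel = relay.var("kernel", shape=(8, 1, 3, 3, 1, 8))
out = relay.nn.contrib_depthwise_conv2d_nchwc(
    data, kernel, strides=(1, 1), padding=(1, 1), groups=64,
    channels=64, kernel_size=(3, 3), data_layout="NCHW8c",
    kernel_layout="OIHW1i8o", out_layout="NCHW8c")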

def contrib_conv2d_winograd_weight_transform(weight,
tile_size):
52 changes: 52 additions & 0 deletions src/relay/op/nn/convolution.cc
@@ -582,5 +582,57 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
Conv2DInferCorrectLayout<Conv2DAttrs>);


// Positional relay function to create depthwise conv2d NCHWc operator
// used by frontend FFI.
Expr MakeDepthwiseConv2DNCHWc(Expr data,
Expr kernel,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc");
return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeDepthwiseConv2DNCHWc, args, rv);
});


RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.
- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);


} // namespace relay
} // namespace tvm
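
The unpack_call<Expr, 12> above forwards twelve positional arguments straight into MakeDepthwiseConv2DNCHWc. As a sketch (reusing the hypothetical data and kernel vars from the earlier example), the Python wrapper's call reduces to:

from tvm.relay.op.nn import _make

# Twelve positional args, in the order MakeDepthwiseConv2DNCHWc declares them:
# data, kernel, strides, padding, dilation, groups, channels, kernel_size,
# data_layout, kernel_layout, out_layout, out_dtype.
call = _make.contrib_depthwise_conv2d_NCHWc(
    data, kernel, (1, 1), (1, 1), (1, 1), 64, 64, (3, 3),
    "NCHW8c", "OIHW1i8o", "NCHW8c", "float32")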
2 changes: 1 addition & 1 deletion tests/lint/pylintrc
@@ -114,7 +114,7 @@ single-line-if-stmt=no
no-space-check=trailing-comma,dict-separator

# Maximum number of lines in a module
max-module-lines=1000
max-module-lines=1500

# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
4 changes: 2 additions & 2 deletions topi/python/topi/nn/depthwise_conv2d.py
@@ -292,7 +292,7 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
Filter : tvm.Tensor
4-D with shape [out_channel_chunk, filter_height, filter_width, out_channel_block]
6-D with shape [out_channel_chunk, 1, filter_height, filter_width, 1, out_channel_block]
In NCHWc depthwise convolution,
we group the kernel's in_channel and channel_multiplier together and then do the tiling.
@@ -317,6 +317,6 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
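
To make the 6-D filter layout concrete, here is a small NumPy sketch (mirroring the test helper _transform_kernel later in this commit) that packs an unpacked depthwise kernel of shape (channel, channel_multiplier, kh, kw) into the [out_channel_chunk, 1, kh, kw, 1, out_channel_block] form; the block size of 8 is illustrative.

import numpy as np

def pack_depthwise_kernel(kernel, bn=8):
    # kernel: (channel, channel_multiplier, kh, kw), unpacked layout.
    channel, multiplier, kh, kw = kernel.shape
    out_channel = channel * multiplier            # merge(channel, multiplier)
    k = kernel.reshape(out_channel // bn, bn, kh, kw)
    k = k.transpose(0, 2, 3, 1)                   # -> (chunk, kh, kw, block)
    return k.reshape(out_channel // bn, 1, kh, kw, 1, bn)

packed = pack_depthwise_kernel(np.ones((64, 1, 3, 3), "float32"))
assert packed.shape == (8, 1, 3, 3, 1, 8)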
30 changes: 18 additions & 12 deletions topi/python/topi/x86/conv2d.py
@@ -1,6 +1,8 @@
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D schedule on x86"""

import logging

import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import deserialize_args
@@ -16,6 +18,8 @@

from . import conv2d_avx_1x1, conv2d_avx_common

logger = logging.getLogger('topi')

def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
"""
Get default schedule config for the workload
@@ -290,7 +294,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
batch_size, in_channel, height, width = get_const_tuple(data.shape)

groups = attrs.get_int("groups")
out_channel = attrs.get_int("channels")
out_channel = attrs.get_int("channels") if F == sym else attrs.get_int("channels").value
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
@@ -330,16 +334,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):

new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
dtype=data.dtype)
if is_depthwise:
# channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block
# in which out_channel = merge(channel, channel_multiplier)
kernel_sym = copy_inputs[1]
kernel_sym = sym.reshape(kernel_sym, shape=(out_channel//oc_bn, oc_bn, kh, kw))
kernel_sym = sym.transpose(kernel_sym, axes=(0, 2, 3, 1))
copy_inputs[1] = kernel_sym

if is_depthwise:
new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn), dtype=kernel.dtype)
new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
@@ -356,9 +355,16 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)

dispatch_ctx.update(target, new_workload, cfg)
if F == sym:
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)

if is_depthwise:
if F == sym:
logger.warning("Depthwise conv2d NCHWc is not supported for NNVM; falling back to the NCHW op.")
return None
return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
else:
if F == sym:
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
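
As a worked example of the altered attributes (all values hypothetical): with oc_bn = 8 and a 64-channel depthwise kernel, the layout string and placeholder shape above become

oc_bn, out_channel, kh, kw = 8, 64, 3, 3
kernel_layout = 'OIHW1i%do' % oc_bn   # -> 'OIHW1i8o'
new_kernel_shape = (out_channel // oc_bn, 1, kh, kw, 1, oc_bn)   # (8, 1, 3, 3, 1, 8)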


@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
4 changes: 2 additions & 2 deletions topi/python/topi/x86/depthwise_conv2d.py
@@ -58,7 +58,7 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
layout, out_layout, out_dtype=None):
out_dtype = data.dtype if out_dtype is None else out_dtype
batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape)
out_channel_chunk, filter_height, filter_width, out_channel_block \
out_channel_chunk, _, filter_height, filter_width, __, out_channel_block \
= get_const_tuple(kernel.shape)

strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
@@ -102,7 +102,7 @@ def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
oh*HSTR+kh, ow*WSTR+kw,
((oco * out_channel_block + oci) // channel_multiplier) % in_channel_block]
.astype(out_dtype) *
kernel[oco, kh, kw, oci].astype(out_dtype)),
kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)),
axis=[kh, kw]),
name='DepthwiseConv2d', tag="depthwise_conv2d_NCHWc")
return Output
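
For readers checking the indexing above, here is a self-contained NumPy reference of the NCHWc depthwise compute, simplified to channel_multiplier == 1, unit stride, and no padding (the TOPI kernel also handles a general multiplier via the // channel_multiplier % in_channel_block index shown above). It is an illustrative sketch, not the scheduled implementation.

import numpy as np

def depthwise_nchwc_ref(data, kernel):
    # data: (batch, ic_chunk, H, W, ic_block); kernel: (oc_chunk, 1, kh, kw, 1, oc_block).
    # With multiplier 1, oc_chunk == ic_chunk and oc_block == ic_block.
    batch, ic_chunk, H, W, ic_block = data.shape
    oc_chunk, _, kh, kw, _, oc_block = kernel.shape
    OH, OW = H - kh + 1, W - kw + 1
    out = np.zeros((batch, oc_chunk, OH, OW, oc_block), dtype=data.dtype)
    for b in range(batch):
        for oco in range(oc_chunk):
            for oh in range(OH):
                for ow in range(OW):
                    for oci in range(oc_block):
                        # Each output channel reads exactly one input channel.
                        window = data[b, oco, oh:oh+kh, ow:ow+kw, oci]
                        out[b, oco, oh, ow, oci] = np.sum(
                            window * kernel[oco, 0, :, :, 0, oci])
    return out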
6 changes: 4 additions & 2 deletions topi/tests/python/test_topi_depthwise_conv2d.py
@@ -216,7 +216,8 @@ def _transform_kernel(kernel, bn):
out_channel = channel * channel_multiplier
kernel = np.reshape(kernel, (out_channel//bn, bn, kh, kw))
kernel = np.transpose(kernel, (0, 2, 3, 1))
return kernel
out_channel_chunk, kh, kw, out_channel_block = kernel.shape
return kernel.reshape(out_channel_chunk, 1, kh, kw, 1, out_channel_block)

def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
in_width = in_height
@@ -246,7 +247,7 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m

# placeholder
Input = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='Input')
Filter = tvm.placeholder((out_channel//oc_block, filter_height, filter_width, oc_block), name='Filter')
Filter = tvm.placeholder((out_channel//oc_block, 1, filter_height, filter_width, 1, oc_block), name='Filter')
in_layout = "NCHW%dc" % ic_block
out_layout = "NCHW%dc" % oc_block
dtype = 'float32'
Expand Down Expand Up @@ -297,6 +298,7 @@ def get_ref_data():

input_tvm = tvm.nd.array(input_np, ctx)
filter_tvm = tvm.nd.array(filter_np, ctx)

depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),
dtype=DepthwiseConv2d.dtype), ctx)
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
