[AlterLayout] NCHWc upsampling, fix depthwise conv (apache#2806)
* [AlterLayout] NCHW upsampling

* [Relay][Pass] Fix Depthwise AlterLayout
antinucleon authored and wweic committed Mar 24, 2019
1 parent 47f3be0 commit 6cdf90c
Showing 13 changed files with 299 additions and 26 deletions.
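
For orientation: an NCHWc layout such as NCHW16c blocks the channel axis into fixed-size chunks and moves the block innermost, which suits vectorized CPU kernels. A minimal numpy sketch of the packing (the block size of 16 and the shapes are illustrative, not part of this commit):

import numpy as np

# NCHW -> NCHW16c: split C into C//16 chunks of 16, move the block innermost.
x = np.zeros((1, 32, 28, 28), dtype="float32")              # NCHW
x_packed = x.reshape(1, 32 // 16, 16, 28, 28).transpose(0, 1, 3, 4, 2)
assert x_packed.shape == (1, 2, 28, 28, 16)                 # NCHW16c
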
25 changes: 25 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -345,3 +345,28 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):

reg.register_pattern("nn.contrib_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of depthwise conv2d NCHWc"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
data_layout, out_layout, out_dtype)
return [out]

@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of contrib_conv2d_NCHWc"""
with target:
return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)

reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)
64 changes: 64 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -927,6 +927,70 @@ def contrib_conv2d_nchwc(data,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)

def contrib_depthwise_conv2d_nchwc(data,
kernel,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW8c",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""Variant of 2D depthwise convolution.
This operator takes the weight as the depthwise convolution kernel
and depthwise convolves it with data to produce an output, following a specialized
NCHWc data layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
kernel : tvm.relay.Expr
The kernel expression.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding applied to both sides of the inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output. By default, out_layout is the same as data_layout.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_depthwise_conv2d_NCHWc(data, kernel, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
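
A usage sketch for the new wrapper, with illustrative packed shapes (in practice this op is normally produced by the AlterOpLayout pass rather than written by hand, and whether these exact shapes pass type checking depends on the registered type relation):

from tvm import relay

# Packed 5-D data [batch, C//8, H, W, 8] and packed 6-D depthwise weight
# [C//8, 1, kh, kw, 1, 8]; all shapes are assumptions for illustration.
data = relay.var("data", shape=(1, 4, 28, 28, 8))
weight = relay.var("weight", shape=(4, 1, 3, 3, 1, 8))
out = relay.nn.contrib_depthwise_conv2d_nchwc(
    data, weight,
    strides=(1, 1), padding=(1, 1), groups=32,
    channels=32, kernel_size=(3, 3),
    data_layout="NCHW8c", out_layout="NCHW8c")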

def contrib_conv2d_winograd_weight_transform(weight,
tile_size):
52 changes: 52 additions & 0 deletions src/relay/op/nn/convolution.cc
@@ -582,5 +582,57 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
Conv2DInferCorrectLayout<Conv2DAttrs>);


// Positional relay function to create depthwise conv2d NCHWc operator
// used by frontend FFI.
Expr MakeDepthwiseConv2DNCHWc(Expr data,
Expr kernel,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc");
return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeDepthwiseConv2DNCHWc, args, rv);
});


RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.
- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DepthwiseConv2D")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);


} // namespace relay
} // namespace tvm
35 changes: 32 additions & 3 deletions src/relay/op/nn/upsampling.cc
@@ -18,6 +18,31 @@ namespace relay {

TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
// NOTE: Discard "const" qualifier here.
T *params = const_cast<T*>(attrs.as<T>());

if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);

Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
}
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}
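
The rule above adopts the layout requested by the producer only when the H and W axes keep the positions they have in the stored layout and neither spatial axis is split into sub-axes (lower-case h/w); otherwise the op keeps its original layout and the pass falls back to inserting layout transforms. A rough Python restatement of the condition (hypothetical helper; character positions stand in for the axis indices IndexOf returns):

def adopts_new_layout(raw_layout, new_layout):
    # Adopt only if H and W keep their positions and neither is split.
    same_spatial = (new_layout.find('H') == raw_layout.find('H') and
                    new_layout.find('W') == raw_layout.find('W'))
    unsplit = 'h' not in new_layout and 'w' not in new_layout
    return same_spatial and unsplit

assert adopts_new_layout("NCHW", "NCHW16c")        # channel-only split: adopted
assert not adopts_new_layout("NCHW", "NCHW2h16c")  # split H axis: rejected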

bool UpSamplingRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
@@ -91,6 +116,8 @@ RELAY_REGISTER_OP("nn.upsampling")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("UpSampling", UpSamplingRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSamplingAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
@@ -101,14 +128,16 @@ RELAY_REGISTER_OP("nn.upsampling")
CHECK(uattrs != nullptr);
auto out_tt = out_type.as<TensorTypeNode>();
CHECK(out_tt) << "expected a tensor type: " << out_type;
CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC")
const auto layout = uattrs->layout;
const auto base_layout = layout.substr(0, 4);
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;

Array<HalideIR::Expr> oshape;
if (uattrs->layout == "NCHW") {
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
} else if (uattrs->layout == "NHWC") {
} else if (layout == "NHWC") {
oshape.push_back(out_tt->shape[1]);
oshape.push_back(out_tt->shape[2]);
}
2 changes: 1 addition & 1 deletion tests/lint/pylintrc
@@ -114,7 +114,7 @@ single-line-if-stmt=no
no-space-check=trailing-comma,dict-separator

# Maximum number of lines in a module
max-module-lines=1000
max-module-lines=1500

# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
46 changes: 46 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -411,6 +411,51 @@ def expected():

assert(alpha_equal(a, b))


def test_alter_layout_nchw_upsamping_op():
"""Test upsamping operators """
def before():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var('weight', shape=(32, 32, 3, 3))
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
y = relay.nn.upsampling(y, scale=2)
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2))
y = relay.Function(free_vars(y), y)
return y

@register_alter_op_layout("nn.conv2d", level=108)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)

def expected():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var("weight")
x = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
data_layout="NCHW16c")
y = relay.nn.upsampling(y, scale=2, layout="NCHW16c")
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c')
y = relay.layout_transform(y, "NCHW16c", "NCHW")
y = relay.Function(free_vars(y), y)
return y

a = before()
a = infer_type(a)
a = canonicalize_ops(a)
a = infer_type(a)

a = alter_op_layout(a)
a = infer_type(a)

b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))


if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
@@ -420,3 +465,4 @@ def expected():
test_alter_layout_broadcast_op()
test_alter_layout_scalar()
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
3 changes: 3 additions & 0 deletions tests/scripts/task_python_topi.sh
@@ -4,5 +4,8 @@ export PYTHONPATH=python:topi/python
make cython || exit -1
make cython3 || exit -1

rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
rm -rf topi/python/topi/*.pyc topi/python/topi/*/*.pyc topi/python/topi/*/*/*.pyc topi/python/topi/*/*/*/*.pyc

python -m nose -v topi/tests/python || exit -1
python3 -m nose -v topi/tests/python || exit -1
49 changes: 47 additions & 2 deletions topi/include/topi/image/resize.h
@@ -134,6 +134,45 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
}, name, tag);
}

/*!
* \brief Resize given tensor to given shape using nearest neighbour for NCHWc
*
* \param input The input tensor.
* \param shape Output shape to resize to.
* \param align_corners To preserve centers of 4 corner pixels
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
const Array<Expr>& shape,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(input->shape[4]);

Expr h_ratio = shape[0] / input->shape[2];
Expr w_ratio = shape[1] / input->shape[3];

return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1]);
idx.push_back(indices[2] / h_ratio);
idx.push_back(indices[3] / w_ratio);
idx.push_back(indices[4]);

return input(idx);
}, name, tag);
}
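
Note the ratios use integer Expr division, so the computation assumes the target height and width are exact multiples of the input dimensions. A numpy restatement of the same index math (helper name is ours, for illustration only):

import numpy as np

def resize_nn_nchwc(x, out_h, out_w):
    # Nearest-neighbor resize of a 5-D NCHWc array; assumes out_h and out_w
    # are exact multiples of the input height/width, mirroring the Expr math.
    n, c_chunk, h, w, c_block = x.shape
    h_ratio, w_ratio = out_h // h, out_w // w
    out = np.empty((n, c_chunk, out_h, out_w, c_block), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            out[:, :, i, j, :] = x[:, :, i // h_ratio, j // w_ratio, :]
    return out

x = np.arange(16, dtype="float32").reshape(1, 1, 2, 2, 4)
assert resize_nn_nchwc(x, 4, 4).shape == (1, 1, 4, 4, 4)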

/*!
* \brief Resize given tensor to given shape using nearest neighbour
*
Expand All @@ -153,11 +192,17 @@ inline Tensor resize_nearest_neighbor(const Tensor& input,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour";

auto base_layout = layout.substr(0, 4);
if (layout == "NHWC") {
return resize_nearest_neighbor_nhwc(input, shape, align_corners);
} else {
} else if (layout == "NCHW") {
return resize_nearest_neighbor_nchw(input, shape, align_corners);
} else if (base_layout == "NCHW") {
// NCHWc
return resize_nearest_neighbor_nchwc(input, shape, align_corners);
} else {
LOG(FATAL) << "Unknown layout: " << layout;
return Tensor();
}
}

4 changes: 2 additions & 2 deletions topi/python/topi/nn/depthwise_conv2d.py
@@ -292,7 +292,7 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
Filter : tvm.Tensor
4-D with shape [out_channel_chunk, filter_height, filter_width, out_channel_block]
6-D with shape [out_channel_chunk, 1, filter_height, filter_width, 1, out_channel_block]
In NCHWc depthwise convolution,
we group the kernel's in_channel and channel_multiplier together and then do the tiling.
@@ -317,6 +317,6 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
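
To make the packed shapes concrete, here is a numpy reference sketch of the computation under simplifying assumptions (channel_multiplier == 1, unit dilation, no padding; the helper name is ours, not a TVM API):

import numpy as np

def depthwise_conv2d_nchwc_ref(data, kernel, stride=1):
    # data:   [batch, C//c, H, W, c]
    # kernel: [C//c, 1, kh, kw, 1, c]   (multiplier folded into the 1-dims)
    # out:    [batch, C//c, OH, OW, c]
    n, c_chunk, h, w, c_block = data.shape
    _, _, kh, kw, _, _ = kernel.shape
    oh, ow = (h - kh) // stride + 1, (w - kw) // stride + 1
    out = np.zeros((n, c_chunk, oh, ow, c_block), dtype=data.dtype)
    k = kernel[:, 0, :, :, 0, :]                     # [C//c, kh, kw, c]
    for i in range(oh):
        for j in range(ow):
            win = data[:, :, i*stride:i*stride+kh, j*stride:j*stride+kw, :]
            # spatial reduction per (chunk, block); no cross-channel sum
            out[:, :, i, j, :] = np.einsum("nchwb,chwb->ncb", win, k)
    return out

data = np.random.rand(1, 2, 8, 8, 4).astype("float32")
kernel = np.random.rand(2, 1, 3, 3, 1, 4).astype("float32")
assert depthwise_conv2d_nchwc_ref(data, kernel).shape == (1, 2, 6, 6, 4)
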
4 changes: 2 additions & 2 deletions topi/python/topi/nn/upsampling.py
@@ -30,8 +30,8 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
4-D with shape [batch, channel, in_height*scale, in_width*scale]
or [batch, in_height*scale, in_width*scale, channel]
"""

if layout == "NCHW":
base_layout = layout[0:4]
if base_layout == "NCHW":
out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale))
elif layout == "NHWC":
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
