Add support for MXNet pad operator.
alexgl-github committed Aug 13, 2019
1 parent a78adbd commit b5bee4b
Showing 6 changed files with 82 additions and 8 deletions.
7 changes: 6 additions & 1 deletion include/tvm/relay/attrs/nn.h
@@ -401,13 +401,18 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
  double pad_value;
  Array<Array<IndexExpr> > pad_width;
  std::string pad_mode;

  TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
    TVM_ATTR_FIELD(pad_value).set_default(0.0)
        .describe("Specifies the strides of the convolution.");
        .describe("The value used for padding when mode is 'constant'.");
    TVM_ATTR_FIELD(pad_width)
        .describe("Number of values padded to the edges of each axis, "
                  "in the format of ((before_1, after_1), ..., (before_N, after_N))");
    TVM_ATTR_FIELD(pad_mode).set_default("constant")
        .describe("Padding type to use. \"constant\" pads with constant_value, "
                  "\"edge\" pads using the edge values of the input array, "
                  "\"reflect\" pads by reflecting values with respect to the edges.");
  }
};
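The three modes described here match numpy's np.pad semantics, so a short numpy sketch (an illustration only, not part of the commit) shows what each mode produces on a small input:

import numpy as np

x = np.array([[1., 2.],
              [3., 4.]])

# "constant": new border cells take the constant pad value (here 0.0).
np.pad(x, ((1, 1), (1, 1)), mode="constant", constant_values=0.0)

# "edge": the nearest input value is repeated into the border.
np.pad(x, ((1, 1), (1, 1)), mode="edge")

# "reflect": values are mirrored around the edge; the edge itself is not repeated.
np.pad(x, ((1, 1), (1, 1)), mode="reflect")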

20 changes: 20 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -365,6 +365,24 @@ def _mx_expand_dims(inputs, attrs):
    axis = attrs.get_int("axis")
    return _op.expand_dims(inputs[0], axis=axis)

def _mx_pad(inputs, attrs):
    pad_mode = attrs.get_str('mode', None)
    if pad_mode is None:
        raise tvm.error.OpAttributeRequired(
            'Attribute "mode" not found in operator pad.')
    if pad_mode not in ['constant', 'edge', 'reflect']:
        raise tvm.error.OpAttributeInvalid(
            'Value ' + pad_mode + ' in attribute "mode" of operator pad is not valid.')
    pad_width = attrs.get_int_tuple('pad_width', None)
    if pad_width is None:
        raise tvm.error.OpAttributeRequired(
            'Attribute "pad_width" not found in operator pad.')
    if None in pad_width:
        raise tvm.error.OpAttributeInvalid(
            'Value None in attribute "pad_width" of operator pad is not valid.')
    constant_value = attrs.get_float('constant_value', 0.0)
    padding = tuple((b, a) for b, a in zip(pad_width[::2], pad_width[1::2]))
    return _op.nn.pad(data=inputs[0], pad_width=padding,
                      pad_value=constant_value, pad_mode=pad_mode)
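As a quick illustration of the pad_width reshuffling above (example values, not part of the commit): MXNet passes one flat tuple (before_1, after_1, ..., before_N, after_N), while relay.nn.pad expects nested (before, after) pairs.

pad_width = (0, 0, 0, 0, 1, 2, 3, 4)   # flat MXNet layout
padding = tuple((b, a) for b, a in zip(pad_width[::2], pad_width[1::2]))
assert padding == ((0, 0), (0, 0), (1, 2), (3, 4))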

def _mx_leaky_relu(inputs, attrs):
    act_type = attrs.get_str("act_type")
@@ -1026,6 +1044,8 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"_full" : _mx_full,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"pad" : _mx_pad,
"Pad" : _mx_pad,
"take" : _mx_take,
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
10 changes: 7 additions & 3 deletions python/tvm/relay/op/nn/nn.py
@@ -665,7 +665,8 @@ def prelu(data, alpha, axis=1):

def pad(data,
        pad_width,
        pad_value=0.0):
        pad_value=0.0,
        pad_mode='constant'):
    r"""Padding

    This operator takes in a tensor and pads each axis by the specified
@@ -680,13 +681,16 @@ def pad(data,
        of ((before_1, after_1), ..., (before_N, after_N))

    pad_value: float, optional, default=0.0
        The value used for padding

    pad_mode: 'constant', 'edge', 'reflect'
        'constant' pads with the constant pad_value
        'edge' pads using the edge values of the input array
        'reflect' pads by reflecting values with respect to the edge

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.pad(data, pad_width, pad_value)
    return _make.pad(data, pad_width, pad_value, pad_mode)
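A minimal usage sketch of the extended signature, assuming a TVM build that includes this commit:

from tvm import relay

x = relay.var("x", shape=(1, 1, 3, 5))
y = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (1, 2), (3, 4)), pad_mode="edge")
# Each padded axis grows by before + after: (1, 1, 3 + 1 + 2, 5 + 3 + 4) = (1, 1, 6, 12)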


def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
6 changes: 4 additions & 2 deletions src/relay/op/nn/pad.cc
@@ -102,14 +102,16 @@ Array<Tensor> PadCompute(const Attrs& attrs,
  }
  const auto* out_ttype = out_type.as<TensorTypeNode>();
  return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
                                  tvm::make_const(out_ttype->dtype, param->pad_value)) };
                                  tvm::make_const(out_ttype->dtype, param->pad_value),
                                  param->pad_mode) };
}

// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value, std::string pad_mode) {
  auto attrs = make_node<PadAttrs>();
  attrs->pad_value = pad_value;
  attrs->pad_width = std::move(pad_width);
  attrs->pad_mode = std::move(pad_mode);
  static const Op& op = Op::Get("nn.pad");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}
17 changes: 17 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
@@ -778,6 +778,22 @@ def verify(shape, axis=-1):
    verify((2, 5), axis=0)
    verify((2, 5, 6))


def test_forward_pad():
    def verify(data_shape, out_shape, mode, pad_width, constant_value=0.0):
        data = mx.sym.var('data')
        mx_sym = mx.sym.pad(data, mode=mode, pad_width=pad_width, constant_value=constant_value)
        verify_mxnet_frontend_impl(mx_sym, data_shape=data_shape, out_shape=out_shape)

    verify(data_shape=(1, 1, 3, 5), out_shape=(1, 1, 6, 12), mode="constant", pad_width=(0, 0, 0, 0, 1, 2, 3, 4))
    verify(data_shape=(1, 1, 3, 5), out_shape=(1, 1, 6, 12), mode="constant", pad_width=(0, 0, 0, 0, 1, 2, 3, 4), constant_value=3.0)
    verify(data_shape=(1, 1, 3, 5), out_shape=(1, 1, 6, 12), mode="edge", pad_width=(0, 0, 0, 0, 1, 2, 3, 4))
    verify(data_shape=(1, 1, 3, 5), out_shape=(1, 1, 6, 12), mode="reflect", pad_width=(0, 0, 0, 0, 1, 2, 3, 4))
    verify(data_shape=(1, 1, 3, 5, 7), out_shape=(1, 1, 6, 12, 18), mode="constant", pad_width=(0, 0, 0, 0, 1, 2, 3, 4, 5, 6))
    verify(data_shape=(1, 1, 3, 5, 7), out_shape=(1, 1, 6, 12, 18), mode="constant", pad_width=(0, 0, 0, 0, 1, 2, 3, 4, 5, 6), constant_value=3.0)
    verify(data_shape=(1, 1, 3, 5, 7), out_shape=(1, 1, 6, 12, 18), mode="edge", pad_width=(0, 0, 0, 0, 1, 2, 3, 4, 5, 6))
    verify(data_shape=(1, 1, 3, 5, 7), out_shape=(1, 1, 6, 12, 18), mode="reflect", pad_width=(0, 0, 0, 0, 1, 2, 3, 4, 5, 6))

if __name__ == '__main__':
    test_forward_mlp()
    test_forward_vgg()
@@ -791,6 +807,7 @@ def verify(shape, axis=-1):
    test_forward_split()
    test_forward_split_squeeze()
    test_forward_expand_dims()
    test_forward_pad()
    test_forward_pooling()
    test_forward_adaptive_pooling()
    test_forward_lrn()
30 changes: 28 additions & 2 deletions topi/include/topi/nn.h
@@ -144,6 +144,10 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
 * \param pad_after An Array of Expr describing the padding after the
 * respective iterator
 * \param pad_value The value to fill padding elements with
 * \param pad_mode Padding type to use.
 * "constant" pads with constant_value;
 * "edge" pads using the edge values of the input array;
 * "reflect" pads by reflecting values with respect to the edges.
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
@@ -172,6 +176,7 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
                       const tvm::Array<tvm::Expr>& pad_before,
                       tvm::Array<tvm::Expr> pad_after = tvm::Array<tvm::Expr>(),
                       Expr pad_value = Expr(),
                       std::string pad_mode = "constant",
                       std::string name = "T_pad",
                       std::string tag = kElementWise) {
  if (pad_after.size() < pad_before.size()) {
@@ -198,6 +203,7 @@ auto l = [&](tvm::Array<tvm::Var> ovars) {
  auto l = [&](tvm::Array<tvm::Var> ovars) {
    tvm::Array<tvm::Expr> indices;
    tvm::Array<tvm::Expr> sel;
    tvm::Array<tvm::Expr> pad_idx;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      if (i >= pad_before.size()) {
        indices.push_back(ovars[i]);
@@ -212,10 +218,30 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
      if (!topi::detail::EqualCheck(pad_after[i], 0)) {
        sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i]));
      }
      if (pad_mode == "edge") {
        pad_idx.push_back(tvm::if_then_else(
            ovars[i] < pad_before[i],
            0,
            tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
                              t->shape[i] - 1,
                              ovars[i] - pad_before[i])));
      } else if (pad_mode == "reflect") {
        pad_idx.push_back(tvm::if_then_else(
            ovars[i] < pad_before[i],
            pad_before[i] - ovars[i],
            tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
                              t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
                              ovars[i] - pad_before[i])));
      }
    }
    if (sel.size() != 0) {
      return tvm::if_then_else(
          detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
      if (pad_mode == "constant") {
        return tvm::if_then_else(
            detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
      } else if (pad_mode == "edge" || pad_mode == "reflect") {
        return tvm::if_then_else(
            detail::Map(sel, tvm::ir::And::make), t(indices), t(pad_idx));
      }
    }
    return t(indices);
  };
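The nested if_then_else chains pick, for each output coordinate, which input element to read. A plain-Python mirror of the per-axis index logic (an illustration with made-up names, not TVM code) makes the "edge" and "reflect" formulas easy to verify:

def src_index(o, before, size, mode):
    # o is the output coordinate on one axis; size is the input extent on that axis.
    if o < before:                      # left padding region
        return 0 if mode == "edge" else before - o
    if o >= before + size:              # right padding region
        return size - 1 if mode == "edge" else 2 * size - o + before - 2
    return o - before                   # interior: plain shift

x = [10, 20, 30]
before = 2
out = [x[src_index(o, before, len(x), "reflect")] for o in range(before + len(x) + 2)]
assert out == [30, 20, 10, 20, 30, 20, 10]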
