Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][OP] Add reverse_reshape #2503

Merged
merged 4 commits into from
Feb 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.layout_transform
tvm.relay.device_copy
tvm.relay.annotation.on_device
tvm.relay.reverse_reshape


Level 1 Definitions
Expand Down Expand Up @@ -257,4 +258,5 @@ Level 10 Definitions
.. autofunction:: tvm.relay.slice_like
.. autofunction:: tvm.relay.layout_transform
.. autofunction:: tvm.relay.device_copy
.. autofunction:: tvm.relay.annotation.on_device
.. autofunction:: tvm.relay.annotation.on_device
.. autofunction:: tvm.relay.reverse_reshape
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,13 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<Integer> newshape;
bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape)
.describe("The new shape. Should be compatible with the original shape.");
TVM_ATTR_FIELD(reverse)
.describe("Infer the special values from right to left if true")
.set_default(false);
}
}; // struct ReshapeAttrs

Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/nnvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ def impl(inputs, _, _dtype='float32'):


def _reshape(inputs, attrs):
if attrs.get_bool("reverse", False):
raise RuntimeError("reshape do not support option reverse")
shape = attrs.get_int_tuple("shape")
reverse = attrs.get_bool("reverse", False)
if reverse:
return _op.reverse_reshape(inputs[0], newshape=shape)
return _op.reshape(inputs[0], newshape=shape)


Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_reg.register_schedule("take", schedule_injective)
_reg.register_schedule("transpose", schedule_injective)
_reg.register_schedule("where", schedule_broadcast)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)

# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
Expand Down
33 changes: 32 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def reshape(data, newshape):
Example::
- data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4)
- data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4)
- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
Parameters
Expand Down Expand Up @@ -449,3 +449,34 @@ def layout_transform(data, src_layout, dst_layout):
The transformed tensor.
"""
return _make.layout_transform(data, src_layout, dst_layout)


def reverse_reshape(data, newshape):
"""Reshapes the input array where the special values are inferred from
right to left.
Example::
The special values have the same semantics as :py:class:`tvm.relay.reshape`.
The difference is that special values are inferred from right to left. It
can be explained in the example below::
- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5)
Parameters
----------
data : relay.Expr
The input data to the operator.
newshape : Union[int, Tuple[int], List[int]]
The new shape. Should be compatible with the original shape.
Returns
-------
result : relay.Expr
The reshaped result.
"""
if isinstance(newshape, int):
newshape = [newshape]
return _make._contrib_reverse_reshape(data, list(newshape))
95 changes: 76 additions & 19 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,20 +382,29 @@ bool ReshapeRel(const Array<Type>& types,
}

const auto* param = attrs.as<ReshapeAttrs>();
Array<IndexExpr> data_shape;
Array<Integer> newshape;
if (param->reverse) {
data_shape.assign(data->shape.rbegin(), data->shape.rend());
newshape.assign(param->newshape.rbegin(), param->newshape.rend());
} else {
data_shape = data->shape;
newshape = param->newshape;
}
Array<IndexExpr> oshape;
size_t src_idx = 0;
int infer_idx = -1;

for (size_t i = 0; i < param->newshape.size(); ++i) {
int svalue = param->newshape[i]->value;
for (size_t i = 0; i < newshape.size(); ++i) {
int svalue = newshape[i]->value;
// special flag handling for shape inference.
if (svalue > 0) {
oshape.push_back(param->newshape[i]);
oshape.push_back(newshape[i]);
++src_idx;
} else if (svalue == 0) {
// keep same
CHECK_LT(src_idx, data->shape.size());
oshape.push_back(data->shape[src_idx++]);
CHECK_LT(src_idx, data_shape.size());
oshape.push_back(data_shape[src_idx++]);
} else if (svalue == -1) {
// inference based on rest
CHECK_LT(infer_idx, 0)
Expand All @@ -405,42 +414,51 @@ bool ReshapeRel(const Array<Type>& types,
++src_idx;
} else if (svalue == -2) {
// copy all remaining dims from source
while (src_idx < data->shape.size()) {
oshape.push_back(data->shape[src_idx++]);
while (src_idx < data_shape.size()) {
oshape.push_back(data_shape[src_idx++]);
}
} else if (svalue == -3) {
// merge two dims from source
CHECK_LT(src_idx + 1, data->shape.size());
IndexExpr d1 = data->shape[src_idx++];
IndexExpr d2 = data->shape[src_idx++];
CHECK_LT(src_idx + 1, data_shape.size());
IndexExpr d1 = data_shape[src_idx++];
IndexExpr d2 = data_shape[src_idx++];
oshape.push_back(d1 * d2);
} else if (svalue == -4) {
// split the source dim s into two dims
// read the left dim and then the right dim (either can be -1)
CHECK_LT(i + 2, param->newshape.size());
CHECK_LT(src_idx, data->shape.size());
IndexExpr d0 = data->shape[src_idx++];
Integer d1 = param->newshape[++i];
Integer d2 = param->newshape[++i];
CHECK_LT(i + 2, newshape.size());
CHECK_LT(src_idx, data_shape.size());
IndexExpr d0 = data_shape[src_idx++];
Integer d1 = newshape[++i];
Integer d2 = newshape[++i];
if (d1->value == -1) {
CHECK(d2->value != -1)
<< "Split dims cannot both be -1.";
oshape.push_back(d0 / d2);
oshape.push_back(d2);
} else {
CHECK_EQ(d2->value, -1);
oshape.push_back(d1);
oshape.push_back(d0 / d1);
if (d2->value == -1) {
oshape.push_back(d0 / d1);
} else {
oshape.push_back(d2);
}
}
}
}

if (infer_idx >= 0) {
IndexExpr new_size = arith::ComputeReduce<tvm::ir::Mul>(oshape, 1);
IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data->shape, 1);
IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data_shape, 1);
oshape.Set(infer_idx, old_size / new_size);
}
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));

if (param->reverse) {
reporter->Assign(types[1], TensorTypeNode::make(
Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
} else {
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
}
return true;
}

Expand All @@ -457,6 +475,7 @@ Expr MakeReshape(Expr data,
Array<Integer> newshape) {
auto attrs = make_node<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = false;
static const Op& op = Op::Get("reshape");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
Expand Down Expand Up @@ -1699,5 +1718,43 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
.set_support_level(5)
.set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);


/* relay._contrib_reverse_reshape */
Expr MakeReverseReshape(Expr data,
Array<Integer> newshape) {
auto attrs = make_node<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = true;
static const Op& op = Op::Get("_contrib_reverse_reshape");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReverseReshape, args, rv);
});

RELAY_REGISTER_OP("_contrib_reverse_reshape")
.describe(R"code(Reshapes the input array where the special values are inferred from
right to left.
Example::
The special values have the same semantics as reshape. The difference is that
special values are inferred from right to left. It can be explained in the
example below::
- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5)
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ReshapeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(10)
.add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace relay
} // namespace tvm
23 changes: 23 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,31 @@ def test_slice_like():
axes=(2, 3),
output=(1, 3, 112, 112))

def test_reverse_reshape():
def verify_reverse_reshape(shape, newshape, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.reverse_reshape(x, newshape=newshape)
zz = relay.ir_pass.infer_type(z)
print(zz.checked_type)
assert "newshape=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_reverse_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
verify_reverse_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4))
verify_reverse_reshape((2, 3, 4), (0, -1), (3, 8))
verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4))
verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12))

if __name__ == "__main__":
test_collapse_sum_like()
test_broadcast_to_like()
test_slice_like()
test_reverse_reshape()
27 changes: 19 additions & 8 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,36 @@ def test_reshape_infer_type():
(n, t, 2000), "float32")

def test_reshape():
def verify_reshape(shape, oshape):
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape)

def verify_reshape(shape, newshape, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.reshape(x, newshape=ref_res.shape)
z = relay.reshape(x, newshape=newshape)
zz = relay.ir_pass.infer_type(z)
assert "newshape=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

func = relay.Function([x], z)

x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))
verify_reshape((2, 3, 4), (8, 3), (8, 3))
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
verify_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4))
verify_reshape((2, 3, 4), (0, -1), (2, 12))
verify_reshape((2, 3, 4), (-1, 0), (8, 3))
verify_reshape((2, 3, 4), (2, -2), (2, 3, 4))
verify_reshape((2, 3, 4), (-2, 1, 1), (2, 3, 4, 1, 1))
verify_reshape((2, 3, 4), (-3, 4), (6, 4))
verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20))
verify_reshape((2, 3, 4), (0, -3), (2, 12))
verify_reshape((2, 3, 4), (-3, -2), (6, 4))
verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))


def test_reshape_like_infer_type():
# concrete shape
Expand Down