From a9f25aa345e8742edbeb4177d0ad65a34698e826 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Fri, 26 Jun 2020 10:42:21 -0700
Subject: [PATCH] remove dynamic behavior from standard reshape

---
 include/tvm/relay/attrs/transform.h     |  2 +-
 python/tvm/relay/_parser.py             |  2 +-
 python/tvm/relay/frontend/tensorflow.py |  6 ++++
 python/tvm/relay/op/_tensor_grad.py     |  2 +-
 python/tvm/relay/op/transform.py        | 17 ++-------
 src/relay/analysis/util.cc              |  9 ++---
 src/relay/op/tensor/transform.cc        | 46 +++++++------------
 src/relay/op/tensor/transform.h         |  2 +-
 src/relay/transforms/fold_scale_axis.cc |  3 +-
 src/relay/transforms/pattern_util.h     |  6 ++--
 10 files changed, 30 insertions(+), 65 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 750a8a43163c3..8af9f6349d418 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
 
 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
-  Optional<Array<Integer>> newshape;
+  Array<Integer> newshape;
   bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape).describe(
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index eb567658f2a19..09217870b07ad 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -114,7 +114,7 @@ def convert(self, v):
     def __call__(self, args, attrs, type_args):
         if attrs is None:
             attrs = {}
-        if self.operator in (op.reshape, op.strided_slice):
+        if self.operator in (op.strided_slice,):
             x = self.operator(*args)
         elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
             x = self.operator(*args, dtype=attrs["dtype"])
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 62dadce1d3f61..a77cb736a6123 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1183,6 +1183,12 @@ def _impl(inputs, attr, params, mod):
                 return _op.reshape_like(inputs[0], pop_node.args[0])
             shape_arg = pop_node
 
+        if isinstance(shape_arg, _expr.Expr):
+            return AttrCvt(
+                op_name="dyn.reshape",
+                extras={'newshape': shape_arg},
+                ignores=['Tshape'])(inputs, attr)
+
         return AttrCvt(
             op_name="reshape",
             extras={'newshape': shape_arg},
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 0deb87a60e34d..6d4f8fa60f2c1 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -482,7 +482,7 @@ def dense_grad(orig, grad):
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
-    return [reshape_like(grad, orig.args[0]), orig.args[1]]
+    return [reshape_like(grad, orig.args[0])]
 
 
 @register_gradient("cast")
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index a37226ea4f586..864917cefc95b 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -202,7 +202,7 @@ def reshape(data, newshape):
     data : relay.Expr
         The input data to the operator.
 
-    newshape : Union[int, Tuple[int], List[int]] or relay.Expr
+    newshape : Union[int, Tuple[int], List[int]]
         The new shape. Should be compatible with the original shape.
 
     Returns
     -------
     result : relay.Expr
         The reshaped result.
""" if isinstance(newshape, int): - newshape = const([newshape]) - if isinstance(newshape, (tuple, list)): - tempshape = [] - for shape in newshape: - if isinstance(shape, _expr.IntImm): - tempshape.append(shape.value) - else: - try: - tempshape.append(int(shape)) - except ValueError as err: - raise RuntimeError('Unrecognized shape type: %s' % err) - newshape = const(tempshape) - return _make.reshape(data, newshape) + newshape = [newshape] + return _make.reshape(data, list(newshape)) def argwhere(condition): """Find the indices of elements of a tensor that are diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index f07c14a286d89..f7ef176f5ef6f 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -448,13 +448,8 @@ bool IsDataDependant(const CallNode* call) { return false; } - if (op->name == "reshape") { - if (const auto* attrs = call->attrs.as()) { - if (attrs->newshape) { - // If newshape attribute exists, it isn't data dependant. - return false; - } - } + if (op->name == "dyn.reshape") { + return true; } else if (op->name == "topk") { if (const auto* attrs = call->attrs.as()) { if (attrs->k) { diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ee5e291e3d532..8dee6d89f1385 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -449,13 +449,8 @@ TVM_REGISTER_NODE_TYPE(ReshapeAttrs); bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { const auto* param = attrs.as(); - if (param->reverse) { - // types: [data, result] - CHECK_EQ(types.size(), 2); - } else { - // types: [data, newshape, result] - CHECK_EQ(types.size(), 3); - } + // types: [data, result] + CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) @@ -467,25 +462,12 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, Array data_shape; Array newshape; - if (param->newshape) { - auto temp = param->newshape.value(); - if (param->reverse) { - data_shape.Assign(data->shape.rbegin(), data->shape.rend()); - newshape.Assign(temp.rbegin(), temp.rend()); - } else { - data_shape = data->shape; - newshape = temp; - } + if (param->reverse) { + data_shape.Assign(data->shape.rbegin(), data->shape.rend()); + newshape.Assign(param->newshape.rbegin(), param->newshape.rend()); } else { - const auto* newshape = types[1].as(); - - // Doesn't support dynamic output rank - for (int i = 0; i < newshape->shape[0].as()->value; i++) { - oshape.push_back(Any()); - } - - reporter->Assign(types[2], TensorType(oshape, data->dtype)); - return true; + data_shape = data->shape; + newshape = param->newshape; } std::unordered_set used_input_dims; @@ -600,7 +582,7 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); } else { - reporter->Assign(types[2], TensorType(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); } return true; } @@ -620,15 +602,12 @@ Array ReshapeCompute(const Attrs& attrs, const Array& in return {topi::reshape(inputs[0], newshape)}; } -Expr MakeReshape(Expr data, Expr newshape) { +Expr MakeReshape(Expr data, Array newshape) { auto attrs = make_object(); - if (const ConstantNode* c = newshape.as()) { - CHECK_EQ(c->data->ndim, 1); - attrs->newshape = ToVector(c->data); - } + attrs->newshape = std::move(newshape); attrs->reverse = false; static const Op& op = 
Op::Get("reshape"); - return Call(op, {data, newshape}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape); @@ -684,10 +663,9 @@ Example:: - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) )code" TVM_ADD_FILELINE) - .set_num_inputs(2) + .set_num_inputs(1) .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") - .add_argument("newshape", "Tensor", "The shape of output tensor.") .set_support_level(3) .add_type_rel("Reshape", ReshapeRel) .set_attr("FTVMCompute", ReshapeCompute) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 7149417aa9b55..0ff0b3ffa25b0 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -38,7 +38,7 @@ namespace tvm { namespace relay { -extern Expr MakeReshape(Expr data, Expr newshape); +extern Expr MakeReshape(Expr data, Array newshape); template bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index a3765f3c3befd..0c2abbfdd238e 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -329,8 +329,7 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array& shape, arr.push_back(1); } } - return MakeReshape( - scale, MakeConstantTensor(DataType::Int(32), {static_cast(arr.size())}, arr)); + return MakeReshape(scale, std::move(arr)); } // if only one axis, use expand dim. Else, use reshape diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 7518eb9ac81a1..27764ecc3624c 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -624,12 +624,10 @@ static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclu return Call(op, {data}, Attrs(attrs), {}); } -Expr MakeReshape(Expr data, Expr newshape); +Expr MakeReshape(Expr data, Array newshape); static inline Expr Reshape(Expr data, Array newshape) { - auto newshape_tensor = - MakeConstantTensor(DataType::Int(32), {static_cast(newshape.size())}, newshape); - return MakeReshape(data, newshape_tensor); + return MakeReshape(data, newshape); } static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides,