Commit a9f25aa: remove dynamic behavior from standard reshape

Matthew Brookhart committed Jun 26, 2020
1 parent aaf6959

Showing 10 changed files with 30 additions and 65 deletions.
include/tvm/relay/attrs/transform.h (1 addition & 1 deletion)

@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {

 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
-  Optional<Array<Integer>> newshape;
+  Array<Integer> newshape;
   bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape).describe(
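
With `newshape` demoted from `Optional<Array<Integer>>` to a plain `Array<Integer>`, the target shape is always a compile-time attribute and the shape tensor disappears from the op's inputs. A quick way to observe this from Python (a minimal sketch, assuming a TVM build that includes this commit):

from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
call = relay.reshape(x, (2, 12))
print(call.attrs.newshape)  # [2, 12] -- carried on the attributes
print(len(call.args))       # 1 -- only `data`; no shape tensor input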
python/tvm/relay/_parser.py (1 addition & 1 deletion)

@@ -114,7 +114,7 @@ def convert(self, v):
     def __call__(self, args, attrs, type_args):
         if attrs is None:
             attrs = {}
-        if self.operator in (op.reshape, op.strided_slice):
+        if self.operator in (op.strided_slice,):
             x = self.operator(*args)
         elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
             x = self.operator(*args, dtype=attrs["dtype"])
python/tvm/relay/frontend/tensorflow.py (6 additions & 0 deletions)

@@ -1183,6 +1183,12 @@ def _impl(inputs, attr, params, mod):
                     return _op.reshape_like(inputs[0], pop_node.args[0])
                 shape_arg = pop_node

+        if isinstance(shape_arg, _expr.Expr):
+            return AttrCvt(
+                op_name="dyn.reshape",
+                extras={'newshape': shape_arg},
+                ignores=['Tshape'])(inputs, attr)
+
         return AttrCvt(
             op_name="reshape",
             extras={'newshape': shape_arg},
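
The importer now chooses the operator purely from the type of the (constant-folded) shape argument: anything still a `relay.Expr` must be computed at run time and is routed to `dyn.reshape`. The rule, distilled into a standalone sketch (the function name is illustrative, not TVM frontend API):

def pick_reshape_op(shape_arg, expr_type):
    """Name of the Relay op the TF importer emits for tf.reshape."""
    if isinstance(shape_arg, expr_type):  # shape is a runtime tensor
        return "dyn.reshape"
    return "reshape"                      # shape folded to ints at import time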
python/tvm/relay/op/_tensor_grad.py (1 addition & 1 deletion)

@@ -482,7 +482,7 @@ def dense_grad(orig, grad):
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
-    return [reshape_like(grad, orig.args[0]), orig.args[1]]
+    return [reshape_like(grad, orig.args[0])]


 @register_gradient("cast")
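
A registered gradient returns one adjoint per call argument; since `newshape` is no longer an argument, `reshape` contributes exactly one. A sanity check (a sketch, assuming this commit; `relay.transform.gradient` is the standard reverse-mode entry point):

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 6), dtype="float32")
fwd = relay.Function([x], relay.sum(relay.reshape(x, (3, 4))))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(fwd))
bwd = relay.transform.gradient(mod["main"], mod)  # adjoint for `x` only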
python/tvm/relay/op/transform.py (3 additions & 14 deletions)

@@ -202,7 +202,7 @@ def reshape(data, newshape):
     data : relay.Expr
         The input data to the operator.

-    newshape : Union[int, Tuple[int], List[int]] or relay.Expr
+    newshape : Union[int, Tuple[int], List[int]]
        The new shape. Should be compatible with the original shape.

     Returns
@@ -211,19 +211,8 @@ def reshape(data, newshape):
         The reshaped result.
     """
     if isinstance(newshape, int):
-        newshape = const([newshape])
-    if isinstance(newshape, (tuple, list)):
-        tempshape = []
-        for shape in newshape:
-            if isinstance(shape, _expr.IntImm):
-                tempshape.append(shape.value)
-            else:
-                try:
-                    tempshape.append(int(shape))
-                except ValueError as err:
-                    raise RuntimeError('Unrecognized shape type: %s' % err)
-        newshape = const(tempshape)
-    return _make.reshape(data, newshape)
+        newshape = [newshape]
+    return _make.reshape(data, list(newshape))


 def argwhere(condition):
     """Find the indices of elements of a tensor that are
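
With the dynamic path split out, the Python helper only has to normalize `newshape` to a list of ints. Typical use after this change (a minimal sketch, assuming a build with this commit):

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
y = relay.reshape(x, newshape=(2, -1))  # -1 infers the remaining 12
mod = relay.transform.InferType()(tvm.IRModule.from_expr(y))
print(mod["main"].body.checked_type)  # Tensor[(2, 12), float32]
# A relay.Expr newshape is no longer accepted here; runtime-computed
# shapes go through the separate dyn.reshape operator instead.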
src/relay/analysis/util.cc (2 additions & 7 deletions)

@@ -448,13 +448,8 @@ bool IsDataDependant(const CallNode* call) {
     return false;
   }

-  if (op->name == "reshape") {
-    if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
-      if (attrs->newshape) {
-        // If newshape attribute exists, it isn't data dependant.
-        return false;
-      }
-    }
+  if (op->name == "dyn.reshape") {
+    return true;
   } else if (op->name == "topk") {
     if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
       if (attrs->k) {
src/relay/op/tensor/transform.cc (12 additions & 34 deletions)

@@ -449,13 +449,8 @@ TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
 bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
   const auto* param = attrs.as<ReshapeAttrs>();
-  if (param->reverse) {
-    // types: [data, result]
-    CHECK_EQ(types.size(), 2);
-  } else {
-    // types: [data, newshape, result]
-    CHECK_EQ(types.size(), 3);
-  }
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) {
     CHECK(types[0].as<IncompleteTypeNode>())
@@ -467,25 +462,12 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   Array<IndexExpr> data_shape;
   Array<Integer> newshape;

-  if (param->newshape) {
-    auto temp = param->newshape.value();
-    if (param->reverse) {
-      data_shape.Assign(data->shape.rbegin(), data->shape.rend());
-      newshape.Assign(temp.rbegin(), temp.rend());
-    } else {
-      data_shape = data->shape;
-      newshape = temp;
-    }
+  if (param->reverse) {
+    data_shape.Assign(data->shape.rbegin(), data->shape.rend());
+    newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
   } else {
-    const auto* newshape = types[1].as<TensorTypeNode>();
-
-    // Doesn't support dynamic output rank
-    for (int i = 0; i < newshape->shape[0].as<IntImmNode>()->value; i++) {
-      oshape.push_back(Any());
-    }
-
-    reporter->Assign(types[2], TensorType(oshape, data->dtype));
-    return true;
+    data_shape = data->shape;
+    newshape = param->newshape;
   }

   std::unordered_set<size_t> used_input_dims;
@@ -600,7 +582,7 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     reporter->Assign(types[1],
                      TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
   } else {
-    reporter->Assign(types[2], TensorType(oshape, data->dtype));
+    reporter->Assign(types[1], TensorType(oshape, data->dtype));
   }
   return true;
 }
@@ -620,15 +602,12 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
   return {topi::reshape(inputs[0], newshape)};
 }

-Expr MakeReshape(Expr data, Expr newshape) {
+Expr MakeReshape(Expr data, Array<Integer> newshape) {
   auto attrs = make_object<ReshapeAttrs>();
-  if (const ConstantNode* c = newshape.as<ConstantNode>()) {
-    CHECK_EQ(c->data->ndim, 1);
-    attrs->newshape = ToVector(c->data);
-  }
+  attrs->newshape = std::move(newshape);
   attrs->reverse = false;
   static const Op& op = Op::Get("reshape");
-  return Call(op, {data, newshape}, Attrs(attrs), {});
+  return Call(op, {data}, Attrs(attrs), {});
 }

 TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape);
@@ -684,10 +663,9 @@ Example::
 - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(2)
+    .set_num_inputs(1)
     .set_attrs_type<ReshapeAttrs>()
     .add_argument("data", "Tensor", "The input tensor.")
-    .add_argument("newshape", "Tensor", "The shape of output tensor.")
     .set_support_level(3)
     .add_type_rel("Reshape", ReshapeRel)
     .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
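
Because `newshape` is always available on the attributes, `ReshapeRel` can resolve the special codes (`-1` infer, `-2` copy the rest, `-3` merge, `-4` split) entirely at type-inference time, including the documented example above (a sketch under the same assumptions):

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
y = relay.reshape(x, (2, -4, -1, 3, -2))  # from the Example:: block
mod = relay.transform.InferType()(tvm.IRModule.from_expr(y))
print(mod["main"].body.checked_type)  # Tensor[(2, 1, 3, 4), float32]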
src/relay/op/tensor/transform.h (1 addition & 1 deletion)

@@ -38,7 +38,7 @@
 namespace tvm {
 namespace relay {

-extern Expr MakeReshape(Expr data, Expr newshape);
+extern Expr MakeReshape(Expr data, Array<Integer> newshape);

 template <typename AttrType>
 bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
src/relay/transforms/fold_scale_axis.cc (1 addition & 2 deletions)

@@ -329,8 +329,7 @@ static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
       arr.push_back(1);
     }
   }
-  return MakeReshape(
-      scale, MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(arr.size())}, arr));
+  return MakeReshape(scale, std::move(arr));
 }

 // if only one axis, use expand dim. Else, use reshape
src/relay/transforms/pattern_util.h (2 additions & 4 deletions)

@@ -624,12 +624,10 @@ static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
   return Call(op, {data}, Attrs(attrs), {});
 }

-Expr MakeReshape(Expr data, Expr newshape);
+Expr MakeReshape(Expr data, Array<Integer> newshape);

 static inline Expr Reshape(Expr data, Array<Integer> newshape) {
-  auto newshape_tensor =
-      MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(newshape.size())}, newshape);
-  return MakeReshape(data, newshape_tensor);
+  return MakeReshape(data, newshape);
 }

 static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
