Skip to content

Commit

Permalink
[RELAY] Turn reshape into nop in graph executor backend. (apache#7945)
Browse files Browse the repository at this point in the history
* [RELAY] Turn reshape into nop in graph executor backend.

Previously we generated function calls for reshape.
This PR updates the optimization to turn reshape into a nop:

- Tag a fused function as reshape-only if it contains only reshape ops.
- Update the memory planner to force the input and output to share the same piece of memory.
- Update the graph executor codegen to emit a nop when a reshape-only function is encountered.

* Address review comments.

* Additional comment and TODOs on the rationale
  • Loading branch information
tqchen authored May 2, 2021
1 parent ce8f52b commit b2c4f1c
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 40 deletions.
3 changes: 3 additions & 0 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ constexpr const char* kComposite = "Composite";
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";

/*!
 * \brief Mark the function as only composed of reshape operations.
 *
 * Consumers query this key with HasNonzeroAttr, so the function is treated as
 * reshape-only when the attribute is present with a non-zero integer value.
 */
constexpr const char* kReshapeOnly = "relay.reshape_only";
} // namespace attr

} // namespace relay
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ using TOpIsStateful = bool;
*/
using TNonComputational = bool;

/*!
 * \brief Mark the operator as a reshape op of its first input.
 *
 * Such an operator can be turned into a nop when its first input and its
 * output share the same piece of memory.
 */
using TReshapeOp = bool;

/*!
* \brief Mark the operator whether output shape is data dependent.
*/
Expand Down
23 changes: 23 additions & 0 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,16 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
return AddNode(node, GetRef<Expr>(op));
}

/*!
 * \brief Check whether two expressions were assigned the same storage id.
 * \param lhs The first expression; must have an entry in storage_device_map_.
 * \param rhs The second expression; must have an entry in storage_device_map_.
 * \return true iff element [0][0] of both map entries holds the same value.
 */
bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
  const auto lit = storage_device_map_.find(lhs);
  ICHECK(lit != storage_device_map_.end());
  const auto rit = storage_device_map_.find(rhs);
  ICHECK(rit != storage_device_map_.end());
  // Only the first storage id of each expression takes part in the comparison.
  return ((*lit).second)[0][0]->value == ((*rit).second)[0][0]->value;
}

std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
Expr expr = GetRef<Expr>(op);
Function func;
Expand Down Expand Up @@ -380,6 +390,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
}

// In the current flat memory allocation scenario
// the flat memory allocator can always allocate the input
// and output of the reshape to the same memory, so we can turn a
// reshape-only function into a nop.
//
// NOTE that for non-flat memory this is not necessarily true.
//
// TODO(tvm-team) Update checks of flat memory enablement when we support
// opaque-nd memory planning to skip this path.
if (func->HasNonzeroAttr(attr::kReshapeOnly) && ShareSameStorage(expr, op->args[0])) {
return GraphAddCallNode(op, "reshape_nop", "__nop");
}

ICHECK_GE(storage_device_map_.count(expr), 0);
auto& device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value;
Expand Down
45 changes: 43 additions & 2 deletions src/relay/backend/graph_plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
auto it = prototype_.find(op);
ICHECK(it != prototype_.end());
std::vector<StorageToken*> tokens;

for (StorageToken* tok : it->second) {
if (can_realloc) {
tokens.push_back(Request(tok));
Expand All @@ -250,6 +251,22 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
}
token_map_[op] = tokens;
}
/*!
 * \brief Make op reuse input_token, tying the two memories together.
 * \param op The expression whose output should alias the input; it must not
 *        have a token assigned yet and must have exactly one prototype token.
 * \param input_token The token of the input that op will share.
 */
void ReuseInputToken(const ExprNode* op, StorageToken* input_token) {
  ICHECK(!token_map_.count(op));
  auto it = prototype_.find(op);
  ICHECK(it != prototype_.end());
  ICHECK_EQ(it->second.size(), 1U);
  // Fold the output's reference count into the shared token so that it can
  // only be released after the references to both input and output expire.
  input_token->ref_counter += it->second[0]->ref_counter;
  // From here on the output of op is backed by the input's token.
  token_map_[op] = {input_token};
}

// The call map
void VisitExpr_(const CallNode* op) final {
std::vector<StorageToken*> args;
Expand All @@ -259,8 +276,21 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
args.push_back(tok);
}
}
// create token for the call node.
CreateToken(op, true);
// Under the flat-memory setting,
// we can force aliasing the input and output of reshape
// to make it a nop. Note that this is not true
// for the non-flat memory case. Given that the current graph plan memory
// only works for the flat memory case, we will go with this choice.
//
// TODO(tvm-team) Update checks of flat memory enablement when we support
// opaque-nd memory planning to skip this path.
if (IsReshape(op)) {
ICHECK_EQ(args.size(), 1U);
ReuseInputToken(op, args[0]);
} else {
// create token for the call node.
CreateToken(op, true);
}
// check if there is orphaned output that can be released immediately.
for (StorageToken* tok : token_map_.at(op)) {
CheckForRelease(tok);
Expand All @@ -278,6 +308,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
static size_t DivRoundUp(size_t size, size_t word_size) {
  // Computes ceil(size / word_size); word_size must be non-zero.
  // The quotient/remainder form is used instead of (size + word_size - 1) / word_size
  // because that sum wraps around for sizes close to SIZE_MAX and yields a too-small result.
  return size / word_size + (size % word_size != 0 ? 1 : 0);
}
/*!
 * \brief Check whether the call invokes a reshape-only function.
 * \param call The call to be checked.
 * \return true iff call->op is a FunctionNode carrying a non-zero
 *         attr::kReshapeOnly attribute; false otherwise.
 */
static bool IsReshape(const CallNode* call) {
  const auto* fn = call->op.as<FunctionNode>();
  return fn != nullptr && fn->HasNonzeroAttr(attr::kReshapeOnly);
}
/*!
* \brief Get the memory requirement.
* \param prototype The prototype token.
Expand Down
3 changes: 2 additions & 1 deletion src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ RELAY_REGISTER_OP("dyn.reshape")
.set_support_level(3)
.add_type_rel("DynamicReshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TReshapeOp>("TReshapeOp", true);

// tile operator
// TVM_REGISTER_NODE_TYPE(TileAttrs);
Expand Down
12 changes: 8 additions & 4 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ RELAY_REGISTER_OP("expand_dims")
.set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel)
.set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
.set_attr<TOpPattern>("TOpPattern", kBroadcast)
.set_attr<TReshapeOp>("TReshapeOp", true);

// relay.concatenate
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
Expand Down Expand Up @@ -887,7 +888,8 @@ Example::
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TReshapeOp>("TReshapeOp", true);

/*!
* \brief ReshapeLikeRel User defined type constraint function.
Expand Down Expand Up @@ -2243,7 +2245,8 @@ RELAY_REGISTER_OP("squeeze")
.add_type_rel("Squeeze", SqueezeRel)
.set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout)
.set_attr<TReshapeOp>("TReshapeOp", true);

// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down Expand Up @@ -3221,7 +3224,8 @@ example below::
.set_support_level(10)
.add_type_rel("ReverseReshape", ReverseReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TReshapeOp>("TReshapeOp", true);

// gather operator
TVM_REGISTER_NODE_TYPE(GatherAttrs);
Expand Down
30 changes: 27 additions & 3 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -948,15 +948,39 @@ class FuseMutator : private MixedModeMutator {
}

/*!
 * \brief Wrap the fused body of a group into a function and return a call to it.
 *
 * The function is always tagged with attr::kPrimitive, set to whether the body
 * contains any call. It is additionally tagged with attr::kReshapeOnly when the
 * body contains a call and the visitor below never cleared its reshape_only flag.
 *
 * \param group The group whose GroupInfo (params and arguments) is used.
 * \param ret_type The return type of the new function.
 * \param body The fused body.
 * \return A call to the new function with the group's arguments.
 */
Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
  // Quickly check special properties of the fused function.
  // A pass to check if the fused op contains only reshape ops.
  class CheckReshapeOnly : public ExprVisitor {
   public:
    void VisitExpr_(const CallNode* cn) final {
      this->has_call = true;
      static auto freshape_op = Op::GetAttrMap<TReshapeOp>("TReshapeOp");

      // Any callee without a true TReshapeOp attribute disqualifies the function.
      if (!freshape_op.get(cn->op, false)) {
        this->reshape_only = false;
      }

      // Once disqualified there is nothing left to learn from the sub-expressions.
      if (!this->reshape_only) return;
      ExprVisitor::VisitExpr_(cn);
    }

    void VisitExpr_(const VarNode* vn) final {
      // Only variables annotated with a tensor type keep the function reshape-only.
      if (!vn->type_annotation.defined() || !vn->type_annotation->IsInstance<TensorTypeNode>()) {
        this->reshape_only = false;
      }
    }

    bool reshape_only = true;
    bool has_call = false;
  } visitor;

  visitor(body);
  const GroupInfo& ginfo = ginfo_[group];
  auto func = Function(ginfo.params, body, ret_type, {});
  // If the function has no call, it is not a primitive function.
  func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
  if (visitor.has_call && visitor.reshape_only) {
    func = WithAttr(std::move(func), attr::kReshapeOnly, tvm::Integer(visitor.reshape_only));
  }
  return Call(func, ginfo.arguments, Attrs());
}

Expand Down
34 changes: 4 additions & 30 deletions src/relay/transforms/memory_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,39 +64,13 @@ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dt
return AllocTensor(storage, offset, shape, dtype, assert_shape);
}

/*!
 * \brief A pass to check if the fused op contains only reshape ops.
 *
 * After visiting an expression, reshape_only is true only if every visited
 * call targets "reshape", "contrib_reverse_reshape" or "dyn.reshape" and every
 * visited variable has a defined tensor checked type.
 */
class CheckReshapeOnly : public ExprVisitor {
 public:
  CheckReshapeOnly()
      : reshape_(Op::Get("reshape")),
        contr_reshape_(Op::Get("contrib_reverse_reshape")),
        dyn_reshape_(Op::Get("dyn.reshape")) {}

  void VisitExpr_(const CallNode* cn) final {
    if (!reshape_only) return;
    if (cn->op != reshape_ && cn->op != contr_reshape_ && cn->op != dyn_reshape_) {
      // The verdict can only go from true to false, so it is final here and
      // walking the arguments cannot change it.
      reshape_only = false;
      return;
    }
    for (auto arg : cn->args) ExprVisitor::VisitExpr(arg);
  }

  void VisitExpr_(const VarNode* vn) final {
    // Guard against a variable whose checked type has not been populated;
    // dereferencing an undefined type would be a null access.
    if (!vn->checked_type_.defined() || !vn->checked_type_->IsInstance<TensorTypeNode>()) {
      reshape_only = false;
    }
  }

  /*! \brief The "reshape" operator. */
  const Op& reshape_;
  /*! \brief The "contrib_reverse_reshape" operator. */
  const Op& contr_reshape_;
  /*! \brief The "dyn.reshape" operator. */
  const Op& dyn_reshape_;
  /*! \brief Result of the check; starts true and is cleared on the first violation. */
  bool reshape_only{true};
};

/*!
 * \brief Check if the primitive function contains only reshape ops.
 *
 * The answer is read from the attr::kReshapeOnly attribute of the function
 * instead of being recomputed by walking its body.
 *
 * \param expr The expression to check.
 * \return true iff expr is a FunctionNode with a non-zero attr::kReshapeOnly
 *         attribute; false for any other expression.
 */
bool IsReshapeOnly(const Expr& expr) {
  if (const auto* func = expr.as<FunctionNode>()) {
    return func->HasNonzeroAttr(attr::kReshapeOnly);
  }
  return false;
}

class DialectRewriter : public ExprMutator {
Expand Down
47 changes: 47 additions & 0 deletions tests/python/relay/test_backend_graph_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

import tvm
import json
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.op import add
Expand Down Expand Up @@ -146,6 +147,51 @@ def test_plan_memory():
assert len(device_types) == 1


def test_reshape_nop():
    """Check that reshape-only fused functions are emitted as "__nop" nodes.

    Builds a small graph mixing reshape-like ops (expand_dims, reshape) with
    real computation (abs, sqrt), then verifies the planned storage ids, the
    function names in the graph JSON, and the numerical outputs.
    """
    x = relay.var("x", shape=(10, 4))
    xx = relay.abs(x)
    y = relay.expand_dims(xx, axis=1)
    t0 = relay.reshape(y, (1, 40))
    t1 = relay.abs(y)

    z0 = relay.reshape(t0, (2, 20))
    z1 = relay.sqrt(t1)
    z2 = relay.reshape(t1, (1, 40))

    func = relay.Function([x], relay.Tuple([z0, z1, z2]))
    x_data = np.random.rand(10, 4).astype("float32")
    graph = relay.build(tvm.IRModule.from_expr(func), "llvm")
    graph_json_str = graph.get_json()

    graph_json = json.loads(graph_json_str)

    # reshape must force sharing memory
    storage_ids = graph_json["attrs"]["storage_id"][1]
    assert tuple(storage_ids) == (0, 1, 1, 2, 3, 2)
    assert graph_json["nodes"][2]["attrs"]["func_name"] == "__nop"
    assert graph_json["nodes"][5]["attrs"]["func_name"] == "__nop"

    gmod = graph_executor.GraphModule(graph["default"](tvm.cpu(0)))

    gmod.set_input(x=x_data)
    gmod.run()
    # z0 is reshape(reshape(expand_dims(abs(x)))), so the reference must apply
    # abs as well; omitting it only happens to pass for non-negative inputs.
    z0_np = np.abs(x_data).reshape(2, 20)
    z1_np = np.sqrt(
        np.abs(
            x_data.reshape(
                10,
                1,
                4,
            )
        )
    )
    z2_np = np.abs(x_data).reshape(1, 40)
    tvm.testing.assert_allclose(gmod.get_output(0).asnumpy(), z0_np)
    tvm.testing.assert_allclose(gmod.get_output(1).asnumpy(), z1_np)
    tvm.testing.assert_allclose(gmod.get_output(2).asnumpy(), z2_np)


@tvm.testing.uses_gpu
def test_gru_like():
def unit(rnn_dim):
Expand Down Expand Up @@ -231,6 +277,7 @@ def test_graph_executor_nested_tuples():


if __name__ == "__main__":
test_reshape_nop()
test_plan_memory()
test_with_params()
test_add_op_scalar()
Expand Down

0 comments on commit b2c4f1c

Please sign in to comment.