Skip to content

Commit

Permalink
Add TupleGetItem to CSE (apache#5931)
Browse files Browse the repository at this point in the history
* Add TupleGetItem to CSE

* rename a local variable
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Jul 10, 2020
1 parent c269afb commit 1da674e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 13 deletions.
51 changes: 38 additions & 13 deletions src/relay/transforms/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,52 @@ class CommonSubexprEliminator : public ExprMutator {

auto it = expr_map_.find(new_call->op);
if (it != expr_map_.end()) {
for (const CallNode* candidate : it->second) {
bool is_equivalent = true;
if (!attrs_equal(new_call->attrs, candidate->attrs)) {
continue;
for (const Expr& candidate_expr : it->second) {
if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
bool is_equivalent = true;
if (!attrs_equal(new_call->attrs, candidate->attrs)) {
continue;
}
for (size_t i = 0; i < new_call->args.size(); i++) {
if (!new_call->args[i].same_as(candidate->args[i]) &&
!IsEqualScalar(new_call->args[i], candidate->args[i])) {
is_equivalent = false;
break;
}
}
if (!is_equivalent) continue;
return GetRef<Call>(candidate);
}
for (size_t i = 0; i < new_call->args.size(); i++) {
if (!new_call->args[i].same_as(candidate->args[i]) &&
!IsEqualScalar(new_call->args[i], candidate->args[i])) {
is_equivalent = false;
break;
}
}
expr_map_[new_call->op].push_back(new_expr);
return new_expr;
}

Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr new_expr = ExprMutator::VisitExpr_(op);
const TupleGetItemNode* new_tuple_item = new_expr.as<TupleGetItemNode>();
CHECK(new_tuple_item);

if (fskip_ != nullptr && fskip_(new_expr)) {
return new_expr;
}

auto it = expr_map_.find(new_tuple_item->tuple);
if (it != expr_map_.end()) {
for (const Expr& candidate_expr : it->second) {
if (const TupleGetItemNode* candidate = candidate_expr.as<TupleGetItemNode>()) {
if (new_tuple_item->index == candidate->index) {
return GetRef<Expr>(candidate);
}
}
if (!is_equivalent) continue;
return GetRef<Call>(candidate);
}
}
expr_map_[new_call->op].push_back(new_call);
expr_map_[new_tuple_item->tuple].push_back(new_expr);
return new_expr;
}

std::unordered_map<Expr, std::vector<const CallNode*>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
runtime::TypedPackedFunc<bool(Expr)> fskip_;
};

Expand Down
29 changes: 29 additions & 0 deletions tests/python/relay/test_pass_eliminate_common_subexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,35 @@ def fskip(expr):
z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip))
assert tvm.ir.structural_equal(z, expected())

def test_tuple_get_time():
def before():
x = relay.var('x', shape=(1, 16, 1, 1))
var = relay.var('var', shape=(16,))
mean = relay.var('mean', shape=(16,))
beta = relay.var('beta', shape=(16,))
gamma = relay.var('gamma', shape=(16,))
BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
T1 = BN[0]
T2 = BN[0]
add = T1 + T2
f = relay.Function([x, var, mean, beta, gamma], add)
return f

def expected():
x = relay.var('x', shape=(1, 16, 1, 1))
var = relay.var('var', shape=(16,))
mean = relay.var('mean', shape=(16,))
beta = relay.var('beta', shape=(16,))
gamma = relay.var('gamma', shape=(16,))
BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
T1 = BN[0]
add = T1 + T1
f = relay.Function([x, var, mean, beta, gamma], add)
return run_opt_pass(f, transform.InferType())

z = before()
z = run_opt_pass(z, transform.EliminateCommonSubexpr())
assert tvm.ir.structural_equal(z, expected())

if __name__ == "__main__":
test_simple()
Expand Down

0 comments on commit 1da674e

Please sign in to comment.