diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index 8f7375c9dd35..dc3f77e3180c 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -58,27 +58,52 @@ class CommonSubexprEliminator : public ExprMutator { auto it = expr_map_.find(new_call->op); if (it != expr_map_.end()) { - for (const CallNode* candidate : it->second) { - bool is_equivalent = true; - if (!attrs_equal(new_call->attrs, candidate->attrs)) { - continue; + for (const Expr& candidate_expr : it->second) { + if (const CallNode* candidate = candidate_expr.as()) { + bool is_equivalent = true; + if (!attrs_equal(new_call->attrs, candidate->attrs)) { + continue; + } + for (size_t i = 0; i < new_call->args.size(); i++) { + if (!new_call->args[i].same_as(candidate->args[i]) && + !IsEqualScalar(new_call->args[i], candidate->args[i])) { + is_equivalent = false; + break; + } + } + if (!is_equivalent) continue; + return GetRef(candidate); } - for (size_t i = 0; i < new_call->args.size(); i++) { - if (!new_call->args[i].same_as(candidate->args[i]) && - !IsEqualScalar(new_call->args[i], candidate->args[i])) { - is_equivalent = false; - break; + } + } + expr_map_[new_call->op].push_back(new_expr); + return new_expr; + } + + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr new_expr = ExprMutator::VisitExpr_(op); + const TupleGetItemNode* new_tuple_item = new_expr.as(); + CHECK(new_tuple_item); + + if (fskip_ != nullptr && fskip_(new_expr)) { + return new_expr; + } + + auto it = expr_map_.find(new_tuple_item->tuple); + if (it != expr_map_.end()) { + for (const Expr& candidate_expr : it->second) { + if (const TupleGetItemNode* candidate = candidate_expr.as()) { + if (new_tuple_item->index == candidate->index) { + return GetRef(candidate); } } - if (!is_equivalent) continue; - return GetRef(candidate); } } - expr_map_[new_call->op].push_back(new_call); + expr_map_[new_tuple_item->tuple].push_back(new_expr); return new_expr; } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> expr_map_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> expr_map_; runtime::TypedPackedFunc fskip_; }; diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py index 7af524d3ae01..45d21a472cd1 100644 --- a/tests/python/relay/test_pass_eliminate_common_subexpr.py +++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -84,6 +84,35 @@ def fskip(expr): z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip)) assert tvm.ir.structural_equal(z, expected()) +def test_tuple_get_time(): + def before(): + x = relay.var('x', shape=(1, 16, 1, 1)) + var = relay.var('var', shape=(16,)) + mean = relay.var('mean', shape=(16,)) + beta = relay.var('beta', shape=(16,)) + gamma = relay.var('gamma', shape=(16,)) + BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5) + T1 = BN[0] + T2 = BN[0] + add = T1 + T2 + f = relay.Function([x, var, mean, beta, gamma], add) + return f + + def expected(): + x = relay.var('x', shape=(1, 16, 1, 1)) + var = relay.var('var', shape=(16,)) + mean = relay.var('mean', shape=(16,)) + beta = relay.var('beta', shape=(16,)) + gamma = relay.var('gamma', shape=(16,)) + BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5) + T1 = BN[0] + add = T1 + T1 + f = relay.Function([x, var, mean, beta, gamma], add) + return run_opt_pass(f, transform.InferType()) + + z = before() + z = run_opt_pass(z, transform.EliminateCommonSubexpr()) + assert tvm.ir.structural_equal(z, expected()) if __name__ == "__main__": test_simple()