From 294c66750721f044f49a3df120b385429ca8d6dd Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 21 Feb 2019 11:08:47 +0800 Subject: [PATCH] [RELAY][PASS] Common subexpression elimination --- python/tvm/relay/ir_pass.py | 21 ++++++ src/relay/pass/eliminate_common_subexpr.cc | 71 +++++++++++++++++++ src/relay/pass/pattern_util.h | 15 ++++ .../test_pass_eliminate_common_subexpr.py | 63 ++++++++++++++++ 4 files changed, 170 insertions(+) create mode 100644 src/relay/pass/eliminate_common_subexpr.cc create mode 100644 tests/python/relay/test_pass_eliminate_common_subexpr.py diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 561c5d3887884..7e9b550bde827 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -534,6 +534,7 @@ def gradient(expr, mod=None): """ return _ir_pass.first_order_gradient(expr, mod) + def get_total_mac_number(expr): """ Count the number of MACs (multiply-accumulate) of a model @@ -549,3 +550,23 @@ def get_total_mac_number(expr): The number of MACs (multiply-accumulate) of a model """ return _ir_pass.GetTotalMacNumber(expr) + + +def eliminate_common_subexpr(expr, fskip=None): + """ + Eliminate common subexpressions. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + fskip: function + The callback function that decides whether an expression should be skipped. + + Returns + ------- + expr : tvm.relay.Expr + The output expression. + """ + return _ir_pass.eliminate_common_subexpr(expr, fskip) diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc new file mode 100644 index 0000000000000..914627dc4aaf8 --- /dev/null +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -0,0 +1,71 @@ +/*! + * Copyright (c) 2019 by Contributors + * + * \file eliminate_common_subexpr.cc + * \brief Combine common subexpressions. + * + * This is an optimization pass that eliminates common subexpressions. During the pass, it tries + * to replace an expression with a previously appeared expression with the same input and + * attributes. The fskip callback argument allows us to skip specific expressions. + */ +#include +#include +#include +#include "./pattern_util.h" + +namespace tvm { +namespace relay { + +class CommonSubexprEliminator : public ExprMutator { + public: + explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip): fskip_(fskip) {} + + Expr VisitExpr_(const CallNode* call) final { + static auto op_stateful = Op::GetAttr("TOpIsStateful"); + Expr new_expr = ExprMutator::VisitExpr_(call); + const CallNode* new_call = new_expr.as(); + CHECK(new_call); + const OpNode* op = new_call->op.as(); + AttrsEqual attrs_equal; + + if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef(op), false)) { + return new_expr; + } + if (fskip_ != nullptr && fskip_(new_expr)) { + return new_expr; + } + + auto it = expr_map_.find(new_call->args[0]); + if (it != expr_map_.end()) { + for (const CallNode* candidate : it->second) { + bool is_equivalent = true; + if (!new_call->op.same_as(candidate->op)) continue; + for (size_t i = 0; i < new_call->args.size(); i++) { + if (!new_call->args[i].same_as(candidate->args[i]) && + !IsEqualScalar(new_call->args[i], candidate->args[i]) && + !attrs_equal(new_call->attrs, candidate->attrs)) { + is_equivalent = false; + break; + } + } + if (!is_equivalent) continue; + return GetRef(candidate); + } + } + expr_map_[new_call->args[0]].push_back(new_call); + return new_expr; + } + + std::unordered_map, NodeHash, NodeEqual> expr_map_; + runtime::TypedPackedFunc fskip_; +}; + +Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { + return CommonSubexprEliminator(callback)(expr); +} + +TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr") +.set_body_typed(EliminateCommonSubexpr); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 08fc017f41eb3..8433fc1508054 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -192,6 +192,21 @@ inline Constant MakeConstantScalar(DataType dtype, T value) { return ConstantNode::make(arr); } +/*! + * \brief Check if two expressions are equal scalars. + * \param a The expression to be checked. + * \param b The expression to be checked + * \return Whether two expressions are equal scalars. + */ +inline bool IsEqualScalar(const Expr& a, const Expr& b) { + const auto* constant_a = a.as(); + const auto* constant_b = b.as(); + if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) { + return false; + } + return AlphaEqual(a, b); +} + inline Expr GetField(Expr t, size_t i) { return TupleGetItemNode::make(t, i); } diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py new file mode 100644 index 0000000000000..381a54a3d3245 --- /dev/null +++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -0,0 +1,63 @@ +"""Test eliminate common subexpr pass""" +from tvm import relay +from tvm.relay.op import register_alter_op_layout +from tvm.relay import ir_pass + + +def test_simple(): + def before(): + x = relay.var("x", shape=(1, 16)) + y1 = relay.nn.relu(x) + y2 = relay.nn.relu(x) + y1 = relay.add(y1, relay.const(1.0, "float32")) + y2 = relay.add(y2, relay.const(1.0, "float32")) + y = relay.add(y1, y2) + f = relay.Function([x], y) + return f + + def expected(): + x = relay.var("x", shape=(1, 16)) + y = relay.nn.relu(x) + y = relay.add(y, relay.const(1.0, "float32")) + y = relay.add(y, y) + f = relay.Function([x], y) + return f + + z = before() + z = ir_pass.eliminate_common_subexpr(z) + assert ir_pass.alpha_equal(z, expected()) + + +def test_callback(): + def before(): + x = relay.var("x", shape=(1, 16)) + y1 = relay.nn.relu(x) + y2 = relay.nn.relu(x) + y1 = relay.add(y1, relay.const(1.0, "float32")) + y2 = relay.add(y2, relay.const(1.0, "float32")) + y = relay.add(y1, y2) + f = relay.Function([x], y) + return f + + def expected(): + x = relay.var("x", shape=(1, 16)) + y = relay.nn.relu(x) + y1 = relay.add(y, relay.const(1.0, "float32")) + y2 = relay.add(y, relay.const(1.0, "float32")) + y = relay.add(y1, y2) + f = relay.Function([x], y) + return f + + def fskip(expr): + if isinstance(expr, relay.expr.Call) and expr.op.name == 'add': + return True + return False + + z = before() + z = ir_pass.eliminate_common_subexpr(z, fskip) + assert ir_pass.alpha_equal(z, expected()) + + +if __name__ == "__main__": + test_simple() + test_callback()