Skip to content

Commit

Permalink
[RELAY][PASS] Common subexpression elimination (apache#2639)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and wweic committed Mar 12, 2019
1 parent b095d70 commit fb2cd3f
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,23 @@ def get_total_mac_number(expr):
The number of MACs (multiply-accumulate) of a model
"""
return _ir_pass.GetTotalMacNumber(expr)


def eliminate_common_subexpr(expr, fskip=None):
"""
Eliminate common subexpressions.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
fskip: function
The callback function that decides whether an expression should be skipped.
Returns
-------
expr : tvm.relay.Expr
The output expression.
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)
72 changes: 72 additions & 0 deletions src/relay/pass/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*!
* Copyright (c) 2019 by Contributors
*
* \file eliminate_common_subexpr.cc
* \brief Combine common subexpressions.
*
* This is an optimization pass that eliminates common subexpressions. During the pass, it tries
* to replace an expression with a previously appeared expression with the same input and
* attributes. The fskip callback argument allows us to skip specific expressions.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <unordered_map>
#include "./pattern_util.h"

namespace tvm {
namespace relay {

class CommonSubexprEliminator : public ExprMutator {
public:
explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip): fskip_(fskip) {}

Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
Expr new_expr = ExprMutator::VisitExpr_(call);
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
const OpNode* op = new_call->op.as<OpNode>();
AttrsEqual attrs_equal;

if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
return new_expr;
}
if (fskip_ != nullptr && fskip_(new_expr)) {
return new_expr;
}

auto it = expr_map_.find(new_call->op);
if (it != expr_map_.end()) {
for (const CallNode* candidate : it->second) {
bool is_equivalent = true;
if (!attrs_equal(new_call->attrs, candidate->attrs)) {
continue;
}
for (size_t i = 0; i < new_call->args.size(); i++) {
if (!new_call->args[i].same_as(candidate->args[i]) &&
!IsEqualScalar(new_call->args[i], candidate->args[i])) {
is_equivalent = false;
break;
}
}
if (!is_equivalent) continue;
return GetRef<Call>(candidate);
}
}
expr_map_[new_call->op].push_back(new_call);
return new_expr;
}

std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> expr_map_;
runtime::TypedPackedFunc<bool(Expr)> fskip_;
};

Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
return CommonSubexprEliminator(callback)(expr);
}

TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr")
.set_body_typed<Expr(Expr, PackedFunc)>(EliminateCommonSubexpr);

} // namespace relay
} // namespace tvm
15 changes: 15 additions & 0 deletions src/relay/pass/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,21 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
return ConstantNode::make(arr);
}

/*!
* \brief Check if two expressions are equal scalars.
* \param a The expression to be checked.
* \param b The expression to be checked
* \return Whether two expressions are equal scalars.
*/
inline bool IsEqualScalar(const Expr& a, const Expr& b) {
const auto* constant_a = a.as<ConstantNode>();
const auto* constant_b = b.as<ConstantNode>();
if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
return false;
}
return AlphaEqual(a, b);
}

inline Expr GetField(Expr t, size_t i) {
return TupleGetItemNode::make(t, i);
}
Expand Down
63 changes: 63 additions & 0 deletions tests/python/relay/test_pass_eliminate_common_subexpr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Test eliminate common subexpr pass"""
from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay import ir_pass


def test_simple():
def before():
x = relay.var("x", shape=(1, 16))
y1 = relay.nn.relu(x)
y2 = relay.nn.relu(x)
y1 = relay.add(y1, relay.const(1.0, "float32"))
y2 = relay.add(y2, relay.const(1.0, "float32"))
y = relay.add(y1, y2)
f = relay.Function([x], y)
return f

def expected():
x = relay.var("x", shape=(1, 16))
y = relay.nn.relu(x)
y = relay.add(y, relay.const(1.0, "float32"))
y = relay.add(y, y)
f = relay.Function([x], y)
return f

z = before()
z = ir_pass.eliminate_common_subexpr(z)
assert ir_pass.alpha_equal(z, expected())


def test_callback():
def before():
x = relay.var("x", shape=(1, 16))
y1 = relay.nn.relu(x)
y2 = relay.nn.relu(x)
y1 = relay.add(y1, relay.const(1.0, "float32"))
y2 = relay.add(y2, relay.const(1.0, "float32"))
y = relay.add(y1, y2)
f = relay.Function([x], y)
return f

def expected():
x = relay.var("x", shape=(1, 16))
y = relay.nn.relu(x)
y1 = relay.add(y, relay.const(1.0, "float32"))
y2 = relay.add(y, relay.const(1.0, "float32"))
y = relay.add(y1, y2)
f = relay.Function([x], y)
return f

def fskip(expr):
if isinstance(expr, relay.expr.Call) and expr.op.name == 'add':
return True
return False

z = before()
z = ir_pass.eliminate_common_subexpr(z, fskip)
assert ir_pass.alpha_equal(z, expected())


if __name__ == "__main__":
test_simple()
test_callback()

0 comments on commit fb2cd3f

Please sign in to comment.