Skip to content

Commit

Permalink
[RELAY][PASS] FoldScaleAxis Backward (apache#2024)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Oct 30, 2018
1 parent 4823d55 commit feca27e
Show file tree
Hide file tree
Showing 6 changed files with 667 additions and 35 deletions.
6 changes: 3 additions & 3 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ class ExprVisitor
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t);

private:
// internal visited flag.
std::unordered_set<const Node*> visited_;
protected:
// Internal visiting counter
std::unordered_map<const Node*, size_t> visit_counter_;
};

/*!
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,29 @@ def infer_type(expr, env=None):
return _ir_pass.infer_type(expr, env)


def backward_fold_scale_axis(expr):
"""Backward fold axis scaling into weights of conv2d/dense.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
Returns
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
"""
return _ir_pass.backward_fold_scale_axis(expr)


def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.
Expand All @@ -44,6 +67,12 @@ def forward_fold_scale_axis(expr):
-------
folded_expr : tvm.relay.Expr
The folded expression after transformation.
Note
----
It is recommended to call backward_fold_scale_axis
before using forward_fold_scale_axis.
As backward folding targets common conv-bn pattern.
"""
return _ir_pass.forward_fold_scale_axis(expr)

Expand Down
12 changes: 8 additions & 4 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,14 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
Type ExprMutator::VisitType(const Type& t) { return t; }

void ExprVisitor::VisitExpr(const Expr& expr) {
if (visited_.count(expr.get())) return;
using TParent = ExprFunctor<void(const Expr&)>;
TParent::VisitExpr(expr);
visited_.insert(expr.get());
auto it = visit_counter_.find(expr.get());
if (it != visit_counter_.end()) {
++it->second;
} else {
using TParent = ExprFunctor<void(const Expr&)>;
TParent::VisitExpr(expr);
visit_counter_.insert({expr.get(), 1});
}
}

void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
Expand Down
Loading

0 comments on commit feca27e

Please sign in to comment.