Skip to content

Commit

Permalink
revert change
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Jan 12, 2019
1 parent bf57ea5 commit f9e8665
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 37 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
3 changes: 1 addition & 2 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,10 @@ Expr FoldConstant(const Expr& expr);
/*!
 * \brief Fuse operations in expr into separate functions.
 * \param expr The expression.
 * \param fuse_opt_level Optimization level.
 * \return The optimized expression.
 */
Expr FuseOps(const Expr& expr, int fuse_opt_level);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def optimize(self, expr):
"""
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def build(func,
func = optimize(func, target, params)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, mod=None, opt_level=cfg.opt_level)
func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
Expand Down
7 changes: 2 additions & 5 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,14 @@ def fold_constant(expr):
return _ir_pass.FoldConstant(expr)


def fuse_ops(expr, opt_level=1):
    """Fuse operators in expr together.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    opt_level : int
        The level of fuse optimization.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        Transformed expression, containing fused result.
    """
    # Delegates to the C++ pass registered as relay._ir_pass.FuseOps.
    return _ir_pass.FuseOps(expr, opt_level)


def combine_parallel_conv2d(expr):
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ConstantFolder : public ExprMutator {
// Constant-evaluate an expression: type-check it, fuse ops at opt
// level 0 so the interpreter can run it, re-check types, then execute
// it and convert the resulting value back to an expression.
Expr ConstEvaluate(Expr expr) {
  expr = InferType(expr, Module(nullptr));
  // Level 0 fusion — minimal grouping, just enough for execution.
  expr = FuseOps(expr, 0);
  expr = InferType(expr, Module(nullptr));
  return ValueToExpr(executor_(expr));
}
Expand Down
28 changes: 3 additions & 25 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,6 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
return std::move(groups_);
}

Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set<GlobalVar>* gv);

class FuseMutator : private ExprMutator {
public:
// Run the transform
Expand All @@ -669,13 +667,8 @@ class FuseMutator : private ExprMutator {
return this->Mutate(body);
}

FuseMutator(const Module& mod, int fuse_opt_level, std::set<GlobalVar>* visited) :
mod_(mod), fuse_opt_level_(fuse_opt_level), visited_(visited) { }

private:
Module mod_;
int fuse_opt_level_;
std::set<GlobalVar>* visited_;
/*! \brief Temporary information from each group. */
struct GroupInfo {
public:
Expand Down Expand Up @@ -758,16 +751,6 @@ class FuseMutator : private ExprMutator {
return new_tuple;
}

Expr VisitExpr_(const GlobalVarNode* node) {
GlobalVar gv = GetRef<GlobalVar>(node);
if (visited_->count(gv) == 0) {
visited_->insert(gv);
mod_->Update(gv,
Downcast<Function>(FuseOps(mod_->Lookup(gv), mod_, fuse_opt_level_, visited_)));
}
return gv;
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
Expand Down Expand Up @@ -807,22 +790,17 @@ class FuseMutator : private ExprMutator {
};


/*!
 * \brief Fuse operators in expr into separate functions.
 * \param expr The expression to transform.
 * \param fuse_opt_level The level of fuse optimization.
 * \return The transformed expression with fused groups.
 */
Expr FuseOps(const Expr& expr, int fuse_opt_level) {
  // First we convert all chains of fusable ops into
  // abstracted functions which we mark as primitive,
  // then we convert these primitive functions into
  // new operators.
  return FuseMutator().Transform(expr, fuse_opt_level);
}

// FFI entry point: exposes FuseOps(expr, fuse_opt_level) to the
// Python frontend as relay._ir_pass.FuseOps.
TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = FuseOps(args[0], args[1]);
  });
} // namespace relay
} // namespace tvm

0 comments on commit f9e8665

Please sign in to comment.