From f9e866525dffc5d52c9a6374615ccd3168aed9e1 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Fri, 11 Jan 2019 20:37:03 -0800 Subject: [PATCH] revert change --- 3rdparty/HalideIR | 2 +- 3rdparty/dmlc-core | 2 +- include/tvm/relay/pass.h | 3 +-- python/tvm/relay/backend/interpreter.py | 2 +- python/tvm/relay/build_module.py | 2 +- python/tvm/relay/ir_pass.py | 7 ++----- src/relay/pass/fold_constant.cc | 2 +- src/relay/pass/fuse_ops.cc | 28 +++---------------------- 8 files changed, 11 insertions(+), 37 deletions(-) diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index 6e7c1f046fda5..a08e26e5a97f4 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit 6e7c1f046fda536562dc80977e93324fee2324bd +Subproject commit a08e26e5a97f4ef4d566a42f6c78704b3f9c7b8a diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index d07fb7a443b5d..519d013a213c0 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit d07fb7a443b5db8a89d65a15a024af6a425615a5 +Subproject commit 519d013a213c0c447a971f51219473ef564d2348 diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index ad27a69bf0eed..a48db14daba9a 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -154,11 +154,10 @@ Expr FoldConstant(const Expr& expr); /*! * \brief Fuse operations into expr into seperate functions. * \param expr The expression. - * \param mod The global module. * \param fuse_opt_level Optimization level. * \return The optimized expression. */ -Expr FuseOps(const Expr& expr, const Module& mod, int fuse_opt_level); +Expr FuseOps(const Expr& expr, int fuse_opt_level); /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 0d5bf2582d465..88e4f89a3dfce 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -218,7 +218,7 @@ def optimize(self, expr): """ # TODO: We need to move this optimization code into the optimizer/pass manager ck_expr = ir_pass.infer_type(expr, mod=self.mod) - fused_expr = ir_pass.fuse_ops(ck_expr, mod=self.mod) + fused_expr = ir_pass.fuse_ops(ck_expr) ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod) return ck_fused diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index c58dfa8f3128e..c18367a070c33 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -234,7 +234,7 @@ def build(func, func = optimize(func, target, params) # Fuse ops before running code gen func = ir_pass.infer_type(func) - func = ir_pass.fuse_ops(func, mod=None, opt_level=cfg.opt_level) + func = ir_pass.fuse_ops(func, cfg.opt_level) # Graph code generation func = ir_pass.infer_type(func) graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 77d430e4ab38e..7c0b9138f1bce 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -305,7 +305,7 @@ def fold_constant(expr): return _ir_pass.FoldConstant(expr) -def fuse_ops(expr, mod=None, opt_level=1): +def fuse_ops(expr, opt_level=1): """Fuse operators in expr together. Parameters @@ -313,9 +313,6 @@ def fuse_ops(expr, mod=None, opt_level=1): expr : tvm.relay.Expr The input expression. - mod : Optional[tvm.relay.Module] - The global module. - opt_level : int The level of fuse optimization. @@ -324,7 +321,7 @@ def fuse_ops(expr, mod=None, opt_level=1): transformed_expr : tvm.relay.Expr Transformed expression, containing fused result. """ - return _ir_pass.FuseOps(expr, mod, opt_level) + return _ir_pass.FuseOps(expr, opt_level) def combine_parallel_conv2d(expr): diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index b319c5108d1ec..60994cdd6ca92 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -128,7 +128,7 @@ class ConstantFolder : public ExprMutator { // Constant evaluate a expression. Expr ConstEvaluate(Expr expr) { expr = InferType(expr, Module(nullptr)); - expr = FuseOps(expr, Module(nullptr), 0); + expr = FuseOps(expr, 0); expr = InferType(expr, Module(nullptr)); return ValueToExpr(executor_(expr)); } diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 3ad61cfd81c79..b2b35c51a1ca2 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -650,8 +650,6 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) { return std::move(groups_); } -Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set* gv); - class FuseMutator : private ExprMutator { public: // Run the transform @@ -669,13 +667,8 @@ class FuseMutator : private ExprMutator { return this->Mutate(body); } - FuseMutator(const Module& mod, int fuse_opt_level, std::set* visited) : - mod_(mod), fuse_opt_level_(fuse_opt_level), visited_(visited) { } private: - Module mod_; - int fuse_opt_level_; - std::set* visited_; /*! \brief Temporary information from each group. */ struct GroupInfo { public: @@ -758,16 +751,6 @@ class FuseMutator : private ExprMutator { return new_tuple; } - Expr VisitExpr_(const GlobalVarNode* node) { - GlobalVar gv = GetRef(node); - if (visited_->count(gv) == 0) { - visited_->insert(gv); - mod_->Update(gv, - Downcast(FuseOps(mod_->Lookup(gv), mod_, fuse_opt_level_, visited_))); - } - return gv; - } - Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { const GroupInfo& ginfo = ginfo_[group]; auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); @@ -807,22 +790,17 @@ class FuseMutator : private ExprMutator { }; -Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level) { - std::set gv; - return FuseOps(expr, m, fuse_opt_level, &gv); -} - -Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set* gv) { +Expr FuseOps(const Expr& expr, int fuse_opt_level) { // First we convert all chains of fusable ops into // abstracted functions which we mark as primtive // then we convert these primtive functions into // new operators. - return FuseMutator(m, fuse_opt_level, gv).Transform(expr, fuse_opt_level); + return FuseMutator().Transform(expr, fuse_opt_level); } TVM_REGISTER_API("relay._ir_pass.FuseOps") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = FuseOps(args[0], args[1], args[2]); + *ret = FuseOps(args[0], args[1]); }); } // namespace relay } // namespace tvm