From a04af16649d0c2e995300ea730133fe61f6dcf12 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 2 Jul 2019 19:19:14 -0700 Subject: [PATCH 1/2] do fix some error fix remove cout retrigger build --- include/tvm/relay/module.h | 2 +- src/relay/ir/expr_functor.cc | 26 ++++- src/relay/pass/partial_eval.cc | 114 ++++++++++--------- src/relay/pass/util.cc | 10 ++ tests/python/relay/test_pass_partial_eval.py | 4 +- 5 files changed, 92 insertions(+), 64 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 638f75968fd3..4a3ff0b6eb19 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -55,7 +55,7 @@ struct Module; * The functional style allows users to construct custom * environments easily, for example each thread can store * a Module while auto-tuning. - * */ + */ class ModuleNode : public RelayNode { public: diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 36692c5c571b..0434e2ac59c6 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file src/tvm/relay/expr_mutator.cc * \brief A wrapper around ExprFunctor which functionally updates the AST. * @@ -26,6 +26,7 @@ * the cost of using functional updates. */ #include +#include #include "type_functor.h" namespace tvm { @@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit") }); // Implement bind. -class ExprBinder : public ExprMutator { +class ExprBinder : public ExprMutator, PatternMutator { public: explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) { @@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator { } } + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Clause VisitClause(const Clause& c) final { + Pattern pat = VisitPattern(c->lhs); + return ClauseNode::make(pat, VisitExpr(c->rhs)); + } + + Var VisitVar(const Var& v) final { + return Downcast(VisitExpr(v)); + } + private: const tvm::Map& args_map_; }; Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { - Expr new_body = ExprBinder(args_map).Mutate(func->body); + Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; for (Var param : func->params) { if (!args_map.count(param)) { @@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { func->type_params, func->attrs); } else { - return ExprBinder(args_map).Mutate(expr); + return ExprBinder(args_map).VisitExpr(expr); } } diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 6887c7a60322..4efa3844111c 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -64,7 +64,7 @@ * 3: The generated code reuses bindings (although they are not shadowed), * so we have to deduplicate them. * - * 4: In the generated code, multiple VarNode might have same Id. + * 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id. * While it is permitted, most pass use NodeHash for Var, * and having multiple VarNode for same Id break them. * Thus we remap them to a single Id for now. @@ -216,9 +216,9 @@ Static MkSRef() { } using Func = std::function&, - const Attrs&, - const Array&, - LetList*)>; + const Attrs&, + const Array&, + LetList*)>; struct SFuncNode : StaticNode { Func func; @@ -256,6 +256,7 @@ class Environment { void Insert(const Var& v, const PStatic& ps) { CHECK(ps.defined()); + CHECK_EQ(env_.back().locals.count(v), 0); env_.back().locals[v] = ps; } @@ -287,12 +288,17 @@ class Environment { /*! * \brief As our store require rollback, we implement it as a frame. - * every time we need to copy the store, a new frame is insert. - * every time we roll back, a frame is popped. + * + * Every time we need to copy the store, a new frame is insert. + * Every time we roll back, a frame is popped. */ struct StoreFrame { std::unordered_map store; - /*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */ + /*! + * \brief On unknown effect, history_valid is set to true to signal above frame is outdated. + * + * It only outdate the frame above it, but not the current frame. + */ bool history_valid = true; explicit StoreFrame(const std::unordered_map& store) : store(store) { } StoreFrame() = default; @@ -310,6 +316,7 @@ class Store { } void Insert(const SRefNode* r, const PStatic& ps) { + CHECK(r); store_.back().store[r] = ps; } @@ -317,19 +324,21 @@ class Store { PStatic Lookup(const SRefNode* r) { auto rit = store_.rbegin(); while (rit != store_.rend()) { - if (!rit->history_valid) { - return PStatic(); - } if (rit->store.find(r) != rit->store.end()) { return rit->store.find(r)->second; } + if (!rit->history_valid) { + return PStatic(); + } ++rit; } return PStatic(); } void Invalidate() { - store_.back().history_valid = false; + StoreFrame sf; + sf.history_valid = false; + store_.push_back(sf); } private: @@ -341,6 +350,10 @@ class Store { store_->store_.push_back(StoreFrame()); } ~StoreFrameContext() { + // push one history valid frame off. + while (!store_->store_.back().history_valid) { + store_->store_.pop_back(); + } store_->store_.pop_back(); } }; @@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) { class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const tvm::Array& free_vars, - const Module& mod) : - mod_(mod) { - for (const Var& v : free_vars) { - env_.Insert(v, NoStatic(v)); - } - } + PartialEvaluator(const Module& mod) : mod_(mod) { } PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic ret = ExprFunctor::VisitExpr(e, ll); @@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor return env_.Lookup(GetRef(op)); } - PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { - GlobalVar gv = GetRef(op); + PStatic VisitGlobalVar(const GlobalVar& gv) { + CHECK(mod_.defined()); if (gv_map_.count(gv) == 0) { - if (mod_.defined()) { - Function func = mod_->Lookup(gv); - InitializeFuncId(func); - Func f = VisitFuncStatic(func, gv); - gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); - func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); - mod_->Update(gv, func); - } else { - gv_map_.insert({gv, NoStatic(gv)}); - } + Function func = mod_->Lookup(gv); + InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + mod_->Update(gv, func); } return gv_map_.at(gv); } + PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { + return VisitGlobalVar(GetRef(op)); + } + PStatic VisitExpr_(const LetNode* op, LetList* ll) final { env_.Insert(op->var, VisitExpr(op->value, ll)); return VisitExpr(op->body, ll); @@ -672,22 +679,22 @@ class PartialEvaluator : public ExprFunctor }; } - Expr VisitFuncDynamic(const Function& func, const Func& f) { return store_.Extend([&]() { - store_.Invalidate(); - return FunctionNode::make(func->params, LetList::With([&](LetList* ll) { - std::vector pv; - for (const auto& v : func->params) { - pv.push_back(NoStatic(v)); - } - tvm::Array type_args; - for (const auto& tp : func->type_params) { - type_args.push_back(tp); - } - return f(pv, Attrs(), type_args, ll)->dynamic; - }), func->ret_type, func->type_params, func->attrs); - }); + store_.Invalidate(); + return FunctionNode::make(func->params, + LetList::With([&](LetList* ll) { + std::vector pv; + for (const auto& v : func->params) { + pv.push_back(NoStatic(v)); + } + tvm::Array type_args; + for (const auto& tp : func->type_params) { + type_args.push_back(tp); + } + return f(pv, Attrs(), type_args, ll)->dynamic; + }), func->ret_type, func->type_params, func->attrs); + }); } PStatic VisitFunc(const Function& func, LetList* ll) { @@ -1012,17 +1019,14 @@ Expr PostProcess(const Expr& e) { Module PartialEval(const Module& m) { CHECK(m->entry_func.defined()); - auto func = m->Lookup(m->entry_func); - Expr ret = - TransformF([&](const Expr& e) { - return LetList::With([&](LetList* ll) { - relay::partial_eval::PartialEvaluator pe(FreeVars(e), m); - pe.InitializeFuncId(e); - return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic); - }); - }, func); - CHECK(ret->is_type()); - m->Update(m->entry_func, Downcast(ret)); + relay::partial_eval::PartialEvaluator pe(m); + std::vector gvs; + for (const auto& p : m->functions) { + gvs.push_back(p.first); + } + for (const auto& gv : gvs) { + pe.VisitGlobalVar(gv); + } return m; } diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 2497197ffbe5..e2b71570bd2f 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -425,6 +425,16 @@ Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { Var VisitVar(const Var& v) final { return Downcast(VisitExpr(v)); } + + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Clause VisitClause(const Clause& c) final { + Pattern pat = VisitPattern(c->lhs); + return ClauseNode::make(pat, VisitExpr(c->rhs)); + } + private: const tvm::Map& subst_map_; }; diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 6a7f59c91daa..8855b089ba8c 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -307,10 +307,10 @@ def test_double(): if __name__ == '__main__': - test_empty_ad() + test_ref() test_tuple() + test_empty_ad() test_const_inline() - test_ref() test_ad() test_if_ref() test_function_invalidate() From 004caf509f6cb312ed47d2bbfac4b0b8abec4304 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jul 2019 02:01:42 -0700 Subject: [PATCH 2/2] fix it! --- src/relay/ir/type_functor.cc | 6 +++++- src/relay/ir/type_functor.h | 1 + src/relay/pass/let_list.h | 2 +- src/relay/pass/partial_eval.cc | 2 +- src/relay/pass/type_infer.cc | 1 + 5 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 516f4c875b20..cde68c50daef 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { } } +Type TypeMutator::VisitType(const Type& t) { + return t.defined() ? TypeFunctor::VisitType(t) : t; +} + // Type Mutator. Array TypeMutator::MutateArray(Array arr) { // The array will do copy on write @@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator { }; Type Bind(const Type& type, const tvm::Map& args_map) { - return type.defined() ? TypeBinder(args_map).VisitType(type) : type; + return TypeBinder(args_map).VisitType(type); } } // namespace relay diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 27ac288fe48d..c3ee14eedd48 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor { // Mutator that transform a type to another one. class TypeMutator : public TypeFunctor { public: + Type VisitType(const Type& t) override; Type VisitType_(const TypeVarNode* op) override; Type VisitType_(const TensorTypeNode* op) override; Type VisitType_(const IncompleteTypeNode* op) override; diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 1b422d2a878f..73c5fe3abc22 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -48,7 +48,7 @@ class LetList { public: ~LetList() { if (lets_.size() > 0 && !used_) { - std::cout << "Warning: letlist not used" << std::endl; + LOG(WARNING) << "letlist not used"; } } /*! diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 4efa3844111c..b7f12b65751d 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -636,7 +636,7 @@ class PartialEvaluator : public ExprFunctor subst.Set(func->type_params[i], type_args[i]); } for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { - subst.Set(func->type_params[i], Type()); + subst.Set(func->type_params[i], IncompleteTypeNode::make(kType)); } std::vector