Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Fix PE #3482

Merged
merged 2 commits into from
Jul 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct Module;
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* a Module while auto-tuning.
* */
*/

class ModuleNode : public RelayNode {
public:
Expand Down
26 changes: 20 additions & 6 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -18,14 +18,15 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file src/tvm/relay/expr_mutator.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"

namespace tvm {
Expand Down Expand Up @@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._analysis.post_order_visit")
});

// Implement bind.
class ExprBinder : public ExprMutator {
class ExprBinder : public ExprMutator, PatternMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
: args_map_(args_map) {
Expand Down Expand Up @@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator {
}
}

Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}

Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}

Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
}

private:
const tvm::Map<Var, Expr>& args_map_;
};

Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).Mutate(func->body);
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params;
for (Var param : func->params) {
if (!args_map.count(param)) {
Expand All @@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
func->type_params,
func->attrs);
} else {
return ExprBinder(args_map).Mutate(expr);
return ExprBinder(args_map).VisitExpr(expr);
}
}

Expand Down
6 changes: 5 additions & 1 deletion src/relay/ir/type_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
}
}

Type TypeMutator::VisitType(const Type& t) {
return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
}

// Type Mutator.
Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
// The array will do copy on write
Expand Down Expand Up @@ -221,7 +225,7 @@ class TypeBinder : public TypeMutator {
};

Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
return type.defined() ? TypeBinder(args_map).VisitType(type) : type;
return TypeBinder(args_map).VisitType(type);
}

} // namespace relay
Expand Down
1 change: 1 addition & 0 deletions src/relay/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
// Mutator that transform a type to another one.
class TypeMutator : public TypeFunctor<Type(const Type& n)> {
public:
Type VisitType(const Type& t) override;
Type VisitType_(const TypeVarNode* op) override;
Type VisitType_(const TensorTypeNode* op) override;
Type VisitType_(const IncompleteTypeNode* op) override;
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class LetList {
public:
~LetList() {
if (lets_.size() > 0 && !used_) {
std::cout << "Warning: letlist not used" << std::endl;
LOG(WARNING) << "letlist not used";
}
}
/*!
Expand Down
116 changes: 60 additions & 56 deletions src/relay/pass/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
* 3: The generated code reuses bindings (although they are not shadowed),
* so we have to deduplicate them.
*
* 4: In the generated code, multiple VarNode might have same Id.
* 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id.
* While it is permitted, most pass use NodeHash for Var,
* and having multiple VarNode for same Id break them.
* Thus we remap them to a single Id for now.
Expand Down Expand Up @@ -216,9 +216,9 @@ Static MkSRef() {
}

using Func = std::function<PStatic(const std::vector<PStatic>&,
const Attrs&,
const Array<Type>&,
LetList*)>;
const Attrs&,
const Array<Type>&,
LetList*)>;

struct SFuncNode : StaticNode {
Func func;
Expand Down Expand Up @@ -256,6 +256,7 @@ class Environment {

void Insert(const Var& v, const PStatic& ps) {
CHECK(ps.defined());
CHECK_EQ(env_.back().locals.count(v), 0);
env_.back().locals[v] = ps;
}

Expand Down Expand Up @@ -287,12 +288,17 @@ class Environment {

/*!
* \brief As our store require rollback, we implement it as a frame.
* every time we need to copy the store, a new frame is insert.
* every time we roll back, a frame is popped.
*
* Every time we need to copy the store, a new frame is insert.
* Every time we roll back, a frame is popped.
*/
struct StoreFrame {
std::unordered_map<const SRefNode*, PStatic> store;
/*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */
/*!
* \brief On unknown effect, history_valid is set to true to signal above frame is outdated.
*
* It only outdate the frame above it, but not the current frame.
*/
bool history_valid = true;
explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) { }
StoreFrame() = default;
Expand All @@ -310,26 +316,29 @@ class Store {
}

void Insert(const SRefNode* r, const PStatic& ps) {
CHECK(r);
store_.back().store[r] = ps;
}

// return null if not found
PStatic Lookup(const SRefNode* r) {
auto rit = store_.rbegin();
while (rit != store_.rend()) {
if (!rit->history_valid) {
return PStatic();
}
if (rit->store.find(r) != rit->store.end()) {
return rit->store.find(r)->second;
}
if (!rit->history_valid) {
return PStatic();
}
++rit;
}
return PStatic();
}

void Invalidate() {
store_.back().history_valid = false;
StoreFrame sf;
sf.history_valid = false;
store_.push_back(sf);
}

private:
Expand All @@ -341,6 +350,10 @@ class Store {
store_->store_.push_back(StoreFrame());
}
~StoreFrameContext() {
// push one history valid frame off.
while (!store_->store_.back().history_valid) {
store_->store_.pop_back();
}
store_->store_.pop_back();
}
};
Expand Down Expand Up @@ -442,13 +455,7 @@ Function AsFunc(const Expr& e) {
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public:
PartialEvaluator(const tvm::Array<Var>& free_vars,
const Module& mod) :
mod_(mod) {
for (const Var& v : free_vars) {
env_.Insert(v, NoStatic(v));
}
}
PartialEvaluator(const Module& mod) : mod_(mod) { }

PStatic VisitExpr(const Expr& e, LetList* ll) final {
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
Expand Down Expand Up @@ -484,23 +491,23 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return env_.Lookup(GetRef<Var>(op));
}

PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
GlobalVar gv = GetRef<GlobalVar>(op);
PStatic VisitGlobalVar(const GlobalVar& gv) {
CHECK(mod_.defined());
if (gv_map_.count(gv) == 0) {
if (mod_.defined()) {
Function func = mod_->Lookup(gv);
InitializeFuncId(func);
Func f = VisitFuncStatic(func, gv);
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
mod_->Update(gv, func);
} else {
gv_map_.insert({gv, NoStatic(gv)});
}
Function func = mod_->Lookup(gv);
InitializeFuncId(func);
Func f = VisitFuncStatic(func, gv);
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
mod_->Update(gv, func);
}
return gv_map_.at(gv);
}

PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
return VisitGlobalVar(GetRef<GlobalVar>(op));
}

PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
env_.Insert(op->var, VisitExpr(op->value, ll));
return VisitExpr(op->body, ll);
Expand Down Expand Up @@ -629,7 +636,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], Type());
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
std::vector<Time> args_time;
for (const auto& v : pv) {
Expand Down Expand Up @@ -672,22 +679,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
};
}


Expr VisitFuncDynamic(const Function& func, const Func& f) {
return store_.Extend<Expr>([&]() {
store_.Invalidate();
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
for (const auto& v : func->params) {
pv.push_back(NoStatic(v));
}
tvm::Array<Type> type_args;
for (const auto& tp : func->type_params) {
type_args.push_back(tp);
}
return f(pv, Attrs(), type_args, ll)->dynamic;
}), func->ret_type, func->type_params, func->attrs);
});
store_.Invalidate();
return FunctionNode::make(func->params,
LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
for (const auto& v : func->params) {
pv.push_back(NoStatic(v));
}
tvm::Array<Type> type_args;
for (const auto& tp : func->type_params) {
type_args.push_back(tp);
}
return f(pv, Attrs(), type_args, ll)->dynamic;
}), func->ret_type, func->type_params, func->attrs);
});
}

PStatic VisitFunc(const Function& func, LetList* ll) {
Expand Down Expand Up @@ -1012,17 +1019,14 @@ Expr PostProcess(const Expr& e) {

Module PartialEval(const Module& m) {
CHECK(m->entry_func.defined());
auto func = m->Lookup(m->entry_func);
Expr ret =
TransformF([&](const Expr& e) {
return LetList::With([&](LetList* ll) {
relay::partial_eval::PartialEvaluator pe(FreeVars(e), m);
pe.InitializeFuncId(e);
return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic);
});
}, func);
CHECK(ret->is_type<FunctionNode>());
m->Update(m->entry_func, Downcast<Function>(ret));
relay::partial_eval::PartialEvaluator pe(m);
std::vector<GlobalVar> gvs;
for (const auto& p : m->functions) {
gvs.push_back(p.first);
}
for (const auto& gv : gvs) {
pe.VisitGlobalVar(gv);
}
return m;
}

Expand Down
1 change: 1 addition & 0 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return it->second.checked_type;
}
Type ret = this->VisitExpr(expr);
CHECK(ret.defined());
KindCheck(ret, mod_);
ResolvedTypeInfo& rti = type_map_[expr];
rti.checked_type = ret;
Expand Down
10 changes: 10 additions & 0 deletions src/relay/pass/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,16 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
}

Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}

Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}

private:
const tvm::Map<TypeVar, Type>& subst_map_;
};
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_pass_partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ def test_double():


if __name__ == '__main__':
test_empty_ad()
test_ref()
test_tuple()
test_empty_ad()
test_const_inline()
test_ref()
test_ad()
test_if_ref()
test_function_invalidate()
Expand Down