Skip to content

Commit

Permalink
[Relay] use unordered_map instead of map in ANF (apache#3024)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and Wei Chen committed May 13, 2019
1 parent 6254550 commit a976268
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/relay/pass/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
namespace tvm {
namespace relay {

Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);

struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;
Expand Down Expand Up @@ -104,7 +106,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
const Module& m,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* gv) {
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
Fill fi(m, dg, node_scope, gv);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
}
Expand All @@ -113,13 +115,13 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Module mod_;
const DependencyGraph& dg_;
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
std::set<GlobalVar>* visited_;
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;

Fill(Module mod,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* visited) :
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
mod_(mod),
dg_(dg),
node_scope_(node_scope),
Expand Down Expand Up @@ -273,7 +275,9 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
};

Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalFormAux(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
/* When you lift a lambda, what is inside is also being lift.
*
* So we must determine the scope of the lambda before determining the scope of it's body.
Expand All @@ -299,12 +303,14 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
}

Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e);
}

Expr ToANormalForm(const Expr& e, const Module& m) {
std::set<GlobalVar> gv;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
return ToANormalForm(e, m, &gv);
}

Expand Down

0 comments on commit a976268

Please sign in to comment.