forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Back Propagation. Ready for PR. (apache#49)
* commit * do wlist * commit * fix lint * fix lint * do unaryop * fix lint * implement more for evaluator * commit * finish rev for binop * accidentally deleted free_vars.h, force push. * fix compile error * fix lint * impl more anf * test passed. working on impl. * impl body of anf * add skeleton for lambda_backprop * add node type according to lambda backprop * fix lint * commit whatever * commit! * remove reverse node, we dont need them * add relay expr for ref, getref, setref * Revert "remove reverse node, we dont need them" This reverts commit fd7a93b1973c17321b6f0e108e3c5fc4ec8357d3.
- Loading branch information
1 parent
6983913
commit d6da89e
Showing
15 changed files
with
690 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file anf.h | ||
* \brief Convert expression to A Normal Form. | ||
*/ | ||
#ifndef NNVM_RELAY_ANF_H_ | ||
#define NNVM_RELAY_ANF_H_ | ||
|
||
#include <unordered_map> | ||
#include <string> | ||
#include <vector> | ||
#include <utility> | ||
#include "environment.h" | ||
#include "expr_functor.h" | ||
#include "ir.h" | ||
#include "gen_fresh.h" | ||
|
||
namespace nnvm { | ||
namespace relay { | ||
|
||
using std::vector; | ||
using std::pair; | ||
|
||
struct PartialLocalId { | ||
LocalId lid; | ||
explicit PartialLocalId(const LocalId & lid) : lid(lid) { } | ||
PartialLocalId() { | ||
throw; | ||
} | ||
}; | ||
|
||
struct LetList { | ||
GenFresh gf; | ||
vector<pair<PartialLocalId, Expr>> lets; | ||
PartialLocalId let(const Expr & expr) { | ||
PartialLocalId id(gf.fresh()); | ||
lets.push_back({id, expr}); | ||
return id; | ||
} | ||
Expr plug(const Expr & expr) const { | ||
Expr ret = expr; | ||
for (auto rit = lets.rbegin(); rit != lets.rend(); rit++) { | ||
ret = LetNode::make(rit->first.lid, rit->second, ret); | ||
} | ||
return ret; | ||
} | ||
}; | ||
|
||
struct ANF : public ExprFunctor<PartialLocalId(const Expr& n)> { | ||
LetList ll; | ||
ANF() {} | ||
static Expr ToANF(const Expr& expr); | ||
PartialLocalId Convert(const Expr& expr); | ||
PartialLocalId VisitExpr_(const LocalIdNode* op) override; | ||
PartialLocalId VisitExpr_(const GlobalIdNode* op) override; | ||
PartialLocalId VisitExpr_(const IntrinsicIdNode* op) override; | ||
PartialLocalId VisitExpr_(const FloatLitNode* op) override; | ||
PartialLocalId VisitExpr_(const BoolLitNode* op) override; | ||
PartialLocalId VisitExpr_(const IntLitNode* op) override; | ||
PartialLocalId VisitExpr_(const TensorLitNode* op) override; | ||
PartialLocalId VisitExpr_(const ProductLitNode* op) override; | ||
PartialLocalId VisitExpr_(const CastNode* op) override; | ||
PartialLocalId VisitExpr_(const ParamNode* op) override; | ||
PartialLocalId VisitExpr_(const FunctionNode* op) override; | ||
PartialLocalId VisitExpr_(const CallNode* op) override; | ||
PartialLocalId VisitExpr_(const DebugNode* op) override; | ||
PartialLocalId VisitExpr_(const UnaryOpNode* op) override; | ||
PartialLocalId VisitExpr_(const BinaryOpNode* op) override; | ||
PartialLocalId VisitExpr_(const LetNode* op) override; | ||
}; | ||
|
||
} // namespace relay | ||
} // namespace nnvm | ||
#endif // NNVM_RELAY_ANF_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file lambda_backprop.h | ||
* \brief Reverse Mode Automatic Differentiation for the Relay IR | ||
*/ | ||
#ifndef NNVM_RELAY_LAMBDA_BACKPROP_H_ | ||
#define NNVM_RELAY_LAMBDA_BACKPROP_H_ | ||
|
||
#include <unordered_map> | ||
#include <string> | ||
#include <vector> | ||
#include <utility> | ||
#include <map> | ||
#include "environment.h" | ||
#include "expr_functor.h" | ||
#include "ir.h" | ||
#include "gen_fresh.h" | ||
#include "anf.h" | ||
|
||
// Do we really need a J-> and J->-1? | ||
// Maybe we can just ask J-> to store original version as a Product? | ||
namespace nnvm { | ||
namespace relay { | ||
|
||
struct Transformed { | ||
Expr expr; | ||
explicit Transformed(const Expr & expr) : expr(expr) { } | ||
}; | ||
|
||
struct Transformer : ExprFunctor<Transformed(const Expr& n)> { | ||
Transformed Transform(const Expr & expr) { | ||
return (*this)(expr); | ||
} | ||
Transformed VisitExpr_(const LocalIdNode* op) override; | ||
Transformed VisitExpr_(const GlobalIdNode* op) override; | ||
Transformed VisitExpr_(const IntrinsicIdNode* op) override; | ||
Transformed VisitExpr_(const FloatLitNode* op) override; | ||
Transformed VisitExpr_(const BoolLitNode* op) override; | ||
Transformed VisitExpr_(const IntLitNode* op) override; | ||
Transformed VisitExpr_(const TensorLitNode* op) override; | ||
Transformed VisitExpr_(const ProductLitNode* op) override; | ||
Transformed VisitExpr_(const CastNode* op) override; | ||
Transformed VisitExpr_(const ParamNode* op) override; | ||
Transformed VisitExpr_(const FunctionNode* op) override; | ||
Transformed VisitExpr_(const CallNode* op) override; | ||
Transformed VisitExpr_(const DebugNode* op) override; | ||
Transformed VisitExpr_(const UnaryOpNode* op) override; | ||
Transformed VisitExpr_(const BinaryOpNode* op) override; | ||
Transformed VisitExpr_(const LetNode* op) override; | ||
}; | ||
|
||
struct UnTransform : ExprFunctor<Expr(const Expr & n)> { | ||
Expr UnTransform(const Expr & expr) { | ||
return (*this)(expr); | ||
} | ||
Expr VisitExpr_(const LocalIdNode* op) override; | ||
Expr VisitExpr_(const GlobalIdNode* op) override; | ||
Expr VisitExpr_(const IntrinsicIdNode* op) override; | ||
Expr VisitExpr_(const FloatLitNode* op) override; | ||
Expr VisitExpr_(const BoolLitNode* op) override; | ||
Expr VisitExpr_(const IntLitNode* op) override; | ||
Expr VisitExpr_(const TensorLitNode* op) override; | ||
Expr VisitExpr_(const ProductLitNode* op) override; | ||
Expr VisitExpr_(const CastNode* op) override; | ||
Expr VisitExpr_(const ParamNode* op) override; | ||
Expr VisitExpr_(const FunctionNode* op) override; | ||
Expr VisitExpr_(const CallNode* op) override; | ||
Expr VisitExpr_(const DebugNode* op) override; | ||
Expr VisitExpr_(const UnaryOpNode* op) override; | ||
Expr VisitExpr_(const BinaryOpNode* op) override; | ||
Expr VisitExpr_(const LetNode* op) override; | ||
}; | ||
|
||
Expr UnTransform(const Transformed & t) { | ||
throw; | ||
} | ||
|
||
} // namespace relay | ||
} // namespace nnvm | ||
#endif // NNVM_RELAY_LAMBDA_BACKPROP_H_ |
Oops, something went wrong.