Skip to content

Commit

Permalink
Back Propagation. Ready for PR. (apache#49)
Browse files Browse the repository at this point in the history
* commit

* do wlist

* commit

* fix lint

* fix lint

* do unaryop

* fix lint

* implement more for evaluator

* commit

* finish rev for binop

* accidentally deleted free_vars.h, force push.

* fix compile error

* fix lint

* impl more anf

* test passed. working on impl.

* impl body of anf

* add skeleton for lambda_backprop

* add node type according to lambda backprop

* fix lint

* commit whatever

* commit!

* remove reverse node, we dont need them

* add relay expr for ref, getref, setref

* Revert "remove reverse node, we dont need them"

This reverts commit fd7a93b1973c17321b6f0e108e3c5fc4ec8357d3.
  • Loading branch information
MarisaKirisame authored and jroesch committed Aug 16, 2018
1 parent 6983913 commit d6da89e
Show file tree
Hide file tree
Showing 15 changed files with 690 additions and 112 deletions.
74 changes: 74 additions & 0 deletions relay/include/relay/anf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*!
* Copyright (c) 2018 by Contributors
* \file anf.h
* \brief Convert expression to A Normal Form.
*/
#ifndef NNVM_RELAY_ANF_H_
#define NNVM_RELAY_ANF_H_

#include <unordered_map>
#include <string>
#include <vector>
#include <utility>
#include "environment.h"
#include "expr_functor.h"
#include "ir.h"
#include "gen_fresh.h"

namespace nnvm {
namespace relay {

using std::vector;
using std::pair;

struct PartialLocalId {
LocalId lid;
explicit PartialLocalId(const LocalId & lid) : lid(lid) { }
PartialLocalId() {
throw;
}
};

struct LetList {
GenFresh gf;
vector<pair<PartialLocalId, Expr>> lets;
PartialLocalId let(const Expr & expr) {
PartialLocalId id(gf.fresh());
lets.push_back({id, expr});
return id;
}
Expr plug(const Expr & expr) const {
Expr ret = expr;
for (auto rit = lets.rbegin(); rit != lets.rend(); rit++) {
ret = LetNode::make(rit->first.lid, rit->second, ret);
}
return ret;
}
};

struct ANF : public ExprFunctor<PartialLocalId(const Expr& n)> {
LetList ll;
ANF() {}
static Expr ToANF(const Expr& expr);
PartialLocalId Convert(const Expr& expr);
PartialLocalId VisitExpr_(const LocalIdNode* op) override;
PartialLocalId VisitExpr_(const GlobalIdNode* op) override;
PartialLocalId VisitExpr_(const IntrinsicIdNode* op) override;
PartialLocalId VisitExpr_(const FloatLitNode* op) override;
PartialLocalId VisitExpr_(const BoolLitNode* op) override;
PartialLocalId VisitExpr_(const IntLitNode* op) override;
PartialLocalId VisitExpr_(const TensorLitNode* op) override;
PartialLocalId VisitExpr_(const ProductLitNode* op) override;
PartialLocalId VisitExpr_(const CastNode* op) override;
PartialLocalId VisitExpr_(const ParamNode* op) override;
PartialLocalId VisitExpr_(const FunctionNode* op) override;
PartialLocalId VisitExpr_(const CallNode* op) override;
PartialLocalId VisitExpr_(const DebugNode* op) override;
PartialLocalId VisitExpr_(const UnaryOpNode* op) override;
PartialLocalId VisitExpr_(const BinaryOpNode* op) override;
PartialLocalId VisitExpr_(const LetNode* op) override;
};

} // namespace relay
} // namespace nnvm
#endif // NNVM_RELAY_ANF_H_
6 changes: 6 additions & 0 deletions relay/include/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ReverseNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ReverseInverseNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ZeroSNode*op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const PlusSNode*op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const AccumulateNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ZeroNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -160,6 +163,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(BinaryOpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ReverseNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ReverseInverseNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ZeroSNode);
RELAY_EXPR_FUNCTOR_DISPATCH(PlusSNode);
RELAY_EXPR_FUNCTOR_DISPATCH(AccumulateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ZeroNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ProjectionNode);
Expand Down
2 changes: 1 addition & 1 deletion relay/include/relay/gen_fresh.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace relay {
struct GenFresh {
int i = 0;
LocalId fresh() {
return LocalIdNode::make("fresh_" + i++);
return LocalIdNode::make("fresh_" + std::to_string(++i));
}
};

Expand Down
23 changes: 23 additions & 0 deletions relay/include/relay/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,27 @@
#include "nnvm/relay/ir/type.h"
#include "nnvm/relay/ir/value.h"

namespace nnvm {
namespace relay {

Expr GetField(const Expr & x, int index);

Expr Pair(const Expr & l, const Expr & r);

Expr Neg(const Expr & x);

Expr Square(const Expr & x);

Expr Plus(const Expr & l, const Expr & r);

Expr Sub(const Expr & l, const Expr & r);

Expr Mul(const Expr & l, const Expr & r);

Expr Div(const Expr & l, const Expr & r);

Expr Float(double d);

} // namespace relay
} // namespace nnvm
#endif // NNVM_RELAY_IR_H_
104 changes: 104 additions & 0 deletions relay/include/relay/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,59 @@ class ReverseNode : public ExprNode {

RELAY_DEFINE_EXPR(Reverse, ReverseNode);

class ReverseInverse;

struct ReverseInverseNode : ExprNode {
Expr node;

ReverseInverseNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("node", &node); }

TVM_DLL static ReverseInverse make(Expr node);

static constexpr const char* _type_key = "nnvm.ReverseInverse";
TVM_DECLARE_NODE_TYPE_INFO(ReverseInverseNode, ExprNode);
};

RELAY_DEFINE_EXPR(ReverseInverse, ReverseInverseNode);

class PlusS;

struct PlusSNode : ExprNode {
Expr left;
Expr right;
PlusSNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("left", &left);
v->Visit("right", &right);
}

TVM_DLL static PlusS make(Expr left, Expr right);

static constexpr const char* _type_key = "nnvm.PlusS";
TVM_DECLARE_NODE_TYPE_INFO(PlusSNode, ExprNode);
};

RELAY_DEFINE_EXPR(PlusS, PlusSNode);

class ZeroS;

struct ZeroSNode : ExprNode {
Expr node;
ZeroSNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("node", &node); }

TVM_DLL static ZeroS make(Expr node);

static constexpr const char* _type_key = "nnvm.ZeroS";
TVM_DECLARE_NODE_TYPE_INFO(ZeroSNode, ExprNode);
};

RELAY_DEFINE_EXPR(ZeroS, ZeroSNode);

class Accumulate;

class AccumulateNode : public ExprNode {
Expand Down Expand Up @@ -431,6 +484,57 @@ class ZeroNode : public ExprNode {

RELAY_DEFINE_EXPR(Zero, ZeroNode);

class Ref;

struct RefNode : ExprNode {
Expr expr;
RefNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("expr", &expr); }

TVM_DLL static Ref make(Expr expr);

static constexpr const char * _type_key = "nnvm.Ref";
TVM_DECLARE_NODE_TYPE_INFO(RefNode, ExprNode);
};

RELAY_DEFINE_EXPR(Ref, RefNode);

class GetRef;

struct GetRefNode : ExprNode {
Expr expr;
GetRefNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("expr", &expr); }

TVM_DLL static Ref make(Expr expr);

static constexpr const char * _type_key = "nnvm.GetRef";
TVM_DECLARE_NODE_TYPE_INFO(GetRefNode, ExprNode);
};

RELAY_DEFINE_EXPR(GetRef, GetRefNode);

class SetRef;

struct SetRefNode : ExprNode {
Expr ref, val;
SetRefNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("val", &val);
}

TVM_DLL static Ref make(Expr ref, Expr val);

static constexpr const char * _type_key = "nnvm.SetRef";
TVM_DECLARE_NODE_TYPE_INFO(SetRefNode, ExprNode);
};

RELAY_DEFINE_EXPR(SetRef, SetRefNode);

class Projection;

class ProjectionNode : public ExprNode {
Expand Down
18 changes: 17 additions & 1 deletion relay/include/relay/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#ifndef NNVM_RELAY_IR_TYPE_H_
#define NNVM_RELAY_IR_TYPE_H_

#include <nnvm/relay/ir/type.h>
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node.h>
Expand Down Expand Up @@ -151,6 +150,23 @@ class Shape : public NodeRef {
using ContainerType = Node;
};

class RefType;
struct RefTypeNode : TypeNode {
Type type;
RefTypeNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("type", &type);
}

TVM_DLL static RefType make(Type type);

static constexpr const char* _type_key = "nnvm.RefType";
TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
};

RELAY_DEFINE_TYPE(RefType, RefTypeNode);

class TensorType;

/*! \brief The type of Tensors. */
Expand Down
80 changes: 80 additions & 0 deletions relay/include/relay/lambda_backprop.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*!
* Copyright (c) 2018 by Contributors
* \file lambda_backprop.h
* \brief Reverse Mode Automatic Differentiation for the Relay IR
*/
#ifndef NNVM_RELAY_LAMBDA_BACKPROP_H_
#define NNVM_RELAY_LAMBDA_BACKPROP_H_

#include <unordered_map>
#include <string>
#include <vector>
#include <utility>
#include <map>
#include "environment.h"
#include "expr_functor.h"
#include "ir.h"
#include "gen_fresh.h"
#include "anf.h"

// Do we really need a J-> and J->-1?
// Maybe we can just ask J-> to store original version as a Product?
namespace nnvm {
namespace relay {

struct Transformed {
Expr expr;
explicit Transformed(const Expr & expr) : expr(expr) { }
};

struct Transformer : ExprFunctor<Transformed(const Expr& n)> {
Transformed Transform(const Expr & expr) {
return (*this)(expr);
}
Transformed VisitExpr_(const LocalIdNode* op) override;
Transformed VisitExpr_(const GlobalIdNode* op) override;
Transformed VisitExpr_(const IntrinsicIdNode* op) override;
Transformed VisitExpr_(const FloatLitNode* op) override;
Transformed VisitExpr_(const BoolLitNode* op) override;
Transformed VisitExpr_(const IntLitNode* op) override;
Transformed VisitExpr_(const TensorLitNode* op) override;
Transformed VisitExpr_(const ProductLitNode* op) override;
Transformed VisitExpr_(const CastNode* op) override;
Transformed VisitExpr_(const ParamNode* op) override;
Transformed VisitExpr_(const FunctionNode* op) override;
Transformed VisitExpr_(const CallNode* op) override;
Transformed VisitExpr_(const DebugNode* op) override;
Transformed VisitExpr_(const UnaryOpNode* op) override;
Transformed VisitExpr_(const BinaryOpNode* op) override;
Transformed VisitExpr_(const LetNode* op) override;
};

struct UnTransform : ExprFunctor<Expr(const Expr & n)> {
Expr UnTransform(const Expr & expr) {
return (*this)(expr);
}
Expr VisitExpr_(const LocalIdNode* op) override;
Expr VisitExpr_(const GlobalIdNode* op) override;
Expr VisitExpr_(const IntrinsicIdNode* op) override;
Expr VisitExpr_(const FloatLitNode* op) override;
Expr VisitExpr_(const BoolLitNode* op) override;
Expr VisitExpr_(const IntLitNode* op) override;
Expr VisitExpr_(const TensorLitNode* op) override;
Expr VisitExpr_(const ProductLitNode* op) override;
Expr VisitExpr_(const CastNode* op) override;
Expr VisitExpr_(const ParamNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* op) override;
Expr VisitExpr_(const DebugNode* op) override;
Expr VisitExpr_(const UnaryOpNode* op) override;
Expr VisitExpr_(const BinaryOpNode* op) override;
Expr VisitExpr_(const LetNode* op) override;
};

Expr UnTransform(const Transformed & t) {
throw;
}

} // namespace relay
} // namespace nnvm
#endif // NNVM_RELAY_LAMBDA_BACKPROP_H_
Loading

0 comments on commit d6da89e

Please sign in to comment.