Skip to content

Commit

Permalink
[Relay] Higher order reverse mode automatic differentiation that work…
Browse files Browse the repository at this point in the history
… with control flow (apache#2496)

add test

remove dead code

stash

do it

add more test
  • Loading branch information
MarisaKirisame authored and wweic committed Mar 12, 2019
1 parent 8ce998e commit 1f04aed
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 56 deletions.
21 changes: 17 additions & 4 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,9 +530,11 @@ def to_graph_normal_form(expr):
return _ir_pass.to_graph_normal_form(expr)


def gradient(expr, mod=None):
def gradient(expr, mod=None, mode='higher_order'):
"""
Transform a function to return original result paired with gradient of input.
Transform the input function,
returning a function that calculate the original result,
paired with gradient of the input.
Parameters
----------
Expand All @@ -541,12 +543,23 @@ def gradient(expr, mod=None):
mod : Optional[tvm.relay.Module]
mode : Optional[String]
The mode of the automatic differentiation algorithm.
'first_order' only work on first order code, but will not produce reference nor closure.
'higher_order' work on all code using reference and closure.
Returns
-------
expr : tvm.relay.Expr
The output expression.
The transformed expression.
"""
return _ir_pass.first_order_gradient(expr, mod)
if mode == 'first_order':
return _ir_pass.first_order_gradient(expr, mod)
elif mode == 'higher_order':
return _ir_pass.gradient(expr, mod)
else:
raise Exception('unknown mode')



def get_total_mac_number(expr):
Expand Down
1 change: 1 addition & 0 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}

node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
Expand Down
201 changes: 152 additions & 49 deletions src/relay/pass/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;

/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
Expr foward;
Expr forward;
mutable Expr reverse; // must be a variable to avoid duplication
ADTensor(LetList* ll, const Expr& foward) :
foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { }
ADTensor(LetList* ll, const Expr& forward) :
forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
};

/*! \brief A staged representation of the program, we reflect
Expand All @@ -105,14 +105,14 @@ struct ADFunction : ADValueNode {
func(func) { }
};

struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping
std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
LetList* ll;

ReverseAD(LetList* ll) : ll(ll) { }
FirstOrderReverseAD(LetList* ll) : ll(ll) { }

ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
Expand All @@ -121,21 +121,22 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
const Attrs& attrs,
const tvm::Array<Type>& type_args) {
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().foward);
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
}
});
return ret;
});
return ret;
});
}

ADValue VisitExpr_(const ConstantNode* op) final {
Expand Down Expand Up @@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
}
};

Type GradRetType(const Function& f) {
// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
if (!f->ret_type.defined()) {
return Type();
}
std::vector<Type> vt;
for (const auto& p : f->params) {
if (!p->type_annotation.defined()) {
return Type();
}
vt.push_back(p->type_annotation);
}

return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}

Expr FirstOrderGradient(const Expr& re, const Module& mod) {
// Currently we first remove any global functions for the first
// order case.
Expand All @@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {

// We will then build a sequence of lets which implement reverse mode.
Expr body = LetList::With([&](LetList* ll) {
ReverseAD reverse_ad(ll);
FirstOrderReverseAD reverse_ad(ll);
ADValue rev = reverse_ad(e);
std::vector<ADValue> args;
for (const auto& p : f->params) {
Expand All @@ -191,46 +209,131 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
auto c = rev->get<ADFunction>().func(args, Attrs(), {});
const auto& res = c->get<ADTensor>();
Expr grad = LetList::With([&](LetList* ll) {
res.reverse = OneLike(res.foward);
for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
res.reverse = OnesLike(res.forward);
for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
}
std::vector<Expr> grad_res;
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
}
return TupleNode::make(grad_res);
});
return Pair(res.forward, grad);
});

return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});

struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleTypeNode::make({t, RefTypeNode::make(t)});
}
};

struct ReverseAD : ExprMutator {
Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");

ReverseAD(const Var& bp) : bp(bp) { }

Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
}

Expr VisitExpr_(const CallNode* op) final {
if (const OpNode* op_node = op->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);
CHECK(rev_map.count(op_ref))
<< op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
for (const auto& arg : op->args) {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> grad_res;
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
std::vector<Expr> orig_args;
for (const auto& arg : args) {
orig_args.push_back(GetField(VisitExpr(arg), 0));
}
return TupleNode::make(grad_res);
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
Var orig_var = ll->Push(orig);
auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
ll->Push(RefWriteNode::make(GetField(args[i], 1),
Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
rev[i])));
}
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return Pair(orig_var, ref);
});
return Pair(res.foward, grad);
});

// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
Type ret_type = Type();
std::vector<Type> vt;
bool missing = !f->ret_type.defined();
for (const auto& p : f->params) {
if (missing || !p->type_annotation.defined()) {
missing = true;
break;
}
vt.push_back(p->type_annotation);
return ExprMutator::VisitExpr_(op);
}

Expr VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return Pair(e, RefCreateNode::make(ZerosLike(e)));
}

if (!missing) {
ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
Type VisitType(const Type& t) final {
return t.defined() ? ReverseADType()(t) : t;
}
};

return FunctionNode::make(f->params, body, ret_type, {});
Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
return RefCreateNode::make(unitF);
}

TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});
Expr Gradient(const Expr& re, const Module& mod) {
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp)(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
}
auto c = ll->Push(CallNode::make(rev, args));
ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
ll->Push(CallNode::make(RefReadNode::make(bp), {}));
std::vector<Expr> ret;
for (const auto& a : args) {
ret.push_back(RefReadNode::make(GetField(a, 1)));
}
return Pair(GetField(c, 0), TupleNode::make(ret));
});
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._ir_pass.gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = Gradient(args[0], args[1]);
});

} // namespace relay
} // namespace tvm
4 changes: 2 additions & 2 deletions src/relay/pass/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,12 @@ inline Expr Divide(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}

inline Expr ZeroLike(Expr e) {
inline Expr ZerosLike(Expr e) {
static const Op& op = Op::Get("zeros_like");
return CallNode::make(op, {e});
}

inline Expr OneLike(Expr e) {
inline Expr OnesLike(Expr e) {
static const Op& op = Op::Get("ones_like");
return CallNode::make(op, {e});
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ bool TupleGetItemRel(const Array<Type>& types,
const auto* param = attrs.as<TupleGetItemAttrs>();
CHECK(param != nullptr);
CHECK_GE(param->index, 0);
CHECK_LT(param->index, data->fields.size());
CHECK_LT(param->index, data->fields.size());
reporter->Assign(types[1], data->fields[param->index]);
return true;
}
Expand Down
Loading

0 comments on commit 1f04aed

Please sign in to comment.