[Relay] Higher order reverse mode automatic differentiation that works with control flow #2496

Merged · 1 commit · Mar 4, 2019
21 changes: 17 additions & 4 deletions python/tvm/relay/ir_pass.py
@@ -530,9 +530,11 @@ def to_graph_normal_form(expr):
return _ir_pass.to_graph_normal_form(expr)


def gradient(expr, mod=None):
def gradient(expr, mod=None, mode='higher_order'):
"""
Transform a function to return original result paired with gradient of input.
Transform the input function,
returning a function that calculates the original result,
paired with the gradient of the input.

Parameters
----------
@@ -541,12 +543,23 @@ def gradient(expr, mod=None):

mod : Optional[tvm.relay.Module]

mode : Optional[String]
The mode of the automatic differentiation algorithm.
'first_order' only works on first-order code, but will not produce references or closures.
'higher_order' works on all code, using references and closures.

Returns
-------
expr : tvm.relay.Expr
The output expression.
The transformed expression.
"""
return _ir_pass.first_order_gradient(expr, mod)
if mode == 'first_order':
return _ir_pass.first_order_gradient(expr, mod)
elif mode == 'higher_order':
return _ir_pass.gradient(expr, mod)
else:
raise Exception('unknown mode')



def get_total_mac_number(expr):
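
As an editorial aside (not part of this diff), here is a minimal sketch of how the new mode argument would be used from Python; the variable names, shapes, and the x + x body are illustrative assumptions, and it presumes a gradient is registered for the ops involved:

```python
from tvm import relay

# Build a tiny function: f(x) = x + x (illustrative only).
x = relay.var("x", shape=(3,), dtype="float32")
func = relay.Function([x], x + x)

# First-order mode: the transformed program contains no references or closures.
grad_fo = relay.ir_pass.gradient(func, mode='first_order')

# Higher-order mode (the default): handles programs with control flow,
# at the cost of emitting references and closures.
grad_ho = relay.ir_pass.gradient(func)
```
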
1 change: 1 addition & 0 deletions src/relay/pass/fuse_ops.cc
@@ -225,6 +225,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}

node->pattern = op_pattern;
this->Update(call->op, nullptr, kOpaque);
const auto* rtype = call->checked_type().as<TensorTypeNode>();
// pass the message back to all the children it references.
for (size_t i = 0; i < call->args.size(); ++i) {
201 changes: 152 additions & 49 deletions src/relay/pass/gradient.cc
@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;

/*! \brief AD over a program which generates a tensor output. */
@merrymercy (Member), Feb 27, 2019:

What if the program generates a tuple of tensors as output?
ADFunction and ADTensor cannot cover this case.

struct ADTensor : ADValueNode {
Expr foward;
Expr forward;
mutable Expr reverse; // must be a variable to avoid duplication
ADTensor(LetList* ll, const Expr& foward) :
foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { }
ADTensor(LetList* ll, const Expr& forward) :
forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
};

/*! \brief A staged representation of the program, we reflect
@@ -105,14 +105,14 @@ struct ADFunction : ADValueNode {
func(func) { }
};

struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping
std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
LetList* ll;

ReverseAD(LetList* ll) : ll(ll) { }
FirstOrderReverseAD(LetList* ll) : ll(ll) { }

ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
@@ -121,21 +121,22 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
const Attrs& attrs,
const tvm::Array<Type>& type_args) {
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().foward);
std::vector<Expr> call_args;
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
@sgrechanik-h (Contributor): Is it possible to use the real original node instead of a reconstruction? Reconstructing a node may lead to losing some information, e.g. the inferred type checked_type_.

@MarisaKirisame (Author): Maybe, but it will require a big change in code structure. If such a case comes up I will do it.

@sgrechanik-h (Contributor): I need checked_type_ in the integration with the tensor expression AD, mostly for finding out the number of outputs of the original operation. However, I think I can get this information from other sources. Would passing and reassigning just checked_type_ be dangerous in this case?

@MarisaKirisame (Author): @sgrechanik-h can I just rerun type inference? Right now every pass will destroy checked_type_ and rebuild it via type inference.

@sgrechanik-h (Contributor): @MarisaKirisame Not sure what you mean, but rerunning type inference sounds like a bit of an overkill, and I'm not sure it can be done before calling the FPrimalGradient attribute. If checked_type_ must be reset after running the differentiation pass, one solution could be setting it to the original value before calling FPrimalGradient and then resetting it to nullptr after FPrimalGradient has finished, but this feels kinda hacky.

(Also, currently I think that in my particular case the proper solution would be to fix the signature of FTVMCompute so that it accepts input types, not only the out_type. And this is not connected to the automatic differentiation pass.)

@MarisaKirisame (Author): @sgrechanik-h all passes (FuseOps, AD, ANF, GNF, DeadCodeElimination, FoldScaleAxis) remove the type annotation and rerun it AFAIK. I am not sure why it is an AD-specific issue.

@sgrechanik-h (Contributor): @MarisaKirisame I think some passes may benefit from using type information, and, of course, they should use it before erasing it (or recreating the node; I don't think checked_type_ gets literally erased anywhere). In the case of the code we are currently discussing, the node is recreated (and thus type information is erased) before the call to the FPrimalGradient function, which could use the type information if it were still there. I don't insist on fixing it if it's difficult or unnatural, because I have only one case where this might be useful, and even in that case it would be better to fix a completely different part of Relay.

@MarisaKirisame (Author): My other passes use type info too. But we just rerun type inference, and we are encoding rerunning type inference into the pass manager too.
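
(Editorial aside, not from the thread: a minimal sketch of the "re-run type inference after the pass" workflow being discussed, assuming the relay.ir_pass API of this era; the function being differentiated is an illustrative placeholder.)

```python
from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
func = relay.ir_pass.infer_type(relay.Function([x], x + x))
# func.body.checked_type is populated at this point.

grad = relay.ir_pass.gradient(func)    # the pass rebuilds call nodes, so checked_type_ is lost
grad = relay.ir_pass.infer_type(grad)  # and type inference is simply re-run afterwards
```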

auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
for (size_t i = 0; i < args.size(); ++i) {
args[i]->get<ADTensor>().reverse =
ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
}
});
return ret;
});
return ret;
});
}

ADValue VisitExpr_(const ConstantNode* op) final {
@@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
}
};

Type GradRetType(const Function& f) {
// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
if (!f->ret_type.defined()) {
return Type();
}
std::vector<Type> vt;
for (const auto& p : f->params) {
if (!p->type_annotation.defined()) {
return Type();
}
vt.push_back(p->type_annotation);
}

return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}
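
// Editorial illustration (not part of the diff): for a fully annotated input such as
//   fn (%x : Tensor[(3, 4), float32], %y : Tensor[(3, 4), float32]) -> Tensor[(3, 4), float32]
// GradRetType builds the tuple type
//   TupleType([Tensor[(3, 4), float32], TupleType([Tensor[(3, 4), float32], Tensor[(3, 4), float32]])])
// i.e. the original result paired with one gradient per parameter. If any annotation
// is missing, it returns an empty Type() and the result type is left to inference.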

Expr FirstOrderGradient(const Expr& re, const Module& mod) {
// Currently we first remove any global functions for the first
// order case.
@@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {

// We will then build a sequence of lets which implement reverse mode.
Expr body = LetList::With([&](LetList* ll) {
ReverseAD reverse_ad(ll);
FirstOrderReverseAD reverse_ad(ll);
ADValue rev = reverse_ad(e);
std::vector<ADValue> args;
for (const auto& p : f->params) {
@@ -191,46 +209,131 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
auto c = rev->get<ADFunction>().func(args, Attrs(), {});
const auto& res = c->get<ADTensor>();
Expr grad = LetList::With([&](LetList* ll) {
res.reverse = OneLike(res.foward);
for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
res.reverse = OnesLike(res.forward);
for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend();
++it) {
(*it)(ll);
}
std::vector<Expr> grad_res;
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
}
return TupleNode::make(grad_res);
});
return Pair(res.forward, grad);
});

return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});
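
As an editorial illustration (nothing here is TVM API; the names and the single mul rule are assumptions), the first-order machinery above amounts to the following plain-Python tape scheme: record one backprop action per op during the forward pass, then replay the actions in reverse:

```python
class Dual:
    def __init__(self, forward):
        self.forward = forward
        self.reverse = 0.0            # plays the role of ZerosLike

def first_order_gradient(f, *inputs):
    actions = []                      # analogue of backprop_actions

    def mul(x, y):
        # forward computation, plus one recorded backprop action per call
        out = Dual(x.forward * y.forward)
        def backprop():
            x.reverse += out.reverse * y.forward
            y.reverse += out.reverse * x.forward
        actions.append(backprop)
        return out

    args = [Dual(v) for v in inputs]
    out = f(mul, *args)
    out.reverse = 1.0                 # analogue of OnesLike on the result
    for act in reversed(actions):     # replay the tape in reverse
        act()
    return out.forward, tuple(a.reverse for a in args)

# f(x, y) = (x * y) * y  =>  df/dx = y**2, df/dy = 2 * x * y
value, grads = first_order_gradient(lambda mul, x, y: mul(mul(x, y), y), 3.0, 4.0)
print(value, grads)                   # 48.0 (16.0, 24.0)
```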

struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleTypeNode::make({t, RefTypeNode::make(t)});
}
};

struct ReverseAD : ExprMutator {
Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");

ReverseAD(const Var& bp) : bp(bp) { }

Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
}

Expr VisitExpr_(const CallNode* op) final {
if (const OpNode* op_node = op->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);
CHECK(rev_map.count(op_ref))
<< op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
for (const auto& arg : op->args) {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> grad_res;
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
std::vector<Expr> orig_args;
for (const auto& arg : args) {
orig_args.push_back(GetField(VisitExpr(arg), 0));
}
return TupleNode::make(grad_res);
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
Var orig_var = ll->Push(orig);
auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
ll->Push(RefWriteNode::make(GetField(args[i], 1),
Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
rev[i])));
}
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return Pair(orig_var, ref);
});
return Pair(res.foward, grad);
});

// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
Type ret_type = Type();
std::vector<Type> vt;
bool missing = !f->ret_type.defined();
for (const auto& p : f->params) {
if (missing || !p->type_annotation.defined()) {
missing = true;
break;
}
vt.push_back(p->type_annotation);
return ExprMutator::VisitExpr_(op);
}

Expr VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return Pair(e, RefCreateNode::make(ZerosLike(e)));
}

if (!missing) {
ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
Type VisitType(const Type& t) final {
return t.defined() ? ReverseADType()(t) : t;
}
};

return FunctionNode::make(f->params, body, ret_type, {});
Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
return RefCreateNode::make(unitF);
}

TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});
Expr Gradient(const Expr& re, const Module& mod) {
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp)(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
}
auto c = ll->Push(CallNode::make(rev, args));
ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
ll->Push(CallNode::make(RefReadNode::make(bp), {}));
std::vector<Expr> ret;
for (const auto& a : args) {
ret.push_back(RefReadNode::make(GetField(a, 1)));
}
return Pair(GetField(c, 0), TupleNode::make(ret));
});
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._ir_pass.gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = Gradient(args[0], args[1]);
});

} // namespace relay
} // namespace tvm
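
As a companion editorial illustration (again, nothing here is TVM API), the higher-order ReverseAD/Gradient pair above encodes the same toy example with a backpropagator reference instead of a tape: every value is paired with a mutable gradient cell, and each op chains a closure onto bp that accumulates into its operands' cells before invoking the previously chained work:

```python
class Ref:
    """A mutable cell, standing in for Relay's reference type."""
    def __init__(self, value):
        self.value = value

def dual(forward):
    # analogue of Pair(x, RefCreate(ZerosLike(x)))
    return (forward, Ref(0.0))

def make_mul(bp):
    # A differentiable op: compute the forward value, then chain a new
    # backpropagator closure onto bp (analogue of the nbp RefWrite above).
    def mul(x, y):
        out = dual(x[0] * y[0])
        old_bp = bp.value
        def new_bp():
            x[1].value += out[1].value * y[0]
            y[1].value += out[1].value * x[0]
            old_bp()                   # then run everything chained before this op
        bp.value = new_bp
        return out
    return mul

def gradient(f, *inputs):
    bp = Ref(lambda: None)             # analogue of BPEmpty()
    args = [dual(x) for x in inputs]
    out = f(make_mul(bp), *args)
    out[1].value = 1.0                 # OnesLike written into the output's gradient cell
    bp.value()                         # fire the accumulated reverse pass
    return out[0], tuple(a[1].value for a in args)

# f(x, y) = (x * y) * y  =>  df/dx = y**2, df/dy = 2 * x * y
value, grads = gradient(lambda mul, x, y: mul(mul(x, y), y), 3.0, 4.0)
print(value, grads)                    # 48.0 (16.0, 24.0)
```
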
4 changes: 2 additions & 2 deletions src/relay/pass/pattern_util.h
@@ -284,12 +284,12 @@ inline Expr Divide(Expr lhs, Expr rhs) {
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}

inline Expr ZeroLike(Expr e) {
inline Expr ZerosLike(Expr e) {
static const Op& op = Op::Get("zeros_like");
return CallNode::make(op, {e});
}

inline Expr OneLike(Expr e) {
inline Expr OnesLike(Expr e) {
static const Op& op = Op::Get("ones_like");
return CallNode::make(op, {e});
}
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
@@ -53,7 +53,7 @@ bool TupleGetItemRel(const Array<Type>& types,
const auto* param = attrs.as<TupleGetItemAttrs>();
CHECK(param != nullptr);
CHECK_GE(param->index, 0);
CHECK_LT(param->index, data->fields.size());
CHECK_LT(param->index, data->fields.size());
reporter->Assign(types[1], data->fields[param->index]);
return true;
}