diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 04b92ba68e3b..2d8e99ae8b25 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -530,9 +530,11 @@ def to_graph_normal_form(expr):
     return _ir_pass.to_graph_normal_form(expr)
 
 
-def gradient(expr, mod=None):
+def gradient(expr, mod=None, mode='higher_order'):
     """
-    Transform a function to return original result paired with gradient of input.
+    Transform the input function,
+    returning a function that calculates the original result,
+    paired with the gradient of the input.
 
     Parameters
     ----------
@@ -541,12 +543,23 @@ def gradient(expr, mod=None):
 
     mod : Optional[tvm.relay.Module]
 
+    mode : Optional[String]
+        The mode of the automatic differentiation algorithm.
+        'first_order' only works on first-order code, but will not produce references or closures.
+        'higher_order' works on all code, using references and closures.
+
     Returns
     -------
     expr : tvm.relay.Expr
-        The output expression.
+        The transformed expression.
     """
-    return _ir_pass.first_order_gradient(expr, mod)
+    if mode == 'first_order':
+        return _ir_pass.first_order_gradient(expr, mod)
+    elif mode == 'higher_order':
+        return _ir_pass.gradient(expr, mod)
+    else:
+        raise Exception('unknown mode')
+
 
 def get_total_mac_number(expr):
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 11f96c48a311..66ff9caf4ae4 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -225,6 +225,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     }
     node->pattern = op_pattern;
+    this->Update(call->op, nullptr, kOpaque);
     const auto* rtype = call->checked_type().as<TensorTypeNode>();
     // pass the message back to all the children it references.
     for (size_t i = 0; i < call->args.size(); ++i) {
diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc
index 780490a45b0a..d564e02b5596 100644
--- a/src/relay/pass/gradient.cc
+++ b/src/relay/pass/gradient.cc
@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;
 
 /*! \brief AD over a program which generates a tensor output. */
 struct ADTensor : ADValueNode {
-  Expr foward;
+  Expr forward;
   mutable Expr reverse;  // must be a variable to avoid duplication
-  ADTensor(LetList* ll, const Expr& foward) :
-    foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { }
+  ADTensor(LetList* ll, const Expr& forward) :
+    forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
 };
 
 /*! \brief A staged representation of the program, we reflect
@@ -105,14 +105,14 @@ struct ADFunction : ADValueNode {
     func(func) { }
 };
 
-struct ReverseAD : ExprFunctor<ADValue(const Expr&)> {
+struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
   std::vector<std::function<void(LetList*)>> backprop_actions;
   // we assume no closure so no need for lexical scoping
   std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
   LetList* ll;
 
-  ReverseAD(LetList* ll) : ll(ll) { }
+  FirstOrderReverseAD(LetList* ll) : ll(ll) { }
 
   ADValue VisitExpr_(const OpNode* op) final {
     Op op_ref = GetRef<Op>(op);
@@ -121,21 +121,22 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr&)> {
     return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
                                                        const Attrs& attrs,
                                                        const tvm::Array<Type>& type_args) {
-      std::vector<Expr> call_args;
-      for (const ADValue& adval : args) {
-        call_args.push_back(adval->get<ADTensor>().foward);
+        std::vector<Expr> call_args;
+        for (const ADValue& adval : args) {
+          call_args.push_back(adval->get<ADTensor>().forward);
+        }
+        auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
+        auto ret = std::make_shared<ADTensor>(ll, orig);
+        backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
+          tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
+          CHECK(args.size() == rev.size());
+          for (size_t i = 0; i < args.size(); ++i) {
+            args[i]->get<ADTensor>().reverse =
+              ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
          }
-      auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
-      auto ret = std::make_shared<ADTensor>(ll, orig);
-      backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
-        tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
-        for (size_t i = 0; i < args.size(); ++i) {
-          args[i]->get<ADTensor>().reverse =
-            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
-        }
-      });
-      return ret; });
+        return ret;
+      });
   }
 
   ADValue VisitExpr_(const ConstantNode* op) final {
@@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr&)> {
   }
 };
 
+Type GradRetType(const Function& f) {
+  // if type annotations are provided, we will construct a ret type;
+  // otherwise, leave it to be inferred
+  if (!f->ret_type.defined()) {
+    return Type();
+  }
+  std::vector<Type> vt;
+  for (const auto& p : f->params) {
+    if (!p->type_annotation.defined()) {
+      return Type();
+    }
+    vt.push_back(p->type_annotation);
+  }
+
+  return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
+}
+
 Expr FirstOrderGradient(const Expr& re, const Module& mod) {
   // Currently we first remove any global functions for the first
   // order case.
@@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
   // We will then build a sequence of lets which implement reverse mode.
   Expr body = LetList::With([&](LetList* ll) {
-    ReverseAD reverse_ad(ll);
+    FirstOrderReverseAD reverse_ad(ll);
     ADValue rev = reverse_ad(e);
     std::vector<ADValue> args;
     for (const auto& p : f->params) {
@@ -191,46 +209,131 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
     auto c = rev->get<ADFunction>().func(args, Attrs(), {});
     const auto& res = c->get<ADTensor>();
     Expr grad = LetList::With([&](LetList* ll) {
-      res.reverse = OneLike(res.foward);
-      for (auto it = reverse_ad.backprop_actions.rbegin();
-           it != reverse_ad.backprop_actions.rend();
-           ++it) {
-        (*it)(ll);
+      res.reverse = OnesLike(res.forward);
+      for (auto it = reverse_ad.backprop_actions.rbegin();
+           it != reverse_ad.backprop_actions.rend();
+           ++it) {
+        (*it)(ll);
+      }
+      std::vector<Expr> grad_res;
+      for (const auto& a : args) {
+        grad_res.push_back(a->get<ADTensor>().reverse);
+      }
+      return TupleNode::make(grad_res);
+    });
+    return Pair(res.forward, grad);
+  });
+
+  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+}
+
+TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  CHECK_EQ(args.size(), 2);
+  *ret = FirstOrderGradient(args[0], args[1]);
+});
+
+struct ReverseADType : TypeMutator {
+  Type VisitType_(const TensorTypeNode* ttn) final {
+    Type t = GetRef<Type>(ttn);
+    return TupleTypeNode::make({t, RefTypeNode::make(t)});
+  }
+};
+
+struct ReverseAD : ExprMutator {
+  Var bp;
+  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+
+  ReverseAD(const Var& bp) : bp(bp) { }
+
+  Expr VisitExpr_(const OpNode* op) final {
+    LOG(FATAL) << "op should only be inside call";
+    throw;
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    if (const OpNode* op_node = op->op.as<OpNode>()) {
+      Op op_ref = GetRef<Op>(op_node);
+      CHECK(rev_map.count(op_ref))
+        << op_node->name << " does not have reverse mode defined";
+      return LetList::With([&](LetList* ll) {
+        std::vector<Expr> args;
+        for (const auto& arg : op->args) {
+          args.push_back(ll->Push(VisitExpr(arg)));
         }
-        std::vector<Expr> grad_res;
-        for (const auto& a : args) {
-          grad_res.push_back(a->get<ADTensor>().reverse);
+        std::vector<Expr> orig_args;
+        for (const auto& arg : args) {
+          orig_args.push_back(GetField(VisitExpr(arg), 0));
         }
-        return TupleNode::make(grad_res);
+        Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
+        Var orig_var = ll->Push(orig);
+        auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
+        auto bpv = ll->Push(RefReadNode::make(bp));
+        Expr nbp = FunctionNode::make(
+          {},
+          LetList::With([&](LetList* ll) {
+            tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
+            CHECK(args.size() == rev.size());
+            for (size_t i = 0; i < args.size(); ++i) {
+              ll->Push(RefWriteNode::make(GetField(args[i], 1),
+                                          Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
+                                              rev[i])));
+            }
+            return CallNode::make(bpv, {});
+          }),
+          TupleTypeNode::make({}),
+          {});
+        ll->Push(RefWriteNode::make(bp, nbp));
+        return Pair(orig_var, ref);
       });
-      return Pair(res.foward, grad);
-    });
-
-    // if type annotations are provided, we will construct a ret type;
-    // otherwise, leave it to be inferred
-    Type ret_type = Type();
-    std::vector<Type> vt;
-    bool missing = !f->ret_type.defined();
-    for (const auto& p : f->params) {
-      if (missing || !p->type_annotation.defined()) {
-        missing = true;
-        break;
       }
-      vt.push_back(p->type_annotation);
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const ConstantNode* op) final {
+    Expr e = GetRef<Expr>(op);
+    return Pair(e, RefCreateNode::make(ZerosLike(e)));
   }
-    if (!missing) {
-      ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
+  Type VisitType(const Type& t) final {
+    return t.defined() ? ReverseADType()(t) : t;
   }
+};
 
-    return FunctionNode::make(f->params, body, ret_type, {});
+Expr BPEmpty() {
+  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
+  return RefCreateNode::make(unitF);
 }
 
-TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-    CHECK_EQ(args.size(), 2);
-    *ret = FirstOrderGradient(args[0], args[1]);
-  });
+Expr Gradient(const Expr& re, const Module& mod) {
+  auto e = DeGlobal(mod, re);
+  auto f = e.as<FunctionNode>();
+  CHECK(f) << "input need to be a function";
+  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
+  Expr body = LetList::With([&](LetList* ll) {
+    Var bp = ll->Push(BPEmpty());
+    Expr rev = ReverseAD(bp)(e);
+    std::vector<Expr> args;
+    for (const auto& p : f->params) {
+      args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
+    }
+    auto c = ll->Push(CallNode::make(rev, args));
+    ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
+    ll->Push(CallNode::make(RefReadNode::make(bp), {}));
+    std::vector<Expr> ret;
+    for (const auto& a : args) {
+      ret.push_back(RefReadNode::make(GetField(a, 1)));
+    }
+    return Pair(GetField(c, 0), TupleNode::make(ret));
+  });
+  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+}
+
+TVM_REGISTER_API("relay._ir_pass.gradient")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  CHECK_EQ(args.size(), 2);
+  *ret = Gradient(args[0], args[1]);
+});
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index e59efa958310..96038745474e 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -299,12 +299,12 @@ inline Expr Divide(Expr lhs, Expr rhs) {
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
 }
 
-inline Expr ZeroLike(Expr e) {
+inline Expr ZerosLike(Expr e) {
   static const Op& op = Op::Get("zeros_like");
   return CallNode::make(op, {e});
 }
 
-inline Expr OneLike(Expr e) {
+inline Expr OnesLike(Expr e) {
   static const Op& op = Op::Get("ones_like");
   return CallNode::make(op, {e});
 }
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 8dd02f39adce..ea6b9a95da50 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -53,7 +53,7 @@ bool TupleGetItemRel(const Array<Type>& types,
   const auto* param = attrs.as<TupleGetItemAttrs>();
   CHECK(param != nullptr);
   CHECK_GE(param->index, 0);
-  CHECK_LT(param->index, data->fields.size());
+  CHECK_LT(param->index, data->fields.size());
   reporter->Assign(types[1], data->fields[param->index]);
   return true;
 }
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 6b5d0e776934..400941f12617 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -2,6 +2,7 @@
 from tvm import relay
 from tvm.relay.ir_pass import free_vars, free_type_vars, gradient
 from tvm.relay import create_executor
+from tvm.relay.prelude import Prelude
 
 import numpy as np
 
@@ -123,6 +124,72 @@ def test_broadcast_subtract():
                                -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
 
 
+def test_tuple():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    y = relay.var("y", t)
+    z = relay.var("z", t)
+    tup = relay.Var("tup")
+    func = relay.Function([x, y, z], relay.Let(tup, relay.Tuple([x, y, z]),
+                                               relay.TupleGetItem(tup, 0) +
+                                               relay.TupleGetItem(tup, 1) -
+                                               relay.TupleGetItem(tup, 2)))
+    back_func = relay.ir_pass.infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
+    x_nd = rand(dtype, *shape)
+    y_nd = rand(dtype, *shape)
+    z_nd = rand(dtype, *shape)
+    x_np = x_nd.asnumpy()
+    y_np = y_nd.asnumpy()
+    z_np = z_nd.asnumpy()
+    expected_forward = x_np + y_np - z_np
+    ex = create_executor()
+    forward, (grad_x, grad_y, grad_z) = ex.evaluate(back_func)(x_nd, y_nd, z_nd)
+    np.testing.assert_allclose(forward.asnumpy(), expected_forward)
+    np.testing.assert_allclose(grad_x.asnumpy(), np.ones_like(grad_x.asnumpy()))
+    np.testing.assert_allclose(grad_y.asnumpy(), np.ones_like(grad_y.asnumpy()))
+    np.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
+
+
+def test_pow():
+    mod = relay.Module()
+    p = Prelude(mod)
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    double = relay.Function([x], x + x)
+    i = relay.var("i", t)
+    func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i]))
+    back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod)
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    i_nd = rand(dtype, *shape)
+    ex = create_executor(mod=mod)
+    forward, (grad_i,) = ex.evaluate(back_func)(i_nd)
+    np.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
+    np.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
+
+def test_ref():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    r = relay.Var("r")
+    u = relay.Var("u")
+    body = relay.RefRead(r)
+    body = relay.Let(u, relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)), body)
+    body = relay.Let(r, relay.RefCreate(x), body)
+    func = relay.Function([x], body)
+    back_func = relay.ir_pass.infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    x_nd = rand(dtype, *shape)
+    ex = create_executor()
+    forward, (grad_x,) = ex.evaluate(back_func)(x_nd)
+    np.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
+    np.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
+
 if __name__ == "__main__":
     test_id()
     test_add()
@@ -130,3 +197,6 @@ def test_broadcast_subtract():
     test_sub()
     test_broadcast_add()
     test_broadcast_subtract()
+    test_tuple()
+    test_pow()
+    test_ref()
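
Usage sketch (not part of the diff): the snippet below exercises the new `mode` argument of `relay.ir_pass.gradient`, using only APIs that already appear in this change (`gradient`, `infer_type`, `create_executor`); the shape, dtype, and the doubling function are illustrative assumptions rather than anything prescribed by the patch.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import gradient, infer_type
from tvm.relay import create_executor

# Build a small function f(x) = x + x with an annotated input type.
t = relay.TensorType((10, 10), 'float32')
x = relay.var("x", t)
func = relay.Function([x], x + x)

# 'higher_order' is the new default mode; mode='first_order' keeps the old behaviour.
back_func = infer_type(gradient(func, mode='higher_order'))

# The transformed function returns the forward result paired with a tuple of gradients.
ex = create_executor()
x_nd = tvm.nd.array(np.random.rand(10, 10).astype('float32'))
forward, (grad_x,) = ex.evaluate(back_func)(x_nd)
np.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
np.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))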