From 630cf61b08b3ed365969c95302e62eb529b26375 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
Date: Thu, 27 Aug 2020 11:32:40 -0700
Subject: [PATCH] [Relay][Training] Make AutoDiff thread through global function. (#6336)

* save

* lint

* lint

* fix warning

* fix test

* save
---
 src/printer/doc.cc                       |   2 +-
 src/relay/transforms/gradient.cc         | 106 ++++++++++++++++++-----
 tests/python/relay/test_pass_gradient.py |  41 ++++++++-
 3 files changed, 124 insertions(+), 25 deletions(-)

diff --git a/src/printer/doc.cc b/src/printer/doc.cc
index d487e3e7aa3e..ab1eddbe7d1e 100644
--- a/src/printer/doc.cc
+++ b/src/printer/doc.cc
@@ -129,7 +129,7 @@ Doc Doc::Indent(int indent, Doc doc) {
 }
 
 Doc Doc::StrLiteral(const std::string& value, std::string quote) {
-  // TODO(M.K.): add escape.
+  // TODO(@M.K.): add escape.
   Doc doc;
   return doc << quote << value << quote;
 }
diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 7894c34de55d..9c472542cc91 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -72,7 +72,7 @@ Type WithGradientType(const Type&);
 Expr FirstOrderGradient(const Expr& e, const Optional<IRModule>& mod);
 
 Type WithGradientType(const Type& t) {
-  // TODO(M.K.): stricter checking
+  // TODO(@M.K.): stricter checking
   auto ty = t.as<FuncTypeNode>();
   CHECK(ty) << "input should be a function";
   return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {});
@@ -85,7 +85,7 @@ Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
   if (mod.defined() && x) {
     BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
     if (auto* n = base_func.as<FunctionNode>()) {
-      return n->body;
+      return GetRef<Function>(n);
     } else {
       return e;
     }
@@ -338,11 +338,22 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
 
 TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);
 
+Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));
+
 struct ReverseADType : TypeMutator {
   Type VisitType_(const TensorTypeNode* ttn) final {
     Type t = GetRef<Type>(ttn);
     return TupleType({t, RelayRefType(t)});
   }
+
+  Type VisitType_(const FuncTypeNode* ftn) final {
+    std::vector<Type> arg_types;
+    for (const auto& t : ftn->arg_types) {
+      arg_types.push_back(VisitType(t));
+    }
+    arg_types.push_back(bpt);
+    return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints);
+  }
 };
 
 Type ReverseType(const Type& t) { return ReverseADType()(t); }
@@ -438,12 +449,18 @@ Expr BPEmpty() {
 
 struct ReverseAD : ExprMutator {
   using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;
-
+  using ADGlobalVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
+  Optional<IRModule> mod;
+  // TODO(@M.K.) refactor AD to always use mod.
   Var bp;
   std::shared_ptr<ADVarMap> ad_vars;
+  std::shared_ptr<ADGlobalVarMap> ad_gvars;
   const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
 
-  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
+  explicit ReverseAD(const Optional<IRModule>& mod, const Var& bp,
+                     const std::shared_ptr<ADVarMap>& ad_vars,
+                     const std::shared_ptr<ADGlobalVarMap>& ad_gvars)
+      : mod(mod), bp(bp), ad_vars(ad_vars), ad_gvars(ad_gvars) {}
 
   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
@@ -481,9 +498,8 @@ struct ReverseAD : ExprMutator {
       Expr nbp = Function({}, LetList::With([&](LetList* ll) {
                             // we need a new ReverseAD visitor to avoid clobbering the bp local var
                             auto dup_bp = ll->Push(BPEmpty());
-                            ReverseAD dup_diff(dup_bp, ad_vars);
-                            auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
-
+                            auto dup_ad =
+                                ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x)));
                             TransferGrads(call->checked_type(), ret, dup_ad, ll);
                             ll->Push(Call(RefRead(dup_bp), {}));
                             return Call(bpv, {});
@@ -518,22 +534,29 @@ struct ReverseAD : ExprMutator {
         orig_var->checked_type_ = call->checked_type();
         auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
         auto bpv = ll->Push(RefRead(bp));
-        Expr nbp = Function({}, LetList::With([&](LetList* ll) {
-                              tvm::Array<Expr> rev =
-                                  rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
-                              CHECK(args.size() == rev.size());
-                              for (size_t i = 0; i < args.size(); ++i) {
-                                UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
-                              }
-                              return Call(bpv, {});
-                            }),
-                            TupleType::Empty(), {});
+        Expr nbp_body = LetList::With([&](LetList* ll) {
+          tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
+          CHECK(args.size() == rev.size());
+          for (size_t i = 0; i < args.size(); ++i) {
+            UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
+          }
+          return Call(bpv, {});
+        });
+        Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
         ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
         // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
         return ret;
       });
+    } else if (call->op.as<ConstructorNode>()) {
+      return ExprMutator::VisitExpr_(call);
+    } else {
+      std::vector<Expr> args;
+      for (const auto& arg : call->args) {
+        args.push_back(VisitExpr(arg));
+      }
+      args.push_back(bp);
+      return Call(VisitExpr(call->op), args);
     }
-    return ExprMutator::VisitExpr_(call);
   }
 
   Expr VisitExpr_(const ConstantNode* op) final {
@@ -559,6 +582,39 @@ struct ReverseAD : ExprMutator {
     return ad_vars->at(var_ref);
   }
 
+  Expr VisitExpr_(const GlobalVarNode* op) final {
+    // todo: concatenating string to add attribute seems like a brittle hack.
+    // maybe get module indexed by a rose tree of string?
+    CHECK(mod.defined());
+    auto orig_gv = GetRef<GlobalVar>(op);
+    if (ad_gvars->count(orig_gv) == 0) {
+      GlobalVar gv(op->name_hint + "_grad");
+      (*ad_gvars)[orig_gv] = gv;
+      Function orig_f = Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv)));
+      std::vector<Var> params;
+      for (const auto& p : orig_f->params) {
+        params.push_back(Downcast<Var>(VisitExpr(p)));
+      }
+      params.push_back(bp);
+      Expr body = VisitExpr(orig_f->body);
+      Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
+      std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
+      mod.value()->Add(gv, f);
+    }
+    return ad_gvars->at(orig_gv);
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) final {
+    std::vector<Var> params;
+    for (const auto& var : op->params) {
+      params.push_back(Downcast<Var>(VisitExpr(var)));
+    }
+    auto new_bp = Var("bp", bpt);
+    params.push_back(new_bp);
+    return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body),
+                    VisitType(op->ret_type), op->type_params, op->attrs);
+  }
+
   Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
 };
 
@@ -604,12 +660,16 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
   }
   CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
   Expr body = LetList::With([&](LetList* ll) {
-    Var bp = ll->Push(BPEmpty());
-    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
-    std::vector<Expr> args;
+    Var bp = ll->Push(BPEmpty(), bpt);
+    Expr rev = ReverseAD(mod, bp, std::make_shared<ReverseAD::ADVarMap>(),
+                         std::make_shared<ReverseAD::ADGlobalVarMap>())(e);
+    std::vector<Expr> normal_args, args;
     for (const auto& p : f->params) {
-      args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p)))));
+      auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
+      normal_args.push_back(x);
+      args.push_back(x);
     }
+    args.push_back(bp);
     auto c = ll->Push(Call(rev, args));
     std::function<void(const Expr&, const Type&)> init_grad;
     init_grad = [&](const Expr& e, const Type& t) {
@@ -626,7 +686,7 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
     init_grad(c, f->body->checked_type());
     ll->Push(Call(RefRead(bp), {}));
     std::vector<Expr> ret;
-    for (const auto& a : args) {
+    for (const auto& a : normal_args) {
       ret.push_back(RefRead(GetField(a, 1)));
     }
     std::function<Expr(const Expr&, const Type&)> get_final_result;
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 296d3e5e9354..b239ef4fc4a6 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -21,6 +21,7 @@
 import tvm
 from tvm import te
 from tvm import relay
+from tvm.relay import GlobalVar
 from tvm.relay.analysis import free_vars, free_type_vars
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
@@ -29,7 +30,7 @@
 import tvm.relay.op as op
 
 
-def test_id():
+def test_fo_id():
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
@@ -44,6 +45,21 @@ def test_id():
     tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
     tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
+def test_id():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    func = relay.Function([x], x)
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    ex = create_executor()
+    x = rand(dtype, *shape)
+    forward, (grad,) = ex.evaluate(back_func)(x)
+    tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
+
 def test_relu():
     shape = (10, 10)
@@ -341,5 +357,28 @@ def test_no_duplication():
     counts = count_ops(gr)
     assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two backward)"
 
+
+def test_global_function():
+    m = tvm.IRModule()
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.Var('x', t)
+    d = GlobalVar('double')
+    m[d] = relay.Function([x], x + x)
+    y = relay.Var('y', t)
+    q = GlobalVar('q')
+    m[q] = relay.Function([y], d(d(y)))
+    g = GlobalVar('grad')
+    m[g] = tvm.relay.transform.gradient(q, m)
+    back_func = m[g]
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    ex = create_executor(mod=m)
+    x = rand(dtype, *shape)
+    forward, (grad,) = ex.evaluate(back_func)(x)
+    tvm.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
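
Usage note (illustrative sketch, not part of the patch): with this change, relay.transform.gradient() accepts the enclosing IRModule, and ReverseAD follows calls to global functions, registering a differentiated "<name>_grad" counterpart for each one it visits. The snippet below simply mirrors the test_global_function test added above; the module m and the names double, q, and grad are only for illustration.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import GlobalVar, create_executor

    m = tvm.IRModule()
    t = relay.TensorType((10, 10), "float32")
    x = relay.Var("x", t)
    d = GlobalVar("double")
    m[d] = relay.Function([x], x + x)        # double(x) = x + x
    y = relay.Var("y", t)
    q = GlobalVar("q")
    m[q] = relay.Function([y], d(d(y)))      # q(y) = double(double(y)) = 4 * y
    # Pass the module so AD can look up and differentiate the callee `double`.
    m[GlobalVar("grad")] = relay.transform.gradient(q, m)
    data = np.random.rand(10, 10).astype("float32")
    forward, (grad,) = create_executor(mod=m).evaluate(m["grad"])(data)
    # forward is 4 * data; grad is 4 * np.ones_like(data)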