From 5456d56c2883eb10f558ba2af65a0fc02b48aa6d Mon Sep 17 00:00:00 2001
From: Andrew Liu
Date: Mon, 23 Mar 2020 20:09:20 -0700
Subject: [PATCH] increase code readability

---
 include/tvm/relay/transform.h                 |   2 +-
 python/tvm/relay/transform/transform.py       |   8 +-
 ...gradient_cell.cc => lazy_gradient_init.cc} | 235 +++++++++---------
 ...ell.py => test_pass_lazy_gradient_init.py} |  72 +++---
 4 files changed, 162 insertions(+), 155 deletions(-)
 rename src/relay/transforms/{gradient_cell.cc => lazy_gradient_init.cc} (57%)
 rename tests/python/relay/{test_pass_gradient_cell.py => test_pass_lazy_gradient_init.py} (85%)

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 3f5be9694649d..deb084c65d546 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -89,7 +89,7 @@ TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
  *
  * \return the pass
  */
-TVM_DLL Pass GradientCell();
+TVM_DLL Pass LazyGradientInit();
 
 /*!
  * \brief Fold constant expressions.
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 3edc1bf39aa05..b406459d0a16e 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -219,8 +219,8 @@ def DeadCodeElimination(inline_once=False):
     """
     return _ffi_api.DeadCodeElimination(inline_once)
 
-def GradientCell():
-    """Reduces memory usage of tensors with all 0s or 1s
+def LazyGradientInit():
+    """Reduces memory usage of gradient tensors
 
     Parameters
     ----------
@@ -228,9 +228,9 @@ def GradientCell():
     Returns
     -------
     ret: tvm.relay.Pass
-        The registered pass that delays or reduces memory allocation
+        A pass which delays and/or reduces memory allocation by lazily
+        allocating 0-filled or 1-filled tensors.
     """
-    return _ffi_api.GradientCell()
+    return _ffi_api.LazyGradientInit()
 
 def FoldConstant():
     """Fold the constant expressions in a Relay program.
diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/lazy_gradient_init.cc
similarity index 57%
rename from src/relay/transforms/gradient_cell.cc
rename to src/relay/transforms/lazy_gradient_init.cc
index 2c21c751a0e7a..ba6ca05663bbf 100644
--- a/src/relay/transforms/gradient_cell.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -19,12 +19,14 @@
 
 /*!
  *
- * \file gradient_cell.cc
+ * \file lazy_gradient_init.cc
  *
- * \brief Convert all tensors to a Gradient Cell
+ * \brief Lazily instantiate 0-filled or 1-filled tensors.
+ * This pass should be used after reverse-mode AD so that gradient tensors
+ * are not instantiated until after the forward pass.
  *
  * This pass delays or removes memory allocation by converting tensors into
- * GradCell, an algebraic data type defined in gradient.rly
+ * GradCell, an algebraic data type defined in gradient.rly.
  *
  * This will delay or decrease memory usage. All calls to
 * ones, ones_like, zeros, zeros_like will call the One or Zero constructor
@@ -67,13 +69,28 @@ namespace tvm {
 namespace relay {
 
 /*!
-* \brief Visitor to wrap inputs
+* \brief Visitor that wraps inputs in the GradCell Raw constructor
+*
+* Recursively looks at the type of the expression
+* (only TensorType and TupleType are supported for now)
+* and either calls the GradCell Raw constructor for a TensorType
+* or unfolds the expression and recursively visits for a TupleType
 */
 class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
  public:
   explicit InputVisitor(IRModule module): module_(module) {}
 
-  Expr wrapExpr(const Expr expr, const Type& type) {
+  Expr VisitExpr_(const VarNode* op, const Type& t) final {
+    return WrapExpr(GetRef<Expr>(op), op->type_annotation);
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
+    return WrapExpr(GetRef<Expr>(op), t);
+  }
+ private:
+  IRModule module_;
+
+  Expr WrapExpr(const Expr expr, const Type& type) {
     if (type.as<TensorTypeNode>()) {
       return Call(module_->GetConstructor("GradCell", "Raw"),
                   {expr}, Attrs(), {type});
@@ -89,27 +106,30 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
     return expr;
   }
-
-  Expr VisitExpr_(const VarNode* op, const Type& t) final {
-    std::cout << op->type_annotation << std::endl;
-    return wrapExpr(GetRef<Expr>(op), op->type_annotation);
-  }
-
-  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
-    return wrapExpr(GetRef<Expr>(op), t);
-  }
- private:
-  IRModule module_;
 };
 
 /*!
-* \brief Visitor to unwrap output
+* \brief Visitor that unwraps expressions of GradCell type back into tensors
+*
+* Recursively looks at the type of the expression
+* and either calls the FromGradCell function for a TypeCall to GradCell
+* or unfolds the expression and recursively visits for a TupleType
 */
 class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
  public:
   explicit OutputVisitor(IRModule module): module_(module) {}
 
-  Expr unwrapExpr(const Expr expr, const Type& type) {
+  Expr VisitExpr_(const CallNode* op, const Type& t) final {
+    return UnwrapExpr(GetRef<Expr>(op), t);
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
+    return UnwrapExpr(GetRef<Expr>(op), t);
+  }
+ private:
+  IRModule module_;
+
+  Expr UnwrapExpr(const Expr expr, const Type& type) {
     if (auto* type_call = type.as<TypeCallNode>()) {
       if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
         return Call(module_->GetGlobalVar("FromGradCell"), {expr});
@@ -127,32 +147,22 @@ class OutputVisitor: public ExprFunctor<Expr(const Expr&, const Type&)> {
     return expr;
   }
-
-  Expr VisitExpr_(const CallNode* op, const Type& t) final {
-    return unwrapExpr(GetRef<Expr>(op), t);
-  }
-
-  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
-    return unwrapExpr(GetRef<Expr>(op), t);
-  }
- private:
-  IRModule module_;
 };
 
-class GradientCellTransform: public ExprMutator, public TypeMutator {
+class LazyGradientInitializer: public ExprMutator, public TypeMutator {
  public:
-  explicit GradientCellTransform(IRModule module):
+  explicit LazyGradientInitializer(IRModule module):
     module_(module) {
       module_->ImportFromStd("gradient.rly");
   }
 
   /*!
-   * \brief apply GradientCell transformation and wrap function
+   * \brief apply LazyGradientInit transformation and wrap function
    * so that function type stays the same
    *
    * input/output types should only be a combination of TupleTypes and TensorTypes
    */
-  Expr transform(const Expr& e) {
+  Expr Transform(const Expr& e) {
     auto* f = (e).as<FunctionNode>();
     auto* transformed = this->Mutate(e).as<FunctionNode>();
 
@@ -179,90 +189,46 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
   }
 
   Expr VisitExpr_(const CallNode* call_node) final {
-    // optimize operators
     if (auto* op = (call_node->op).as<OpNode>()) {
       Expr op_expr = GetRef<Expr>(op);
-      if (op_expr == Op::Get("add") && call_node->args.size() == 2 &&
-          AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
-        // case: "add" between two tensors of the same size
-        const auto addFunc = module_->GetGlobalVar("AddGradCell");
-        tvm::Array<Expr> args;
-        // create add function
-        Type paramType = call_node->args[0]->checked_type();
-        tvm::Array<Var> params = {Var("lhs", paramType),
-                                  Var("rhs", paramType)};
-        Expr callAdd = Call(Op::Get("add"), {params[0], params[1]});
-        Expr addTensorsFunc = Function(params, callAdd, paramType,
-                                       Array<TypeVar>());
-
-        // pass add function and tensors into arguments
-        args.push_back(addTensorsFunc);
-        for (Expr expr : call_node->args) {
-          args.push_back(VisitExpr(expr));
-        }
-        return Call(addFunc, args, Attrs(), {paramType});
-      } else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 &&
-                 AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
-        // case: "multiply" between two tensors of the same size
-        const auto multFunc = module_->GetGlobalVar("MultiplyGradCell");
-        // create multiply function
-        tvm::Array<Expr> args;
-        Type paramType = call_node->args[0]->checked_type();
-        tvm::Array<Var> params = {Var("lhs", paramType),
-                                  Var("rhs", paramType)};
-        Expr callMultiply = Call(Op::Get("multiply"),
-                                 {params[0], params[1]});
-        Expr multTensorsFunc = Function(params, callMultiply, paramType,
-                                        Array<TypeVar>());
-
-        // pass multiply function and tensors into arguments
-        args.push_back(multTensorsFunc);
-        for (Expr expr : call_node->args) {
-          args.push_back(VisitExpr(expr));
-        }
-        return Call(multFunc, args, Attrs(), {paramType});
-      } else if (op_expr == Op::Get("ones")) {
-        // ones operator, use One constructor of GradCell
-        Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
-                             {call_node->checked_type()}, {});
-        return Call(module_->GetConstructor("GradCell", "One"),
-                    {func}, Attrs(), {call_node->checked_type()});
-      } else if (op_expr == Op::Get("zeros")) {
-        // zeros operator, use Zero constructor of GradCell
-        Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
-                             {call_node->checked_type()}, {});
-        return Call(module_->GetConstructor("GradCell", "Zero"),
-                    {func}, Attrs(), {call_node->checked_type()});
+
+      if (op_expr == Op::Get("add")) {
+        return CallGradCellFunction(call_node, module_->GetGlobalVar("AddGradCell"));
       }
-      // handle other ops + zeros_like + ones_like
-      // we put zeros_like and ones_like here to make use of
-      // code converting the arguments of CallNode into Tensor
-      const auto fromFunc = module_->GetGlobalVar("FromGradCell");
-      tvm::Array<Expr> args;
-      // use FromGradCell to convert args to Tensor
-      for (Expr expr : call_node->args) {
-        args.push_back(Call(fromFunc,
-                            {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
+
+      if (op_expr == Op::Get("multiply")) {
+        return CallGradCellFunction(call_node, module_->GetGlobalVar("MultiplyGradCell"));
       }
-      const Expr tensorRes = Call(call_node->op, args);
 
+      if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
+        // fn() -> T, function returns result of the operation
+        Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)},
+                             {call_node->checked_type()}, {});
+        // call appropriate GradCell constructor
+        std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
+        return Call(module_->GetConstructor("GradCell", constructor_name),
+                    {func}, Attrs(), {call_node->checked_type()});
+      }
+
-      if (op_expr == Op::Get("ones_like")) {
-        Expr onesFunction = Function({}, tensorRes,
-                             {call_node->checked_type()}, Array<TypeVar>());
-        return Call(module_->GetConstructor("GradCell", "One"),
-                    {onesFunction}, Attrs(), {call_node->checked_type()});
-      } else if (op_expr == Op::Get("zeros_like")) {
-        Expr zerosFunction = Function({}, tensorRes,
-                             {call_node->checked_type()}, Array<TypeVar>());
-        return Call(module_->GetConstructor("GradCell", "Zero"),
-                    {zerosFunction}, Attrs(), {call_node->checked_type()});
+      if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
+        // ones_like and zeros_like need TensorType input
+        Expr result = CallPrimitiveOp(call_node);
+        // fn() -> T, function returns result of operation
+        Expr func = Function({}, result,
+                             {call_node->checked_type()}, Array<TypeVar>());
+        // call appropriate GradCell constructor
+        std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
+        return Call(module_->GetConstructor("GradCell", constructor_name),
+                    {func}, Attrs(), {call_node->checked_type()});
       }
-      return Call(module_->GetConstructor("GradCell", "Raw"), {tensorRes},
+
+      // handle all other ops
+      Expr result = CallPrimitiveOp(call_node);
+      // wrap result with Raw constructor
+      return Call(module_->GetConstructor("GradCell", "Raw"), {result},
                   Attrs(), {call_node->checked_type()});
     }
-    // call-> op is not a relay op
+    // not an op
     return ExprMutator::VisitExpr_(call_node);
   }
 
@@ -280,23 +246,70 @@ class GradientCellTransform: public ExprMutator, public TypeMutator {
  private:
   // Module
   IRModule module_;
+
+  /*!
+   * \brief Convert a call to the add or multiply op into a call to the
+   * overloaded function for the GradCell type
+   */
+  Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
+    // can only use the overloaded function if there are 2 arguments of the same type
+    if (call_node->args.size() != 2 ||
+        !AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) {
+      Expr result = CallPrimitiveOp(call_node);
+      return Call(module_->GetConstructor("GradCell", "Raw"), {result},
+                  Attrs(), {call_node->checked_type()});
+    }
+
+    tvm::Array<Expr> args;
+    // create "fallback" function for overloaded function
+    Type paramType = call_node->args[0]->checked_type();
+    tvm::Array<Var> params = {Var("lhs", paramType),
+                              Var("rhs", paramType)};
+    // use primitive op in this case
+    Expr callOp = Call(call_node->op, {params[0], params[1]});
+    Expr func = Function(params, callOp, paramType,
+                         Array<TypeVar>());
+
+    // pass "fallback" function and tensors as arguments
+    args.push_back(func);
+    for (Expr expr : call_node->args) {
+      args.push_back(VisitExpr(expr));
+    }
+    // return new call to overloaded function
+    return Call(overloaded_op, args, Attrs(), {paramType});
+  }
+
+  /*!
+   * \brief Convert a call to any other op by first converting its args into TensorType
+   * \return call expr returning the result of the op
+   */
+  Expr CallPrimitiveOp(const CallNode* call_node) {
+    const auto fromFunc = module_->GetGlobalVar("FromGradCell");
+    tvm::Array<Expr> args;
+    // use FromGradCell to convert args to Tensor
+    for (Expr expr : call_node->args) {
+      args.push_back(Call(fromFunc,
+                          {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
+    }
+    // result of operation
+    return Call(call_node->op, args);
+  }
 };
 
-Expr GradientCell(const Expr& e, IRModule mod) {
-  return GradientCellTransform(mod).transform(e);
+Expr LazyGradientInit(const Expr& e, IRModule mod) {
+  return LazyGradientInitializer(mod).Transform(e);
 }
 
 namespace transform {
 
-Pass GradientCell() {
+Pass LazyGradientInit() {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
     [=](Function f, IRModule m, PassContext pc) {
-      return Downcast<Function>(GradientCell(f, m));
+      return Downcast<Function>(LazyGradientInit(f, m));
   };
-  return CreateFunctionPass(pass_func, 2, "GradientCell", {});
+  return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
 }
 
-TVM_REGISTER_GLOBAL("relay._transform.GradientCell")
-.set_body_typed(GradientCell);
+TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit")
+.set_body_typed(LazyGradientInit);
 
 }  // namespace transform
diff --git a/tests/python/relay/test_pass_gradient_cell.py b/tests/python/relay/test_pass_lazy_gradient_init.py
similarity index 85%
rename from tests/python/relay/test_pass_gradient_cell.py
rename to tests/python/relay/test_pass_lazy_gradient_init.py
index 2055771ba9eb5..f9c762e5f9055 100644
--- a/tests/python/relay/test_pass_gradient_cell.py
+++ b/tests/python/relay/test_pass_lazy_gradient_init.py
@@ -24,7 +24,7 @@ import pytest
 
 def test_tc():
-    # test typechecks
+    """Simple test case. Check that the transformation typechecks."""
     mod = tvm.IRModule()
 
     shape = (20, 20)
@@ -37,13 +37,13 @@ def test_tc():
     y = relay.Function([x1, x2], (x1 - x2) * x2)
     mod["main"] = y
 
-    mod = transform.GradientCell()(mod)
+    mod = transform.LazyGradientInit()(mod)
 
     # function input/output types should remain the same
     assert mod["main"].checked_type == relay.FuncType([t, t], t)
 
 def test_add():
-    # test simple add
+    """Simple add testcase. Check types and semantic equivalence."""
     mod = tvm.IRModule()
 
     shape = (10, 10)
@@ -55,7 +55,7 @@ def test_add():
     y = relay.Function([x], x+x)
     mod["main"] = y
 
-    mod = transform.GradientCell()(mod)
+    mod = transform.LazyGradientInit()(mod)
     y = mod["main"]
 
     assert mod["main"].checked_type == relay.FuncType([t], t)
@@ -66,7 +66,7 @@ def test_add():
 
     assert_allclose(y.asnumpy(), x.asnumpy() + x.asnumpy())
 
 def test_add_tuple():
-    # test input tuple and add items
+    """Add elements of tuple. Check types and semantic equivalence."""
    mod = tvm.IRModule()
 
     shape = (10, 10)
@@ -79,7 +79,7 @@ def test_add_tuple():
     y = relay.Function([x], relay.TupleGetItem(x, 0) + relay.TupleGetItem(x, 1))
     mod["main"] = y
 
-    mod = transform.GradientCell()(mod)
+    mod = transform.LazyGradientInit()(mod)
+    mod = transform.PrintIR(show_meta_data=True)(mod)
     y = mod["main"]
 
@@ -91,7 +91,7 @@ def test_add_tuple():
 
     assert_allclose(y.asnumpy(), x[0].asnumpy() + x[1].asnumpy())
 
 def test_mult():
-    # test simple add
+    """Simple multiplication testcase.
Check types and semantic equivalence.""" mod = tvm.IRModule() shape = (15, 15) @@ -103,7 +103,7 @@ def test_mult(): y = relay.Function([x], x * x) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -114,7 +114,7 @@ def test_mult(): assert_allclose(y.asnumpy(), x.asnumpy() * x.asnumpy()) def test_ret_tuple(): - # test return tuple + """Test tuple return type. Check types and semantic equivalence.""" mod = tvm.IRModule() shape = (10, 10) @@ -127,7 +127,7 @@ def test_ret_tuple(): func = run_infer_type(func) mod["main"] = func - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) func = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], relay.TupleType([t, t])) @@ -138,8 +138,8 @@ def test_ret_tuple(): assert_allclose(y[0].asnumpy(), x.asnumpy()) assert_allclose(y[1].asnumpy(), x.asnumpy() * 2.0) -def test_broadcast(): - # test broadcast add +def test_add_broadcast(): + """Test adding matrices of different size. Check types and semantic equivalence.""" mod = tvm.IRModule() shape1 = (3, 4, 1) @@ -152,30 +152,25 @@ def test_broadcast(): x2 = relay.var("x2", t2) func = relay.Function([x1,x2], x1 + x2) func = run_infer_type(func) - back_func = transform.gradient(func) - back_func = run_infer_type(back_func) - mod["main"] = back_func - mod = transform.GradientCell()(mod) - back_func = mod["main"] + mod["main"] = func + mod = transform.LazyGradientInit()(mod) + func = mod["main"] x1_np = rand(dtype, *shape1).asnumpy() x2_np = rand(dtype, *shape2).asnumpy() expected_forward = x1_np + x2_np expected_forward_type = relay.TensorType(expected_forward.shape, dtype) - assert mod["main"].checked_type == relay.FuncType([t1, t2], - relay.TupleType([expected_forward_type, relay.TupleType([t1, t2])])) + assert mod["main"].checked_type == relay.FuncType([t1, t2], expected_forward_type) ex = create_executor(mod=mod) - (forward), (grad_x1, grad_x2, ) = ex.evaluate(back_func)(x1_np, x2_np) + forward = ex.evaluate(func)(x1_np, x2_np) assert_allclose(forward.asnumpy(), expected_forward) - assert_allclose(grad_x1.asnumpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True)) - assert_allclose(grad_x2.asnumpy(), np.ones_like(expected_forward).sum(axis=(0,1), keepdims=True).squeeze(axis=0)) def test_reverse_ad_identity(): - # test correctness after reverse mode ad + """Simple test with reverse mode ad.""" # of f(x) = x mod = tvm.IRModule() @@ -191,7 +186,7 @@ def test_reverse_ad_identity(): back_func = run_infer_type(back_func) mod["main"] = back_func - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) back_func = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], @@ -204,8 +199,7 @@ def test_reverse_ad_identity(): assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy())) def test_multivar_reverse_ad(): - # test correctness after reverse mode ad - # of multivariate function + """Simple test with multivariate reverse mode ad.""" mod = tvm.IRModule() shape = (10, 10) @@ -221,7 +215,7 @@ def test_multivar_reverse_ad(): back_func = run_infer_type(back_func) mod["main"] = back_func - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) back_func = mod["main"] assert mod["main"].checked_type == relay.FuncType([t, t], @@ -236,7 +230,7 @@ def test_multivar_reverse_ad(): assert_allclose(grad_y.asnumpy(), x.asnumpy()) def test_after_partial_eval(): - # test GradientCell transformation after 
PartialEval + """Test transformation following reverse mode ad and PartialEval""" mod = tvm.IRModule() shape = (10, 10) @@ -256,7 +250,7 @@ def test_after_partial_eval(): seq = transform.Sequential([ transform.PartialEvaluate(), - transform.GradientCell(), + transform.LazyGradientInit(), transform.DeadCodeElimination() ]) @@ -274,7 +268,7 @@ def test_after_partial_eval(): assert_allclose(grad_y.asnumpy(), x.asnumpy()) def test_before_partial_eval(): - # test GradientCell transformation before PartialEval + """Test transformation before PartialEval""" mod = tvm.IRModule() shape = (10, 10) @@ -291,7 +285,7 @@ def test_before_partial_eval(): mod["main"] = back_func seq = transform.Sequential([ - transform.GradientCell(), + transform.LazyGradientInit(), transform.PartialEvaluate(), transform.DeadCodeElimination() ]) @@ -310,7 +304,7 @@ def test_before_partial_eval(): assert_allclose(grad_y.asnumpy(), x.asnumpy()) def test_zeros(): - # test with zeros operator + """Simple test using "zeros" op""" mod = tvm.IRModule() shape = (10, 10) @@ -321,7 +315,7 @@ def test_zeros(): y = relay.Function([x], x + relay.zeros(shape, dtype)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -332,7 +326,7 @@ def test_zeros(): assert_allclose(y.asnumpy(), x.asnumpy()) def test_ones(): - # test with ones operator + """Simple test using "ones" op""" mod = tvm.IRModule() shape = (10, 10) @@ -343,7 +337,7 @@ def test_ones(): y = relay.Function([x], x + relay.ones(shape, dtype)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -354,7 +348,7 @@ def test_ones(): assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) def test_zeros_like(): - # test with zeros_like operator + """Simple test using "zeros_like" op""" mod = tvm.IRModule() shape = (10, 10) @@ -365,7 +359,7 @@ def test_zeros_like(): y = relay.Function([x], x + relay.zeros_like(x)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t) @@ -376,7 +370,7 @@ def test_zeros_like(): assert_allclose(y.asnumpy(), x.asnumpy()) def test_ones_like(): - # test with ones_like operator + """Simple test using "ones_like" op""" mod = tvm.IRModule() shape = (10, 10) @@ -387,7 +381,7 @@ def test_ones_like(): y = relay.Function([x], x + relay.ones_like(x)) mod["main"] = y - mod = transform.GradientCell()(mod) + mod = transform.LazyGradientInit()(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], t)
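
Not part of the patch above: for reviewers who want to try the renamed pass locally, here is a minimal
usage sketch mirroring test_zeros; the module construction and shapes are illustrative assumptions,
not taken from the diff.

import tvm
from tvm import relay
from tvm.relay import transform

mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)

x = relay.var("x", t)
# x + zeros: after LazyGradientInit, the zeros call becomes the lazy GradCell Zero
# constructor, so the zero-filled tensor is only materialized if it is actually read.
mod["main"] = relay.Function([x], x + relay.zeros(shape, dtype))

mod = transform.LazyGradientInit()(mod)
# the transformed function keeps its original input/output type
assert mod["main"].checked_type == relay.FuncType([t], t)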