From 2509aaf5af26179d42a9d36b7a07b70c55ffda73 Mon Sep 17 00:00:00 2001 From: ArmageddonKnight Date: Mon, 20 Apr 2020 01:31:50 -0400 Subject: [PATCH] Improve the backward mirroring implementation --- docs/static_site/src/pages/api/faq/env_var.md | 11 +- example/image-classification/README.md | 11 +- python/mxnet/rnn/rnn_cell.py | 5 + src/executor/exec_pass.h | 37 +- src/executor/graph_executor.cc | 127 +++- src/executor/graph_executor.h | 8 +- src/imperative/cached_op.h | 2 +- src/imperative/imperative.cc | 2 +- src/nnvm/gradient.cc | 709 ++++++++++++++---- src/nnvm/plan_memory.cc | 15 +- src/operator/nn/activation-inl.h | 9 +- src/operator/nn/activation.cc | 50 +- src/operator/nn/activation.cu | 46 +- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 2 +- tests/python/unittest/test_gradient.py | 125 +++ 15 files changed, 910 insertions(+), 249 deletions(-) create mode 100644 tests/python/unittest/test_gradient.py diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index 75255210933d..506ead707b7a 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -189,14 +189,13 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - The maximum size of an NDArray slice in terms of number of parameters. - This parameter is used to slice an NDArray before synchronizing through P3Store (dist_p3). -## Memonger +## Memory Optimizations -* MXNET_BACKWARD_DO_MIRROR +* MXNET_MEMORY_OPT - Values: 0(false) or 1(true) ```(default=0)``` - - MXNet uses mirroring concept to save memory. Normally backward pass needs some forward input and it is stored in memory but you can choose to release this saved input and recalculate it in backward pass when needed. This basically trades off the computation for memory consumption. - - This parameter decides whether to do `mirror` during training for saving device memory. - - When set to `1`, during forward propagation, graph executor will `mirror` some layer's feature map and drop others, but it will re-compute this dropped feature maps when needed. - - `MXNET_BACKWARD_DO_MIRROR=1` will save 30%~50% of device memory, but retains about 95% of running speed. + - When set to `1`, MXNet will adopt various approaches to reduce the memory consumption of the model. For example, it uses the mirroring concept to save memory: Normally the backward pass needs some forward inputs to compute the gradients. Those inputs have to be stashed in memory and persistent throughout the traianing process. However, you can choose to release those saved inputs and recalculate them in the backward pass when needed. This basically trades off the computation for memory consumption. When set to `1`, during forward propagation, the graph executor will `mirror` some layers' feature maps and drop others, but it will re-compute this dropped feature maps when needed. + - This parameter decides whether to do `mirror` and/or data encodings during training for saving device memory. + - `MXNET_MEMORY_OPT=1` will save 30%~50% of device memory, but retains about 95% of running speed. - One extension of `mirror` in MXNet is called [memonger technology](https://arxiv.org/abs/1604.06174), it will only use O(sqrt(N)) memory at 75% running speed. Checkout the code [here](https://github.com/dmlc/mxnet-memonger). ## Control the profiler diff --git a/example/image-classification/README.md b/example/image-classification/README.md index 78ea94eeb440..4b4a48b33ae4 100644 --- a/example/image-classification/README.md +++ b/example/image-classification/README.md @@ -366,11 +366,12 @@ An over sized batch size may result in out of GPU memory. The common error message is `cudaMalloc failed: out of memory`. Now we can - Reduce the batch size -- Set the environment variable `MXNET_BACKWARD_DO_MIRROR` to 1. It trades off - computation for memory consumption. For example, with batch size 64, - inception-v3 uses 10G memory and trains 30 image/sec on a single K80 GPU. When - mirroring is enabled, with 10G GPU memory consumption, we can run inception-v3 - using batch size 128. The cost is that the speed reduces to 27 images/sec. +- Set the environment variable `MXNET_MEMORY_OPT=1` to perform a series of + memory optimizations (e.g., trades off computation for memory consumption). + For example, with batch size 64, inception-v3 uses 10G memory and trains 30 + image/sec on a single K80 GPU. When mirroring is enabled, with 10G GPU memory + consumption, we can run inception-v3 using batch size 128. The cost is that + the speed reduces to 27 images/sec. ## History diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index ceb33d7dcf0a..82cdc5dcf82c 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -459,6 +459,11 @@ def __call__(self, inputs, states): name='%so'%name) next_c = symbol._internal._plus(forget_gate * states[1], in_gate * in_transform, name='%sstate'%name) + next_c._set_attr({'force_mirroring' : '0'}) + # Cell states are excluded from being mirrored. The reason is because + # they do not pass through the fully-connected layers and will + # significantly increase the overall mirroring depth, incurring large + # performance overhead. next_h = symbol._internal._mul(out_gate, symbol.Activation(next_c, act_type="tanh"), name='%sout'%name) diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index e3d2fa459bc3..270c546f0f49 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -273,16 +273,17 @@ namespace pass { /*! * \brief Get the gradient graph whose outputs are gradients of xs wrt to ys. * \param graph The input graph. - * \param ys The entries we want to take gradient from. - * \param xs The input to take gradient with respect to. - * \param ys_out_grad The symbol for additional gradient to be propagate back to y. - * \param aggregate_fun Aggregation function applied to aggregate the inputs. - * \param mirror_fun Optional mirror function to do mirror optimization and save memory. - * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like. - * \param zero_ops Optional, list of operators that outputs a single zero array. The first one - * must be zeros_like. - * \param copy_op_str Optional, name of the copy operation required to handle duplicates - * on the edge of the graph + * \param ys The entries to take gradient from. + * \param xs The entries to take gradient with respect to. + * \param ys_out_grad The output gradients of ys. + * \param aggregate_fun The aggregation function used for summing gradients. + * \param mirror_fun The backward mirroring function that does mirroring to save memory. + * \param zero_ops The list of operators that output a single zero array, used + * for generating zero gradient nodes. The first operator must + * be zero_like. + * \param copy_op_str The name of the copy operator that handle gradient duplicates. + * \param in_arg_shapes The shapes of input arguments, used for shape inference. + * \param in_arg_dtpyes The data types of input arguments, used for data type inference. * \return A new graph, whose outputs correspond to inputs of xs. */ inline Graph MXGradient( @@ -292,27 +293,27 @@ inline Graph MXGradient( std::vector ys_out_grad, std::function&& inputs)> aggregate_fun = nullptr, std::function mirror_fun = nullptr, - std::function - attr_hint_fun = nullptr, std::vector zero_ops = std::vector(), - std::string copy_op_str = std::string()) { + std::string copy_op_str = std::string(), + mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(), + DTypeVector in_arg_dtypes = DTypeVector()) { graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); graph.attrs["grad_xs"] = std::make_shared(std::move(xs)); graph.attrs["grad_ys_out_grad"] = std::make_shared(std::move(ys_out_grad)); + graph.attrs["in_arg_shapes"] = std::make_shared(std::move(in_arg_shapes)); + graph.attrs["in_arg_dtypes"] = std::make_shared(std::move(in_arg_dtypes)); + if (aggregate_fun != nullptr) { graph.attrs["grad_aggregate_fun"] = std::make_shared(aggregate_fun); } if (mirror_fun != nullptr) { - graph.attrs["grad_mirror_fun"] = std::make_shared(mirror_fun); - } - if (attr_hint_fun != nullptr) { - graph.attrs["attr_hint_fun"] = std::make_shared(attr_hint_fun); + graph.attrs["mirror_fun"] = std::make_shared(mirror_fun); } if (zero_ops.size()) { graph.attrs["zero_ops"] = std::make_shared(std::move(zero_ops)); } if (copy_op_str != std::string()) { - graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); + graph.attrs["copy_op_str"] = std::make_shared(std::move(copy_op_str)); } return ApplyPass(std::move(graph), "MXGradient"); } diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 02ce818d7fa8..4366c6060f9b 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -302,28 +302,15 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { } } -template -inline ValueType get_node_attr( - const nnvm::Node& node, - const std::string& key, ValueType default_value) { - auto it = node.attrs.dict.find(key); - if (it == node.attrs.dict.end()) { - return default_value; - } else { - ValueType ret; - dmlc::parameter::FieldEntry e; - e.Init(key, &ret, ret); - e.Set(&ret, it->second); - return ret; - } -} /*! * \brief Create the graph for backward pass. * This is triggered by both simple_bind and bind flows. */ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, - const std::vector& grad_req_types) { + const std::vector& grad_req_types, + const ShapeVector& in_arg_shapes, + const nnvm::DTypeVector& in_arg_dtypes) { using nnvm::ObjectPtr; using nnvm::NodeEntry; // initial information @@ -356,19 +343,28 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, } } - int do_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0); - auto need_mirror = [do_mirror](const nnvm::Node& node) -> int { - if (node.is_variable()) return 0; - const std::string& type = node.attrs.op->name; - if (type == "Dropout") return false; - if (get_node_attr(node, "__force_mirroring__", false)) return true; - if (do_mirror == 0) return false; - if (type == "Convolution") return false; - if (type == "FullyConnected") return false; - if (type == "Concat") return false; - if (type == "SoftmaxOutput") return false; - return true; - }; + std::function need_mirror = + [](const nnvm::Node& node) -> int { + if (node.is_variable()) return false; + const std::string& type = node.attrs.op->name; + if (type == "Dropout") return false; + // We follow the hidden key attribute "force_mirroring" if it is + // explicitly set. + auto iter = node.attrs.dict.find("__force_mirroring__"); + if (iter != node.attrs.dict.end()) { + bool do_mirror; + dmlc::parameter::FieldEntry e; + e.Init("__force_mirroring__", &do_mirror, do_mirror); + e.Set(&do_mirror, iter->second); + return do_mirror; + } + if (type == "Embedding") return false; + if (type == "Convolution") return false; + if (type == "FullyConnected") return false; + if (type == "Concat") return false; + if (type == "SoftmaxOutput") return false; + return true; + }; std::vector zero_ops; zero_ops.push_back(nnvm::Op::Get("zeros_like")); @@ -377,8 +373,11 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, // take gradient nnvm::Graph g_grad = nnvm::pass::MXGradient( g, symbol.outputs, xs, head_grad_entry_, - AggregateGradient, need_mirror, nullptr, - zero_ops, "_copy"); + AggregateGradient, + dmlc::GetEnv("MXNET_MEMORY_OPT", 0) ? need_mirror : nullptr, + zero_ops, "_copy", + in_arg_shapes, in_arg_dtypes); + CHECK_EQ(g_grad.outputs.size(), xs.size()); for (const auto &e : g_grad.outputs) { g.outputs.push_back(e); @@ -414,8 +413,37 @@ void GraphExecutor::Init(nnvm::Symbol symbol, std::vector aux_state_ctxes(aux_states.size()); std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1); + // Record the shapes and data types of the input arguments in the source graph + // (i.e., the graph prior to the Gradient pass). Such information is need by + // the backward mirroring algorithm for shape and data type inference. + nnvm::Graph src; + src.outputs = symbol.outputs; + const nnvm::IndexedGraph& src_idx = src.indexed_graph(); + const std::unordered_set& src_mutable_nodes = src_idx.mutable_input_nodes(); + size_t src_arg_top = 0, src_aux_top = 0; + ShapeVector src_arg_shapes; + nnvm::DTypeVector src_arg_dtypes; + const size_t src_num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size(); + + for (size_t i = 0; i < src_num_forward_inputs; ++i) { + const uint32_t nid = src_idx.input_nodes().at(i); + + if (src_mutable_nodes.count(nid)) { + CHECK_LT(src_aux_top, aux_states.size()); + src_arg_shapes.push_back(aux_states[src_aux_top].shape()); + src_arg_dtypes.push_back(aux_states[src_aux_top].dtype()); + ++src_aux_top; + } else { + CHECK_LT(src_arg_top, in_args.size()); + src_arg_shapes.push_back(in_args[src_arg_top].shape()); + src_arg_dtypes.push_back(in_args[src_arg_top].dtype()); + ++src_arg_top; + } + } + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, - arg_grad_ctxes, aux_state_ctxes, grad_req_types); + arg_grad_ctxes, aux_state_ctxes, grad_req_types, + src_arg_shapes, src_arg_dtypes); // create arg_shapes and arg_dtypes for shape and type inferences const auto& idx = g.indexed_graph(); @@ -811,8 +839,34 @@ void GraphExecutor::Init(nnvm::Symbol symbol, std::unordered_map* shared_buffer, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { + // Record the shapes and data types of the input arguments in the source graph + // (i.e., the graph prior to the Gradient pass). Such information is need by + // the backward mirroring algorithm for shape and data type inference. + nnvm::Graph src; + src.outputs = symbol.outputs; + const nnvm::IndexedGraph& src_idx = src.indexed_graph(); + ShapeVector src_arg_shapes(src_idx.input_nodes().size(), TShape()); + nnvm::DTypeVector src_arg_dtypes(src_idx.input_nodes().size(), -1); + const size_t src_num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size(); + + for (size_t i = 0; i < src_num_forward_inputs; ++i) { + const uint32_t nid = src_idx.input_nodes().at(i); + const std::string& name = src_idx[nid].source->attrs.name; + std::unordered_map::const_iterator + arg_shape_iter = arg_shape_map.find(name); + std::unordered_map::const_iterator + arg_dtype_iter = arg_dtype_map.find(name); + if (arg_shape_iter != arg_shape_map.end()) { + src_arg_shapes[i] = arg_shape_iter->second; + } + if (arg_dtype_iter != arg_dtype_map.end()) { + src_arg_dtypes[i] = arg_dtype_iter->second; + } + } + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, - aux_state_ctxes, grad_req_types); + aux_state_ctxes, grad_req_types, + src_arg_shapes, src_arg_dtypes); // The following code of shape and dtype inferences and argument // initialization is for simple_bind only. Regular bind operation @@ -1007,9 +1061,12 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, - const std::vector& grad_req_types) { + const std::vector& grad_req_types, + const ShapeVector& in_arg_shapes, + const nnvm::DTypeVector& in_arg_dtypes) { // setup gradient - nnvm::Graph g = InitFullGraph(symbol, grad_req_types); + nnvm::Graph g = InitFullGraph(symbol, grad_req_types, + in_arg_shapes, in_arg_dtypes); #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32) if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", true)) { diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 4164bb758376..ed6eeaa11f4f 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -188,10 +188,14 @@ class GraphExecutor : public Executor { const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, - const std::vector& grad_req_types); + const std::vector& grad_req_types, + const ShapeVector& in_arg_shapes, + const nnvm::DTypeVector& in_arg_dtypes); // intialize the full graph for simple bind, including gradient Graph InitFullGraph(nnvm::Symbol symbol, - const std::vector& grad_req_types); + const std::vector& grad_req_types, + const ShapeVector& in_arg_shapes, + const nnvm::DTypeVector& in_arg_dtypes); // initialize the cached operator void InitCachedOps(); // initialize the opr segments for bulk exec diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 731ba2efa082..c1ef8b82483f 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -167,7 +167,7 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph, *grad_graph = pass::MXGradient( *fwd_graph, fwd_graph->outputs, xs, *ograd_entries, - exec::AggregateGradient, nullptr, nullptr, + exec::AggregateGradient, nullptr, zero_ops, "_copy"); } diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 14fedc93351c..c12fcee1910e 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -455,7 +455,7 @@ std::vector Imperative::Backward( Graph g_graph = pass::MXGradient( graph, graph.outputs, xs, ograd_entries, - exec::AggregateGradient, nullptr, nullptr, + exec::AggregateGradient, nullptr, zero_ops, "_copy"); CHECK_EQ(g_graph.outputs.size(), xs.size()); for (const auto& e : g_graph.outputs) { diff --git a/src/nnvm/gradient.cc b/src/nnvm/gradient.cc index 74cec1623800..a39c10dc5c99 100644 --- a/src/nnvm/gradient.cc +++ b/src/nnvm/gradient.cc @@ -23,178 +23,604 @@ * \brief Passes that takes gradient of the graph * This code code was modified based on mxnet codebase by Min Lin */ +#include #include #include #include + #include +#include +#include #include +#include +#include +#include +#include + +#include "../executor/exec_pass.h" + namespace nnvm { namespace pass { -namespace { -// default aggregate gradient function -// require operator zeros and elemwise_sum to be presented. -NodeEntry DefaultAggregateGradient(std::vector&& v) { - if (v.size() == 1) { - return std::move(v[0]); - } else if (v.size() == 0) { - ObjectPtr zero_node = Node::Create(); - zero_node->attrs.op = Op::Get("zeros"); - zero_node->attrs.name = "zero_grad"; - zero_node->attrs.op->attr_parser(&(zero_node->attrs)); - return NodeEntry{zero_node, 0, 0}; - } else { - ObjectPtr sum_node = Node::Create(); - sum_node->attrs.op = Op::Get("elemwise_sum"); - sum_node->inputs = std::move(v); - sum_node->attrs.name = "grad_sum"; - sum_node->attrs.dict["num_args"] = std::to_string(sum_node->inputs.size()); - sum_node->attrs.op->attr_parser(&(sum_node->attrs)); - return NodeEntry{sum_node, 0, 0}; - } -} +extern size_t MXGetDTypeSize(const int type_flag); // defined in plan_memory.cc -bool CheckGradAllZero(const std::vector& grads, - const std::vector& zero_ops) { - if (!grads.size() || !zero_ops.size()) return false; - for (const auto& g : grads) { - bool found = false; - for (const auto& op : zero_ops) { - if (g.node->op() == op) { - found = true; - break; - } - } - if (!found) return false; - } - return true; -} +namespace { -// helper entry + +/*! Auxiliary Data Structure for Gradient Entries */ struct GradEntry { -#ifdef _MSC_VER - NodeEntry sum = NodeEntry{nullptr, 0, 0}; -#else - NodeEntry sum{nullptr, 0, 0}; -#endif + NodeEntry sum = NodeEntry(nullptr, 0, 0); std::vector grads; - bool need_attr_hint{true}; }; -Graph Gradient(Graph src) { - using nnvm::FGradient; - using MirrorFun = std::function; - using AttrHintFun = std::function; +/*! + * \brief Build the backward graph from the mirror map. This function will be + * invoked twice if backward mirroring has been enabled. + */ +Graph BuildGradientGraph( + const Graph& src, + const std::vector& xs, + const std::vector& topo_order, + std::unordered_map > output_grads, + std::function mirror_fun, + const std::unordered_map& mirror_map); + +/*! + * \brief Auxiliary function that maps the forward node of the source graph to + * its corresponding node on the mirror path. + */ +inline const ObjectPtr& MapFwdNodeToMirrorPath( + const ObjectPtr& n, + const std::unordered_map& mirror_map) { + auto iter = mirror_map.find(n.get()); + if (iter == mirror_map.end() || + iter->second == nullptr) { + return n; + } + return iter->second; +} + + +Graph Gradient(Graph src) { CHECK_NE(src.attrs.count("grad_ys"), 0U) << "Gradient require grad_ys to be presented."; - CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U) - << "Gradient require grad_ys_out_grad to be presented."; CHECK_NE(src.attrs.count("grad_xs"), 0U) << "Gradient require grad_xs to be presented."; + CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U) + << "Gradient require grad_ys_out_grad to be presented."; + const std::vector& xs = + src.GetAttr >("grad_xs"); const std::vector& ys = src.GetAttr >("grad_ys"); const std::vector& ys_out_grad = src.GetAttr >("grad_ys_out_grad"); - const std::vector& xs = - src.GetAttr >("grad_xs"); - using AggFun = std::function&& inputs)>; - AggFun agg_fun = DefaultAggregateGradient; - if (src.attrs.count("grad_aggregate_fun") != 0) { - agg_fun = src.GetAttr("grad_aggregate_fun"); - } - MirrorFun mirror_fun = nullptr; - if (src.attrs.count("grad_mirror_fun") != 0) { - mirror_fun = src.GetAttr("grad_mirror_fun"); - } - AttrHintFun attr_hint_fun = nullptr; - if (src.attrs.count("attr_hint_fun") != 0) { - attr_hint_fun = src.GetAttr("attr_hint_fun"); - } - std::vector zero_ops; - if (src.attrs.count("zero_ops") != 0) { - zero_ops = src.GetAttr >("zero_ops"); - } - const Op* copy_op = (src.attrs.count("copy_op") != 0) ? - Op::Get(src.GetAttr("copy_op")) : - nullptr; + CHECK_EQ(ys.size(), ys_out_grad.size()); - // topo sort + // initialize a topological order of the graph nodes and `output_grads` + // that maps every operator node to its gradient entries std::vector topo_order; - std::unordered_map > output_grads; + std::unordered_map > output_grads; - DFSVisit(ys, [&](const ObjectPtr& node) { - if (output_grads.count(node.get()) == 0) { - output_grads[node.get()].resize(node->num_outputs()); - } - topo_order.push_back(node); - }); + DFSVisit(ys, + [&](const ObjectPtr& node) { + if (output_grads.count(node.get()) == 0) { + output_grads[node.get()].resize(node->num_outputs()); + } + topo_order.push_back(node); + }); - CHECK_EQ(ys.size(), ys_out_grad.size()); for (size_t i = 0; i < ys.size(); ++i) { - NodeEntry ograd = ys_out_grad[i]; - output_grads[ys[i].node.get()][ys[i].index].grads = { ograd }; + output_grads[ys[i].node.get()][ys[i].index].grads = {ys_out_grad[i]}; } - // Check that all xs are reachable from ys + // check that all xs are reachable from ys for (size_t i = 0; i < xs.size(); ++i) { CHECK(output_grads.find(xs[i].node.get()) != output_grads.end()) - << "Cannot differentiate with respect to the " << i+1 << "-th variable " + << "Cannot differentiate with respect to the " + << (i + 1) << "-th variable " << "because it is unreachable from the outputs."; } - // construct mirror as memory reduction strategy if needed - std::unordered_map mirror_map; - if (mirror_fun != nullptr) { - for (const ObjectPtr& node_ptr : topo_order) { - if (mirror_fun(*node_ptr)) { - ObjectPtr new_node = Node::Create(); - *new_node = *node_ptr; - new_node->attrs.name += "_mirror"; - for (auto& e : new_node->inputs) { - e.node = mirror_map.at(e.node.get()); + using MirrorFun = std::function; + MirrorFun mirror_fun = nullptr; + if (src.attrs.count("mirror_fun") != 0) { + mirror_fun = src.GetAttr("mirror_fun"); + } + std::unordered_map mirror_map; + + // complete the backward graph of the src, but without backward mirroring + nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, + output_grads, + nullptr, mirror_map); + if (mirror_fun == nullptr) { + return gsrc; // Gradient pass without mirroring ends here. + } + const IndexedGraph& idx = src.indexed_graph(), + & gidx = gsrc.indexed_graph(); + // =========================================================================== + // ----- Gradient Pass w/ Backward Mirroring ----- + // =========================================================================== + // Record, for each node entry ∈ gsrc, the nodes that reference it as inputs. + // It is important to note that since the node entry reference mapping is + // constructed from gradient graph, it can only be indexed using gidx entry ID. + std::vector > node_entry_ref_map( + gidx.num_node_entries()); + static const auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); + for (uint32_t gnid = 0; gnid < gidx.num_nodes(); ++gnid) { + const IndexedGraph::Node& inode = gidx[gnid]; + if (inode.source->is_variable()) { + continue; + } + for (uint32_t i = 0; i < inode.inputs.size(); ++i) { + if (fignore_inputs.count(inode.source->op()) != 0) { + std::vector ignore_inputs = + fignore_inputs[inode.source->op()](inode.source->attrs); + if (std::find(ignore_inputs.begin(), ignore_inputs.end(), i) + != ignore_inputs.end()) { + continue; } - for (auto& n : new_node->control_deps) { - n = mirror_map.at(n.get()); + } + node_entry_ref_map[gidx.entry_id(inode.inputs[i])].insert(inode.source); + } + } // for (gnid ∈ gidx.num_nodes()) + // Inference the shapes and data types of the gradient graphs. Those + // information is needed in later stages to determine whether putting a node + // on the mirror path can be beneficial or not. + using mxnet::ShapeVector; + ShapeVector in_arg_shapes = std::move(src.GetAttr("in_arg_shapes")); + DTypeVector in_arg_dtypes = std::move(src.GetAttr("in_arg_dtypes")); + src = mxnet::exec::InferShape(std::move(src), std::move(in_arg_shapes), "__shape__"); + src = mxnet::exec::InferType(std::move(src), std::move(in_arg_dtypes), "__dtype__"); + CHECK(src.GetAttr("shape_num_unknown_nodes") == 0U); + CHECK(src.GetAttr("dtype_num_unknown_nodes") == 0U); + const ShapeVector& src_shapes = src.GetAttr("shape"); + const DTypeVector& src_dtypes = src.GetAttr("dtype"); + + std::queue worklist; + // initialize the worklist to the output nodes + for (const NodeEntry& e : src.outputs) { + worklist.push(e.node.get()); + } + for (; !worklist.empty(); worklist.pop()) { + const Node* const workitem = worklist.front(); + // skip the current node if it has already been recorded in the mirror map + if (mirror_map.find(workitem) != mirror_map.end()) { + continue; + } + + // subgraph and its frontier and topological-sorted view + std::unordered_set subgraph; + // The associated boolean variable is used for marking forward propagation. + std::unordered_map subgraph_frontier; + std::deque subgraph_topo_order; + // ========================================================================= + // --- Backward Pass --- + // ========================================================================= + // The sub-worklist is used for constructing the subgraph. It is initialized + // to have the current workitem node. + std::queue subworklist; + subworklist.push(workitem); + // Local auxiliary function that does backpropagation on the subworklist + // items to construct the subgraph. E.g., + // A subworklist = {A} + // ↑ + // B + // After invoking this function. `subgraph` will become {A, B}. + // Note that this function will be invoked multiple times. + auto subworklist_backprop = [&subworklist, &subgraph, + &subgraph_topo_order, + &mirror_fun, &worklist]() { + std::deque subworklist_topo_order; + for (; !subworklist.empty(); subworklist.pop()) { + const Node* const subworkitem = subworklist.front(); + if (subgraph.find(subworkitem) == subgraph.end()) { + subgraph.insert(subworkitem); + subworklist_topo_order.push_front(subworkitem); + } + for (const NodeEntry& e : subworkitem->inputs) { + if (!mirror_fun(*(e.node))) { + worklist.push(e.node.get()); + } else { + subworklist.push(e.node.get()); + } + } + for (const ObjectPtr& n : subworkitem->control_deps) { + if (!mirror_fun(*n)) { + worklist.push(n.get()); + } else { + subworklist.push(n.get()); + } + } + } // for (subworkitem ∈ subworklist) + // please refer to later comments for why the topological order of the + // subworklist should be directly appended to that of the subgraph + subgraph_topo_order.insert(subgraph_topo_order.end(), + subworklist_topo_order.begin(), + subworklist_topo_order.end()); + }; + // Start propagating from the current workitem node backward until the + // mirroring function returns false (indicating that a compute-heavy layer + // has been hit), in which case we put the node that fails the mirroring + // function into the worklist as the new head. During the traversal, we + // build up the subgraph and its topological order at the same time. + subworklist_backprop(); + + // Forward propagate the subgraph nodes in topological order and make sure + // that all the node entries that are part of the forward propagation belong + // to the same subgraph. This process continues until all the node entries + // have been included, in which case we say that the subgraph has converged. + // + // The reason why this step is needed is because, consider the example below: + // A B C subworklist = {A} + // ↑ ↑ ↑ + // ↖ ↑ ↗ + // D + // Without loss of generality, suppose that the previous backpropagation + // starts from node A, then the subgraph will only contain branch D → A. + // However, we want to include branch D → B adn D → C as well since all + // three branches share the same node entries (i.e., the outputs of D) and + // hence they are all affected by the decision on whether D should be put + // onto the mirror path or not. + bool has_subgraph_converged; + do { + has_subgraph_converged = true; + for (const Node* const subgraph_node : subgraph_topo_order) { + for (const NodeEntry& subgraph_node_entry : + subgraph_node->inputs) { + const std::unordered_set ref_nodes = + node_entry_ref_map[gidx.entry_id(subgraph_node_entry)]; + + for (const Node* const ref_node : ref_nodes) { + // If there are other nodes that reference the node entry and that + // node satisfies the following conditions: + // (1) belongs to the forward graph, and + // (2) is not part of the subgraph + // We add that node to the subgraph and adjust the topological order + // accordingly. + if (ref_node != subgraph_node && idx.exist(ref_node) && + subgraph.find(ref_node) == subgraph.end()) { + // Forward propagate from the reference node until the mirroring + // function returns false. This indicates that the head of the + // branch has been reached (i.e., B or C in our previously + // illustrated example), and we add it to the subworklist for + // another backpropagation. + std::queue ref_node_heads; + ref_node_heads.push(ref_node); + for (; !ref_node_heads.empty(); ref_node_heads.pop()) { + const Node* const ref_node_head = ref_node_heads.front(); + bool is_ref_node_head_output = false; + for (const NodeEntry& y : ys) { + if (ref_node_head == y.node.get()) { + is_ref_node_head_output = true; + } + } + if (!mirror_fun(*ref_node_head) || is_ref_node_head_output) { + subworklist.push(ref_node_head); + continue; + } + + uint32_t gnid = gidx.node_id(ref_node_head); + for (uint32_t oid = 0; oid < ref_node_head->num_outputs(); ++oid) { + uint32_t geid = gidx.entry_id(gnid, oid); + for (const Node* const n : node_entry_ref_map[geid]) { + if (idx.exist(n)) { + ref_node_heads.push(n); + } + } + } // for (oid ∈ [0, ref_node_head->num_outputs())) + } // for (ref_node_head ∈ ref_node_heads) + // Do the backpropagation again. The topological order of the + // subworklist can be directly appended to the end of the existing + // order. E,g, in our previous example, we expect to have + // `subgraph_topo_order` = {D, A} + {B} + {C} + subworklist_backprop(); + // indicate that the subgraph has not changed the quit the loop + has_subgraph_converged = false; + break; + } // if (ref_node != subgraph_node && idx.exist(ref_node) && + // subgraph.find(ref_node) == subgraph.end() + } // for (ref_node ∈ ref_nodes) + if (!has_subgraph_converged) { + break; + } + } // for (subgraph_node_entry ∈ subgraph_node->inputs) + if (!has_subgraph_converged) { + break; } - mirror_map[node_ptr.get()] = std::move(new_node); - } else { - mirror_map[node_ptr.get()] = node_ptr; + } // for (subgraph_node ∈ subgraph_topo_order) + } while (!has_subgraph_converged); + // ========================================================================= + // --- Forward Pass --- + // ========================================================================= + // Now that the subgraph is complete, we start by assuming that all the + // nodes in the subgraph can be mirrored, and forward propagate starting + // from the subgraph frontier. The propagation is successful if the amount + // of storage released by removing the frontier nodes off the mirror path is + // greater or equal to the storage allocated. + do { + has_subgraph_converged = true; + // Obtain the subgraph frontier. The subgraph frontier denotes a group of + // nodes whose inputs satisfy the following conditions: + // (1) fails the mirroring function, or + // (2) has been marked as NOT on the mirror path, i.e., + // `mirror_map[input_node] == nullptr` + // E.g., consider the subgraph below: + // A + // ↑ + // B + // ↑ + // C + // The subgraph frontier in this example is {C}. As C is the place where + // the mirror path (and hence the forward propagation) starts. + subgraph_frontier.clear(); + for (const Node* const subgraph_node : subgraph) { + if (!mirror_fun(*subgraph_node)) { + mirror_map[subgraph_node] = nullptr; + continue; + } + if (mirror_map.find(subgraph_node) != mirror_map.end()) { + continue; + } + bool is_frontier = true; + for (const NodeEntry& e : subgraph_node->inputs) { + auto iter = mirror_map.find(e.node.get()); + if (mirror_fun(*(e.node)) && + !(iter != mirror_map.end() && iter->second == nullptr)) { + is_frontier = false; + } + } + for (const ObjectPtr& n : subgraph_node->control_deps) { + auto iter = mirror_map.find(n.get()); + if (mirror_fun(*n) && + !(iter != mirror_map.end() && iter->second == nullptr)) { + is_frontier = false; + } + } + if (is_frontier) { + subgraph_frontier.emplace(subgraph_node, false); + } + } // for (subgraph_node ∈ subgraph) + for (std::pair& frontier_node : subgraph_frontier) { + if (frontier_node.second) { + // If the frontier node has been marked as true, then this indicates + // that the node has been forward propagated before (by other nodes + // that share the same input). + continue; + } + // As we do the forward propagation, we not only propagate the current + // frontier node individually, but all the frontier nodes that share the + // same input with the current one. This is a recursive progress because + // it is possible for A to share the same input with B and B, at the + // same time, to share the same input with C, like in the graph below: + // D E + // ↑ ↑ + // ↗ ↖ ↗ ↖ + // A B C + std::unordered_set forward_candidates{frontier_node.first}; + frontier_node.second = true; + bool has_forward_candidates_converged; + do { + has_forward_candidates_converged = true; + for (const Node* const candidate : forward_candidates) { + for (const NodeEntry candidate_input : candidate->inputs) { + uint32_t geid = gidx.entry_id(candidate_input); + const std::unordered_set& ref_nodes = node_entry_ref_map[geid]; + for (const Node* const ref_node : ref_nodes) { + if (ref_node != frontier_node.first && + subgraph_frontier.find(ref_node) != subgraph_frontier.end() && + forward_candidates.find(ref_node) == forward_candidates.end()) { + subgraph_frontier[ref_node] = true; + forward_candidates.insert(ref_node); + has_forward_candidates_converged = false; + } + } // for (ref_node ∈ ref_nodes) + if (!has_forward_candidates_converged) { + break; + } + } // for (candidate_input ∈ candidate->inputs) + if (!has_forward_candidates_converged) { + break; + } + } // for (candidate ∈ forward_candidates) + } while (!has_forward_candidates_converged); + // Record the node entries that are newly allocated and those that are + // released. A node entry can be released if all its referencing nodes + // are part of the subgraph frontier. Otherwise, it is in the allocated set. + std::unordered_set newly_allocated_node_entries, + released_node_entries; + for (const Node* const candidate : forward_candidates) { + uint32_t nid = idx.node_id(candidate), + gnid = gidx.node_id(candidate); + for (uint32_t oid = 0; oid < candidate->num_outputs(); ++oid) { + uint32_t eid = idx.entry_id(nid, oid), + geid = gidx.entry_id(gnid, oid); + if (node_entry_ref_map[geid].size() != 0) { + newly_allocated_node_entries.insert(eid); + } + } + for (const NodeEntry& candidate_input : candidate->inputs) { + uint32_t eid = idx.entry_id(candidate_input), + geid = gidx.entry_id(candidate_input); + const std::unordered_set& ref_nodes = node_entry_ref_map[geid]; + bool can_be_released = true; + for (const Node* const ref_node : ref_nodes) { + if (subgraph_frontier.find(ref_node) == subgraph_frontier.end()) { + newly_allocated_node_entries.insert(eid); + can_be_released = false; + } + } + if (can_be_released) { + released_node_entries.insert(eid); + } + } // for (candidate_input ∈ candidate->input) + } // for (candidate ∈ forward_candidates) + + // Now, compare the total amount of newly allocated storage versus the + // released storage, if the latter is greater or equal to the former, + // then we remove the current node from the frontier. Otherwise all the + // forward candidate nodes are marked as on the mirror path. + size_t newly_allocated_storage = 0, released_storage = 0; + for (const uint32_t eid : newly_allocated_node_entries) { + newly_allocated_storage += src_shapes[eid].Size() * + MXGetDTypeSize(src_dtypes[eid]); + } + for (const uint32_t eid : released_node_entries) { + released_storage += src_shapes[eid].Size() * MXGetDTypeSize(src_dtypes[eid]); + } + if (released_storage >= newly_allocated_storage) { + for (const Node* const candidate : forward_candidates) { + CHECK(subgraph_frontier.find(candidate) != subgraph_frontier.end()); + subgraph_frontier.erase(candidate); + mirror_map[candidate] = nullptr; + } + has_subgraph_converged = false; + break; + } // if (released_storage >= newly_allocated_storage) + } // for (frontier_node ∈ subgraph_frontier) + } while (!has_subgraph_converged); + + // Finally, mark all the remaining nodes of the subgraph as on the mirror path. + for (const Node* const subgraph_node : subgraph_topo_order) { + if (mirror_map.find(subgraph_node) != mirror_map.end()) { + continue; + } + ObjectPtr subgraph_node_mirror = Node::Create(); + *subgraph_node_mirror = *subgraph_node; + subgraph_node_mirror->attrs.name += "_mirror"; + for (NodeEntry& e : subgraph_node_mirror->inputs) { + e.node = MapFwdNodeToMirrorPath(e.node, mirror_map); } + for (ObjectPtr& n : subgraph_node_mirror->control_deps) { + n = MapFwdNodeToMirrorPath(n, mirror_map); + } + mirror_map[subgraph_node] = subgraph_node_mirror; } + } // for (workitem ∈ worklist) + DFSVisit(ys, + [&](const ObjectPtr& node) { + if (mirror_map.at(node.get()) != nullptr) { + node->attrs.dict["__mirror_stage__"] = "1"; + } else { + node->attrs.dict["__mirror_stage__"] = "0"; + } + }); + return BuildGradientGraph(src, xs, topo_order, + output_grads, + mirror_fun, mirror_map); +} + + +/*! + * \brief Auxiliary function that checks whether all the gradients are zero or not. + */ +inline bool CheckGradAllZero(const std::vector& grads, + const std::vector& zero_ops) { + if (!grads.size() || !zero_ops.size()) return false; + for (const auto& g : grads) { + bool found = false; + for (const auto& op : zero_ops) { + if (g.node->op() == op) { + found = true; + break; + } + } + if (!found) return false; + } + return true; +} + + +Graph BuildGradientGraph( + const Graph& src, + const std::vector& xs, + const std::vector& topo_order, + std::unordered_map > output_grads, + std::function mirror_fun, + const std::unordered_map& mirror_map) { + static auto& grad_fun_map = Op::GetAttr("FGradient"); + + // gradient aggregation function + using AggFun = std::function&&)>; + AggFun agg_fun = [](std::vector&& v)->NodeEntry { + if (v.size() == 1) { + return std::move(v[0]); + } else if (v.size() == 0) { + ObjectPtr zero_grad_node = Node::Create(); + zero_grad_node->attrs.op = Op::Get("zeros"); + zero_grad_node->attrs.name = "zero_grad"; + zero_grad_node->attrs.op->attr_parser(&(zero_grad_node->attrs)); + return NodeEntry{zero_grad_node, 0, 0}; + } else { + ObjectPtr grad_sum_node = Node::Create(); + grad_sum_node->attrs.op = Op::Get("elemwise_sum"); + grad_sum_node->inputs = std::move(v); + grad_sum_node->attrs.name = "grad_sum"; + grad_sum_node->attrs.dict["num_args"] = + std::to_string(grad_sum_node->inputs.size()); + grad_sum_node->attrs.op->attr_parser(&(grad_sum_node->attrs)); + return NodeEntry{grad_sum_node, 0, 0}; + } + }; + if (src.attrs.count("grad_aggregate_fun") != 0) { + agg_fun = src.GetAttr("grad_aggregate_fun"); } - // traverse backward - static auto& grad_fun_map = Op::GetAttr("FGradient"); - static auto& finfer_shape = Op::GetAttr("FInferShape"); + // zero and copy operators + std::vector zero_ops; + if (src.attrs.count("zero_ops") != 0) { + zero_ops = src.GetAttr >("zero_ops"); + } + const Op* copy_op = (src.attrs.count("copy_op_str") != 0) ? + Op::Get(src.GetAttr("copy_op_str")) : nullptr; std::vector out_agg_grads; - for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { - const ObjectPtr& ptr = *rit; - if (ptr->is_variable()) continue; + for (auto topo_order_rit = topo_order.rbegin(); + topo_order_rit != topo_order.rend(); ++topo_order_rit) { + const ObjectPtr& src_fwd_node = *topo_order_rit; + if (src_fwd_node->is_variable()) continue; + + // gather all the output gradient entries and apply the aggregation function out_agg_grads.clear(); - auto& out_grad_vec = output_grads.at(ptr.get()); + auto& out_grad_vec = output_grads.at(src_fwd_node.get()); for (uint32_t i = 0; i < out_grad_vec.size(); ++i) { GradEntry& e = out_grad_vec[i]; e.sum = agg_fun(std::move(e.grads)); - if (e.need_attr_hint && attr_hint_fun != nullptr) { - e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i}); - } out_agg_grads.push_back(e.sum); } - if ((*rit)->inputs.size() != 0) { - ObjectPtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); + if (src_fwd_node->inputs.size() != 0) { + // If the current node has inputs, the gradients need to be further + // propagated backward. + ObjectPtr fwd_node = MapFwdNodeToMirrorPath(src_fwd_node, mirror_map); + // calculate the input gradients std::vector input_grads; - // Check for FGradient - if (grad_fun_map.contains(ptr->op())) { - input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads); - CHECK_EQ((*rit)->inputs.size(), input_grads.size()) - << "Gradient function not returning enough gradient"; + if (grad_fun_map.count(src_fwd_node->op())) { + input_grads = grad_fun_map[src_fwd_node->op()](fwd_node, out_agg_grads); + CHECK_EQ(src_fwd_node->inputs.size(), input_grads.size()) + << "The Gradient function is not returning enough gradients."; + // If the operator node fails the mirror function, it is however still + // possible for its feature maps to be recomputed without incurring + // significant runtime overhead. The reason is because some operators + // have their feature maps sit on the inputs rather than the outputs. + // E.g., the fully-connected layer (Y=XW^T), whose gradients are given + // by dX = dYW, dW = dY^TX and hence have no data dependency on Y. + if (mirror_fun != nullptr && !mirror_fun(*fwd_node)) { + for (NodeEntry& input_grad : input_grads) { + for (NodeEntry& grad_input : input_grad.node->inputs) { + const ObjectPtr& grad_input_node_mirrored = MapFwdNodeToMirrorPath( + grad_input.node, mirror_map); + grad_input = NodeEntry( + grad_input_node_mirrored, + grad_input.index, + grad_input.version); + } // for (grad_input ∈ input_grad.node->inputs) + } // for (input_grad ∈ input_grads) + } // if (mirror_fun != nullptr && !mirror_fun(*fwd_node)) } else if (CheckGradAllZero(out_agg_grads, zero_ops)) { - for (size_t i = 0; i < fwd_node->num_inputs(); ++i) { + for (size_t i = 0; i < src_fwd_node->num_inputs(); ++i) { std::ostringstream os; - if (1 == fwd_node->num_inputs()) { + if (1 == src_fwd_node->num_inputs()) { os << fwd_node->attrs.name << "_backward"; } else { os << fwd_node->attrs.name << "_in" << i << "_backward"; @@ -208,25 +634,25 @@ Graph Gradient(Graph src) { p->op()->attr_parser(&(p->attrs)); } input_grads.emplace_back(p, 0, 0); - } + } // for (i ∈ src_fwd_node->num_inputs()) } else { - LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable " + LOG(FATAL) << "Operator " << src_fwd_node->op()->name << " is non-differentiable " << "because it didn't register FGradient attribute."; } - for (const auto& nodeEntry : input_grads) - CHECK(nodeEntry.node); - auto git = input_grads.begin(); - CHECK((*rit)->inputs.size() <= input_grads.size()); - for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { - auto& output_grad_entry = output_grads[it->node.get()][it->index]; - // if any of the backward op can do shape inference, the hint is not necessary. - if (finfer_shape.contains(git->node->op())) { - output_grad_entry.need_attr_hint = false; - } - output_grad_entry.grads.emplace_back(std::move(*git)); + for (const auto& e : input_grads) { + CHECK(e.node); } - } - } + auto input_grad_iter = input_grads.begin(); + CHECK(src_fwd_node->inputs.size() <= input_grads.size()); + for (auto input_iter = src_fwd_node->inputs.begin(); + input_iter != src_fwd_node->inputs.end(); + ++input_iter, ++input_grad_iter) { + // propagate the input gradients to the output gradients of the input nodes + output_grads[input_iter->node.get()][input_iter->index] + .grads.emplace_back(std::move(*input_grad_iter)); + } + } // if (src_fwd_node->inputs.size() != 0) + } // for (topo_order_rit ∈ reverse(topo_order)) // take out the xs' grads Graph ret; ret.outputs.resize(xs.size()); @@ -237,9 +663,6 @@ Graph Gradient(Graph src) { // aggregate sum if there haven't been if (entry.sum.node.get() == nullptr) { entry.sum = agg_fun(std::move(entry.grads)); - if (entry.need_attr_hint && attr_hint_fun != nullptr) { - entry.sum = attr_hint_fun(entry.sum, e); - } } if (copy_op != nullptr) { auto kv = unique_grads.find(entry.sum); @@ -254,15 +677,16 @@ Graph Gradient(Graph src) { copy_node->attrs.name = os.str(); copy_node->inputs.emplace_back(entry.sum); if (copy_node->attrs.op->attr_parser != nullptr) { - copy_node->attrs.op->attr_parser(&(copy_node->attrs)); + copy_node->attrs.op->attr_parser(&(copy_node->attrs)); } - unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter)); + unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, + std::make_pair(1, counter)); } } else { - ret.outputs[counter] = entry.sum; + ret.outputs[counter] = entry.sum; } ++counter; - } + } // for (e ∈ xs) if (copy_op != nullptr) { for (const auto& kv : unique_grads) { ret.outputs[kv.second.second] = kv.first; @@ -271,6 +695,7 @@ Graph Gradient(Graph src) { return ret; } + // register pass NNVM_REGISTER_PASS(MXGradient) .describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") @@ -278,6 +703,8 @@ NNVM_REGISTER_PASS(MXGradient) .set_change_graph(true) .depend_graph_attr("grad_ys") .depend_graph_attr("grad_xs") +.depend_graph_attr("in_arg_shapes") +.depend_graph_attr("in_arg_dtypes") .depend_graph_attr("grad_ys_out_grad"); } // namespace diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc index 3815f239f88c..804a596cd8b7 100644 --- a/src/nnvm/plan_memory.cc +++ b/src/nnvm/plan_memory.cc @@ -34,9 +34,10 @@ namespace nnvm { namespace pass { -namespace { -// Return bytes of data flag. -static int MXGetDTypeSize(int type_flag) { +/*! + * \brief Return the storage in bytes for the corresponding data flag. + */ +size_t MXGetDTypeSize(const int type_flag) { switch (type_flag) { case mshadow::kUint8: case mshadow::kInt8: @@ -61,6 +62,8 @@ static int MXGetDTypeSize(int type_flag) { } } +namespace { + // simple graph based allocator. class MXGraphAllocator { public: @@ -78,8 +81,7 @@ class MXGraphAllocator { StorageID Request(int dev_id, int dtype, mxnet::TShape shape, uint32_t node_id) { if (!mxnet::shape_is_known(shape)) return kBadStorageID; // search memory block in [size / match_range_, size * match_range_) - // TODO(tqchen) add size of the dtype, assume 4 bytes for now - size_t size = shape.Size() * 4; + size_t size = shape.Size() * MXGetDTypeSize(dtype); if (match_range_ == 0) return this->Alloc(dev_id, size); auto begin = free_.lower_bound(size / match_range_); auto mid = free_.lower_bound(size); @@ -373,7 +375,8 @@ Graph MXPlanMemory(Graph ret) { size_t min_allocated_bytes = -1; size_t max_match_range = dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16); size_t min_match_range = - dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; + dmlc::GetEnv("MXNET_MEMORY_OPT", false) || + dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; for (size_t match_range = min_match_range; match_range <= max_match_range; match_range *= 2) { // Make a copy of related fields StorageVector storage_vec(storage); diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h index 1d8e4c2b6cda..06ff1fe1bedb 100644 --- a/src/operator/nn/activation-inl.h +++ b/src/operator/nn/activation-inl.h @@ -176,8 +176,13 @@ void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ct ctx, inputs[0], inputs[1], req[0], outputs[0]); break; case activation::kSoftSign: - ActivationBackward( - ctx, inputs[0], inputs[2], req[0], outputs[0]); + if (dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) { + ActivationBackward( + ctx, inputs[0], inputs[1], req[0], outputs[0]); + } else { + ActivationBackward( + ctx, inputs[0], inputs[2], req[0], outputs[0]); + } break; default: LOG(FATAL) << "unknown activation type"; diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc index 1259ceb7d9b3..622d3464371f 100644 --- a/src/operator/nn/activation.cc +++ b/src/operator/nn/activation.cc @@ -41,16 +41,19 @@ namespace activation { int GradNumInputs(int act_type) { // check activation.cu \sa ActivationGradCompute + if (dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) { + return 2; + } switch (act_type) { - case kReLU: - return 2; - case kSoftReLU: - case kSoftSign: - case kTanh: - case kSigmoid: - return 3; - default: - CHECK(false) << "missing activation type"; + case kReLU: + return 2; + case kSoftReLU: + case kSoftSign: + case kTanh: + case kSigmoid: + return 3; + default: + CHECK(false) << "missing activation type"; } // unreachable return -1; @@ -65,27 +68,34 @@ struct ActivationGrad { const char *op_name; std::vector operator()(const nnvm::ObjectPtr& n, const std::vector& ograds) const { - // ograds, output... + // ograds std::vector heads(ograds.begin(), ograds.end()); - heads.emplace_back(n, activation::kOut, 0); - const NodeAttrs& attrs = n->attrs; using namespace activation; int act_type = dmlc::get(attrs.parsed).act_type; - // for ReLU, no need to pass input data. This enables inplace optimization during the - // forward pass. - // check activation.cu \sa ActivationGradCompute - switch (act_type) { + + if (dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) { + if (act_type == kSoftSign) { + heads.push_back(n->inputs[activation::kData]); + } else { + heads.emplace_back(n, activation::kOut, 0); + } + } else { + heads.emplace_back(n, activation::kOut, 0); // output + // for ReLU, no need to pass input data. This enables inplace optimization + // during the forward pass. check activation.cu \sa ActivationGradCompute + switch (act_type) { case kReLU: - break; + break; case kSoftReLU: case kSoftSign: case kTanh: case kSigmoid: - heads.push_back(n->inputs[activation::kData]); - break; + heads.push_back(n->inputs[activation::kData]); + break; default: - CHECK(false) << "missing activation type"; + CHECK(false) << "missing activation type"; + } } return MakeGradNode(op_name, n, heads, n->attrs.dict); } diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu index ec7db844b100..1116cf20165b 100644 --- a/src/operator/nn/activation.cu +++ b/src/operator/nn/activation.cu @@ -82,24 +82,48 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); + bool do_memory_opt = dmlc::GetEnv("MXNET_MEMORY_OPT", 0); + // both SoftReLU and SoftSign not supported by CUDNN yet if (act_type == activation::kSoftReLU) { ActivationBackward( ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); } else if (act_type == activation::kSoftSign) { - ActivationBackward( - ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]); + if (do_memory_opt) { + ActivationBackward( + ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); + } else { + ActivationBackward( + ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]); + } } else if (act_type == activation::kReLU) { - MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, { - // XXX: for y = relu(x), y is passed as "in_data" to Backward() - get_cudnn_op(param).Backward(ctx, inputs.at(0), inputs.at(1), - inputs.at(1), req[0], outputs[0]); - }); + if (do_memory_opt) { + ActivationBackward( + ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); + } else { + MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, { + // XXX: for y = relu(x), y is passed as "in_data" to Backward() + get_cudnn_op(param).Backward(ctx, inputs.at(0), inputs.at(1), + inputs.at(1), req[0], outputs[0]); + }); + } } else { - MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, { - get_cudnn_op(param).Backward(ctx, inputs.at(0), inputs.at(2), - inputs.at(1), req[0], outputs[0]); - }); + if (do_memory_opt) { + if (act_type == activation::kTanh) { + ActivationBackward( + ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); + } else if (act_type == activation::kSigmoid) { + ActivationBackward( + ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); + } else { + LOG(FATAL) << "unknown activation type"; + } + } else { + MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, { + get_cudnn_op(param).Backward(ctx, inputs.at(0), inputs.at(2), + inputs.at(1), req[0], outputs[0]); + }); + } // if (do_memory_opt) } } #endif diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 881d3d2247da..95e879acb75d 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -128,7 +128,7 @@ class CuDNNBatchNormOp { // which further indicates that we are in the backward mirroring mode, // and therefore update to the auxiliary states is disabled. // This is done by setting the `momentum` to `1` (or `factor` to `0`). - float factor = (dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) && internal_aux_states_lock_) ? + float factor = (dmlc::GetEnv("MXNET_MEMORY_OPT", 0) && internal_aux_states_lock_) ? 0 : (1 - param_.momentum); CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_, mode, diff --git a/tests/python/unittest/test_gradient.py b/tests/python/unittest/test_gradient.py new file mode 100644 index 000000000000..49d19ef7ee46 --- /dev/null +++ b/tests/python/unittest/test_gradient.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: skip-file +import mxnet as mx +import os + + +num_hidden = 4096 + + +def test_rnn_cell(): + # x →→→ + →→→ tanh ⇒⇒⇒ + # ↑ + # y →→→→ + # + # ⇒⇒⇒ : Backward Dependency + # In this example, there is no benefit in mirroring the elementwise-add + # operator and the tanh operator. + os.environ["MXNET_MEMORY_OPT"] = '1' + x = mx.sym.Variable("x") + x = mx.sym.FullyConnected(x, num_hidden=num_hidden) + y = mx.sym.Variable("y") + y = mx.sym.FullyConnected(y, num_hidden=num_hidden) + tmp = mx.sym._internal._plus(x, y) + z = mx.sym.Activation(tmp, act_type='tanh') + exec = z.simple_bind(mx.cpu(), 'write', x=(num_hidden,), y=(num_hidden,)) + exec_debug_str = exec.debug_str().split('\n') + op_checklist = 0 + for i, line in enumerate(exec_debug_str): + if "Op:elemwise_add" in line: + op_checklist += 1 + assert exec_debug_str[i + 5] == "\t__mirror_stage__=0" + if "Op:Activation" in line: + op_checklist += 1 + assert exec_debug_str[i + 4] == "\t__mirror_stage__=0" + assert op_checklist == 2, \ + "Not all operator nodes have been verified on the mirror stage" + os.environ["MXNET_MEMORY_OPT"] = '0' + + +def test_mlp_attn(): + # x →→→ + →→→ tanh ⇒⇒⇒ + # ↑ + →→→ tanh ⇒⇒⇒ + # y_1 →→ ↑ + →→→ tanh ⇒⇒⇒ + # y_2 →→→→ ↑ ⋱ + # y_3 →→→→→→ + →→→ tanh ⇒⇒⇒ + # ↑ + # y_n →→→→→→→→→→ + os.environ["MXNET_MEMORY_OPT"] = '1' + x = mx.sym.Variable("x") + tmp, z = [], [] + num_steps = 5 + in_arg_shapes = {'x': (num_steps, num_hidden,)} + for i in range(num_steps): + y = mx.sym.Variable("y_t%d"%i) + tmp.append(mx.sym.broadcast_add(x, y, name="broadcast_add%d"%i)) + z.append(mx.sym.Activation(tmp[-1], act_type='tanh', + name="activation%d"%i)) + in_arg_shapes["y_t%d"%i] = (1, num_hidden,) + z = mx.sym.Group(z) + exec = z.simple_bind(mx.cpu(), 'write', **in_arg_shapes) + exec_debug_str = exec.debug_str().split('\n') + op_checklist = 0 + for i, line in enumerate(exec_debug_str): + for t in range(num_steps): + if line == "Op:broadcast_add, Name=broadcast_add%d"%t: + op_checklist += 1 + assert exec_debug_str[i + 5] == "\t__mirror_stage__=1" + if line == "Op:Activation, Name=activation%d"%t: + op_checklist += 1 + assert exec_debug_str[i + 4] == "\t__mirror_stage__=1" + assert op_checklist == 2 * num_steps, \ + "Not all operator nodes have been verified on the mirror stage" + os.environ["MXNET_MEMORY_OPT"] = '0' + + +def test_fc(): + # x →→→ tanh ⇒⇒⇒ tanh ⇒⇒⇒ FC + # →→→ tanh_ →→→ + # ↓ + # FC' + os.environ["MXNET_MEMORY_OPT"] = '1' + x = mx.sym.Variable("x") + y = mx.sym.Activation(x, act_type='tanh', name='y') + z = mx.sym.Activation(y, act_type='tanh', name='z') + z = mx.sym.FullyConnected(z, num_hidden=num_hidden) + exec = z.simple_bind(mx.cpu(), 'write', x=(num_hidden,)) + exec_debug_str = exec.debug_str().split('\n') + op_checklist = 0 + for i, line in enumerate(exec_debug_str): + if line == "Op:Activation, Name=y": + op_checklist += 1 + assert exec_debug_str[i + 4] == "\t__mirror_stage__=0" + if line == "Op:Activation, Name=z": + op_checklist += 1 + assert exec_debug_str[i + 4] == "\t__mirror_stage__=1" + if "Op:FullyConnected" in line: + op_checklist += 1 + assert exec_debug_str[i + 6] == "\t__mirror_stage__=0" + if "Op:_backward_FullyConnected" in line: + op_checklist += 1 + assert exec_debug_str[i + 3] == "\targ[1]=z_mirror(0)" + assert op_checklist == 4, \ + "Not all operator nodes have been verified on the mirror stage" + os.environ["MXNET_MEMORY_OPT"] = '0' + + +if __name__ == "__main__": + import nose + nose.runmodule()