Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve the backward mirroring implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ArmageddonKnight committed May 4, 2020
1 parent 586c8ab commit 2509aaf
Show file tree
Hide file tree
Showing 15 changed files with 910 additions and 249 deletions.
11 changes: 5 additions & 6 deletions docs/static_site/src/pages/api/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,13 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
- The maximum size of an NDArray slice in terms of number of parameters.
- This parameter is used to slice an NDArray before synchronizing through P3Store (dist_p3).

## Memonger
## Memory Optimizations

* MXNET_BACKWARD_DO_MIRROR
* MXNET_MEMORY_OPT
- Values: 0(false) or 1(true) ```(default=0)```
- MXNet uses mirroring concept to save memory. Normally backward pass needs some forward input and it is stored in memory but you can choose to release this saved input and recalculate it in backward pass when needed. This basically trades off the computation for memory consumption.
- This parameter decides whether to do `mirror` during training for saving device memory.
- When set to `1`, during forward propagation, graph executor will `mirror` some layer's feature map and drop others, but it will re-compute this dropped feature maps when needed.
- `MXNET_BACKWARD_DO_MIRROR=1` will save 30%~50% of device memory, but retains about 95% of running speed.
- When set to `1`, MXNet will adopt various approaches to reduce the memory consumption of the model. For example, it uses the mirroring concept to save memory: Normally the backward pass needs some forward inputs to compute the gradients. Those inputs have to be stashed in memory and persistent throughout the traianing process. However, you can choose to release those saved inputs and recalculate them in the backward pass when needed. This basically trades off the computation for memory consumption. When set to `1`, during forward propagation, the graph executor will `mirror` some layers' feature maps and drop others, but it will re-compute this dropped feature maps when needed.
- This parameter decides whether to do `mirror` and/or data encodings during training for saving device memory.
- `MXNET_MEMORY_OPT=1` will save 30%~50% of device memory, but retains about 95% of running speed.
- One extension of `mirror` in MXNet is called [memonger technology](https://arxiv.org/abs/1604.06174), it will only use O(sqrt(N)) memory at 75% running speed. Checkout the code [here](https://github.com/dmlc/mxnet-memonger).

## Control the profiler
Expand Down
11 changes: 6 additions & 5 deletions example/image-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,12 @@ An over sized batch size may result in out of GPU memory. The common error
message is `cudaMalloc failed: out of memory`. Now we can
- Reduce the batch size
- Set the environment variable `MXNET_BACKWARD_DO_MIRROR` to 1. It trades off
computation for memory consumption. For example, with batch size 64,
inception-v3 uses 10G memory and trains 30 image/sec on a single K80 GPU. When
mirroring is enabled, with 10G GPU memory consumption, we can run inception-v3
using batch size 128. The cost is that the speed reduces to 27 images/sec.
- Set the environment variable `MXNET_MEMORY_OPT=1` to perform a series of
memory optimizations (e.g., trades off computation for memory consumption).
For example, with batch size 64, inception-v3 uses 10G memory and trains 30
image/sec on a single K80 GPU. When mirroring is enabled, with 10G GPU memory
consumption, we can run inception-v3 using batch size 128. The cost is that
the speed reduces to 27 images/sec.
## History
Expand Down
5 changes: 5 additions & 0 deletions python/mxnet/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,11 @@ def __call__(self, inputs, states):
name='%so'%name)
next_c = symbol._internal._plus(forget_gate * states[1], in_gate * in_transform,
name='%sstate'%name)
next_c._set_attr({'force_mirroring' : '0'})
# Cell states are excluded from being mirrored. The reason is because
# they do not pass through the fully-connected layers and will
# significantly increase the overall mirroring depth, incurring large
# performance overhead.
next_h = symbol._internal._mul(out_gate, symbol.Activation(next_c, act_type="tanh"),
name='%sout'%name)

Expand Down
37 changes: 19 additions & 18 deletions src/executor/exec_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,17 @@ namespace pass {
/*!
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \param graph The input graph.
* \param ys The entries we want to take gradient from.
* \param xs The input to take gradient with respect to.
* \param ys_out_grad The symbol for additional gradient to be propagate back to y.
* \param aggregate_fun Aggregation function applied to aggregate the inputs.
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like.
* \param zero_ops Optional, list of operators that outputs a single zero array. The first one
* must be zeros_like.
* \param copy_op_str Optional, name of the copy operation required to handle duplicates
* on the edge of the graph
* \param ys The entries to take gradient from.
* \param xs The entries to take gradient with respect to.
* \param ys_out_grad The output gradients of ys.
* \param aggregate_fun The aggregation function used for summing gradients.
* \param mirror_fun The backward mirroring function that does mirroring to save memory.
* \param zero_ops The list of operators that output a single zero array, used
* for generating zero gradient nodes. The first operator must
* be zero_like.
* \param copy_op_str The name of the copy operator that handle gradient duplicates.
* \param in_arg_shapes The shapes of input arguments, used for shape inference.
* \param in_arg_dtpyes The data types of input arguments, used for data type inference.
* \return A new graph, whose outputs correspond to inputs of xs.
*/
inline Graph MXGradient(
Expand All @@ -292,27 +293,27 @@ inline Graph MXGradient(
std::vector<NodeEntry> ys_out_grad,
std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
std::function<int(const Node& node)> mirror_fun = nullptr,
std::function<NodeEntry(const NodeEntry& src, const NodeEntry &like)>
attr_hint_fun = nullptr,
std::vector<const Op*> zero_ops = std::vector<const Op*>(),
std::string copy_op_str = std::string()) {
std::string copy_op_str = std::string(),
mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(),
DTypeVector in_arg_dtypes = DTypeVector()) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
graph.attrs["in_arg_shapes"] = std::make_shared<any>(std::move(in_arg_shapes));
graph.attrs["in_arg_dtypes"] = std::make_shared<any>(std::move(in_arg_dtypes));

if (aggregate_fun != nullptr) {
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
}
if (mirror_fun != nullptr) {
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
}
if (attr_hint_fun != nullptr) {
graph.attrs["attr_hint_fun"] = std::make_shared<any>(attr_hint_fun);
graph.attrs["mirror_fun"] = std::make_shared<any>(mirror_fun);
}
if (zero_ops.size()) {
graph.attrs["zero_ops"] = std::make_shared<any>(std::move(zero_ops));
}
if (copy_op_str != std::string()) {
graph.attrs["copy_op"] = std::make_shared<any>(std::move(copy_op_str));
graph.attrs["copy_op_str"] = std::make_shared<any>(std::move(copy_op_str));
}
return ApplyPass(std::move(graph), "MXGradient");
}
Expand Down
127 changes: 92 additions & 35 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,28 +302,15 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
}
}

template<typename ValueType>
inline ValueType get_node_attr(
const nnvm::Node& node,
const std::string& key, ValueType default_value) {
auto it = node.attrs.dict.find(key);
if (it == node.attrs.dict.end()) {
return default_value;
} else {
ValueType ret;
dmlc::parameter::FieldEntry<ValueType> e;
e.Init(key, &ret, ret);
e.Set(&ret, it->second);
return ret;
}
}

/*!
* \brief Create the graph for backward pass.
* This is triggered by both simple_bind and bind flows.
*/
nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types) {
const std::vector<OpReqType>& grad_req_types,
const ShapeVector& in_arg_shapes,
const nnvm::DTypeVector& in_arg_dtypes) {
using nnvm::ObjectPtr;
using nnvm::NodeEntry;
// initial information
Expand Down Expand Up @@ -356,19 +343,28 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
}
}

int do_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0);
auto need_mirror = [do_mirror](const nnvm::Node& node) -> int {
if (node.is_variable()) return 0;
const std::string& type = node.attrs.op->name;
if (type == "Dropout") return false;
if (get_node_attr(node, "__force_mirroring__", false)) return true;
if (do_mirror == 0) return false;
if (type == "Convolution") return false;
if (type == "FullyConnected") return false;
if (type == "Concat") return false;
if (type == "SoftmaxOutput") return false;
return true;
};
std::function<int(const nnvm::Node&)> need_mirror =
[](const nnvm::Node& node) -> int {
if (node.is_variable()) return false;
const std::string& type = node.attrs.op->name;
if (type == "Dropout") return false;
// We follow the hidden key attribute "force_mirroring" if it is
// explicitly set.
auto iter = node.attrs.dict.find("__force_mirroring__");
if (iter != node.attrs.dict.end()) {
bool do_mirror;
dmlc::parameter::FieldEntry<bool> e;
e.Init("__force_mirroring__", &do_mirror, do_mirror);
e.Set(&do_mirror, iter->second);
return do_mirror;
}
if (type == "Embedding") return false;
if (type == "Convolution") return false;
if (type == "FullyConnected") return false;
if (type == "Concat") return false;
if (type == "SoftmaxOutput") return false;
return true;
};

std::vector<const nnvm::Op*> zero_ops;
zero_ops.push_back(nnvm::Op::Get("zeros_like"));
Expand All @@ -377,8 +373,11 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
// take gradient
nnvm::Graph g_grad = nnvm::pass::MXGradient(
g, symbol.outputs, xs, head_grad_entry_,
AggregateGradient, need_mirror, nullptr,
zero_ops, "_copy");
AggregateGradient,
dmlc::GetEnv("MXNET_MEMORY_OPT", 0) ? need_mirror : nullptr,
zero_ops, "_copy",
in_arg_shapes, in_arg_dtypes);

CHECK_EQ(g_grad.outputs.size(), xs.size());
for (const auto &e : g_grad.outputs) {
g.outputs.push_back(e);
Expand Down Expand Up @@ -414,8 +413,37 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
std::vector<Context> aux_state_ctxes(aux_states.size());
std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1);

// Record the shapes and data types of the input arguments in the source graph
// (i.e., the graph prior to the Gradient pass). Such information is need by
// the backward mirroring algorithm for shape and data type inference.
nnvm::Graph src;
src.outputs = symbol.outputs;
const nnvm::IndexedGraph& src_idx = src.indexed_graph();
const std::unordered_set<uint32_t>& src_mutable_nodes = src_idx.mutable_input_nodes();
size_t src_arg_top = 0, src_aux_top = 0;
ShapeVector src_arg_shapes;
nnvm::DTypeVector src_arg_dtypes;
const size_t src_num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size();

for (size_t i = 0; i < src_num_forward_inputs; ++i) {
const uint32_t nid = src_idx.input_nodes().at(i);

if (src_mutable_nodes.count(nid)) {
CHECK_LT(src_aux_top, aux_states.size());
src_arg_shapes.push_back(aux_states[src_aux_top].shape());
src_arg_dtypes.push_back(aux_states[src_aux_top].dtype());
++src_aux_top;
} else {
CHECK_LT(src_arg_top, in_args.size());
src_arg_shapes.push_back(in_args[src_arg_top].shape());
src_arg_dtypes.push_back(in_args[src_arg_top].dtype());
++src_arg_top;
}
}

nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes,
arg_grad_ctxes, aux_state_ctxes, grad_req_types);
arg_grad_ctxes, aux_state_ctxes, grad_req_types,
src_arg_shapes, src_arg_dtypes);

// create arg_shapes and arg_dtypes for shape and type inferences
const auto& idx = g.indexed_graph();
Expand Down Expand Up @@ -811,8 +839,34 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray>* shared_buffer,
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
// Record the shapes and data types of the input arguments in the source graph
// (i.e., the graph prior to the Gradient pass). Such information is need by
// the backward mirroring algorithm for shape and data type inference.
nnvm::Graph src;
src.outputs = symbol.outputs;
const nnvm::IndexedGraph& src_idx = src.indexed_graph();
ShapeVector src_arg_shapes(src_idx.input_nodes().size(), TShape());
nnvm::DTypeVector src_arg_dtypes(src_idx.input_nodes().size(), -1);
const size_t src_num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size();

for (size_t i = 0; i < src_num_forward_inputs; ++i) {
const uint32_t nid = src_idx.input_nodes().at(i);
const std::string& name = src_idx[nid].source->attrs.name;
std::unordered_map<std::string, TShape>::const_iterator
arg_shape_iter = arg_shape_map.find(name);
std::unordered_map<std::string, int>::const_iterator
arg_dtype_iter = arg_dtype_map.find(name);
if (arg_shape_iter != arg_shape_map.end()) {
src_arg_shapes[i] = arg_shape_iter->second;
}
if (arg_dtype_iter != arg_dtype_map.end()) {
src_arg_dtypes[i] = arg_dtype_iter->second;
}
}

nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
aux_state_ctxes, grad_req_types);
aux_state_ctxes, grad_req_types,
src_arg_shapes, src_arg_dtypes);

// The following code of shape and dtype inferences and argument
// initialization is for simple_bind only. Regular bind operation
Expand Down Expand Up @@ -1007,9 +1061,12 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types) {
const std::vector<OpReqType>& grad_req_types,
const ShapeVector& in_arg_shapes,
const nnvm::DTypeVector& in_arg_dtypes) {
// setup gradient
nnvm::Graph g = InitFullGraph(symbol, grad_req_types);
nnvm::Graph g = InitFullGraph(symbol, grad_req_types,
in_arg_shapes, in_arg_dtypes);

#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", true)) {
Expand Down
8 changes: 6 additions & 2 deletions src/executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,14 @@ class GraphExecutor : public Executor {
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types);
const std::vector<OpReqType>& grad_req_types,
const ShapeVector& in_arg_shapes,
const nnvm::DTypeVector& in_arg_dtypes);
// intialize the full graph for simple bind, including gradient
Graph InitFullGraph(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types);
const std::vector<OpReqType>& grad_req_types,
const ShapeVector& in_arg_shapes,
const nnvm::DTypeVector& in_arg_dtypes);
// initialize the cached operator
void InitCachedOps();
// initialize the opr segments for bulk exec
Expand Down
2 changes: 1 addition & 1 deletion src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph,

*grad_graph = pass::MXGradient(
*fwd_graph, fwd_graph->outputs, xs, *ograd_entries,
exec::AggregateGradient, nullptr, nullptr,
exec::AggregateGradient, nullptr,
zero_ops, "_copy");
}

Expand Down
2 changes: 1 addition & 1 deletion src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ std::vector<NDArray*> Imperative::Backward(

Graph g_graph = pass::MXGradient(
graph, graph.outputs, xs, ograd_entries,
exec::AggregateGradient, nullptr, nullptr,
exec::AggregateGradient, nullptr,
zero_ops, "_copy");
CHECK_EQ(g_graph.outputs.size(), xs.size());
for (const auto& e : g_graph.outputs) {
Expand Down
Loading

0 comments on commit 2509aaf

Please sign in to comment.