This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve the backward mirroring implementation (#18228)
ArmageddonKnight authored May 21, 2020
1 parent 5343aef commit 4827de8
Showing 17 changed files with 1,009 additions and 250 deletions.
6 changes: 6 additions & 0 deletions ci/windows/test_py3_cpu.ps1
@@ -36,3 +36,9 @@ if ($LastExitCode -ne 0) { Throw ("Error running serial train tests, python exit
$env:MXNET_SAFE_ACCUMULATION=1
C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_unittest.xml --cov-append tests\python\unittest\test_operator.py::test_norm
if ($LastExitCode -ne 0) { Throw ("Error running unittest, python exited with status code " + ('{0:X}' -f $LastExitCode)) }

# Similar to the MXNET_SAFE_ACCUMULATION test case above. Need to explicitly
# set the environment variable for MXNET_MEMORY_OPT.
$env:MXNET_MEMORY_OPT=1
C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_unittest.xml --cov-append tests\python\unittest\test_memory_opt.py
if ($LastExitCode -ne 0) { Throw ("Error running unittest, python exited with status code " + ('{0:X}' -f $LastExitCode)) }
7 changes: 7 additions & 0 deletions ci/windows/test_py3_gpu.ps1
@@ -44,9 +44,16 @@ C:\Python37\python.exe -m pytest -v -m 'not serial' -n 4 --durations=50 --cov-re
if ($LastExitCode -ne 0) { Throw ("Error running parallel tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) }
C:\Python37\python.exe -m pytest -v -m 'serial' --durations=50 --cov-report xml:tests_train.xml --cov-append tests\python\train
if ($LastExitCode -ne 0) { Throw ("Error running serial tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) }

# Adding this extra test since it's not possible to set env var on the fly in Windows.
$env:MXNET_SAFE_ACCUMULATION=1
C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_operator.xml --cov-append tests\python\gpu\test_operator_gpu.py::test_norm
if ($LastExitCode -ne 0) { Throw ("Error running tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) }
C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_tvm_op.xml tests\python\gpu\test_tvm_op_gpu.py
if ($LastExitCode -ne 0) { Throw ("Error running TVM op tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) }

# Similar to the MXNET_SAFE_ACCUMULATION test case above. Need to explicitly
# set the environment variable for MXNET_MEMORY_OPT.
$env:MXNET_MEMORY_OPT=1
C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_unittest.xml --cov-append tests\python\unittest\test_memory_opt.py
if ($LastExitCode -ne 0) { Throw ("Error running memory optimization tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) }
6 changes: 5 additions & 1 deletion docs/static_site/src/pages/api/faq/env_var.md
@@ -189,7 +189,7 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
- The maximum size of an NDArray slice in terms of number of parameters.
- This parameter is used to slice an NDArray before synchronizing through P3Store (dist_p3).

## Memonger
## Memory Optimizations

* MXNET_BACKWARD_DO_MIRROR
- Values: 0(false) or 1(true) ```(default=0)```
@@ -199,6 +199,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
- `MXNET_BACKWARD_DO_MIRROR=1` will save 30%~50% of device memory, but retains about 95% of running speed.
- One extension of `mirror` in MXNet is called [memonger technology](https://arxiv.org/abs/1604.06174), which uses only O(sqrt(N)) memory at 75% of the running speed. Check out the code [here](https://github.com/dmlc/mxnet-memonger).

* MXNET_MEMORY_OPT
- Values: 0(no optimizations) or 1(highest optimization level) ```(default=0)```
- If set to '1', various optimizations on memory consumption will be enabled.
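
A minimal sketch of enabling the variable for a training run, assuming the training script is launched as a child process. The helper name `env_with_memory_opt` is hypothetical and not part of MXNet. The variable is placed in the environment before the process starts because, in this commit, the executor reads it while building the backward graph.

```python
import os


def env_with_memory_opt(base_env=None, level=1):
    """Return a copy of the environment with MXNET_MEMORY_OPT set.

    `env_with_memory_opt` is an illustrative helper, not an MXNet API.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["MXNET_MEMORY_OPT"] = str(level)  # 0 = no optimizations, 1 = highest level
    return env


if __name__ == "__main__":
    # Pass the result as `env=` to subprocess.run([...]) when launching training.
    print(env_with_memory_opt({})["MXNET_MEMORY_OPT"])
```

Copying the environment instead of mutating `os.environ` keeps the setting scoped to the launched process.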

## Control the profiler

The following environments can be used to profile the application without changing code. Execution options may affect the granularity of profiling result. If you need profiling result of every operator, please set `MXNET_EXEC_BULK_EXEC_INFERENCE`, `MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN` and `MXNET_EXEC_BULK_EXEC_TRAIN` to 0.
11 changes: 6 additions & 5 deletions example/image-classification/README.md
@@ -366,11 +366,12 @@ An over sized batch size may result in out of GPU memory. The common error
message is `cudaMalloc failed: out of memory`. Now we can
- Reduce the batch size
- Set the environment variable `MXNET_BACKWARD_DO_MIRROR` to 1. It trades off
computation for memory consumption. For example, with batch size 64,
inception-v3 uses 10G memory and trains 30 image/sec on a single K80 GPU. When
mirroring is enabled, with 10G GPU memory consumption, we can run inception-v3
using batch size 128. The cost is that the speed reduces to 27 images/sec.
- Set the environment variable `MXNET_MEMORY_OPT=1` to enable a series of
memory optimizations (e.g., trading extra computation for lower memory consumption).
For example, with batch size 64, inception-v3 uses 10G memory and trains 30
images/sec on a single K80 GPU. When mirroring is enabled, with the same 10G of
GPU memory, we can run inception-v3 using batch size 128. The cost is that
the speed drops to 27 images/sec.
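
The trade-off quoted in the last bullet can be checked with a little arithmetic. The figures are the ones stated in this README (inception-v3 on a single K80); the function and variable names are only for illustration.

```python
# Figures quoted in the README for inception-v3 on a single K80 GPU.
MEMORY_GB = 10.0
baseline = {"batch_size": 64, "images_per_sec": 30.0}
mirrored = {"batch_size": 128, "images_per_sec": 27.0}


def slowdown(before, after):
    """Fraction of throughput lost when moving from `before` to `after`."""
    return 1.0 - after["images_per_sec"] / before["images_per_sec"]


def memory_per_image_gb(config, total_gb=MEMORY_GB):
    """Device memory per image in the batch, for a fixed total budget."""
    return total_gb / config["batch_size"]


if __name__ == "__main__":
    print("throughput lost: {:.0%}".format(slowdown(baseline, mirrored)))
    print("GB per image:", memory_per_image_gb(baseline),
          "->", memory_per_image_gb(mirrored))
```

Under these numbers, doubling the batch size within the same memory budget costs about a tenth of the throughput.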
## History
5 changes: 5 additions & 0 deletions python/mxnet/rnn/rnn_cell.py
@@ -459,6 +459,11 @@ def __call__(self, inputs, states):
name='%so'%name)
next_c = symbol._internal._plus(forget_gate * states[1], in_gate * in_transform,
name='%sstate'%name)
next_c._set_attr(force_mirroring='0')
# Cell states are excluded from being mirrored because they do not pass
# through the fully-connected layers; mirroring them would significantly
# increase the overall mirroring depth, incurring large performance
# overhead.
next_h = symbol._internal._mul(out_gate, symbol.Activation(next_c, act_type="tanh"),
name='%sout'%name)

37 changes: 19 additions & 18 deletions src/executor/exec_pass.h
@@ -273,16 +273,17 @@ namespace pass {
/*!
* \brief Get the gradient graph whose outputs are gradients of ys with respect to xs.
* \param graph The input graph.
* \param ys The entries we want to take gradient from.
* \param xs The input to take gradient with respect to.
* \param ys_out_grad The symbol for additional gradient to be propagate back to y.
* \param aggregate_fun Aggregation function applied to aggregate the inputs.
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like.
* \param zero_ops Optional, list of operators that outputs a single zero array. The first one
* must be zeros_like.
* \param copy_op_str Optional, name of the copy operation required to handle duplicates
* on the edge of the graph
* \param ys The entries to take gradient from.
* \param xs The entries to take gradient with respect to.
* \param ys_out_grad The output gradients of ys.
* \param aggregate_fun The aggregation function used for summing gradients.
* \param mirror_fun The backward mirroring function that does mirroring to save memory.
* \param zero_ops The list of operators that output a single zero array, used
* for generating zero gradient nodes. The first operator must
* be zeros_like.
* \param copy_op_str The name of the copy operator that handles gradient duplicates.
* \param in_arg_shapes The shapes of input arguments, used for shape inference.
* \param in_arg_dtypes The data types of input arguments, used for data type inference.
* \return A new graph, whose outputs correspond to inputs of xs.
*/
inline Graph MXGradient(
@@ -292,27 +293,27 @@ inline Graph MXGradient(
std::vector<NodeEntry> ys_out_grad,
std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
std::function<int(const Node& node)> mirror_fun = nullptr,
std::function<NodeEntry(const NodeEntry& src, const NodeEntry &like)>
attr_hint_fun = nullptr,
std::vector<const Op*> zero_ops = std::vector<const Op*>(),
std::string copy_op_str = std::string()) {
std::string copy_op_str = std::string(),
mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(),
DTypeVector in_arg_dtypes = DTypeVector()) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
graph.attrs["in_arg_shapes"] = std::make_shared<any>(std::move(in_arg_shapes));
graph.attrs["in_arg_dtypes"] = std::make_shared<any>(std::move(in_arg_dtypes));

if (aggregate_fun != nullptr) {
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
}
if (mirror_fun != nullptr) {
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
}
if (attr_hint_fun != nullptr) {
graph.attrs["attr_hint_fun"] = std::make_shared<any>(attr_hint_fun);
graph.attrs["mirror_fun"] = std::make_shared<any>(mirror_fun);
}
if (zero_ops.size()) {
graph.attrs["zero_ops"] = std::make_shared<any>(std::move(zero_ops));
}
if (copy_op_str != std::string()) {
graph.attrs["copy_op"] = std::make_shared<any>(std::move(copy_op_str));
graph.attrs["copy_op_str"] = std::make_shared<any>(std::move(copy_op_str));
}
return ApplyPass(std::move(graph), "MXGradient");
}
128 changes: 93 additions & 35 deletions src/executor/graph_executor.cc
@@ -302,28 +302,15 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
}
}

template<typename ValueType>
inline ValueType get_node_attr(
const nnvm::Node& node,
const std::string& key, ValueType default_value) {
auto it = node.attrs.dict.find(key);
if (it == node.attrs.dict.end()) {
return default_value;
} else {
ValueType ret;
dmlc::parameter::FieldEntry<ValueType> e;
e.Init(key, &ret, ret);
e.Set(&ret, it->second);
return ret;
}
}

/*!
* \brief Create the graph for backward pass.
* This is triggered by both simple_bind and bind flows.
*/
nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types) {
const std::vector<OpReqType>& grad_req_types,
const ShapeVector& in_arg_shapes,
const nnvm::DTypeVector& in_arg_dtypes) {
using nnvm::ObjectPtr;
using nnvm::NodeEntry;
// initial information
@@ -356,19 +343,28 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
}
}

int do_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0);
auto need_mirror = [do_mirror](const nnvm::Node& node) -> int {
if (node.is_variable()) return 0;
const std::string& type = node.attrs.op->name;
if (type == "Dropout") return false;
if (get_node_attr(node, "__force_mirroring__", false)) return true;
if (do_mirror == 0) return false;
if (type == "Convolution") return false;
if (type == "FullyConnected") return false;
if (type == "Concat") return false;
if (type == "SoftmaxOutput") return false;
return true;
};
std::function<int(const nnvm::Node&)> need_mirror =
[](const nnvm::Node& node) -> int {
if (node.is_variable()) return false;
const std::string& type = node.attrs.op->name;
if (type == "Dropout") return false;
// We follow the hidden key attribute "force_mirroring" if it is
// explicitly set.
auto iter = node.attrs.dict.find("__force_mirroring__");
if (iter != node.attrs.dict.end()) {
bool do_mirror;
dmlc::parameter::FieldEntry<bool> e;
e.Init("__force_mirroring__", &do_mirror, do_mirror);
e.Set(&do_mirror, iter->second);
return do_mirror;
}
if (type == "Embedding") return false;
if (type == "Convolution") return false;
if (type == "FullyConnected") return false;
if (type == "Concat") return false;
if (type == "SoftmaxOutput") return false;
return true;
};
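
The decision order of the new `need_mirror` lambda can be sketched in Python as follows. This is a plain transcription for illustration, not MXNet code: nodes are reduced to an operator name plus an attribute dict, and `parse_bool` is a hypothetical stand-in for the parsing that `dmlc::parameter::FieldEntry<bool>` performs (the accepted spellings here are an assumption).

```python
# Operators that the lambda never mirrors unless the attribute forces it.
NOT_MIRRORED = frozenset(
    ["Embedding", "Convolution", "FullyConnected", "Concat", "SoftmaxOutput"])


def parse_bool(text):
    """Illustrative stand-in for dmlc's bool parsing (assumed spellings)."""
    value = text.strip().lower()
    if value in ("1", "true"):
        return True
    if value in ("0", "false"):
        return False
    raise ValueError("not a boolean: %r" % text)


def need_mirror(op_name, attrs, is_variable=False):
    """Return True if the node may be recomputed in the backward pass."""
    if is_variable:
        return False
    if op_name == "Dropout":
        # Checked before the attribute, so Dropout can never be forced on.
        return False
    if "__force_mirroring__" in attrs:
        # An explicitly set attribute overrides the per-operator defaults.
        return parse_bool(attrs["__force_mirroring__"])
    return op_name not in NOT_MIRRORED


if __name__ == "__main__":
    print(need_mirror("Activation", {}))
    print(need_mirror("_plus", {"__force_mirroring__": "0"}))
```

The `_plus` case corresponds to the LSTM cell state in `rnn_cell.py`, assuming `force_mirroring='0'` ends up stored under the `__force_mirroring__` key.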

std::vector<const nnvm::Op*> zero_ops;
zero_ops.push_back(nnvm::Op::Get("zeros_like"));
@@ -377,8 +373,12 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
// take gradient
nnvm::Graph g_grad = nnvm::pass::MXGradient(
g, symbol.outputs, xs, head_grad_entry_,
AggregateGradient, need_mirror, nullptr,
zero_ops, "_copy");
AggregateGradient,
(dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) ||
dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) ? need_mirror : nullptr,
zero_ops, "_copy",
in_arg_shapes, in_arg_dtypes);
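
As a rough Python analogue of the gating in this call: the mirroring function is handed to the gradient pass only when at least one of the two environment variables is non-zero, and otherwise a null function disables mirroring. The helper names are hypothetical, and reading the variables as integers with default 0 is an assumption that follows the `dmlc::GetEnv(..., 0)` calls.

```python
def get_env_int(env, name, default=0):
    """Read an integer setting from a mapping such as os.environ."""
    return int(env.get(name, default))


def select_mirror_fun(env, need_mirror):
    """Return `need_mirror` if either switch is on, else None."""
    enabled = (get_env_int(env, "MXNET_BACKWARD_DO_MIRROR")
               or get_env_int(env, "MXNET_MEMORY_OPT"))
    return need_mirror if enabled else None


if __name__ == "__main__":
    def always(node):
        return True

    print(select_mirror_fun({"MXNET_MEMORY_OPT": "1"}, always) is always)
```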

CHECK_EQ(g_grad.outputs.size(), xs.size());
for (const auto &e : g_grad.outputs) {
g.outputs.push_back(e);
@@ -414,8 +414,37 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
std::vector<Context> aux_state_ctxes(aux_states.size());
std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1);

// Record the shapes and data types of the input arguments in the source graph
// (i.e., the graph prior to the Gradient pass). Such information is needed by
// the backward mirroring algorithm for shape and data type inference.
nnvm::Graph src;
src.outputs = symbol.outputs;
const nnvm::IndexedGraph& src_idx = src.indexed_graph();
const std::unordered_set<uint32_t>& src_mutable_nodes = src_idx.mutable_input_nodes();
size_t src_arg_top = 0, src_aux_top = 0;
ShapeVector src_arg_shapes;
nnvm::DTypeVector src_arg_dtypes;
const size_t src_num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size();

for (size_t i = 0; i < src_num_forward_inputs; ++i) {
const uint32_t nid = src_idx.input_nodes().at(i);

if (src_mutable_nodes.count(nid)) {
CHECK_LT(src_aux_top, aux_states.size());
src_arg_shapes.push_back(aux_states[src_aux_top].shape());
src_arg_dtypes.push_back(aux_states[src_aux_top].dtype());
++src_aux_top;
} else {
CHECK_LT(src_arg_top, in_args.size());
src_arg_shapes.push_back(in_args[src_arg_top].shape());
src_arg_dtypes.push_back(in_args[src_arg_top].dtype());
++src_arg_top;
}
}
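
The loop that ends here walks the forward inputs in graph order and pulls each shape and dtype from either the auxiliary states or the regular arguments, depending on whether the input node is mutable. A simplified Python sketch, with arrays reduced to `(shape, dtype)` tuples and all names chosen for illustration only:

```python
def collect_input_info(input_nodes, mutable_nodes, in_args, aux_states):
    """Gather shapes and dtypes of forward inputs in graph order.

    `in_args` and `aux_states` are lists of (shape, dtype) tuples standing
    in for NDArrays; mutable input nodes consume `aux_states`, all other
    input nodes consume `in_args`.
    """
    shapes, dtypes = [], []
    arg_top = aux_top = 0
    for nid in input_nodes:
        if nid in mutable_nodes:
            assert aux_top < len(aux_states)
            shape, dtype = aux_states[aux_top]
            aux_top += 1
        else:
            assert arg_top < len(in_args)
            shape, dtype = in_args[arg_top]
            arg_top += 1
        shapes.append(shape)
        dtypes.append(dtype)
    return shapes, dtypes


if __name__ == "__main__":
    # Node 2 plays the role of an auxiliary state (e.g. a moving average).
    print(collect_input_info(
        [0, 1, 2], {2},
        [((32, 100), "float32"), ((64, 100), "float32")],
        [((64,), "float32")]))
```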

nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes,
arg_grad_ctxes, aux_state_ctxes, grad_req_types);
arg_grad_ctxes, aux_state_ctxes, grad_req_types,
src_arg_shapes, src_arg_dtypes);

// create arg_shapes and arg_dtypes for shape and type inferences
const auto& idx = g.indexed_graph();
@@ -811,8 +840,34 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray>* shared_buffer,
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
// Record the shapes and data types of the input arguments in the source graph
// (i.e., the graph prior to the Gradient pass). Such information is needed by
// the backward mirroring algorithm for shape and data type inference.
nnvm::Graph src;
src.outputs = symbol.outputs;
const nnvm::IndexedGraph& src_idx = src.indexed_graph();
ShapeVector src_arg_shapes(src_idx.input_nodes().size(), TShape());
nnvm::DTypeVector src_arg_dtypes(src_idx.input_nodes().size(), -1);
const size_t src_num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size();

for (size_t i = 0; i < src_num_forward_inputs; ++i) {
const uint32_t nid = src_idx.input_nodes().at(i);
const std::string& name = src_idx[nid].source->attrs.name;
std::unordered_map<std::string, TShape>::const_iterator
arg_shape_iter = arg_shape_map.find(name);
std::unordered_map<std::string, int>::const_iterator
arg_dtype_iter = arg_dtype_map.find(name);
if (arg_shape_iter != arg_shape_map.end()) {
src_arg_shapes[i] = arg_shape_iter->second;
}
if (arg_dtype_iter != arg_dtype_map.end()) {
src_arg_dtypes[i] = arg_dtype_iter->second;
}
}
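
In the simple_bind flow the shapes and dtypes are looked up by input name instead, and inputs missing from the maps keep their "unknown" defaults (an empty `TShape` and dtype `-1` in the C++ code). A small Python sketch of that lookup, using `None` for an unknown shape; the function name is hypothetical:

```python
UNKNOWN_DTYPE = -1  # matches the -1 default used in the C++ code


def collect_named_input_info(input_names, arg_shape_map, arg_dtype_map):
    """Look up each input's shape and dtype by name, keeping defaults
    (None for shape, -1 for dtype) when the name is not in the map."""
    shapes = [arg_shape_map.get(name) for name in input_names]
    dtypes = [arg_dtype_map.get(name, UNKNOWN_DTYPE) for name in input_names]
    return shapes, dtypes


if __name__ == "__main__":
    print(collect_named_input_info(
        ["data", "fc_weight"], {"data": (32, 100)}, {"data": 0}))
```

Entries left at their defaults are presumably resolved later by shape and type inference.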

nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
aux_state_ctxes, grad_req_types);
aux_state_ctxes, grad_req_types,
src_arg_shapes, src_arg_dtypes);

// The following code of shape and dtype inferences and argument
// initialization is for simple_bind only. Regular bind operation
@@ -1007,9 +1062,12 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types) {
const std::vector<OpReqType>& grad_req_types,
const ShapeVector& in_arg_shapes,
const nnvm::DTypeVector& in_arg_dtypes) {
// setup gradient
nnvm::Graph g = InitFullGraph(symbol, grad_req_types);
nnvm::Graph g = InitFullGraph(symbol, grad_req_types,
in_arg_shapes, in_arg_dtypes);

#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", true)) {
8 changes: 6 additions & 2 deletions src/executor/graph_executor.h
@@ -188,10 +188,14 @@ class GraphExecutor : public Executor {
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types);
const std::vector<OpReqType>& grad_req_types,
const ShapeVector& in_arg_shapes,
const nnvm::DTypeVector& in_arg_dtypes);
// initialize the full graph for simple bind, including gradient
Graph InitFullGraph(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types);
const std::vector<OpReqType>& grad_req_types,
const ShapeVector& in_arg_shapes,
const nnvm::DTypeVector& in_arg_dtypes);
// initialize the cached operator
void InitCachedOps();
// initialize the opr segments for bulk exec
2 changes: 1 addition & 1 deletion src/imperative/cached_op.h
@@ -167,7 +167,7 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph,

*grad_graph = pass::MXGradient(
*fwd_graph, fwd_graph->outputs, xs, *ograd_entries,
exec::AggregateGradient, nullptr, nullptr,
exec::AggregateGradient, nullptr,
zero_ops, "_copy");
}

2 changes: 1 addition & 1 deletion src/imperative/imperative.cc
@@ -455,7 +455,7 @@ std::vector<NDArray*> Imperative::Backward(

Graph g_graph = pass::MXGradient(
graph, graph.outputs, xs, ograd_entries,
exec::AggregateGradient, nullptr, nullptr,
exec::AggregateGradient, nullptr,
zero_ops, "_copy");
CHECK_EQ(g_graph.outputs.size(), xs.size());
for (const auto& e : g_graph.outputs) {