Add unoptimized symbol to executor for sharing (apache#16798)
* Add unoptimized symbol to executor for sharing

* Copy the symbol in Reshape

* Added test for multiple reshapes
ptrendx authored Nov 20, 2019
1 parent f1c6880 commit 61c8baf
Showing 3 changed files with 41 additions and 9 deletions.
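In essence, each GraphExecutor now keeps a pristine copy of the symbol it was built from, so a child executor created by Reshape can run the graph passes again starting from the unoptimized graph instead of an already fused or subgraph-rewritten one. The following is a minimal standalone sketch of that pattern, not the actual MXNet code: Symbol, Optimize, and Executor here are simplified stand-ins for nnvm::Symbol, the fusion/subgraph passes, and GraphExecutor.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-in for nnvm::Symbol; illustration only.
struct Symbol {
  std::vector<std::string> ops;
  Symbol Copy() const { return *this; }  // deep copy of the op list
};

// A destructive "optimization" pass, standing in for fusion/subgraph rewriting:
// it collapses the op list into a single opaque node that cannot be re-optimized.
Symbol Optimize(const Symbol& s) {
  Symbol out;
  out.ops = {"fused(" + std::to_string(s.ops.size()) + " ops)"};
  return out;
}

class Executor {
 public:
  // Keep an unoptimized copy of the symbol at construction (the idea behind
  // GraphExecutor::GraphExecutor(const nnvm::Symbol&) in this commit).
  explicit Executor(const Symbol& symbol) : symbol_(symbol.Copy()) {}

  void Init() { graph_ = Optimize(symbol_); }  // the executor runs the optimized graph

  // Child executors (e.g. from Reshape) are built from the stored pristine
  // symbol, not from the already optimized graph_.
  Executor Reshape() const {
    Executor child(symbol_);
    child.Init();
    return child;
  }

  void Print() const {
    for (const auto& op : graph_.ops) std::cout << op << " ";
    std::cout << "\n";
  }

 private:
  Symbol symbol_;  // unoptimized copy, shared with children
  Symbol graph_;   // optimized form that is actually executed
};

int main() {
  Symbol s;
  s.ops = {"add", "add_scalar", "relu"};
  Executor e(s);
  e.Init();
  e.Print();                 // fused(3 ops)
  Executor f = e.Reshape();  // re-optimizes from the pristine symbol
  Executor g = f.Reshape();  // works again: f also kept a pristine copy
  g.Print();                 // fused(3 ops), not fused(fused(...))
  return 0;
}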
20 changes: 12 additions & 8 deletions src/executor/graph_executor.cc
@@ -50,7 +50,7 @@ static const std::string GetDefaultSubgraphBackend() {
#endif
}

-GraphExecutor::GraphExecutor() {
+GraphExecutor::GraphExecutor(const nnvm::Symbol& symbol) {
log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
need_grad_ = false;
is_dynamic_ = false;
@@ -60,6 +60,7 @@ GraphExecutor::GraphExecutor() {
LOG(INFO) << "MXNET_SUBGRAPH_BACKEND=NONE is detected, subgraph backend is not in use";
}
engine_ref_ = Engine::_GetSharedRef();
+symbol_ = symbol.Copy();
}

GraphExecutor::~GraphExecutor() {
@@ -888,10 +889,9 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states) {
nnvm::Graph g;
-g.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
-graph_.outputs.begin() + num_forward_outputs_);
nnvm::Symbol symbol;
-symbol.outputs = g.outputs;
+symbol.outputs = symbol_.outputs;
+g.outputs = symbol_.outputs;
const nnvm::IndexedGraph& idx = g.indexed_graph();
mxnet::ShapeVector arg_shapes(idx.input_nodes().size(), mxnet::TShape());
for (size_t i = 0; i < num_forward_inputs_; ++i) {
@@ -975,8 +975,8 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
}
}
}
-auto exec = new GraphExecutor();
-exec->Init(symbol, default_ctx, ctx_map,
+auto exec = new GraphExecutor(symbol);
+exec->Init(symbol.Copy(), default_ctx, ctx_map,
*in_args, *arg_grads, grad_req_types, *aux_states,
this);
return exec;
@@ -1967,7 +1967,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
std::vector<NDArray>* aux_states,
std::unordered_map<std::string, NDArray>* shared_buffer,
Executor* shared_exec) {
-auto exec = new exec::GraphExecutor();
+auto exec = new exec::GraphExecutor(symbol);
bool init = false;
if (!exec->subgraph_property().empty()) {
static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
@@ -1987,6 +1987,8 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
symbol = exec::BuildSubgraph(symbol, backend, arg_shape_map, arg_dtype_map, arg_stype_map,
default_ctx, group2ctx, &tmp_in_arg_ctxes, &tmp_arg_grad_ctxes,
&tmp_grad_req_types, &tmp_aux_state_ctxes, verbose);
+// Subgraph cannot be recreated from unoptimized symbol
+exec = new exec::GraphExecutor(symbol);
exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes,
tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map,
tmp_grad_req_types, shared_arg_names, &tmp_in_args, &tmp_arg_grads,
@@ -2041,7 +2043,7 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
const std::vector<OpReqType> &grad_req_type,
const std::vector<NDArray> &aux_states,
Executor* shared_exec) {
-auto exec = new exec::GraphExecutor();
+auto exec = new exec::GraphExecutor(symbol);
static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
std::vector<NDArray> tmp_in_args = in_args;
std::vector<NDArray> tmp_arg_grad_store = arg_grad_store;
@@ -2056,6 +2058,8 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
symbol = exec::BuildSubgraph(symbol, backend, default_ctx, group2ctx, &tmp_in_args,
&tmp_arg_grad_store, &tmp_grad_req_type, &tmp_aux_states,
verbose);
+// Subgraph cannot be recreated from unoptimized symbol
+exec = new exec::GraphExecutor(symbol);
}
}
exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store,
5 changes: 4 additions & 1 deletion src/executor/graph_executor.h
@@ -58,7 +58,7 @@ class GraphExecutor : public Executor {
public:
using Executor::MonitorCallback;

-GraphExecutor();
+explicit GraphExecutor(const nnvm::Symbol& symbol);
virtual ~GraphExecutor();
void Forward(bool is_train) override;
void PartialForward(bool is_train, int step, int *step_left) override;
@@ -267,6 +267,9 @@ class GraphExecutor : public Executor {
std::string subgraph_property_;
// ref of engine
std::shared_ptr<Engine> engine_ref_;
+// Unoptimized copy of the symbol for sharing with
+// child executors
+nnvm::Symbol symbol_;
};

} // namespace exec
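A side note on the "Copy the symbol in Reshape" bullet: an nnvm symbol holds shared pointers to its graph nodes, so copying only the outputs vector still shares the underlying nodes, and a later in-place rewrite (such as the subgraph replacement that happens during Init) would be visible through every holder. Below is a small standalone illustration of that difference; Node and Sym are hypothetical types, not the real nnvm classes.

#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical graph node; nnvm::Node is similar in spirit (shared ownership).
struct Node {
  std::string op;
};

// Hypothetical symbol: a list of shared pointers to output nodes,
// loosely mirroring nnvm::Symbol::outputs.
struct Sym {
  std::vector<std::shared_ptr<Node>> outputs;

  // Shallow copy: the same Node objects remain shared.
  Sym ShallowCopy() const { return *this; }

  // Deep copy: fresh Node objects, analogous in intent to nnvm::Symbol::Copy().
  Sym DeepCopy() const {
    Sym out;
    for (const auto& n : outputs) out.outputs.push_back(std::make_shared<Node>(*n));
    return out;
  }
};

int main() {
  Sym original;
  original.outputs.push_back(std::make_shared<Node>(Node{"relu"}));

  Sym shallow = original.ShallowCopy();
  Sym deep = original.DeepCopy();

  // An in-place rewrite (think: subgraph replacement / fusion) on the shallow
  // copy leaks back into the original, because the nodes are shared.
  shallow.outputs[0]->op = "fused_relu";
  assert(original.outputs[0]->op == "fused_relu");

  // The deep copy is unaffected; this is why the executor stores symbol.Copy().
  assert(deep.outputs[0]->op == "relu");
  return 0;
}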
25 changes: 25 additions & 0 deletions tests/python/gpu/test_fusion.py
@@ -239,6 +239,31 @@ def test_fusion_compiler_cache():
check_fused_symbol(a+b, ctx=mx.gpu(1), a=arr1, b=arr2)


@with_seed()
def test_fusion_reshape_executor():
    a = mx.sym.Variable("data1")
    b = mx.sym.Variable("data2")
    c = a + b + 1
    sym = mx.sym.relu(c)
    orig_shape = (10,10)
    e = sym.simple_bind(ctx=mx.gpu(), data1=orig_shape, data2=orig_shape)
    data = mx.nd.zeros(orig_shape, ctx=mx.gpu())
    out = e.forward(is_train=False)
    assert out[0].sum().asscalar() == 100
    changed_shape = (80, 2)
    new_shape = {'data1': changed_shape, 'data2': changed_shape}
    data = mx.nd.zeros(new_shape['data1'], ctx=mx.gpu())
    f = e.reshape(allow_up_sizing=True, **new_shape)
    out = f.forward(is_train=False, data1=data, data2=data)
    assert out[0].sum().asscalar() == 160
    # Reshape again
    changed_shape = (30, 5)
    new_shape = {'data1': changed_shape, 'data2': changed_shape}
    data = mx.nd.zeros(new_shape['data1'], ctx=mx.gpu())
    f = e.reshape(allow_up_sizing=True, **new_shape)
    out = f.forward(is_train=False, data1=data, data2=data)
    assert out[0].sum().asscalar() == 150

if __name__ == '__main__':
    import nose
    nose.runmodule()
