diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index b31ae7cf0cd7..7d5c2af0f881 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -50,7 +50,7 @@ static const std::string GetDefaultSubgraphBackend() {
 #endif
 }
 
-GraphExecutor::GraphExecutor() {
+GraphExecutor::GraphExecutor(const nnvm::Symbol& symbol) {
   log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
   need_grad_ = false;
   is_dynamic_ = false;
@@ -60,6 +60,7 @@ GraphExecutor::GraphExecutor() {
     LOG(INFO) << "MXNET_SUBGRAPH_BACKEND=NONE is detected, subgraph backend is not in use";
   }
   engine_ref_ = Engine::_GetSharedRef();
+  symbol_ = symbol.Copy();
 }
 
 GraphExecutor::~GraphExecutor() {
@@ -888,10 +889,9 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
                                  std::vector<NDArray>* arg_grads,
                                  std::vector<NDArray>* aux_states) {
   nnvm::Graph g;
-  g.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
-                                           graph_.outputs.begin() + num_forward_outputs_);
   nnvm::Symbol symbol;
-  symbol.outputs = g.outputs;
+  symbol.outputs = symbol_.outputs;
+  g.outputs = symbol_.outputs;
   const nnvm::IndexedGraph& idx = g.indexed_graph();
   mxnet::ShapeVector arg_shapes(idx.input_nodes().size(), mxnet::TShape());
   for (size_t i = 0; i < num_forward_inputs_; ++i) {
@@ -975,8 +975,8 @@ Executor* GraphExecutor::Reshape(const bool partial_shaping,
       }
     }
   }
-  auto exec = new GraphExecutor();
-  exec->Init(symbol, default_ctx, ctx_map,
+  auto exec = new GraphExecutor(symbol);
+  exec->Init(symbol.Copy(), default_ctx, ctx_map,
              *in_args, *arg_grads, grad_req_types, *aux_states, this);
   return exec;
 }
@@ -1967,7 +1967,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
                                std::vector<NDArray>* aux_states,
                                std::unordered_map<std::string, NDArray>* shared_buffer,
                                Executor* shared_exec) {
-  auto exec = new exec::GraphExecutor();
+  auto exec = new exec::GraphExecutor(symbol);
   bool init = false;
   if (!exec->subgraph_property().empty()) {
     static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
@@ -1987,6 +1987,8 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
       symbol = exec::BuildSubgraph(symbol, backend, arg_shape_map, arg_dtype_map, arg_stype_map,
                                    default_ctx, group2ctx, &tmp_in_arg_ctxes, &tmp_arg_grad_ctxes,
                                    &tmp_grad_req_types, &tmp_aux_state_ctxes, verbose);
+      // Subgraph cannot be recreated from unoptimized symbol
+      exec = new exec::GraphExecutor(symbol);
       exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes,
                  tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map,
                  tmp_grad_req_types, shared_arg_names, &tmp_in_args, &tmp_arg_grads,
@@ -2041,7 +2043,7 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
                          const std::vector<OpReqType> &grad_req_type,
                          const std::vector<NDArray> &aux_states,
                          Executor* shared_exec) {
-  auto exec = new exec::GraphExecutor();
+  auto exec = new exec::GraphExecutor(symbol);
   static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
   std::vector<NDArray> tmp_in_args = in_args;
   std::vector<NDArray> tmp_arg_grad_store = arg_grad_store;
@@ -2056,6 +2058,8 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
       symbol = exec::BuildSubgraph(symbol, backend, default_ctx, group2ctx, &tmp_in_args,
                                    &tmp_arg_grad_store, &tmp_grad_req_type, &tmp_aux_states,
                                    verbose);
+      // Subgraph cannot be recreated from unoptimized symbol
+      exec = new exec::GraphExecutor(symbol);
     }
   }
   exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store,
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index f150165796ad..bfa6980a8e29 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -58,7 +58,7 @@ class GraphExecutor : public Executor {
  public:
   using Executor::MonitorCallback;
 
-  GraphExecutor();
+  explicit GraphExecutor(const nnvm::Symbol& symbol);
   virtual ~GraphExecutor();
   void Forward(bool is_train) override;
   void PartialForward(bool is_train, int step, int *step_left) override;
@@ -267,6 +267,9 @@ class GraphExecutor : public Executor {
   std::string subgraph_property_;
   // ref of engine
   std::shared_ptr<Engine> engine_ref_;
+  // Unoptimized copy of the symbol for sharing with
+  // child executors
+  nnvm::Symbol symbol_;
 };
 
 }  // namespace exec
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 5606eb19a9c5..693336f22496 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -239,6 +239,31 @@ def test_fusion_compiler_cache():
     check_fused_symbol(a+b, ctx=mx.gpu(1), a=arr1, b=arr2)
 
+@with_seed()
+def test_fusion_reshape_executor():
+    a = mx.sym.Variable("data1")
+    b = mx.sym.Variable("data2")
+    c = a + b + 1
+    sym = mx.sym.relu(c)
+    orig_shape = (10,10)
+    e = sym.simple_bind(ctx=mx.gpu(), data1=orig_shape, data2=orig_shape)
+    data = mx.nd.zeros(orig_shape, ctx=mx.gpu())
+    out = e.forward(is_train=False)
+    assert out[0].sum().asscalar() == 100
+    changed_shape = (80, 2)
+    new_shape = {'data1': changed_shape, 'data2': changed_shape}
+    data = mx.nd.zeros(new_shape['data1'], ctx=mx.gpu())
+    f = e.reshape(allow_up_sizing=True, **new_shape)
+    out = f.forward(is_train=False, data1=data, data2=data)
+    assert out[0].sum().asscalar() == 160
+    # Reshape again
+    changed_shape = (30, 5)
+    new_shape = {'data1': changed_shape, 'data2': changed_shape}
+    data = mx.nd.zeros(new_shape['data1'], ctx=mx.gpu())
+    f = e.reshape(allow_up_sizing=True, **new_shape)
+    out = f.forward(is_train=False, data1=data, data2=data)
+    assert out[0].sum().asscalar() == 150
+
 
 if __name__ == '__main__':
     import nose
     nose.runmodule()
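
Note: a minimal usage sketch of the behavior this patch addresses (a hypothetical repro,
not part of the patch; the names x, y, net and the shapes are illustrative). It assumes an
MXNet 1.x CUDA build, where the pointwise-fusion subgraph pass rewrites elementwise chains
into fused ops. Previously, Executor.reshape rebuilt the child executor from the
already-optimized graph_, so the fused graph could not be re-derived; with the stored
unoptimized symbol_, reshape re-runs Init (including the fusion pass) and the chain below
works:

    import mxnet as mx

    x = mx.sym.Variable('x')
    y = mx.sym.Variable('y')
    net = mx.sym.relu(x * y + 2)   # elementwise chain, fused into one op on GPU

    exe = net.simple_bind(ctx=mx.gpu(), x=(4, 4), y=(4, 4))

    # reshape() rebuilds the child executor from the unoptimized symbol,
    # then re-optimizes it for the new shapes.
    bigger = exe.reshape(allow_up_sizing=True, x=(8, 8), y=(8, 8))
    out = bigger.forward(is_train=False,
                         x=mx.nd.ones((8, 8), ctx=mx.gpu()),
                         y=mx.nd.ones((8, 8), ctx=mx.gpu()))
    print(out[0].shape)            # (8, 8)

The same reasoning explains the SimpleBind/Bind hunks: once BuildSubgraph has rewritten
the symbol, the executor is recreated from the optimized symbol, since a copy taken before
optimization could not reproduce the subgraph ops.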