This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

fix
piiswrong committed May 7, 2018
1 parent 21a1c26 commit 60d6544
Showing 4 changed files with 80 additions and 29 deletions.
37 changes: 19 additions & 18 deletions include/mxnet/imperative.h
@@ -159,33 +159,42 @@ class Imperative {
std::mutex mutex;
Context context;
GraphInfo info;
// Static memory only

bool initialized = false;
bool bwd_pending = false;
bool recording = false;
std::vector<NDArray> buff;
std::vector<NDArray*> arrays;
std::vector<OpReqType> array_reqs;
std::vector<std::shared_ptr<exec::OpExecutor> > execs;
std::vector<std::unique_ptr<engine::Opr, EngineOprDeleter> > engine_oprs;

void Reset(bool bwd_only);
void ResetStaticRuntime(bool keep_fwd);
};

DeviceState* GetDeviceState(const Context& ctx);
bool SetForwardGraph(GraphInfo* info,
const bool recording,
const std::vector<NDArray*>& inputs);
bool SetBackwardGraph(GraphInfo* info,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& inputs);
OpStatePtr DynamicForward(const Context& default_ctx,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
void DynamicBackward(const bool retain_graph,
const OpStatePtr& state,
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
void StaticResetState(DeviceState* dev_state,
bool recording,
bool for_bwd);
bool recording,
bool keep_fwd);
void StaticRunOps(const Context& default_ctx,
const nnvm::Graph& g,
const DeviceState* dev_state,
size_t start_nid,
size_t end_nid);
bool SetForwardGraph(GraphInfo* info,
const bool recording,
const std::vector<NDArray*>& inputs);
bool SetBackwardGraph(GraphInfo* info,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& inputs);
OpStatePtr StaticForward(const Context& default_ctx,
const std::vector<NDArray*>& args,
const std::vector<NDArray*>& outputs);
@@ -194,14 +203,6 @@ class Imperative {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
OpStatePtr DynamicForward(const Context& default_ctx,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
void DynamicBackward(const bool retain_graph,
const OpStatePtr& state,
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);

CachedOpConfig config_;
nnvm::Graph fwd_graph_;
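The hunk above replaces the old ad-hoc runtime state with an explicit per-context DeviceState carrying its own mutex, lifecycle flags (initialized, bwd_pending, recording), and the statically allocated buffers and cached executors. For orientation only, a minimal C++ sketch of that per-context lookup pattern follows; all names and the map-based storage are assumptions made for the sketch, not the actual CachedOp internals.

// Simplified illustration of a per-context state lookup; type and member
// names are placeholders, not the real CachedOp code.
#include <memory>
#include <mutex>
#include <unordered_map>

struct ToyDeviceState {
  std::mutex mutex;          // serializes StaticForward/StaticBackward per context
  bool initialized = false;  // static buffers have been planned and allocated
  bool bwd_pending = false;  // a recorded forward still awaits its backward pass
  bool recording = false;    // the cached plan was built in recording mode
};

class ToyCachedOp {
 public:
  // Lazily create one state object per device and reuse it across calls.
  ToyDeviceState* GetDeviceState(int device_id) {
    std::lock_guard<std::mutex> lock(states_mutex_);
    auto& slot = states_[device_id];
    if (!slot) slot.reset(new ToyDeviceState());
    return slot.get();
  }

 private:
  std::mutex states_mutex_;
  std::unordered_map<int, std::unique_ptr<ToyDeviceState> > states_;
};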
5 changes: 5 additions & 0 deletions include/mxnet/op_attr_types.h
@@ -145,6 +145,11 @@ class OpStatePtr {
void reset() {
ptr_.reset();
}
/* \brief Whether the managed object is owned solely by the current
   OpStatePtr instance */
bool unique() {
return ptr_.unique();
}
/* \brief Whether state is empty */
explicit operator bool() const {
return ptr_ ? true : false;
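The added unique() appears to forward to the same call on OpStatePtr's internal shared pointer, i.e. it reports whether this handle is the sole owner of the managed state. A standalone sketch of those semantics using a plain std::shared_ptr (not the MXNet types):

// Demonstrates shared_ptr-style unique() semantics in isolation; this is a
// plain std::shared_ptr example, not OpStatePtr itself.
#include <cassert>
#include <memory>

int main() {
  std::shared_ptr<int> state = std::make_shared<int>(42);
  assert(state.unique());   // exactly one owner -> true

  std::shared_ptr<int> alias = state;
  assert(!state.unique());  // two owners now -> false

  alias.reset();
  assert(state.unique());   // sole ownership restored
  return 0;
}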
36 changes: 25 additions & 11 deletions src/imperative/cached_op.cc
@@ -409,16 +409,17 @@ Imperative::CachedOp::DeviceState* Imperative::CachedOp::GetDeviceState(
return state;
}

void Imperative::CachedOp::DeviceState::Reset(bool for_bwd) {
void Imperative::CachedOp::DeviceState::ResetStaticRuntime(bool keep_fwd) {
size_t num_forward_nodes = info.fwd_graph.indexed_graph().num_nodes();
size_t num_forward_entries = info.fwd_graph.indexed_graph().num_node_entries();

for (size_t i = for_bwd ? num_forward_entries : 0; i < buff.size(); ++i) {

for (size_t i = keep_fwd ? num_forward_entries : 0; i < buff.size(); ++i) {
buff[i] = NDArray();
array_reqs[i] = kNullOp;
}
for (size_t i = 0; i < buff.size(); ++i) arrays[i] = &buff[i];
for (size_t i = for_bwd ? num_forward_nodes : 0; i < execs.size(); ++i) {
for (size_t i = keep_fwd ? num_forward_nodes : 0; i < execs.size(); ++i) {
execs[i].reset();
engine_oprs[i].reset();
}
@@ -427,19 +428,19 @@ void Imperative::CachedOp::DeviceState::Reset(bool for_bwd) {
void Imperative::CachedOp::StaticResetState(
DeviceState* dev_state,
bool recording,
bool for_bwd) {
bool keep_fwd) {
using namespace nnvm;
using namespace imperative;
nnvm::Graph& g = for_bwd ? dev_state->info.full_graph : dev_state->info.fwd_graph;
nnvm::Graph& g = keep_fwd ? dev_state->info.full_graph : dev_state->info.fwd_graph;
const auto& idx = g.indexed_graph();
const auto& vstorage_inplace = g.GetAttr<std::vector<int> >("storage_inplace_index");
const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
for_bwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" : "forward_mem_plan"));
size_t start_nid =
for_bwd ? dev_state->info.fwd_graph.indexed_graph().num_nodes() : 0;
keep_fwd ? dev_state->info.fwd_graph.indexed_graph().num_nodes() : 0;
size_t end_nid = idx.num_nodes();
size_t start_eid =
for_bwd ? dev_state->info.fwd_graph.indexed_graph().num_node_entries() : 0;
keep_fwd ? dev_state->info.fwd_graph.indexed_graph().num_node_entries() : 0;
size_t end_eid = idx.num_node_entries();

for (size_t i = start_nid; i < end_nid; ++i) {
@@ -469,6 +470,7 @@ void Imperative::CachedOp::StaticResetState(
}

dev_state->initialized = true;
dev_state->bwd_pending = false;
dev_state->recording = recording;
}

@@ -538,13 +540,17 @@ OpStatePtr Imperative::CachedOp::StaticForward(
auto dev_state = GetDeviceState(default_ctx);
std::lock_guard<std::mutex> lock(dev_state->mutex);

CHECK(!dev_state->bwd_pending)
<< "Cannot forward for the second time before calling backward first "
<< "when use_static_memory=True.";

bool match = SetForwardGraph(&dev_state->info, recording, inputs);

nnvm::Graph& g = dev_state->info.fwd_graph;
const auto& idx = g.indexed_graph();

if (!(dev_state->initialized && dev_state->recording == recording && match)) {
dev_state->Reset(false);
dev_state->ResetStaticRuntime(false);

for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
auto nid = idx.input_nodes()[fwd_params_idx_[i]];
@@ -573,6 +579,8 @@

StaticRunOps(default_ctx, g, dev_state, 0, idx.num_nodes());

dev_state->bwd_pending = recording;

return OpStatePtr();
}

@@ -785,20 +793,24 @@ void Imperative::CachedOp::StaticBackward(
using namespace nnvm;
using namespace imperative;


Context default_ctx = outputs[0]->ctx();

auto dev_state = GetDeviceState(default_ctx);
std::lock_guard<std::mutex> lock(dev_state->mutex);

CHECK(dev_state->bwd_pending)
<< "Must forward with is_recording=True before calling backward.";

bool match = SetBackwardGraph(&dev_state->info, reqs, inputs);
// TODO: check if param grads match
// TODO(eric): check if param grads match

nnvm::Graph& g = dev_state->info.full_graph;
const auto& idx = g.indexed_graph();
auto num_forward_nodes = dev_state->info.fwd_graph.indexed_graph().num_nodes();

if (!match) {
dev_state->Reset(true);
dev_state->ResetStaticRuntime(true);

for (auto i : fwd_params_idx_) {
const auto iter = fwd_input_to_grad_output_.find(i);
@@ -829,6 +841,8 @@
}

StaticRunOps(default_ctx, g, dev_state, num_forward_nodes, idx.num_nodes());

dev_state->bwd_pending = retain_graph;
}

void Imperative::CachedOp::Backward(
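Taken together, the bwd_pending checks added to StaticForward and StaticBackward enforce a strict forward-then-backward ordering per device when static memory is used: a recorded forward marks a backward pass as pending and blocks further forwards, and backward requires a pending forward and keeps the flag set only when retain_graph is true. Below is a minimal stand-in sketch of just that guard; the real methods also build graphs, plan static memory, and launch ops, all omitted here.

// Stand-in sketch of the forward/backward ordering guard only.
#include <stdexcept>

struct ToyDeviceState { bool bwd_pending = false; };

void ToyStaticForward(ToyDeviceState* s, bool recording) {
  if (s->bwd_pending) {
    throw std::runtime_error(
        "Cannot forward a second time before calling backward "
        "when use_static_memory=True.");
  }
  // ... run the cached forward graph on the static buffers ...
  s->bwd_pending = recording;  // a backward pass is owed only if we recorded
}

void ToyStaticBackward(ToyDeviceState* s, bool retain_graph) {
  if (!s->bwd_pending) {
    throw std::runtime_error(
        "Must forward with is_recording=True before calling backward.");
  }
  // ... run the backward segment of the full graph ...
  s->bwd_pending = retain_graph;  // a retained graph may be backpropagated again
}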
31 changes: 31 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -977,6 +977,37 @@ def test_hybrid_multi_context():
net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()


def test_hybrid_static_memory():
    x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
    x.attach_grad()

    net1 = gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True, prefix='net_')
    net2 = gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True, prefix='net_')
    net2.hybridize(use_static_memory=True)
    net1(x)
    net2(x)

    net1.save_params('test.params')
    net2.load_params('test.params')

    def test(net, x):
        with mx.autograd.record(False):
            y = net(x)
            y.backward()

        grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}

        return y, grads

    y1, grads1 = test(net1, x)
    y2, grads2 = test(net2, x)

    assert_allclose(y1.asnumpy(), y2.asnumpy())
    for key in grads1:
        assert_allclose(grads1[key].asnumpy(), grads2[key].asnumpy())


if __name__ == '__main__':
import nose
nose.runmodule()
