Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
reuse memory pool for cachedop
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed May 23, 2018
1 parent f655ca9 commit b1bf748
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 26 deletions.
11 changes: 8 additions & 3 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ struct CachedOp::CachedOpState {
std::vector<OpStatePtr> op_states;
std::vector<std::shared_ptr<exec::OpExecutor> > execs;
std::vector<imperative::EngineOprSeg> opr_segs;
std::multimap<size_t, NDArray> fwd_reuse_pool;
std::multimap<size_t, NDArray> bwd_reuse_pool;
};

CachedOp::CachedOp(
Expand Down Expand Up @@ -495,9 +497,10 @@ void CachedOp::StaticAllocMemory(
}
}

imperative::AllocateMemory(
auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool;
reuse_pool = imperative::AllocateMemory(
g, idx, default_ctx, start_eid, end_eid, mem_plan,
state.arrays, &state.array_reqs);
state.arrays, &state.array_reqs, std::move(reuse_pool));

state.recording = recording;
if (keep_fwd) {
Expand Down Expand Up @@ -929,7 +932,9 @@ void CachedOp::StaticBackward(
const auto& idx = g.indexed_graph();
auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes();

if (!state.bwd_alloc || !match) StaticAllocMemory(state_ptr, true, true);
if (!state.bwd_alloc || !match) {
StaticAllocMemory(state_ptr, true, true);
}

if (config_.static_shape) {
for (size_t i = 0; i < config_.param_indices.ndim(); ++i) {
Expand Down
49 changes: 31 additions & 18 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <utility>
#include <algorithm>
#include <vector>
#include <map>
#include <string>
#include "../executor/graph_executor.h"
#include "../executor/exec_pass.h"
Expand Down Expand Up @@ -776,18 +777,22 @@ inline MemoryPlanVector PlanMemory(
}


inline void AllocateMemory(const nnvm::Graph& g,
const nnvm::IndexedGraph& idx,
const Context& default_ctx,
const uint32_t entry_start, const uint32_t entry_end,
const MemoryPlanVector& mem_plan,
const std::vector<NDArray*>& arrays,
std::vector<OpReqType> *array_reqs) {
inline std::multimap<size_t, NDArray> AllocateMemory(
const nnvm::Graph& g,
const nnvm::IndexedGraph& idx,
const Context& default_ctx,
const uint32_t entry_start, const uint32_t entry_end,
const MemoryPlanVector& mem_plan,
const std::vector<NDArray*>& arrays,
std::vector<OpReqType> *array_reqs,
std::multimap<size_t, NDArray>&& pool = std::multimap<size_t, NDArray>()) {
using namespace nnvm;
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");

std::multimap<size_t, NDArray> new_pool;

for (uint32_t i = entry_start; i < entry_end; ++i) {
if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;
CHECK(arrays[i]->is_none());
Expand All @@ -796,22 +801,30 @@ inline void AllocateMemory(const nnvm::Graph& g,
shapes[i], default_ctx, true, dtypes[i]);
continue;
}

CHECK_EQ(stypes[i], kDefaultStorage);
if (mem_plan[i].root == i) {
CHECK_GT(mem_plan[i].size, 0);
NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
default_ctx, true, mshadow::kUint8);
*arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
continue;
}

CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
*arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
array_reqs->at(i) = kWriteInplace;
auto iter = pool.lower_bound(mem_plan[i].size);
if (iter != pool.end()) {
*arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
new_pool.insert(*iter);
pool.erase(iter);
} else {
NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
default_ctx, true, mshadow::kUint8);
*arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
new_pool.insert({mem_plan[i].size, buff});
}
} else {
CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
*arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
array_reqs->at(i) = kWriteInplace;
}
}
}

return new_pool;
}

inline void SetupOpExec(
Expand Down
19 changes: 14 additions & 5 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,15 +1069,15 @@ def test_zero_grad():
grad = net.collect_params()['test_zero_grad_weight'].grad()
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)

def test_hybrid_static_memory():
def check_hybrid_static_memory(**kwargs):
x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
x.attach_grad()

net1 = gluon.model_zoo.vision.get_resnet(
1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
net2 = gluon.model_zoo.vision.get_resnet(
1, 18, pretrained=True, prefix='net_', ctx=mx.context.current_context())
net2.hybridize(use_static_memory=True)
net2.hybridize(**kwargs)
net1(x)
net2(x)

Expand All @@ -1097,23 +1097,32 @@ def test(net, x):
for key in grads1:
assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)

def test_hybrid_static_memory():
    """Run the static-memory check under every hybridize configuration."""
    # Same three configurations as before: default, static_alloc only,
    # and static_alloc combined with static_shape.
    for hybridize_kwargs in ({},
                             {'static_alloc': True},
                             {'static_alloc': True, 'static_shape': True}):
        check_hybrid_static_memory(**hybridize_kwargs)

def test_hybrid_static_memory_switching():
x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
def check_hybrid_static_memory_switching(**kwargs):
    """Forward/backward a hybridized ResNet with two different input
    shapes to exercise re-allocation when the static memory plan must
    switch between shapes.

    Keyword args are forwarded verbatim to ``HybridBlock.hybridize``
    (e.g. ``static_alloc``, ``static_shape``).
    """
    model = gluon.model_zoo.vision.get_resnet(
        1, 18, pretrained=True, ctx=mx.context.current_context())
    model.hybridize(**kwargs)

    # Run the same inference + training step at two batch sizes; the
    # shape change is what triggers the memory switch under test.
    for batch_size in (4, 2):
        data = mx.nd.random.uniform(shape=(batch_size, 3, 32, 32))
        model(data)
        with mx.autograd.record():
            out = model(data)
            out.backward()
    # Block until all async engine work completes so failures surface here.
    mx.nd.waitall()

def test_hybrid_static_memory_switching():
    """Run the shape-switching check under every hybridize configuration."""
    # Same three configurations as before: default, static_alloc only,
    # and static_alloc combined with static_shape.
    for hybridize_kwargs in ({},
                             {'static_alloc': True},
                             {'static_alloc': True, 'static_shape': True}):
        check_hybrid_static_memory_switching(**hybridize_kwargs)

@with_seed()
def test_hook():
Expand Down

0 comments on commit b1bf748

Please sign in to comment.