This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Static memory allocation for cached_op #10817

Merged
merged 28 commits on May 31, 2018
5 changes: 0 additions & 5 deletions include/mxnet/c_api.h
@@ -987,11 +987,6 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
int num_inputs,
const char** input_names,
int num_params,
const char** param_names,
NDArrayHandle* params,
CachedOpHandle *out);
/*!
* \brief free cached operator
89 changes: 0 additions & 89 deletions include/mxnet/imperative.h
@@ -35,23 +35,6 @@
#include "./ndarray.h"

namespace mxnet {
/*! \brief CachedOp Parameters */
struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
uint32_t inline_limit;
uint32_t forward_bulk_size;
uint32_t backward_bulk_size;
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(inline_limit)
.set_default(2)
.describe("Maximum number of operators that can be inlined.");
DMLC_DECLARE_FIELD(forward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during forward pass.");
DMLC_DECLARE_FIELD(backward_bulk_size)
.set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
.describe("Segment size of bulk execution during backward pass.");
}
};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
@@ -94,67 +77,6 @@ class Imperative {
&& info.out_grads.size() == 1;
}
};
class CachedOp {
public:
CachedOp(
const nnvm::Symbol& sym,
const std::vector<std::pair<std::string, std::string> >& flags,
const std::vector<std::string> arg_names,
const std::unordered_map<std::string, std::vector<NDArray> >& params);
uint32_t num_inputs() {
return fwd_graph_.indexed_graph().input_nodes().size();
}
uint32_t num_outputs() {
return fwd_graph_.outputs.size();
}
uint32_t num_backward_inputs() {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
std::vector<bool>& save_inputs() {
return save_inputs_;
}
std::vector<bool>& save_outputs() {
return save_outputs_;
}
const std::unordered_set<uint32_t>& mutable_input_nodes() {
return fwd_graph_.indexed_graph().mutable_input_nodes();
}
nnvm::Graph GetForwardGraph(const bool recording,
const std::vector<NDArray*>& inputs);
nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& inputs);
std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds);
void Forward(const std::shared_ptr<CachedOp>& op_ptr,
const std::vector<NDArray*>& args,
const std::vector<NDArray*>& outputs);
void Backward(const bool retain_graph,
const OpStatePtr& state,
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);

private:
struct CachedOpState {
std::vector<NDArray> buff;
std::vector<OpStatePtr> states;
};
std::mutex mutex_;
CachedOpConfig config_;
nnvm::Graph fwd_graph_;
nnvm::Graph grad_graph_;
nnvm::Graph full_graph_;
std::unordered_map<Context, std::vector<NDArray> > params_;
bool inlining_;
std::vector<nnvm::NodeEntry> ograd_entries_;
std::vector<bool> curr_grad_req_;
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
std::vector<uint32_t> fwd_args_idx_;
std::vector<uint32_t> fwd_params_idx_;
std::vector<uint32_t> bwd_input_eid_;
std::vector<bool> save_inputs_, save_outputs_;
};
/*! \brief whether operator recording is on. */
bool is_training() const {
return is_train_;
@@ -222,15 +144,6 @@ class Imperative {
uint32_t num_inputs, uint32_t num_outputs,
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs);
void RunGraph(
const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
std::vector<OpStatePtr> *p_states,
const DispatchModeVector& dispatch_modes);
/*! \brief indicate whether is training. */
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
@@ -247,7 +160,5 @@ class Imperative {
int backward_bulk_size_{0};
};

using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;

} // namespace mxnet
#endif // MXNET_IMPERATIVE_H_
8 changes: 8 additions & 0 deletions include/mxnet/ndarray.h
@@ -155,6 +155,14 @@ class NDArray {
return byte_offset_ > 0 || shape() != ptr_->storage_shape;
}

/* \brief Check whether the two arrays are the same array */
inline bool IsSame(const NDArray& other) {
return ptr_ == other.ptr_ &&
shape_ == other.shape_ &&
byte_offset_ == other.byte_offset_ &&
dtype_ == other.dtype_;
}

/*!
* \return the shape of current NDArray.
*/
33 changes: 20 additions & 13 deletions include/mxnet/op_attr_types.h
@@ -126,25 +126,36 @@ class OpStatePtr {
template<typename T, typename... Args>
static OpStatePtr Create(Args&&... args) {
OpStatePtr ret;
ret.ptr_ = std::make_shared<OpState>();
ret.ptr_->var_ = Engine::Get()->NewVariable();
ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
auto state = new T(std::forward<Args>(args)...);
auto var = Engine::Get()->NewVariable();
ret.ptr_.reset(
new OpState(var, state),
[](OpState* p) {
Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
delete reinterpret_cast<T*>(p->state);
delete p;
});

return ret;
}
/* \brief Get engine variable associated with this state */
engine::VarHandle get_var() const {
return ptr_->var_;
return ptr_->var;
}
/* \brief Get state of type T */
template<typename T>
T& get_state() const {
return dmlc::get<T>(ptr_->state_);
return *reinterpret_cast<T*>(ptr_->state);
}
/* \brief clear state */
void reset() {
ptr_.reset();
}
/* \brief checks whether the managed object is managed only by the current
OpStatePtr instance */
bool unique() const {
return ptr_.unique();
}
/* \brief Whether state is empty */
explicit operator bool() const {
return ptr_ ? true : false;
@@ -153,16 +164,12 @@
private:
/* \brief state structure */
struct OpState {
OpState() {}
engine::VarHandle var;
void* state;

OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
OpState(const OpState& other) = delete;
OpState& operator=(const OpState& other) = delete;

~OpState() {
Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
}

engine::VarHandle var_;
dmlc::any state_;
};
/* \brief shared pointer to state */
std::shared_ptr<OpState> ptr_;
16 changes: 1 addition & 15 deletions python/mxnet/_ctypes/ndarray.py
@@ -105,28 +105,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
def __init__(self, sym, flags=(), inputs=None, params=None):
def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
param_names = []
param_arrays = []
if inputs is None:
assert params is None, "When inputs is None params must also be None."
inputs = sym.list_inputs()
elif params is not None:
for name, arrs in params.items():
param_arrays.extend(arrs)
param_names.extend([name] * len(arrs))

check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
len(flags),
c_str_array([key for key, _ in flags]),
c_str_array([str(val) for _, val in flags]),
len(inputs),
c_str_array(inputs),
len(param_names),
c_str_array(param_names),
c_handle_array(param_arrays),
ctypes.byref(self.handle)))

def __del__(self):
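Side note (not part of the diff): with this change, `CachedOp` is constructed from a symbol and flags only, and every graph input, parameters included, is passed positionally at call time. A minimal sketch, using a made-up symbol and the `static_alloc` flag introduced elsewhere in this PR:

```python
import mxnet as mx

# Hypothetical example: the constructor no longer takes inputs/params.
data = mx.sym.var('data')
weight = mx.sym.var('weight')
net = mx.sym.FullyConnected(data=data, weight=weight, num_hidden=4, no_bias=True)

op = mx.nd.CachedOp(net, flags=[('static_alloc', 'True')])
# All inputs (data and parameters alike) are supplied when invoking the op.
out = op(mx.nd.ones((2, 8)), mx.nd.ones((4, 8)))  # -> shape (2, 4)
```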
74 changes: 49 additions & 25 deletions python/mxnet/gluon/block.py
@@ -446,8 +446,16 @@ def hybridize(self, active=True, **kwargs):
----------
active : bool, default True
Whether to turn hybrid on or off.
**kwargs : string
Additional flags for hybridized operator.
static_alloc : bool, default False
Statically allocate memory to improve speed. Memory usage may increase.
static_shape : bool, default False
Optimize for invariant input shapes between iterations. Must also
set static_alloc to True. Change of input shapes is still allowed
but slower.
forward_bulk_size : int, default 15
Segment size of bulk execution during forward pass.
backward_bulk_size : int, default 15
Segment size of bulk execution during backward pass.
"""
for cld in self._children.values():
cld.hybridize(active, **kwargs)
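For illustration (not part of the diff), a minimal usage sketch of the new flags on a toy Gluon model; the layer sizes and input shape are arbitrary:

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(128, activation='relu'))
    net.add(nn.Dense(10))
net.initialize()

# Plan memory once and reuse it across iterations; input shapes are
# additionally assumed to stay fixed, which enables static_shape.
net.hybridize(static_alloc=True, static_shape=True)

out = net(mx.nd.ones((32, 100)))
```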
@@ -615,7 +623,7 @@ def __init__(self, prefix=None, params=None):
self._out_format = None
self._in_format = None
self._active = False
self._flags = {}
self._flags = []

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -642,39 +650,43 @@ def _get_graph(self, *args):
return self._cached_graph

def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
input_names = [i.name for i in inputs]

data, out = self._get_graph(*args)
data_names = {data.name : i for i, data in enumerate(data)}
params = self.collect_params()
input_names = out.list_inputs()

param_names = set(params.keys())
expected_names = set(out.list_inputs())
expected_names = set(input_names)
for name in expected_names:
assert name in param_names or name in input_names, \
assert name in param_names or name in data_names, \
"Unknown input to HybridBlock: %s"%name

used_input_names = [i for i in input_names if i in expected_names]
if len(used_input_names) != len(input_names):
unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
used_data_names = [i for i in data_names if i in expected_names]
if len(used_data_names) != len(data_names):
unused = ', '.join(['%d-th'%i for name, i in data_names.items()
if name not in expected_names])
warnings.warn("The %s input to HybridBlock is not used by any "
"computation. Is this intended?"%unused, stacklevel=4)

used_param_names = set(i for i in param_names if i in expected_names)
used_param_names = [i for i in param_names if i in expected_names]
if len(used_param_names) != len(param_names):
unused = ', '.join(list(param_names - used_param_names))
unused = ', '.join(list(param_names - set(used_param_names)))
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)

used_params = {k: params[k] for k in used_param_names}
try:
param_dict = {k: v.list_data() for k, v in used_params.items()}
except DeferredInitializationError:
self._deferred_infer_shape(*args)
for i in used_params.values():
i._finish_deferred_init()
param_dict = {k: v.list_data() for k, v in used_params.items()}

self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
data_indices = []
param_indices = []
self._cached_op_args = []
for i, name in enumerate(input_names):
if name in data_names:
data_indices.append(i)
self._cached_op_args.append((True, data_names[name]))
else:
param_indices.append(i)
self._cached_op_args.append((False, params[name]))
flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
self._flags
self._cached_op = ndarray.CachedOp(out, flags)
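A toy illustration (hypothetical names, not part of the diff) of the index split computed above: positions of user data inputs versus Parameter inputs within the symbol's input list are recorded and handed to CachedOp as flags.

```python
# Hypothetical names for illustration only.
input_names = ['data', 'dense_weight', 'dense_bias']  # out.list_inputs()
data_names = {'data': 0}                              # user-supplied inputs
param_names = {'dense_weight', 'dense_bias'}          # from collect_params()

data_indices = [i for i, n in enumerate(input_names) if n in data_names]
param_indices = [i for i, n in enumerate(input_names) if n in param_names]
# data_indices == [0], param_indices == [1, 2]; both become CachedOp flags.
```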

def _deferred_infer_shape(self, *args):
try:
@@ -690,7 +702,19 @@ def _call_cached_op(self, *args):

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
out = self._cached_op(*args)
try:
cargs = [args[i] if is_arg else i.data()
for is_arg, i in self._cached_op_args]
except DeferredInitializationError:
self._deferred_infer_shape(*args)
cargs = []
for is_arg, i in self._cached_op_args:
if is_arg:
cargs.append(args[i])
else:
i._finish_deferred_init()
cargs.append(i.data())
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)[0]
@@ -711,7 +735,7 @@ def register_child(self, block, name=None):

def hybridize(self, active=True, **kwargs):
self._active = active
self._flags = kwargs.items()
self._flags = list(kwargs.items())
self._clear_cached_op()
if active and self._forward_hooks or self._forward_pre_hooks:
warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '