Commit: Add inferstorage to graph executor

eric-haibin-lin committed May 17, 2017
1 parent 00235b5 commit e1920f2
Showing 8 changed files with 392 additions and 72 deletions.
3 changes: 3 additions & 0 deletions include/mxnet/c_api.h
@@ -1174,6 +1174,9 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
mx_uint* shared_buffer_len,
1 change: 1 addition & 0 deletions include/mxnet/executor.h
@@ -115,6 +115,7 @@ class Executor {
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::unordered_map<std::string, int>& arg_stype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
2 changes: 1 addition & 1 deletion include/mxnet/ndarray.h
@@ -735,7 +735,7 @@ class NDArray {
if (skip_free == false) {
Storage::Get()->Free(h);
for (size_t i = 0; i < aux_h.size(); i++) {
Storage::Get()->Free(aux_h[i]);
if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]);
}
}
}, shandle.ctx, var);
133 changes: 129 additions & 4 deletions python/mxnet/symbol.py
@@ -17,7 +17,9 @@
from .base import NDArrayHandle, ExecutorHandle, SymbolHandle
from .base import check_call, MXNetError
from .context import Context, cpu
from .ndarray import _STORAGE_TYPE_ID_TO_STR, _STORAGE_TYPE_STR_TO_ID
from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from .sparse_ndarray import SparseNDArray
from .executor import Executor
from . import _symbol_internal as _internal
from .attribute import AttrScope
@@ -715,6 +717,89 @@ def list_auxiliary_states(self):
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]

def infer_storage_type(self, *args, **kwargs):
"""Infer the storage types of outputs and arguments, given the known
storage types of some arguments.
Known storage types can be passed either positionally or as keyword
arguments. A tuple of Nones is returned if there is not enough
information to complete the inference. An error is raised if an
inconsistency is found in the known types passed in.

Parameters
----------
*args :
Provide storage types of arguments in a positional way.
An unknown storage type can be marked as None.
**kwargs :
Provide keyword arguments of known storage types.

Returns
-------
arg_storage_types : list of str or None
List of storage types of arguments.
The order is the same as list_arguments().
out_storage_types : list of str or None
List of storage types of outputs.
The order is the same as list_outputs().
aux_storage_types : list of str or None
List of storage types of auxiliary states.
The order is the same as list_auxiliary_states().
"""
# pylint: disable=too-many-locals
if len(args) != 0 and len(kwargs) != 0:
raise ValueError('Can only specify known argument '
'storage types either positionally or through kwargs, not both.')
sdata = []
if len(args) != 0:
keys = None
for s in args:
if s is not None:
if not isinstance(s, basestring) or s not in _STORAGE_TYPE_STR_TO_ID:
raise TypeError('Argument needs to be one of ' + str(_STORAGE_TYPE_STR_TO_ID))
sdata.append(_STORAGE_TYPE_STR_TO_ID[s])
else:
sdata.append(_STORAGE_TYPE_STR_TO_ID['undefined'])
else:
keys = []
for k, v in kwargs.items():
if v in _STORAGE_TYPE_STR_TO_ID:
keys.append(c_str(k))
sdata.append(_STORAGE_TYPE_STR_TO_ID[v])
arg_storage_type_size = mx_uint()
arg_storage_type_data = ctypes.POINTER(ctypes.c_int)()
out_storage_type_size = mx_uint()
out_storage_type_data = ctypes.POINTER(ctypes.c_int)()
aux_storage_type_size = mx_uint()
aux_storage_type_data = ctypes.POINTER(ctypes.c_int)()
complete = ctypes.c_int()
check_call(_LIB.MXSymbolInferStorageType(
self.handle,
mx_uint(len(sdata)),
c_array(ctypes.c_char_p, keys),
c_array(ctypes.c_int, sdata),
ctypes.byref(arg_storage_type_size),
ctypes.byref(arg_storage_type_data),
ctypes.byref(out_storage_type_size),
ctypes.byref(out_storage_type_data),
ctypes.byref(aux_storage_type_size),
ctypes.byref(aux_storage_type_data),
ctypes.byref(complete)))
if complete.value != 0:
arg_storage_types = [
_STORAGE_TYPE_ID_TO_STR[arg_storage_type_data[i]] \
for i in range(arg_storage_type_size.value)]
out_storage_types = [
_STORAGE_TYPE_ID_TO_STR[out_storage_type_data[i]] \
for i in range(out_storage_type_size.value)]
aux_storage_types = [
_STORAGE_TYPE_ID_TO_STR[aux_storage_type_data[i]] \
for i in range(aux_storage_type_size.value)]
return (arg_storage_types, out_storage_types, aux_storage_types)
else:
return (None, None, None)
# pylint: enable=too-many-locals
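
A minimal usage sketch for the new method (illustrative, not part of this commit; it assumes the dot operator registers a storage-type inference function, otherwise (None, None, None) is returned). Storage types are passed as strings such as 'default', 'row_sparse', or 'csr'.

import mxnet as mx

# Infer storage types for a dot product whose left operand is declared 'csr'.
data = mx.sym.Variable('data')
weight = mx.sym.Variable('weight')
out = mx.sym.dot(data, weight)
arg_stypes, out_stypes, aux_stypes = out.infer_storage_type(
    data='csr', weight='default')
# arg_stypes follows list_arguments() order, e.g. ['csr', 'default']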


def infer_type(self, *args, **kwargs):
"""Infers the type of all arguments and all outputs, given the known types
for some arguments.
@@ -1114,8 +1199,9 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing):
raise TypeError('Only accept list of NDArrays or dict of str to NDArray')
return c_array(NDArrayHandle, arg_handles), arg_arrays

def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs):
def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=None,
group2ctx=None, shared_arg_names=None, shared_exec=None,
shared_buffer=None, **kwargs):
"""Bind current symbol to get an executor, allocate all the arguments needed.
Allows specifying data types.
@@ -1157,6 +1243,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
type_dict : Dict of str->numpy.dtype
Input type dictionary, name->dtype
storage_type_dict : Dict of str->str
Input storage type dictionary, name->storage_type
group2ctx : Dict of string to mx.Context
The dict mapping the `ctx_group` attribute to the context assignment.
@@ -1171,7 +1260,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_buffer : Dict of string to `NDArray`
The dict mapping argument names to the `NDArray` that can be reused for initializing
the current executor. This buffer will be checked for reuse if one argument name
of the current executor is not found in `shared_arg_names`.
of the current executor is not found in `shared_arg_names`. The `NDArray`s
are expected to have default storage type.
kwargs : Dict of str->shape
Input shape dictionary, name->shape
@@ -1181,6 +1271,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
executor : mxnet.Executor
The generated executor
"""
# data types
num_provided_arg_types = 0
provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names
provided_arg_type_data = ctypes.POINTER(ctypes.c_int)() # provided types
@@ -1196,6 +1287,22 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names)
provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data)

# storage types
num_provided_arg_stypes = 0
# provided storage type argument names
provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)()
provided_arg_stype_data = ctypes.POINTER(ctypes.c_int)() # provided storage types
if storage_type_dict is not None:
provided_arg_stype_names = []
provided_arg_stype_data = []
for k, v in storage_type_dict.items():
if v in _STORAGE_TYPE_STR_TO_ID:
provided_arg_stype_names.append(c_str(k))
provided_arg_stype_data.append(ctypes.c_int(_STORAGE_TYPE_STR_TO_ID[v]))
num_provided_arg_stypes = mx_uint(len(provided_arg_stype_names))
provided_arg_stype_names = c_array(ctypes.c_char_p, provided_arg_stype_names)
provided_arg_stype_data = c_array(ctypes.c_int, provided_arg_stype_data)

provided_arg_shape_data = [] # shape data
# argument shape index in sdata,
# e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg
@@ -1269,6 +1376,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
shared_buffer_names = []
shared_buffer_handles = []
for k, v in shared_buffer.items():
assert v.storage_type == 'default', \
"shared_buffer is expected to only contain NDArrays with default storage"
shared_buffer_names.append(c_str(k))
shared_buffer_handles.append(v.handle)
shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names)
@@ -1305,6 +1414,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_array(ctypes.c_char_p, shared_arg_name_list),
ctypes.byref(shared_buffer_len),
@@ -1335,6 +1447,16 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i]))
for i in range(num_aux_states.value)]

# redefine NDArray class based on storage types
def check_storage_type(ndarrays):
for idx, array in enumerate(ndarrays):
if array is not None and array.storage_type != 'default':
ndarrays[idx].__class__ = SparseNDArray
return ndarrays
arg_arrays = check_storage_type(arg_arrays)
grad_arrays = check_storage_type(grad_arrays)
aux_arrays = check_storage_type(aux_arrays)

executor = Executor(exe_handle, self, ctx, grad_req, group2ctx)
executor.arg_arrays = arg_arrays
executor.grad_arrays = grad_arrays
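
A hedged usage sketch of the extended simple_bind (illustrative, not part of this commit; the names, shapes, and sparse support of dot are assumptions):

import mxnet as mx

data = mx.sym.Variable('data')
weight = mx.sym.Variable('weight')
out = mx.sym.dot(data, weight)
# Declare 'data' as csr storage; the remaining storage types are inferred.
exe = out.simple_bind(mx.cpu(), storage_type_dict={'data': 'csr'},
                      data=(5, 100), weight=(100, 10))
# Arrays whose inferred storage type is not 'default' come back as
# SparseNDArray, thanks to the check_storage_type reclassification above.
print(type(exe.arg_arrays[0]))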
@@ -1583,7 +1705,8 @@ def reshape(self, shape):
"""
return reshape(self, shape=shape)

def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs):
def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
init=None, storage_type=None, **kwargs):
"""Creates a symbolic variable with specified name.
Example usage:
@@ -1637,6 +1760,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs):
if not isinstance(init, string_types):
init = init.dumps()
attr['__init__'] = init
if storage_type is not None:
attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[storage_type])
for k, v in kwargs.items():
if k.startswith('__') and k.endswith('__'):
attr[k] = str(v)
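
A short sketch of the new storage_type argument to var (illustrative, not part of this commit): the storage type is recorded as the '__storage_type__' attribute on the variable, which the executor-binding C API below falls back to when simple_bind receives no storage_type_dict.

import mxnet as mx

w = mx.sym.var('w', shape=(100, 10), storage_type='row_sparse')
# The attribute stores the integer storage-type id as a string.
print(w.attr('__storage_type__'))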
30 changes: 27 additions & 3 deletions src/c_api/c_api_executor.cc
@@ -173,6 +173,9 @@ int MXExecutorBindEX(SymbolHandle symbol_handle,
* \param num_provided_arg_dtypes number of user provided in_arg and aux_state dtypes
* \param provided_arg_dtype_names argument name list of provided dtypes
* \param provided_arg_dtypes data of provided dtypes
* \param num_provided_arg_stypes number of user provided in_arg and aux_state storage types
* \param provided_arg_stype_names argument name list of provided storage types
* \param provided_arg_stypes data of provided storage types
* \param num_shared_arg_names number of parameter names passed from _bind_ith_exec
* \param shared_arg_name_list parameter name list passed from _bind_ith_exec
* \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec
@@ -203,6 +206,9 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
mx_uint* shared_buffer_len,
@@ -251,6 +257,23 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
}
}

// setup arg_stype_map
std::unordered_map<std::string, int> arg_stype_map;
if (nullptr == provided_arg_stypes) { // use attr_dict
for (const auto& arg_name : in_arg_names) {
const auto it = attr_dict.find(arg_name);
if (it == attr_dict.end() || !it->second.count("__storage_type__")) {
arg_stype_map[arg_name] = kDefaultStorage;
}
}
} else { // use user-provided storage_type_dict
// create stype map for in_args and aux_states
arg_stype_map.reserve(num_provided_arg_stypes);
for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) {
arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i];
}
}

// create default ctx
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
// create ctx map
@@ -391,9 +414,10 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
std::vector<NDArray> aux_state_vec;

*out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
aux_state_ctx_vec, arg_shape_map, arg_dtype_map, grad_req_type_vec,
shared_arg_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec,
use_shared_buffer? &shared_buffer_map : nullptr,
aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
&arg_grad_vec, &aux_state_vec,
use_shared_buffer ? &shared_buffer_map : nullptr,
reinterpret_cast<Executor*>(shared_exec_handle));

// copy ndarray ptrs to ret->handles so that front end
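
An end-to-end sketch of the attribute fallback (illustrative; assumes the involved operators support row_sparse storage): when no storage_type_dict is given, variables carrying a '__storage_type__' attribute keep it for storage-type inference, while arguments without the attribute are pinned to default storage in arg_stype_map above.

import mxnet as mx

x = mx.sym.var('x')
w = mx.sym.var('w', storage_type='row_sparse')
y = mx.sym.dot(x, w)
exe = y.simple_bind(mx.cpu(), x=(5, 100), w=(100, 10))
# 'w' binds with row_sparse storage; 'x' falls back to default storage.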
4 changes: 2 additions & 2 deletions src/executor/attach_op_execs_pass.cc
@@ -236,7 +236,7 @@ Graph AttachOpExecs(Graph g) {
const auto& vctx = g.GetAttr<ContextVector>("context");
const auto& saved_opr = g.GetAttr<
std::unordered_map<const nnvm::Node*, std::shared_ptr<Operator>>>("saved_opr");
const auto& dispatch_stypes = g.GetAttr<StorageTypeVector>("dispatch_storage_types");
const auto& dispatch_stypes = g.GetAttr<StorageTypeVector>("dispatch_stypes");

// get the graph
const auto& idx = g.indexed_graph();
@@ -254,7 +254,7 @@ Graph AttachOpExecs(Graph g) {
FComputeEx fcompute_ex =
common::GetFComputeEx(inode.source->op(), vctx[i], dispatch_stypes[i]);
#if EXEC_ATTACH_OP_DEBUG
LOG(INFO) << "dispatch type = " << dispatch_stypes[i];
LOG(INFO) << "dispatch storage type = " << dispatch_stypes[i];
#endif
if (fcreate_layer_op.count(inode.source->op())) {
std::vector<TShape> ishape;