Fm example (apache#45)
* update csr slice logic to avoid confusion. add more examples.

* add hint to module.update

* more test cases (fallback) for sparse_nd

* add to_csr() and to_rsp() methods. More unit tests (fallback now)

* add fm test. fix lint

* register sparse sgd under Optim.SGD

* update dmlc-core submodule

* change indptr to _indptr temporarily. add const ref to fname
eric-haibin-lin authored May 24, 2017
1 parent b5bcdd6 commit 77132fc
Showing 28 changed files with 557 additions and 319 deletions.
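Note on the user-facing pieces named in the commit message: to_csr() and to_rsp() convert a dense NDArray to compressed-sparse-row or row-sparse storage, and several sparse operators still fall back to dense compute (hence the fallback test cases). A rough usage sketch follows; the exact module paths and storage-type labels are assumptions rather than taken from this diff.

import mxnet as mx

dense = mx.nd.ones((4, 3))                  # ordinary dense NDArray
csr = dense.to_csr()                        # conversion method named in the commit message
rsp = dense.to_rsp()                        # row-sparse counterpart
print(csr.storage_type, rsp.storage_type)   # expected to differ from the dense default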
13 changes: 13 additions & 0 deletions include/mxnet/c_api.h
@@ -385,6 +385,19 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle *out);

/*!
* \brief Slice the NDArray with non-default storage along axis 0.
* \param handle the handle to the NDArray
* \param slice_begin The beginning index of slice
* \param slice_end The ending index of slice
* \param out The NDArrayHandle of sliced NDArray
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySliceEx(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle out);
/*!
* \brief Index the NDArray along axis 0.
* \param handle the handle to the NDArray
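For reference, a minimal ctypes sketch of how a frontend might call the new MXNDArraySliceEx entry point declared above. Unlike MXNDArraySlice, the destination handle is created by the caller and passed by value rather than returned through a pointer; the helper name slice_ex and the two handle arguments are placeholders, not code from this commit.

from mxnet.base import _LIB, check_call, mx_uint

def slice_ex(src_handle, dst_handle, begin, end):
    # Slice rows [begin, end) of a non-default-storage NDArray referenced by
    # src_handle into the pre-created NDArray referenced by dst_handle.
    check_call(_LIB.MXNDArraySliceEx(src_handle,
                                     mx_uint(begin),
                                     mx_uint(end),
                                     dst_handle))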
28 changes: 15 additions & 13 deletions include/mxnet/ndarray.h
@@ -209,7 +209,6 @@ class NDArray {
/*!
* \return the shape of aux data at ith index. If it doesn't exist, return an empty one.
*/
// TODO(haibin) CamelCase
inline const TShape aux_shape(size_t i) const {
CHECK(storage_type() != kDefaultStorage);
return ptr_->aux_shapes[i];
@@ -239,9 +238,7 @@
auto dptr = static_cast<DType*>(ptr_->shandle.dptr);
if (stype == kDefaultStorage) {
dptr += offset_;
} else if (stype == kCSRStorage) {
shape = storage_shape();
} else if (stype == kRowSparseStorage) {
} else if (stype == kCSRStorage || stype == kRowSparseStorage) {
shape = storage_shape();
} else {
LOG(FATAL) << "unknown storage type " << stype;
@@ -263,13 +260,8 @@
auto type = aux_type(i);
MSHADOW_TYPE_SWITCH(type, DType, {
auto dptr = static_cast<DType*>(ptr_->aux_handles[i].dptr);
if (stype == kRowSparseStorage) {
if (stype == kRowSparseStorage || stype == kCSRStorage) {
CHECK_EQ(offset_, 0);
} else if (stype == kCSRStorage) {
if (i == csr::kIndPtr) {
dptr += offset_;
shape[0] = shape_[0] + 1;
}
} else {
LOG(FATAL) << "Unexpected storage type";
}
@@ -472,6 +464,14 @@
* \return sliced NDArray
*/
NDArray Slice(index_t begin, index_t end) const;

/*!
* \brief Slice an NDArray with non-default storage
* \param begin begin index in first dim (inclusive)
* \param end end index in first dim (exclusive)
* \param dst the destination NDArray which holds the sliced result
*/
void SliceEx(index_t begin, index_t end, NDArray *dst) const;
/*!
* \brief Index a NDArray
* \param idx the index
@@ -480,14 +480,14 @@
NDArray At(index_t idx) const;
// Wrap the tblob of aux data into an NDArray which shares the same variable with the
// current one.
inline const NDArray AuxNDArray(size_t i) const {
inline const NDArray aux_ndarray(size_t i) const {
CHECK_NE(storage_type(), kDefaultStorage);
CHECK(i < ptr_->aux_shapes.size());
return NDArray(aux_data(i), ctx().dev_id, var());
}
// Wrap the tblob of data into an NDArray which shares the same variable with the
// current one.
inline const NDArray DataNDArray() const {
inline const NDArray data_ndarray() const {
CHECK_NE(storage_type(), kDefaultStorage);
return NDArray(data(), ctx().dev_id, var());
}
@@ -606,6 +606,9 @@ class NDArray {
// \brief skip the deletion of var handle. Usually set when shared_var is present.
bool skip_delete_var = false;

/*! \brief default constructor */
Chunk() : static_data(true), delay_alloc(false) {}

/*! \brief construct a new chunk */
Chunk(TShape shape, Context ctx_, bool delay_alloc_, int dtype)
: static_data(false), delay_alloc(true), ctx(ctx_) {
@@ -779,7 +782,6 @@ inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) {
// if source storage is not initialized, fill destination with zeros
auto s = ctx.get_stream<to_xpu>();
if (!from.storage_initialized()) {
LOG(FATAL) << "To be implemented";
// TODO(haibin) implement FillZerosCsrImpl
// op::FillZerosCsrImpl<to_xpu>(s, to);
return;
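Background for the aux_data() changes above, not taken from the diff: in CSR storage the indptr auxiliary array always has one more entry than the number of rows, which is why the removed branch set shape[0] = shape_[0] + 1 for csr::kIndPtr. A small NumPy illustration of the layout:

import numpy as np

dense = np.array([[1., 0., 0.],
                  [0., 0., 2.],
                  [0., 0., 0.]])
data = np.array([1., 2.])          # non-zero values in row-major order
indices = np.array([0, 2])         # column index of each stored value
indptr = np.array([0, 1, 2, 2])    # row r owns data[indptr[r]:indptr[r+1]]
assert len(indptr) == dense.shape[0] + 1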
24 changes: 10 additions & 14 deletions python/mxnet/_ctypes/ndarray.py
@@ -14,7 +14,7 @@
from ..base import check_call
from ..ndarray_doc import _build_doc

_ndarray_cls_map = {}
_ndarray_cls = None

class NDArrayBase(object):
"""Base data structure for ndarray"""
@@ -27,6 +27,8 @@ def __init__(self, handle, writable=True):
----------
handle : NDArrayHandle
NDArray handle of C API
writable : bool
Whether the NDArrayBase can be modified
"""
if handle is not None:
assert isinstance(handle, NDArrayHandle)
@@ -177,14 +179,8 @@ def %s(%s):
c_array(ctypes.c_char_p, [c_str(val) for val in vals])))
if original_output is not None:
return original_output
ret_list = []
for i in range(num_output.value):
storage_type = ctypes.c_int(0)
check_call(_LIB.MXNDArrayGetStorageType(ctypes.cast(output_vars[i], NDArrayHandle),
ctypes.byref(storage_type)))
ret_list.append(_ndarray_cls_map[storage_type.value](ctypes.cast(output_vars[i], \
NDArrayHandle)))
return ret_list if num_output.value > 1 else ret_list[0]
ret = [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle)) for i in range(num_output.value)]
return ret if num_output.value > 1 else ret[0]
"""%handle.value)

local = {}
@@ -196,16 +192,16 @@ def %s(%s):
return ndarray_function


def _set_storage_nd_map(storage_nd_map):
def _set_ndarray_cls(ndarray_cls):
"""Set the symbolic class to be cls"""
global _ndarray_cls_map
_ndarray_cls_map = storage_nd_map
global _ndarray_cls
_ndarray_cls = ndarray_cls


# pylint: enable=too-many-locals, invalid-name
def _init_ndarray_module(storage_nd_map, root_namespace):
def _init_ndarray_module(ndarray_cls, root_namespace):
"""List and add all the ndarray functions to current module."""
_set_storage_nd_map(storage_nd_map)
_set_ndarray_cls(ndarray_cls)
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()

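The per-storage-type map `_ndarray_cls_map` is replaced by a single `_ndarray_cls` callable that wraps every output handle. The sketch below shows one way a frontend could register such a callable; it always wraps with the dense NDArray class and only peeks at the storage type, whereas a sparse-aware frontend would return its sparse wrapper class for csr and row_sparse handles, and the enum value 0 for default storage is an assumption.

import ctypes
from mxnet.base import _LIB, check_call
from mxnet.ndarray import NDArray
from mxnet._ctypes.ndarray import _set_ndarray_cls

def _make_ndarray(handle):
    # Query the storage type of the output handle.
    stype = ctypes.c_int(0)
    check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(stype)))
    if stype.value != 0:   # 0 assumed to be kDefaultStorage
        raise NotImplementedError("a sparse-aware frontend would wrap this handle differently")
    return NDArray(handle)

_set_ndarray_cls(_make_ndarray)

In the actual frontend this registration happens once at import time, through `_init_ndarray_module(ndarray_cls, root_namespace)`.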
7 changes: 4 additions & 3 deletions python/mxnet/kvstore.py
@@ -48,7 +48,7 @@ def updater_handle(key, lhs_handle, rhs_handle, _):

class KVStore(object):
"""A key-value store for synchronization of values, over multiple devices."""
def __init__(self, handle):
def __init__(self, handle, name2idx=None):
"""Initializes a new KVStore.
Parameters
@@ -58,6 +58,7 @@ def __init__(self, handle):
"""
assert isinstance(handle, KVStoreHandle)
self.handle = handle
self.name2idx = name2idx if name2idx is not None else {}
self._updater = None
self._updater_func = None

@@ -395,7 +396,7 @@ def _send_command_to_servers(self, head, body):
check_call(_LIB.MXKVStoreSendCommmandToServers(
self.handle, mx_uint(head), c_str(body)))

def create(name='local'):
def create(name='local', name2idx=None):
"""Creates a new KVStore.
For single machine training, there are two commonly used types:
Expand Down Expand Up @@ -435,4 +436,4 @@ def create(name='local'):
handle = KVStoreHandle()
check_call(_LIB.MXKVStoreCreate(c_str(name),
ctypes.byref(handle)))
return KVStore(handle)
return KVStore(handle, name2idx=name2idx)
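A short usage sketch of the new name2idx argument: it maps parameter names to the integer keys used for push and pull, and the updated helpers in model.py consult it when pushing gradients. The parameter names here are invented for illustration.

import mxnet as mx

name2idx = {'fc_weight': 0, 'fc_bias': 1}
kv = mx.kvstore.create('local', name2idx=name2idx)
print(kv.name2idx['fc_weight'])   # -> 0, used in place of the positional index when available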
39 changes: 32 additions & 7 deletions python/mxnet/model.py
@@ -37,7 +37,7 @@
'eval_metric',
'locals'])

def _create_kvstore(kvstore, num_device, arg_params):
def _create_kvstore(kvstore, num_device, arg_params, name2idx=None):
"""Create kvstore
This function selects and creates a proper kvstore when given the kvstore type.
@@ -61,7 +61,7 @@ def _create_kvstore(kvstore, num_device, arg_params):
# no need to use kv for single device and single machine
kv = None
else:
kv = kvs.create(kvstore)
kv = kvs.create(kvstore, name2idx=name2idx)
if kvstore is 'local':
# automatically select a proper local
max_size = max(np.prod(param.shape) for param in
@@ -85,25 +85,50 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
if update_on_kvstore:
kvstore.pull(idx, param_on_devs, priority=-idx)

def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore):
"""Perform update of param_arrays from grad_arrays on kvstore."""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore,
stype_dict=None, param_names=None):
"""Perform update of param_arrays from grad_arrays on kvstore.
If `param_names` is None or kvstore doesn't have a `name2idx` dictionary,
the index of a param is determined by the order it appears in `param_arrays`. """
stype_dict = {} if stype_dict is None else stype_dict
for i, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
index = i
if param_names is not None:
name = param_names[i]
index = index if name not in kvstore.name2idx else kvstore.name2idx[name]
# cast storage type if stype doesn't match
if name in stype_dict:
for i, grad in enumerate(grad_list):
stype = stype_dict[name]
if grad_list[i].storage_type != stype:
grad_list[i] = nd.cast_storage(grad, stype)
# push gradient, priority is negative index
kvstore.push(index, grad_list, priority=-index)
# pull back the weights
kvstore.pull(index, arg_list, priority=-index)

def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None):
kvstore=None, stype_dict=None, param_names=None):
"""Perform update of param_arrays from grad_arrays not on kvstore."""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
stype_dict = {} if stype_dict is None else stype_dict
for i, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
        # cast storage type if stype doesn't match
        if param_names is not None and param_names[i] in stype_dict:
            stype = stype_dict[param_names[i]]
            for j, grad in enumerate(grad_list):
                if grad_list[j].storage_type != stype:
                    grad_list[j] = nd.cast_storage(grad, stype)
        index = i
        if kvstore:
            if param_names is not None:
                name = param_names[i]
                index = index if name not in kvstore.name2idx else kvstore.name2idx[name]
# push gradient, priority is negative index
kvstore.push(index, grad_list, priority=-index)
# pull back the sum gradients, to the same locations.
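The casting step both helpers above perform, shown in isolation: when a gradient's storage type differs from the entry in stype_dict, it is converted with nd.cast_storage before being pushed or handed to the updater. The parameter name and the 'row_sparse' label below are assumptions for illustration.

import mxnet as mx
import mxnet.ndarray as nd

stype_dict = {'embed_weight': 'row_sparse'}   # desired storage type per parameter name
grad = mx.nd.zeros((8, 4))                    # stand-in dense gradient for 'embed_weight'

desired = stype_dict['embed_weight']
if grad.storage_type != desired:
    grad = nd.cast_storage(grad, desired)     # same call the helpers above make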
12 changes: 10 additions & 2 deletions python/mxnet/module/base_module.py
@@ -838,9 +838,17 @@ def get_input_grads(self, merge_multi_context=True):
"""
raise NotImplementedError()

def update(self):
def update(self, storage_type_dict=None):
"""Update parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch.
in the previous forward-backward batch. The storage types of the parameter gradients are cast
according to `storage_type_dict`, if provided.
Parameters
----------
storage_type_dict : dict of str to str
Defaults to ``None``. Desired storage types of the parameter gradients used in the update. If a
parameter gradient is not of the desired storage type, it will be cast
before the update.
Examples
--------
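In a training loop the new argument is passed straight through from BaseModule to the concrete module. A hedged sketch, assuming mod is an already bound Module with an initialized optimizer, and that 'embed_weight' and 'row_sparse' are valid names in the user's model and this branch:

mod.forward(data_batch)
mod.backward()
mod.update(storage_type_dict={'embed_weight': 'row_sparse'})  # keep this gradient sparse during the update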
4 changes: 2 additions & 2 deletions python/mxnet/module/bucketing_module.py
@@ -394,13 +394,13 @@ def backward(self, out_grads=None):
assert self.binded and self.params_initialized
self._curr_module.backward(out_grads=out_grads)

def update(self):
def update(self, storage_type_dict=None):
"""Update parameters according to installed optimizer and the gradient computed
in the previous forward-backward cycle.
"""
assert self.binded and self.params_initialized and self.optimizer_initialized
self._params_dirty = True
self._curr_module.update()
self._curr_module.update(storage_type_dict=storage_type_dict)

def get_outputs(self, merge_multi_context=True):
"""Get outputs from a previous forward computation.
12 changes: 9 additions & 3 deletions python/mxnet/module/module.py
@@ -454,8 +454,12 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',

if self._params_dirty:
self._sync_params_from_devices()
name2idx = {}
for idx, name in enumerate(self._exec_group.param_names):
name2idx[name] = idx

(kvstore, update_on_kvstore) = \
_create_kvstore(kvstore, len(self._context), self._arg_params)
_create_kvstore(kvstore, len(self._context), self._arg_params, name2idx=name2idx)

batch_size = self._exec_group.batch_size
if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type:
@@ -550,7 +554,7 @@ def backward(self, out_grads=None):
assert self.binded and self.params_initialized
self._exec_group.backward(out_grads=out_grads)

def update(self):
def update(self, storage_type_dict=None):
"""Update parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch.
"""
@@ -560,7 +564,9 @@
if self._update_on_kvstore:
_update_params_on_kvstore(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
self._kvstore)
self._kvstore,
stype_dict=storage_type_dict,
param_names=self._param_names)
else:
_update_params(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
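The name2idx mapping built in init_optimizer() above is simply a name-to-position dictionary that is forwarded to _create_kvstore; an equivalent sketch with made-up parameter names:

param_names = ['fc_weight', 'fc_bias']   # stand-in for self._exec_group.param_names
name2idx = {name: idx for idx, name in enumerate(param_names)}
# forwarded as _create_kvstore(kvstore, num_device, arg_params, name2idx=name2idx)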
2 changes: 1 addition & 1 deletion python/mxnet/module/python_module.py
@@ -110,7 +110,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non
"""
pass

def update(self):
def update(self, storage_type_dict=None):
"""Update parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch. Currently we do nothing here. Subclass should
override this method if contains parameters.
4 changes: 2 additions & 2 deletions python/mxnet/module/sequential_module.py
@@ -342,14 +342,14 @@ def backward(self, out_grads=None):

out_grads = module.get_input_grads()

def update(self):
def update(self, storage_type_dict=None):
"""Update parameters according to installed optimizer and the gradient computed
in the previous forward-backward cycle.
"""
assert self.binded and self.params_initialized and self.optimizer_initialized

for module in self._modules:
module.update()
module.update(storage_type_dict=storage_type_dict)

def get_outputs(self, merge_multi_context=True):
"""Get outputs from a previous forward computation.
(Diffs for the remaining changed files were not loaded on this page.)
