bug fix for initializing module with row_sparse weight (#81)
* bug fix for initializing module with row_sparse weight

* update log message
eric-haibin-lin authored Jun 11, 2017
1 parent 16a6d7f commit 87bb1f7
Showing 8 changed files with 72 additions and 11 deletions.
3 changes: 2 additions & 1 deletion python/mxnet/module/module.py
@@ -9,6 +9,7 @@

from .. import context as ctx
from .. import ndarray as nd
from .. import sparse_ndarray as sparse_nd
from .. import optimizer as opt

from .executor_group import DataParallelExecutorGroup
@@ -398,7 +399,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
else:
assert self._arg_params is None and self._aux_params is None
param_arrays = [
nd.zeros(x[0].shape, dtype=x[0].dtype)
sparse_nd.zeros(x[0].storage_type, x[0].shape, dtype=x[0].dtype)
for x in self._exec_group.param_arrays
]
self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)}
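The swap above from nd.zeros to sparse_nd.zeros is the core of the fix: when bind() allocates fresh parameter arrays, they now keep the storage type of the executor's parameter arrays instead of always coming back dense. A minimal sketch of the difference, assuming the mx.sparse_nd.zeros signature and the storage_type attribute that appear elsewhere in this diff:

import mxnet as mx

# dense zeros: any row_sparse storage type requested by the symbol is lost
dense_w = mx.nd.zeros((2, 3), dtype='float32')

# storage-type-aware zeros: init_params() can later write into a row_sparse array
sparse_w = mx.sparse_nd.zeros('row_sparse', (2, 3), dtype='float32')
assert sparse_w.storage_type == 'row_sparse'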
17 changes: 16 additions & 1 deletion python/mxnet/sparse_ndarray.py
@@ -329,7 +329,22 @@ def asnumpy(self):
return self.to_dense().asnumpy()

def astype(self, dtype):
raise Exception('Not implemented for SparseND yet!')
"""Returns a copy of the array after casting to a specified type.
Parameters
----------
dtype : numpy.dtype or str
The type of the returned array.
Examples
--------
>>> x = mx.sparse_nd.zeros('row_sparse', (2,3), dtype='float32')
>>> y = x.astype('int32')
>>> y.dtype
<type 'numpy.int32'>
"""
res = zeros(self.storage_type, self.shape, ctx=self.context, dtype=dtype)
self.copyto(res)
return res


def copyto(self, other):
"""Copies the value of this array to another array.
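The new astype mirrors the dense NDArray API: it allocates zeros with the same storage type and context but the requested dtype, then copies the array into it, so the cast preserves the sparse format. A short usage sketch along the lines of the doctest above (the storage_type and context checks are assumptions based on how the copy is constructed, not part of the doctest):

import numpy as np
import mxnet as mx

x = mx.sparse_nd.zeros('row_sparse', (2, 3), dtype='float32')
y = x.astype('int32')                       # zeros of the same storage type, then copyto
assert y.dtype == np.int32
assert y.storage_type == x.storage_type     # sparse format is preserved by the cast
assert y.context == x.context               # and so is the device context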
2 changes: 1 addition & 1 deletion src/c_api/c_api_executor.cc
@@ -235,7 +235,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,

// attr_dict for setting up type_dict and arg/aux ctx
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> attr_dict;
if (nullptr == provided_arg_dtypes || nullptr != g2c_keys) {
if (nullptr == provided_arg_dtypes || nullptr != g2c_keys || nullptr == provided_arg_stypes) {
std::vector<std::tuple<std::string, std::string, std::string>> attrs =
sym->ListAttrsRecursive();
attr_dict.reserve(attrs.size());
12 changes: 8 additions & 4 deletions src/operator/optimizer_op-inl.h
@@ -163,8 +163,10 @@ inline void SGDUpdateRspRspImpl(const SGDParam& param,
TBlob out_blob = out->data();
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, out_req, &out_blob);
} else {
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when "
<< "weights.values.shape == weights.shape";
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented for "
<< "RowSparse weights with all rows containing non-zeros. "
<< "Expects weights.values.shape[0] (" << weight.storage_shape()[0]
<< ") == weights.shape[0] (" << weight.shape()[0] << ").";
}
}

@@ -360,8 +362,10 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), out_req, &out_blob);
} else {
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when "
<< "weights.values.shape == weights.shape";
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented for "
<< "RowSparse weights with all rows containing non-zeros. "
<< "Expects weights.values.shape[0] (" << weight.storage_shape()[0]
<< ") == weights.shape[0] (" << weight.shape()[0] << ").";
}
}

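The reworded errors state the actual restriction: these row_sparse SGD kernels are only implemented when the weight's values array covers every row, i.e. weight.storage_shape()[0] == weight.shape()[0]. A numpy-only illustration of the row_sparse layout and of that check (variable names here are illustrative, not MXNet API):

import numpy as np

# row_sparse layout: 'indices' lists the stored rows, 'values' holds one slice per stored row
shape = (4, 3)
indices = np.array([0, 2])                          # only rows 0 and 2 are stored
values = np.random.rand(len(indices), shape[1])

# the condition behind the LOG(FATAL) above, in Python terms
print(values.shape[0] == shape[0])                  # False: the sparse SGD path would abort

# supported case: every row is stored, so values has the full dense shape
values_full = np.random.rand(shape[0], shape[1])
print(values_full.shape[0] == shape[0])             # True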
6 changes: 4 additions & 2 deletions src/operator/tensor/indexing_op.h
@@ -217,8 +217,10 @@ void SparseEmbeddingForwardRspImpl(const nnvm::NodeAttrs& attrs,
bool transpose_a = false;
DotCsrRspDnsImpl<xpu>(ctx, data, weight, req, transpose_a, &out_blob);
} else {
LOG(FATAL) << "SparseEmbedding for RowSparse weights is only implemented when "
<< "weights.values.shape == weights.shape";
LOG(FATAL) << "SparseEmbedding for RowSparse weights is only implemented for "
<< "RowSparse weights with all rows containing non-zeros. "
<< "Expects weights.values.shape[0] (" << weight.storage_shape()[0]
<< ") == weights.shape[0] (" << weight.shape()[0] << ").";
}
}

10 changes: 8 additions & 2 deletions src/operator/tensor/matrix_op-inl.h
@@ -711,7 +711,10 @@ void DotCsrRspDnsImpl(const OpContext& ctx,
// reuse csr dns implementation when storage_shape == shape for rhs
DotCsrDnsDnsImpl<xpu>(ctx, lhs, rhs.data(), req, trans_lhs, ret);
} else {
LOG(FATAL) << "Dot for RowSparse rhs is only implemented for rhs.values.shape == rhs.shape";
LOG(FATAL) << "Dot for RowSparse rhs is only implemented for "
<< "RowSparse rhs with all rows containing non-zeros. "
<< "Expects rhs.values.shape[0] (" << rhs.storage_shape()[0]
<< ") == rhs.shape[0] (" << rhs.shape()[0] << ").";
}
}

@@ -739,7 +742,10 @@ void DotBackwardCsrRspDns(const nnvm::NodeAttrs& attrs,
TBlob ret = outputs[1].data();
DotCsrDnsDnsImpl<xpu>(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret);
} else {
LOG(FATAL) << "Dot for RowSparse rhs is only implemented for rhs.values.shape == rhs.shape";
LOG(FATAL) << "Dot for RowSparse rhs is only implemented for "
<< "RowSparse rhs with all rows containing non-zeros. "
<< "Expects rhs.values.shape[0] (" << rhs.storage_shape()[0]
<< ") == rhs.shape[0] (" << rhs.shape()[0] << ").";
}
}

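Same restriction for dot: when a RowSparse rhs stores every row, its values array already has the dense shape, which is why the kernel can "reuse csr dns implementation when storage_shape == shape for rhs", as the comment above says. A small numpy sketch of that equivalence, with a dense lhs standing in for the csr matrix purely for illustration:

import numpy as np

lhs = np.random.rand(5, 4)                          # stand-in for the csr lhs
rhs_dense = np.random.rand(4, 3)

# RowSparse rhs with all rows present: indices == [0, 1, 2, 3] and values == rhs_dense
rhs_values = rhs_dense.copy()
assert rhs_values.shape[0] == rhs_dense.shape[0]    # the condition the new message spells out

# dot against the values array equals dot against the dense rhs
np.testing.assert_allclose(lhs.dot(rhs_values), lhs.dot(rhs_dense))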
26 changes: 26 additions & 0 deletions tests/python/unittest/test_module.py
@@ -442,6 +442,31 @@ def fm_model(k, feature_dim):
# print('Epoch %d, Training %s' % (epoch, metric.get()))
assert(metric.get()[1] < 0.2)

def test_module_initializer():
def regression_model(m):
x = mx.symbol.var("data", storage_type='csr')
v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1),
storage_type='row_sparse')
model = mx.symbol.dot(lhs=x, rhs=v)
y = mx.symbol.Variable("label")
model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out")
return model

n, m = 128, 100
model = regression_model(m)

data = mx.sparse_nd.zeros('csr', (n, m))
label = mx.nd.zeros((n, 1))
iterator = mx.io.NDArrayIter(data=data, label={'label':label}, batch_size=n)

# create module
mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['label'])
mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label)
mod.init_params()
v = mod._arg_params['v']
assert(v.storage_type == 'row_sparse')
assert(np.sum(v.asnumpy()) != 0)

if __name__ == '__main__':
test_module_dtype()
test_module_input_grads()
@@ -453,3 +478,4 @@ def fm_model(k, feature_dim):
test_monitor()
test_executor_group()
test_module_fm()
test_module_initializer()
7 changes: 7 additions & 0 deletions tests/python/unittest/test_sparse_ndarray.py
@@ -272,6 +272,13 @@ def test_sparse_nd_output_fallback():
mx.nd.random_normal(shape=shape, out=out)
assert(np.sum(out.asnumpy()) != 0)

def test_sparse_nd_astype():
stypes = ['row_sparse', 'csr']
for stype in stypes:
x = mx.sparse_nd.zeros(stype, rand_shape_2d(), dtype='float32')
y = x.astype('int32')
assert(y.dtype == np.int32), y.dtype

if __name__ == '__main__':
import nose
nose.runmodule()
