Skip to content

Commit

Permalink
add csr python interface. fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Apr 23, 2017
1 parent dbb304b commit 69aecce
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 21 deletions.
1 change: 1 addition & 0 deletions python/mxnet/contrib/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import functools
from ..base import _LIB, check_call, string_types
from ..base import mx_uint, NDArrayHandle, c_array
# pylint: disable= unused-import
from ..sparse_ndarray import SparseNDArray
from ..ndarray import NDArray, zeros_like
from ..symbol import _GRAD_REQ_MAP
Expand Down
12 changes: 6 additions & 6 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ def _get_outputs(self):
num_output = out_size.value
outputs = []
for i in xrange(num_output):
storage_type = ctypes.c_int(0)
check_call(_LIB.MXNDArrayGetStorageType(ctypes.cast(handles[i], NDArrayHandle),
ctypes.byref(storage_type)))
output = NDArray(NDArrayHandle(handles[i])) if storage_type.value == 1 \
else SparseNDArray(NDArrayHandle(handles[i]))
outputs.append(output)
storage_type = ctypes.c_int(0)
check_call(_LIB.MXNDArrayGetStorageType(ctypes.cast(handles[i], NDArrayHandle),
ctypes.byref(storage_type)))
output = NDArray(NDArrayHandle(handles[i])) if storage_type.value == 1 \
else SparseNDArray(NDArrayHandle(handles[i]))
outputs.append(output)
return outputs

def forward(self, is_train=False, **kwargs):
Expand Down
37 changes: 30 additions & 7 deletions python/mxnet/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
# When possible, use cython to speedup part of computation.
try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from ._ctypes.ndarray import NDArrayBase, _init_ndarray_module
from ._ctypes.ndarray import _init_ndarray_module
elif _sys.version_info >= (3, 0):
from ._cy3.ndarray import NDArrayBase, _init_ndarray_module
from ._cy3.ndarray import _init_ndarray_module
else:
from ._cy2.ndarray import NDArrayBase, _init_ndarray_module
from ._cy2.ndarray import _init_ndarray_module
except ImportError:
if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
from ._ctypes.ndarray import NDArrayBase, _init_ndarray_module
from ._ctypes.ndarray import _init_ndarray_module

# pylint: enable= no-member
_STORAGE_TYPE_ID_TO_STR = {
Expand All @@ -62,7 +62,7 @@
}

def _new_alloc_handle(storage_type, shape, ctx, delay_alloc=True,
dtype=mx_real_t, aux_types=[]):
dtype=mx_real_t, aux_types=None):
"""Return a new handle with specified shape and context.
Empty handle is only used to hold results
Expand Down Expand Up @@ -202,12 +202,35 @@ def as_in_context(self, context):
def to_dense(self):
return to_dense(self)

#TODO(haibin) also add aux_types. Not tested yet.
#We need a to_dense method to test it
def csr(values, indptr, idx, shape, ctx=Context.default_ctx, dtype=mx_real_t):
''' constructor '''
hdl = NDArrayHandle()
#TODO currently only supports NDArray input
assert(isinstance(values, NDArray))
assert(isinstance(index, NDArray))
indices = c_array(NDArrayHandle, [idx.handle, indptr.handle])
num_aux = mx_uint(2)
check_call(_LIB.MXNDArrayCreateSparse(
values.handle, num_aux, indices,
c_array(mx_uint, shape),
mx_uint(len(shape)),
ctypes.c_int(_STORAGE_TYPE_STR_TO_ID['csr']),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(False)),
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
ctypes.byref(hdl)))
return SparseNDArray(hdl)

# pylint: enable= no-member
#TODO(haibin) also specify aux_types
def row_sparse(values, index, shape, ctx=Context.default_ctx, dtype=mx_real_t):
''' constructor '''
hdl = NDArrayHandle()
assert(isinstance(values, NDArrayBase))
assert(isinstance(index, NDArrayBase))
assert(isinstance(values, NDArray))
assert(isinstance(index, NDArray))
indices = c_array(NDArrayHandle, [index.handle])
num_aux = mx_uint(1)
check_call(_LIB.MXNDArrayCreateSparse(
Expand Down
3 changes: 2 additions & 1 deletion src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ int MXNDArrayCreateSparse(NDArrayHandle data,
NDArray* nd_aux_ptr = reinterpret_cast<NDArray*>(aux_vec[i]);
aux_ndarrays.push_back(*nd_aux_ptr);
}
*out = new NDArray(*data_ptr, aux_ndarrays, ctx, kRowSparseStorage, TShape(shape, shape + ndim));
NDArrayStorageType stype = (NDArrayStorageType) storage_type;
*out = new NDArray(*data_ptr, aux_ndarrays, ctx, stype, TShape(shape, shape + ndim));
API_END();
}

Expand Down
1 change: 0 additions & 1 deletion tests/python/unittest/test_infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def test_incomplete_infer_concat():

def test_fc_infer_type():
mx_real_t = mx.base.mx_real_t
# Build MLP
data = mx.symbol.Variable('data')
out = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000)

Expand Down
3 changes: 0 additions & 3 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2906,9 +2906,6 @@ def test_where_numeric_gradient(shape, same_shape):
test_where_numeric_gradient((5, 9), False)
test_where_numeric_gradient((5, 7, 9), True)
test_where_numeric_gradient((5, 7, 9), False)
<<<<<<< HEAD
<<<<<<< HEAD


def test_new_softmax():
for ndim in range(1, 5):
Expand Down
17 changes: 15 additions & 2 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_ndarray_elementwise_fallback():
sparse_plus_sparse = mx.nd.add_n(sparse_nd1, sparse_nd1)
assert_almost_equal(sparse_plus_sparse.asnumpy(), sparse_np1 + sparse_np1)

def test_ndarray_conversion():
def check_conversion_row_sparse():
val = np.array([5, 10])
idx = np.array([1])
sparse_val = np.array([[0, 0], [5, 10], [0, 0], [0, 0], [0, 0]])
Expand All @@ -83,6 +83,20 @@ def test_ndarray_conversion():
f = mx.sparse_nd.to_dense(d)
assert_almost_equal(f.asnumpy(), sparse_val)

def check_conversion_csr():
val = mx.nd.array([1, 2, 3, 4, 5, 6])
indices = mx.nd.array([0, 2, 2, 0, 1, 2], dtype=np.int32)
indptr = mx.nd.array([0, 2, 3, 6], dtype=np.int32)
shape = (3, 3)
#sparse_val = np.array([[0, 0], [5, 10], [0, 0], [0, 0], [0, 0]])
d = mx.sparse_nd.csr(val, indices, indptr, (5,2))
#f = mx.sparse_nd.to_dense(d)
#assert_almost_equal(f.asnumpy(), sparse_val)

def test_ndarray_conversion():
check_conversion_row_sparse()
#TODO check_conversion_csr()

def test_ndarray_zeros():
zero = mx.nd.zeros((2,2))
sparse_zero = mx.sparse_nd.zeros((2,2), 'row_sparse')
Expand All @@ -99,4 +113,3 @@ def test_ndarray_copyto():
test_ndarray_zeros()
test_ndarray_copyto()
test_ndarray_elementwise_fallback()

1 change: 0 additions & 1 deletion tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,3 @@ def test_elemwise_add_sparse_sparse():
test_elemwise_add_dense()
test_elemwise_add_dense_sparse()
test_elemwise_add_sparse_sparse()
print("done")

0 comments on commit 69aecce

Please sign in to comment.