From 1252c48b1908de80e5054fbdebc5cc716ba46c62 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Wed, 24 May 2017 18:17:43 +0000 Subject: [PATCH] change indptr to _indptr temporarily. add const ref to fname --- python/mxnet/sparse_ndarray.py | 8 +++++--- src/operator/operator_common.h | 2 +- tests/python/unittest/test_sparse_ndarray.py | 14 ++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py index 82ff42d4ad66..41e8e2b7ed83 100644 --- a/python/mxnet/sparse_ndarray.py +++ b/python/mxnet/sparse_ndarray.py @@ -288,7 +288,7 @@ def _aux_type(self, i): return _DTYPE_MX_TO_NP[aux_type.value] @property - def values(self): + def _values(self): """The values array of the SparseNDArray. This is a read-only view of the values array. They reveal internal implementation details and should be used with care. @@ -300,7 +300,7 @@ def values(self): return self._data() @property - def indices(self): + def _indices(self): """The indices array of the SparseNDArray. This is a read-only view of the indices array. They reveal internal implementation details and should be used with care. @@ -317,7 +317,7 @@ def indices(self): raise Exception("unknown storage type " + stype) @property - def indptr(self): + def _indptr(self): """The indptr array of the SparseNDArray with `csr` storage type. This is a read-only view of the indptr array. They reveal internal implementation details and should be used with care. @@ -602,6 +602,8 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): def _ndarray_cls(handle): stype = _storage_type(handle) + # TODO(haibin) in the long run, we want to have CSRNDArray and RowSparseNDArray which + # inherit from SparseNDArray return NDArray(handle) if stype == 'default_storage' else SparseNDArray(handle) _init_ndarray_module(_ndarray_cls, "mxnet") diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index a996afad5ef1..6e0bc2ad5ba6 100755 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -325,7 +325,7 @@ void FCompExFallback(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs, FCompute fcompute, - const std::string fname) { + const std::string& fname) { std::vector in_blobs, out_blobs; std::vector tmps; common::GetInputBlobs(inputs, &in_blobs, &tmps, ctx, true); diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index af506cb94c13..224a5e008b3b 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -89,9 +89,9 @@ def check_sparse_nd_prop_rsp(): shape = rand_shape_2d() nd, (v, idx) = rand_sparse_ndarray(shape, storage_type) assert(nd._num_aux == 1) - assert(nd.indices.dtype == np.int32) + assert(nd._indices.dtype == np.int32) assert(nd.storage_type == 'row_sparse') - assert_almost_equal(nd.indices.asnumpy(), idx) + assert_almost_equal(nd._indices.asnumpy(), idx) def test_sparse_nd_basic(): def check_rsp_creation(values, indices, shape): @@ -101,13 +101,13 @@ def check_rsp_creation(values, indices, shape): dns[3] = mx.nd.array(values[1]) assert_almost_equal(rsp.asnumpy(), dns.asnumpy()) indices = mx.nd.array(indices).asnumpy() - assert_almost_equal(rsp.indices.asnumpy(), indices) + assert_almost_equal(rsp._indices.asnumpy(), indices) def check_csr_creation(shape): csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') - assert_almost_equal(csr.indptr.asnumpy(), indptr) - assert_almost_equal(csr.indices.asnumpy(), indices) - assert_almost_equal(csr.values.asnumpy(), values) + assert_almost_equal(csr._indptr.asnumpy(), indptr) + assert_almost_equal(csr._indices.asnumpy(), indices) + assert_almost_equal(csr._values.asnumpy(), values) shape = (4,2) values = np.random.rand(2,2) @@ -147,8 +147,6 @@ def check_sparse_nd_csr_slice(shape): A2 = A.asnumpy() start = rnd.randint(0, shape[0] - 1) end = rnd.randint(start + 1, shape[0]) - values = A[start:end].values - indptr = A[start:end].indptr assert same(A[start:end].asnumpy(), A2[start:end]) shape = (rnd.randint(2, 10), rnd.randint(1, 10))