diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 372a7bc8ea0a..6e1e939c035f 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -771,6 +771,39 @@ class NDArray {
  */
 void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0);
 
+// Make a copy of a CSR NDArray
+template<typename from_xpu, typename to_xpu>
+inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) {
+  using namespace mshadow;
+  CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type";
+  // if source storage is not initialized, fill destination with zeros
+  auto s = ctx.get_stream<to_xpu>();
+  if (!from.storage_initialized()) {
+    LOG(FATAL) << "To be implemented";
+    // TODO(haibin) implement FillZerosCsrImpl
+    // op::FillZerosCsrImpl<to_xpu>(s, to);
+    return;
+  }
+  // Allocate storage
+  to->CheckAndAllocAuxData(csr::kIndPtr, from.aux_shape(csr::kIndPtr));
+  to->CheckAndAllocAuxData(csr::kIdx, from.aux_shape(csr::kIdx));
+  to->CheckAndAllocData(from.aux_shape(csr::kIdx));
+  // FIXME This is a naive implementation for CSR copy. However, it is
+  // not efficient when the source CSR is sliced: in that case we copy
+  // a superset of the slice's values and indices.
+  // Ideally, we should truncate the values and indices arrays, and adjust
+  // indptr accordingly.
+  TBlob val = to->data();
+  TBlob indptr = to->aux_data(csr::kIndPtr);
+  TBlob idx = to->aux_data(csr::kIdx);
+  ndarray::Copy<from_xpu, to_xpu>(from.data(), &val,
+                                  from.ctx(), to->ctx(), ctx);
+  ndarray::Copy<from_xpu, to_xpu>(from.aux_data(csr::kIndPtr), &indptr,
+                                  from.ctx(), to->ctx(), ctx);
+  ndarray::Copy<from_xpu, to_xpu>(from.aux_data(csr::kIdx), &idx,
+                                  from.ctx(), to->ctx(), ctx);
+}
+
 // Make a copy of a row-sparse NDArray
 template<typename from_xpu, typename to_xpu>
 inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) {
@@ -829,10 +862,10 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
     CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
   } else if (to_stype == kRowSparseStorage) {
     CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+  } else if (to_stype == kCSRStorage) {
+    CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
   } else {
-    // TODO(haibin) support csr copy. For sliced csr, we want to only copy the related
-    // indices and values instead of the superset.
-    LOG(FATAL) << "Not implemented yet";
+    LOG(FATAL) << "unknown storage type " << to_stype;
   }
   if (is_same<from_xpu, mshadow::gpu>::value || is_same<to_xpu, mshadow::gpu>::value) {
     // Wait GPU kernel to complete
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 82389918f680..0429f4de797d 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -46,7 +46,7 @@ It updates the weights using::
 .set_attr_parser(ParamParser<SGDParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
-// TODO(haibin) write FCompute for sparse sgd
+// TODO(haibin) implement FCompute for sparse sgd
 // .set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
 .set_attr<FComputeEx>(FCOMP_EX_CPU, SparseSGDUpdateEx<cpu>)
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index e47c2b2a75a1..d3d185b35ceb 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -68,6 +68,7 @@ def check_sparse_nd_copy(from_stype, to_stype):
     check_sparse_nd_copy('row_sparse', 'row_sparse')
     check_sparse_nd_copy('row_sparse', 'default')
     check_sparse_nd_copy('default', 'row_sparse')
+    check_sparse_nd_copy('default', 'csr')
 
 def check_sparse_nd_prop_rsp():
     storage_type = 'row_sparse'
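
A possible follow-up for the FIXME in CopyFromToCsrImpl: when the source CSR is a
slice of a larger matrix, only the sliced rows' own values and column indices need
to be copied, with indptr rebased to start at zero. The sketch below shows that
truncating copy in standalone C++ with std::vector rather than MXNet's TBlob/NDArray;
the CsrSlice struct and TruncateCsrCopy are illustrative names, not part of this
branch's API.

    #include <cstdint>
    #include <vector>

    // CSR storage viewed as a row slice: 'indptr' spans the parent matrix, so
    // the slice's non-zeros live at offsets [indptr[row_begin], indptr[row_end]).
    struct CsrSlice {
      std::vector<int64_t> indptr;  // length num_rows + 1
      std::vector<int64_t> idx;     // column index of each non-zero
      std::vector<float> values;    // value of each non-zero
    };

    // Copy only the non-zeros owned by rows [row_begin, row_end) and rebase
    // indptr so the result is a self-contained CSR matrix.
    CsrSlice TruncateCsrCopy(const CsrSlice &src, int64_t row_begin, int64_t row_end) {
      const int64_t nnz_begin = src.indptr[row_begin];
      const int64_t nnz_end = src.indptr[row_end];
      CsrSlice dst;
      dst.indptr.reserve(row_end - row_begin + 1);
      for (int64_t r = row_begin; r <= row_end; ++r) {
        // Rebase each row offset by the position of the first kept non-zero.
        dst.indptr.push_back(src.indptr[r] - nnz_begin);
      }
      // Copy just the slice's segment of the indices and values arrays.
      dst.idx.assign(src.idx.begin() + nnz_begin, src.idx.begin() + nnz_end);
      dst.values.assign(src.values.begin() + nnz_begin, src.values.begin() + nnz_end);
      return dst;
    }

Compared to the naive copy above, the destination then holds exactly
nnz_end - nnz_begin values and row_end - row_begin + 1 indptr entries instead of
a superset of the parent's arrays.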