inefficient implementation for csr copy
eric-haibin-lin committed May 15, 2017
1 parent 0ee7f1d commit 59e7d80
Showing 3 changed files with 38 additions and 4 deletions.
39 changes: 36 additions & 3 deletions include/mxnet/ndarray.h
@@ -771,6 +771,39 @@ class NDArray {
*/
void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0);

// Make a copy of a CSR NDArray
template<typename from_xpu, typename to_xpu>
inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) {
  using namespace mshadow;
  CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type";
  // If the source storage is not initialized, fill the destination with zeros.
  auto s = ctx.get_stream<to_xpu>();
  if (!from.storage_initialized()) {
    LOG(FATAL) << "To be implemented";
    // TODO(haibin) implement FillZerosCsrImpl
    // op::FillZerosCsrImpl<to_xpu>(s, to);
    return;
  }
  // Allocate storage. The data array holds one value per stored column index,
  // so it shares its shape with the kIdx aux array.
  to->CheckAndAllocAuxData(csr::kIndPtr, from.aux_shape(csr::kIndPtr));
  to->CheckAndAllocAuxData(csr::kIdx, from.aux_shape(csr::kIdx));
  to->CheckAndAllocData(from.aux_shape(csr::kIdx));
  // FIXME: This is a naive implementation of CSR copy. It is inefficient when
  // the source CSR is sliced, because it copies a superset of the slice's
  // values and indices. Ideally we would truncate the values and indices
  // arrays and adjust indptr accordingly (see the sketch below).
  TBlob val = to->data();
  TBlob indptr = to->aux_data(csr::kIndPtr);
  TBlob idx = to->aux_data(csr::kIdx);
  ndarray::Copy<from_xpu, to_xpu>(from.data(), &val,
                                  from.ctx(), to->ctx(), ctx);
  ndarray::Copy<from_xpu, to_xpu>(from.aux_data(csr::kIndPtr), &indptr,
                                  from.ctx(), to->ctx(), ctx);
  ndarray::Copy<from_xpu, to_xpu>(from.aux_data(csr::kIdx), &idx,
                                  from.ctx(), to->ctx(), ctx);
}
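The FIXME above describes the efficient alternative: materialize only the slice's non-zeros and rebase indptr so the first row of the slice starts at offset zero. Below is a minimal standalone sketch of that truncation, using plain std::vector instead of the NDArray/TBlob machinery; CsrSlice and CopySlicedCsr are hypothetical names for illustration, not part of this commit.

#include <cstdint>
#include <vector>

// Hypothetical sketch: copy rows [row_begin, row_end) of a CSR matrix so
// that only the slice's values and column indices are materialized.
struct CsrSlice {
  std::vector<int64_t> indptr;   // rebased row pointers, length rows + 1
  std::vector<int64_t> idx;      // column indices of the slice's non-zeros
  std::vector<float> values;     // values of the slice's non-zeros
};

inline CsrSlice CopySlicedCsr(const std::vector<int64_t> &indptr,
                              const std::vector<int64_t> &idx,
                              const std::vector<float> &values,
                              int64_t row_begin, int64_t row_end) {
  CsrSlice out;
  const int64_t nz_begin = indptr[row_begin];  // first non-zero of the slice
  const int64_t nz_end = indptr[row_end];      // one past the last non-zero
  // Truncate values and indices to the slice's non-zeros only.
  out.values.assign(values.begin() + nz_begin, values.begin() + nz_end);
  out.idx.assign(idx.begin() + nz_begin, idx.begin() + nz_end);
  // Adjust indptr so the first row of the slice starts at offset 0.
  out.indptr.reserve(row_end - row_begin + 1);
  for (int64_t r = row_begin; r <= row_end; ++r) {
    out.indptr.push_back(indptr[r] - nz_begin);
  }
  return out;
}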

// Make a copy of a row-sparse NDArray
template<typename from_xpu, typename to_xpu>
inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) {
@@ -829,10 +862,10 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
    CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
  } else if (to_stype == kRowSparseStorage) {
    CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+ } else if (to_stype == kCSRStorage) {
+   CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
  } else {
-   // TODO(haibin) support csr copy. For sliced csr, we want to only copy the related
-   // indices and values instead of the superset.
-   LOG(FATAL) << "Not implemented yet";
+   LOG(FATAL) << "unknown storage type " << to_stype;
  }
  if (is_same<from_xpu, mshadow::gpu>::value || is_same<to_xpu, mshadow::gpu>::value) {
    // Wait for the GPU kernel to complete
2 changes: 1 addition & 1 deletion src/operator/optimizer_op.cc
@@ -46,7 +46,7 @@ It updates the weights using::
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
-// TODO(haibin) write FCompute for sparse sgd
+// TODO(haibin) implement FCompute for sparse sgd
// .set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
.set_attr<FComputeEx>(FCOMP_EX_CPU, SparseSGDUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
1 change: 1 addition & 0 deletions tests/python/unittest/test_sparse_ndarray.py
@@ -68,6 +68,7 @@ def check_sparse_nd_copy(from_stype, to_stype):
    check_sparse_nd_copy('row_sparse', 'row_sparse')
    check_sparse_nd_copy('row_sparse', 'default')
    check_sparse_nd_copy('default', 'row_sparse')
+   check_sparse_nd_copy('default', 'csr')

def check_sparse_nd_prop_rsp():
    storage_type = 'row_sparse'
