Improve copy sparse tensors (apache#7003)
* Use cast_storage when copying ndarrays of different stypes on same context

* Relaunch test
reminisce authored and eric-haibin-lin committed Jul 20, 2017
1 parent 68cbcc1 commit f158c08
Showing 1 changed file with 32 additions and 20 deletions: src/ndarray/ndarray.cc
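In user-facing terms, the first commit message bullet means that a copy between NDArrays that live on the same context but use different storage types no longer goes through a temporary array; it dispatches straight to a storage cast. Below is a minimal sketch of that case, assuming the Python sparse NDArray API (tostype, copyto) available on the branch this commit targets; the variable names are illustrative only, not part of the diff.

    import mxnet as mx

    # Dense source and row_sparse destination on the same (CPU) context.
    dense = mx.nd.ones((4, 4))
    rsp = mx.nd.zeros((4, 4)).tostype('row_sparse')

    # Same ctx, different stypes: after this commit, CopyFromToImpl calls
    # cast_storage directly into the destination instead of building a
    # temporary casted_nd and copying it afterwards.
    dense.copyto(rsp)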
@@ -406,29 +406,41 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
   // if storage type doesn't match, cast the storage first
   auto from_stype = from.storage_type();
   auto to_stype = to->storage_type();
-  NDArray casted_nd;
-  if (from_stype != to_stype) {
-    TShape shape = from.shape();
-    auto from_ctx = from.ctx();
-    auto s = ctx.get_stream<from_xpu>();
-    // TODO(haibin) inplace conversion
+  CHECK(from_stype == kDefaultStorage
+        || to_stype == kDefaultStorage
+        || from_stype == to_stype)
+    << "Copying ndarray of stype = " << from_stype
+    << " to stype = " << to_stype << " is not supported";
+  const auto from_ctx = from.ctx();
+  const auto to_ctx = to->ctx();
+  auto s = ctx.get_stream<from_xpu>();
+  if (from_ctx == to_ctx && from_stype != to_stype) {
+    // same ctx, different stypes, use cast op directly without copying
+    common::CastStorageDispatch<from_xpu>(s, from, *to);
+  } else {
+    NDArray casted_nd;  // an intermediate result before copying from to to
+    if (from_stype == to_stype) {
+      casted_nd = from;  // same stype, no need to cast from
+    } else {  // different stypes on different ctx need a temporary casted_nd
+      TShape shape = from.shape();
+      if (to_stype == kDefaultStorage) {
+        casted_nd = NDArray(shape, from_ctx);
+      } else {
+        casted_nd = NDArray(to_stype, shape, from_ctx);
+      }
+      // convert from_nd to the same stype as to_nd
+      common::CastStorageDispatch<from_xpu>(s, from, casted_nd);
+    }
+
     if (to_stype == kDefaultStorage) {
-      casted_nd = NDArray(shape, from_ctx);
+      CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+    } else if (to_stype == kRowSparseStorage) {
+      CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+    } else if (to_stype == kCSRStorage) {
+      CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
     } else {
-      casted_nd = NDArray(to_stype, shape, from_ctx);
+      LOG(FATAL) << "unknown storage type" << to_stype;
     }
-    common::CastStorageDispatch<from_xpu>(s, from, casted_nd);
-  } else {
-    casted_nd = from;
-  }
-  if (to_stype == kDefaultStorage) {
-    CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else if (to_stype == kRowSparseStorage) {
-    CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else if (to_stype == kCSRStorage) {
-    CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else {
-    LOG(FATAL) << "unknown storage type" << to_stype;
   }
   if (is_same<from_xpu, mshadow::gpu>::value || is_same<to_xpu, mshadow::gpu>::value) {
     // Wait GPU kernel to complete
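Note the effect of the new CHECK: a copy where both storage types are sparse and differ (for example row_sparse to csr) is now rejected rather than silently cast. A hypothetical illustration under the same assumptions as the sketch above; routing through default (dense) storage is one way around the restriction.

    rsp = mx.nd.ones((4, 4)).tostype('row_sparse')
    csr = mx.nd.zeros((4, 4)).tostype('csr')

    # rsp.copyto(csr)  # would hit the CHECK:
    #   "Copying ndarray of stype = ... to stype = ... is not supported"

    # Convert via a dense intermediate instead.
    csr2 = rsp.tostype('default').tostype('csr')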
