
Commit

Use cast_storage when copying ndarrays of different stypes on same context
reminisce committed Jul 11, 2017
1 parent 038fd31 commit 628439c
Showing 1 changed file with 30 additions and 20 deletions.
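In short: before this change, CopyFromToImpl always materialized a temporary casted_nd whenever the source and destination storage types (stypes) differed, even when both ndarrays lived on the same context. After this change, a same-context copy between different stypes calls common::CastStorageDispatch directly on the destination, skipping the intermediate allocation and the extra copy. As a reading aid, here is a small compilable sketch of the three paths the new code chooses between; Path and ChoosePath are hypothetical names for illustration and appear nowhere in the commit:

// Sketch (not from the commit) of the copy-path decision added here.
enum Path { kDirectCast, kReuseSource, kCastThenCopy };

Path ChoosePath(bool same_ctx, bool same_stype) {
  if (same_ctx && !same_stype) return kDirectCast;   // cast straight into *to
  if (same_stype)              return kReuseSource;  // casted_nd = from, plain copy
  return kCastThenCopy;  // temporary casted_nd on from's ctx, then cross-ctx copy
}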
50 changes: 30 additions & 20 deletions src/ndarray/ndarray.cc
@@ -398,29 +398,39 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
   // if storage type doesn't match, cast the storage first
   auto from_stype = from.storage_type();
   auto to_stype = to->storage_type();
-  NDArray casted_nd;
-  if (from_stype != to_stype) {
-    TShape shape = from.shape();
-    auto from_ctx = from.ctx();
-    auto s = ctx.get_stream<from_xpu>();
-    // TODO(haibin) inplace conversion
-    if (to_stype == kDefaultStorage) {
-      casted_nd = NDArray(shape, from_ctx);
-    } else {
-      casted_nd = NDArray(to_stype, shape, from_ctx);
-    }
-    common::CastStorageDispatch<from_xpu>(s, from, casted_nd);
-  } else {
-    casted_nd = from;
-  }
-  if (to_stype == kDefaultStorage) {
-    CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else if (to_stype == kRowSparseStorage) {
-    CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else if (to_stype == kCSRStorage) {
-    CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
-  } else {
-    LOG(FATAL) << "unknown storage type" << to_stype;
-  }
+  CHECK(from_stype == kDefaultStorage || to_stype == kDefaultStorage)
+    << "Copying ndarray of stype = " << from_stype
+    << " to stype = " << to_stype << " is not supported";
+  const auto from_ctx = from.ctx();
+  const auto to_ctx = to->ctx();
+  auto s = ctx.get_stream<from_xpu>();
+  if (from_ctx == to_ctx && from_stype != to_stype) {
+    // same ctx, different stypes: use the cast op directly without copying
+    common::CastStorageDispatch<from_xpu>(s, from, *to);
+  } else {
+    NDArray casted_nd;  // an intermediate result before copying from to to
+    if (from_stype == to_stype) {
+      casted_nd = from;  // same stype, no need to cast
+    } else {  // different stypes on different ctxs need a temporary casted_nd
+      TShape shape = from.shape();
+      if (to_stype == kDefaultStorage) {
+        casted_nd = NDArray(shape, from_ctx);
+      } else {
+        casted_nd = NDArray(to_stype, shape, from_ctx);
+      }
+      // convert from_nd to the same stype as to_nd
+      common::CastStorageDispatch<from_xpu>(s, from, casted_nd);
+    }
+
+    if (to_stype == kDefaultStorage) {
+      CopyFromToDnsImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+    } else if (to_stype == kRowSparseStorage) {
+      CopyFromToRspImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+    } else if (to_stype == kCSRStorage) {
+      CopyFromToCsrImpl<from_xpu, to_xpu>(casted_nd, to, ctx);
+    } else {
+      LOG(FATAL) << "unknown storage type " << to_stype;
+    }
+  }
   if (is_same<from_xpu, mshadow::gpu>::value || is_same<to_xpu, mshadow::gpu>::value) {
     // Wait GPU kernel to complete
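For context, a hedged usage sketch (not part of this commit): assuming the sparse-branch API visible in this file, namely the row-sparse NDArray constructor and CopyFromTo(const NDArray &from, NDArray *to, int priority = 0) declared in ndarray.h, a same-context copy between different stypes now takes the direct cast path:

#include <mxnet/ndarray.h>

void SameCtxCopyExample() {
  using namespace mxnet;
  const TShape shape = mshadow::Shape2(4, 4);
  const Context ctx = Context::CPU();
  NDArray src(kRowSparseStorage, shape, ctx);  // row-sparse source
  NDArray dst(shape, ctx);                     // dense destination, same ctx
  CopyFromTo(src, &dst);  // different stypes + same ctx -> direct cast, no temporary
  dst.WaitToRead();       // block until the async cast/copy completes
}

The design choice: a cast can write straight into the destination's storage only when from_ctx == to_ctx; for cross-context copies the temporary casted_nd on the source context is still needed, so that the subsequent device-to-device copy sees matching stypes on both sides.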
