diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index b08e6c659c1f..1ff6cf9fc410 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -398,29 +398,39 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) { // if storage type doesn't match, cast the storage first auto from_stype = from.storage_type(); auto to_stype = to->storage_type(); - NDArray casted_nd; - if (from_stype != to_stype) { - TShape shape = from.shape(); - auto from_ctx = from.ctx(); - auto s = ctx.get_stream(); - // TODO(haibin) inplace conversion + CHECK(from_stype == kDefaultStorage || to_stype == kDefaultStorage) + << "Copying ndarray of stype = " << from_stype + << " to stype = " << to_stype << " is not supported"; + const auto from_ctx = from.ctx(); + const auto to_ctx = to->ctx(); + if (from_ctx == to_ctx && from_stype != to_stype) { + // same ctx, different stypes, use cast op directly without copying + common::CastStorageDispatch(s, from, *to); + } else { + NDArray casted_nd; // an intermediate result before copying from to to + if (from_stype == to_stype) { + casted_nd = from; // same stype, no need to cast from + } else { // different stypes on different ctx needs an temporary casted_nd + TShape shape = from.shape(); + auto s = ctx.get_stream(); + if (to_stype == kDefaultStorage) { + casted_nd = NDArray(shape, from_ctx); + } else { + casted_nd = NDArray(to_stype, shape, from_ctx); + } + // convert from_nd to the same stype as to_nd + common::CastStorageDispatch(s, from, casted_nd); + } + if (to_stype == kDefaultStorage) { - casted_nd = NDArray(shape, from_ctx); + CopyFromToDnsImpl(casted_nd, to, ctx); + } else if (to_stype == kRowSparseStorage) { + CopyFromToRspImpl(casted_nd, to, ctx); + } else if (to_stype == kCSRStorage) { + CopyFromToCsrImpl(casted_nd, to, ctx); } else { - casted_nd = NDArray(to_stype, shape, from_ctx); + LOG(FATAL) << "unknown storage type" << to_stype; } - common::CastStorageDispatch(s, from, casted_nd); - } else { - casted_nd = from; - } - if (to_stype == kDefaultStorage) { - CopyFromToDnsImpl(casted_nd, to, ctx); - } else if (to_stype == kRowSparseStorage) { - CopyFromToRspImpl(casted_nd, to, ctx); - } else if (to_stype == kCSRStorage) { - CopyFromToCsrImpl(casted_nd, to, ctx); - } else { - LOG(FATAL) << "unknown storage type" << to_stype; } if (is_same::value || is_same::value) { // Wait GPU kernel to complete