Non-blocking row_sparse_pull. Fix incorrect indices generated by device kvstore.row_sparse_pull (apache#9887)

* nonblocking Kvstore (apache#195)

* draft

* rm use_copy. fix dist kvstore. TODO: fix dtype

* fix dtype, shape

* remove reshape

* cleanup

* fix compilation

* rsp draft

* update param name

* doc update and small refactoring

* minor updates

* enhance test case with 2-D rowids

* update gpu tests

* rewrite gpu unique kernels

* update gpu tests

* update reshape test

* fix lint

* update test for py3
eric-haibin-lin authored and Jin Huang committed Mar 30, 2018
1 parent 65a6e86 commit bf455ef
Showing 10 changed files with 196 additions and 240 deletions.
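For context, a minimal usage sketch of the Python API this commit targets, kvstore.row_sparse_pull on a device kvstore. This is an illustrative sketch, not part of the commit; the kvstore type, key name, shapes, and row ids below are all hypothetical.

```python
import mxnet as mx

# Hypothetical setup: a 'device' kvstore holding one row_sparse value.
kv = mx.kv.create('device')
shape = (8, 4)
kv.init('w', mx.nd.ones(shape).tostype('row_sparse'))

# row_ids passed to row_sparse_pull need not be unique nor sorted;
# the kvstore retains only the requested rows in the output.
out = mx.nd.sparse.zeros('row_sparse', shape)
row_ids = mx.nd.array([5, 1, 5, 0], dtype='int64')
kv.row_sparse_pull('w', out=out, row_ids=row_ids)
print(out.indices.asnumpy())  # rows actually present in the pulled array
```

The diffs below show how the retain-and-broadcast path is now pushed to the dependency engine asynchronously instead of blocking the calling thread.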
2 changes: 1 addition & 1 deletion python/mxnet/kvstore.py
@@ -321,7 +321,7 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
other pull actions.
row_ids : NDArray or list of NDArray
The row_ids for which to pull for each value. Each row_id is an 1D NDArray \
The row_ids for which to pull for each value. Each row_id is an 1-D NDArray \
whose values don't have to be unique nor sorted.
Examples
196 changes: 55 additions & 141 deletions src/kvstore/comm.h
@@ -65,14 +65,14 @@ class Comm {

/**
* \brief broadcast src to dst[i] with target row_ids for every i
* \param key the identifier key for the stored ndarray
* \param src the source row_sparse ndarray to broadcast
* \param dst a list of destination row_sparse NDArray and its target row_ids to broadcast,
where the row_ids are expected to be unique and sorted
* \param use_copy if set to true, directly copy src to dst[i] without looking up the
provided row_ids
where the row_ids are expected to be unique and sorted in row_id.data()
* \param priority the priority of the operation
*/
virtual void BroadcastRowSparse(int key, const NDArray& src,
const std::vector<std::pair<NDArray*, NDArray>>& dst,
const bool use_copy,
const int priority) = 0;

/**
@@ -209,7 +209,6 @@ class CommCPU : public Comm {

void BroadcastRowSparse(int key, const NDArray& src,
const std::vector<std::pair<NDArray*, NDArray>>& dst,
const bool use_copy,
const int priority) override {
using namespace mshadow;
CHECK_EQ(src.storage_type(), kRowSparseStorage)
@@ -219,107 +218,30 @@ class CommCPU : public Comm {
for (size_t i = 0; i < dst.size(); ++i) {
NDArray* out = dst[i].first;
NDArray row_id = dst[i].second;
if (use_copy) {
CopyFromTo(src, out, priority);
} else {
CHECK_EQ(out->storage_type(), kRowSparseStorage)
<< "BroadcastRowSparse expects row_sparse dst NDArray";
CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
<< "BroadcastRowSparse with row_indices on gpu context not supported";
// retain according to unique indices
const bool use_sparse_retain = (src.shape()[0] != src.storage_shape()[0])
|| (row_id.dtype() != out->aux_type(rowsparse::kIdx))
|| (out->ctx().dev_mask() != Context::kGPU);
if (use_sparse_retain) { // use sparse_retain op
const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
NDArray out_cpu = is_to_gpu? NDArray(kRowSparseStorage, src.shape(),
src.ctx(), true, src.dtype(), src.aux_types()) : *out;
Engine::Get()->PushAsync(
[=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
NDArray temp = out_cpu; // get rid of const qualifier
op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
src, indices, kWriteTo,
&temp);
on_complete();
}, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
if (is_to_gpu) {
CopyFromTo(out_cpu, out, priority);
}
} else { // direct copy rows
Engine::Get()->PushAsync(
[=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
CopyRetainedRowsToGPU(rctx.get_stream<cpu>(), rctx.get_stream<gpu>(),
src, row_id, out);
// wait for GPU operations to complete
rctx.get_stream<gpu>()->Wait();
on_complete();
}, out->ctx(), {src.var(), row_id.var()}, {out->var()},
FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU"));
}
}
CHECK_EQ(out->storage_type(), kRowSparseStorage)
<< "BroadcastRowSparse expects row_sparse dst NDArray";
CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
<< "BroadcastRowSparse with row_indices on gpu context not supported";
// retain according to unique indices
const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
NDArray retained_cpu = is_to_gpu ? NDArray(kRowSparseStorage, src.shape(),
src.ctx(), true, src.dtype(), src.aux_types()) : *out;
Engine::Get()->PushAsync(
[=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
NDArray temp = retained_cpu; // get rid of the const qualifier
op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
src, indices, kWriteTo,
&temp);
on_complete();
}, Context::CPU(), {src.var(), row_id.var()}, {retained_cpu.var()},
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
// if retained_cpu == out, CopyFromTo will ignore the copy operation
CopyFromTo(retained_cpu, out, priority);
}
}

private:
/*!
* \brief When src is a rsp with full rows,
* simply copy retained rows directly from cpu to gpu
* without invoking sparse_retain op.
*/
void CopyRetainedRowsToGPU(mshadow::Stream<cpu>* cpu_stream,
mshadow::Stream<gpu>* gpu_stream,
const NDArray& src,
const NDArray& indices,
NDArray* dst) {
#if MXNET_USE_CUDA == 1
CHECK_EQ(src.storage_type(), kRowSparseStorage)
<< "CopyRetainedRowsToGPU expects row-sparse src NDArray";
CHECK_EQ(src.ctx().dev_mask(), Context::kCPU)
<< "CopyRetainedRowsToGPU with src on gpu context not supported";
CHECK_EQ(src.storage_shape()[0], src.shape()[0])
<< "CopyRetainedRowsToGPU only supports src rsp with full rows";
CHECK_EQ(indices.storage_type(), kDefaultStorage);
CHECK_EQ(indices.ctx().dev_mask(), Context::kCPU);
CHECK_EQ(dst->storage_type(), kRowSparseStorage);
CHECK_EQ(dst->ctx().dev_mask(), Context::kGPU);
CHECK_EQ(indices.dtype(), dst->aux_type(rowsparse::kIdx))
<< "CopyRetainedRowsToGPU only supports same data type for idx array and dst aux_data(0)";
if (!src.storage_initialized() || indices.data().Size() == 0U) {
op::FillZerosRspImpl(gpu_stream, *dst);
return;
}
using namespace mshadow;

const TBlob& src_data = src.data();
const TBlob& idx_data = indices.data();
const size_t row_length = src.shape().ProdShape(1, src.shape().ndim());
const size_t num_rows_retained = idx_data.Size();
dst->CheckAndAlloc({Shape1(num_rows_retained)});
TBlob dst_data = dst->data();
TBlob dst_idx_data = dst->aux_data(rowsparse::kIdx);
MSHADOW_TYPE_SWITCH(src.dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(indices.dtype(), IType, {
// copy idx array
Tensor<gpu, 1, IType> dst_idx_tensor = dst_idx_data.FlatTo1D<gpu, IType>(gpu_stream);
const Tensor<cpu, 1, IType> idx_tensor = idx_data.FlatTo1D<cpu, IType>(cpu_stream);
Copy(dst_idx_tensor, idx_tensor, gpu_stream);
// copy src data
const Tensor<cpu, 2, DType> src_data_tensor = src_data.get_with_shape<cpu, 2, DType>(
Shape2(src_data.shape_[0], row_length), cpu_stream);
Tensor<gpu, 2, DType> dst_data_tensor = dst_data.get_with_shape<gpu, 2, DType>(
Shape2(dst_data.shape_[0], row_length), gpu_stream);
for (size_t i = 0; i < num_rows_retained; ++i) {
Copy(dst_data_tensor[i], src_data_tensor[idx_tensor[i]], gpu_stream);
}
})
})
#else
LOG(FATAL) << "GPU not enabled";
#endif
}

// reduce sum into val[0]
inline void ReduceSumCPU(const std::vector<NDArray> &in_data) {
MSHADOW_TYPE_SWITCH(in_data[0].dtype(), DType, {
@@ -632,54 +554,46 @@ class CommDevice : public Comm {

void BroadcastRowSparse(int key, const NDArray& src,
const std::vector<std::pair<NDArray*, NDArray>>& dst,
const bool use_copy,
const int priority) override {
CHECK_EQ(src.storage_type(), kRowSparseStorage)
<< "BroadcastRowSparse expects row-sparse src NDArray";

for (size_t i = 0; i < dst.size(); ++i) {
NDArray* out = dst[i].first;
NDArray row_id = dst[i].second;
if (use_copy) {
CopyFromTo(src, out, priority);
} else {
CHECK_EQ(out->storage_type(), kRowSparseStorage)
<< "BroadcastRowSparse expects row_sparse dst NDArray";

const bool is_diff_ctx = out->ctx() != src.ctx();
NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
src.ctx(), true, out->dtype(), out->aux_types()) : *out;

CHECK_EQ(row_id.ctx(), src.ctx())
<< "row_id and src are expected to be on the same context";

Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray temp = out_gpu;
const TBlob& indices = row_id.data();
switch (temp.ctx().dev_mask()) {
case cpu::kDevMask: {
mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
src, indices, kWriteTo, &temp);
break;
}
CHECK_EQ(out->storage_type(), kRowSparseStorage)
<< "BroadcastRowSparse expects row_sparse dst NDArray";
CHECK_EQ(row_id.ctx(), src.ctx())
<< "row_id and src are expected to be on the same context";
// retain according to indices
const bool is_diff_ctx = out->ctx() != src.ctx();
NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
src.ctx(), true, out->dtype(), out->aux_types()) : *out;
Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
using namespace mxnet::common;
NDArray temp = out_gpu;
switch (temp.ctx().dev_mask()) {
case cpu::kDevMask: {
SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
src, indices, kWriteTo, &temp);
break;
}
#if MXNET_USE_CUDA
case gpu::kDevMask: {
mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
src, indices, kWriteTo, &temp);
// wait for GPU operations to complete
rctx.get_stream<gpu>()->Wait();
break;
}
#endif
default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
case gpu::kDevMask: {
SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
src, indices, kWriteTo, &temp);
// wait for GPU operations to complete
rctx.get_stream<gpu>()->Wait();
break;
}
on_complete();
}, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
if (is_diff_ctx) {
CopyFromTo(out_gpu, out, priority);
}
}
#endif
default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
}
on_complete();
}, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
CopyFromTo(out_gpu, out, priority);
}
}

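The BroadcastRowSparse implementations above reduce, per destination device, to a sparse retain: keep only the rows of src named by that device's row_ids, then copy across contexts if needed. Below is a rough Python analogue of that retain step, assuming the mx.nd.sparse.retain operator; the array contents are made up.

```python
import mxnet as mx

# src is a row_sparse array with every row present; keep only rows 1 and 3.
src = mx.nd.array([[0, 0], [1, 1], [2, 2], [3, 3]]).tostype('row_sparse')
row_ids = mx.nd.array([1, 3], dtype='int64')  # unique and sorted, as comm.h expects
out = mx.nd.sparse.retain(src, row_ids)

print(out.indices.asnumpy())  # -> [1 3]
print(out.data.asnumpy())     # -> [[1. 1.] [3. 3.]]
```

In the commit itself this work runs inside Engine::Get()->PushAsync, so the caller is not blocked while the retain executes.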
34 changes: 16 additions & 18 deletions src/kvstore/kvstore_dist.h
@@ -279,24 +279,20 @@ class KVStoreDist : public KVStoreLocal {
}
auto &target_val_rowids = grouped_val_rowids[i];
const size_t num_vals = target_val_rowids.size();
size_t num_rows = 0;
// TODO(haibin) refactor this for loop
for (size_t i = 0; i < num_vals; i++) {
auto &row_id = target_val_rowids[i].second;
NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
CopyFromTo(row_id, &indices, 0);
Unique(&indices, priority);
target_val_rowids[i].second = indices;
num_rows += indices.shape().Size();
}
if (num_vals > 1) {
// TODO(haibin) aggregate over all unique indices
LOG(FATAL) << "RowSparsePull with multiple values is not implemented yet";
} else {
auto& indices = target_val_rowids[0].second;
PullRowSparse_(key, recv_buf, indices, priority);
comm_->BroadcastRowSparse(key, recv_buf, grouped_val_rowid, num_vals == 1, priority);
target_val_rowids[i].second = Unique(row_id, pinned_ctx_, 0);
}
CHECK_EQ(num_vals, 1) << "RowSparsePull with multiple values is not supported yet";
NDArray& indices = target_val_rowids[0].second;
PullRowSparse_(key, recv_buf, indices, priority);
// The recv_buf contains values pulled from remote server with unique indices.
// Directly broadcast w/o rowids if num_vals == 1
auto get_val = [](const std::pair<NDArray*, NDArray>& p) { return p.first; };
std::vector<NDArray*> grouped_val(grouped_val_rowid.size());
std::transform(grouped_val_rowid.begin(), grouped_val_rowid.end(),
grouped_val.begin(), get_val);
comm_->Broadcast(key, recv_buf, grouped_val, priority);
}
}
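The hunk above hands deduplication to Unique(row_id, pinned_ctx_, 0), and PullRowSparse_ then assumes the indices it receives are unique, sorted int64 values. A small numpy sketch of the expected transformation; the input ids are hypothetical.

```python
import numpy as np

# Row ids coming from the user may contain duplicates and be unsorted.
row_ids = np.array([7, 2, 7, 0, 2], dtype=np.int64)

# Unique() is expected to produce a sorted, deduplicated index list.
unique_ids = np.unique(row_ids)
print(unique_ids)  # -> [0 2 7]
```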

Expand Down Expand Up @@ -462,10 +458,12 @@ class KVStoreDist : public KVStoreLocal {
auto pull_from_servers = [this, key, recv_buf, indices]
(RunContext rctx, Engine::CallbackOnComplete cb) {
// allocate memory for the buffer
size_t num_rows = indices.shape().Size();
CHECK_EQ(indices.dtype(), mshadow::kInt64);
const TBlob idx_data = indices.data();
size_t num_rows = idx_data.shape_.Size();
recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
real_t* data = recv_buf.data().dptr<real_t>();
const auto offsets = indices.data().dptr<int64_t>();
const auto offsets = idx_data.dptr<int64_t>();
const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
const int64_t size = num_rows * unit_len;
// convert to ps keys in row sparse format
Expand All @@ -480,7 +478,7 @@ class KVStoreDist : public KVStoreLocal {
// because after pull is done, the callback function returns and locks are released.
// at this point, later functions may access the indices variable while copy happens
mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
indices.data().FlatTo1D<cpu, int64_t>());
idx_data.FlatTo1D<cpu, int64_t>());
CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
static_cast<int>(DataHandleType::kRowSparsePushPull),
[vals, cb]() { delete vals; cb(); });
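The pull_from_servers lambda above sizes the receive buffer from the unique indices: one slice of unit_len elements per retained row. The same arithmetic restated as a tiny Python check; the value shape and index count are made up.

```python
import numpy as np

value_shape = (100, 8, 2)                  # hypothetical full shape of the stored value
indices = np.array([3, 10, 42], dtype=np.int64)

num_rows = indices.size                    # rows actually pulled
unit_len = int(np.prod(value_shape[1:]))   # elements per row
size = num_rows * unit_len                 # total elements requested from the servers
print(num_rows, unit_len, size)            # -> 3 16 48
```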
(diffs for the remaining 7 changed files are not shown)
