Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Misc fixes for sparse distributed training #8345

Merged
merged 11 commits into from
Oct 21, 2017
40 changes: 18 additions & 22 deletions src/kvstore/kvstore_dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ namespace kvstore {
/**
* \brief distributed kvstore
*
* for a worker node, it always guarantees that all push and pull issued from
* this worker on the same key are serialized. namely push(3) and then pull(3),
* then the data pulled is always containing the modification from the push(3).
*
* it's the server node's job to control the data consistency among all
* workers. see details on \ref ServerHandle::Start
*/
Expand Down Expand Up @@ -248,7 +244,7 @@ class KVStoreDist : public KVStoreLocal {
LOG(FATAL) << "RowSparsePull with multiple values is not implemented yet";
} else {
auto& indices = target_val_rowids[0].second;
PullRowSparse_(key, &recv_buf, indices, priority);
PullRowSparse_(key, recv_buf, indices, priority);
comm_->BroadcastRowSparse(key, recv_buf, grouped_val_rowid, num_vals == 1, priority);
}
}
Expand Down Expand Up @@ -322,24 +318,24 @@ class KVStoreDist : public KVStoreLocal {
}

// pull row sparse weight into `recv_buf` based on indices given by `indices`
void PullRowSparse_(const int key, NDArray *recv_buf, const NDArray& indices, int priority) {
void PullRowSparse_(const int key, const NDArray& recv_buf,
const NDArray& indices, int priority) {
using namespace rowsparse;
auto pull_from_servers = [this, key, recv_buf, indices]
(RunContext rctx, Engine::CallbackOnComplete cb) {
// allocate memory for the buffer
size_t num_rows = indices.shape().Size();
recv_buf->CheckAndAlloc({mshadow::Shape1(num_rows)});
recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
#if MKL_EXPERIMENTAL == 1
mkl_set_tblob_eager_mode(recv_buf->data());
mkl_set_tblob_eager_mode(recv_buf.data());
#endif
real_t* data = recv_buf->data().dptr<real_t>();
auto indices_data = indices.data();
const auto offsets = indices_data.dptr<int64_t>();
const auto unit_len = recv_buf->shape().ProdShape(1, recv_buf->shape().ndim());
real_t* data = recv_buf.data().dptr<real_t>();
const auto offsets = indices.data().dptr<int64_t>();
const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
const int64_t size = num_rows * unit_len;
// convert to ps keys in row sparse format
PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
unit_len, recv_buf->shape()[0]);
unit_len, recv_buf.shape()[0]);
if (this->log_verbose_) {
LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
<< pskv.keys << " size: " << size;
Expand All @@ -348,16 +344,16 @@ class KVStoreDist : public KVStoreLocal {
// Copy indices into recv_buf. This must happen before ZPull is issued:
// once the pull is done, the callback function returns and the locks are released,
// so later functions may access the indices variable while the copy is still in progress.
mshadow::Copy(recv_buf->aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
indices_data.FlatTo1D<cpu, int64_t>());
mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
indices.data().FlatTo1D<cpu, int64_t>());
CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull,
[vals, cb]() { delete vals; cb(); });
};
CHECK_NOTNULL(Engine::Get())->PushAsync(
pull_from_servers,
pinned_ctx_,
{indices.var()},
{recv_buf->var()},
{recv_buf.var()},
FnProperty::kNormal,
priority,
PROFILER_MESSAGE("KVStoreDistRowSparsePull"));
Expand All @@ -366,15 +362,15 @@ class KVStoreDist : public KVStoreLocal {
// push row sparse gradient
void PushRowSparse(int key, const NDArray &send_buf, int priority) {
using namespace rowsparse;
auto push_to_servers = [this, key, &send_buf]
auto push_to_servers = [this, key, send_buf]
(RunContext rctx, Engine::CallbackOnComplete cb) {
#if MKL_EXPERIMENTAL == 1
mkl_set_tblob_eager_mode(send_buf.data());
#endif
real_t* data = send_buf.data().dptr<real_t>();
bool init = send_buf.storage_initialized();
Copy link
Member

@rahul003 rahul003 Oct 19, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove init, it is no longer used

const int64_t num_rows = init ? send_buf.aux_shape(kIdx)[0] : 0;
const auto offsets = init ? send_buf.aux_data(kIdx).dptr<int64_t>() : nullptr;
const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim());
const int64_t size = num_rows * unit_len;

Expand Down Expand Up @@ -472,7 +468,7 @@ class KVStoreDist : public KVStoreLocal {
return pskv;
}

// TODO(haibin) this encoding method for row sparse keys doesn't allow cross-layer batching
// Note: this encoding method for row sparse keys doesn't allow cross-layer batching
inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows,
const int64_t *offsets, const size_t unit_len,
const int64_t total_num_rows) {
Expand All @@ -495,15 +491,15 @@ class KVStoreDist : public KVStoreLocal {
ps::Key master_key = krs[i].begin() + key;
pskv.keys.push_back(master_key);
pskv.lens.push_back(0);
if (offsets) {
if (offsets && size > 0) {
// calculate partition ranges
int64_t part_num_rows =
llround(static_cast<double>(total_num_rows) / num_servers * (i + 1)) -
llround(static_cast<double>(total_num_rows) / num_servers * i);
auto end_row = start_row + part_num_rows;
// search for offsets in [start_row, end_row)
auto lb = std::lower_bound(offsets, offsets + num_rows, start_row);
auto ub = std::upper_bound(offsets, offsets + num_rows, end_row - 1);

for (auto offset = lb; offset < ub; offset++) {
ps::Key ps_key = krs[i].begin() + key + (*offset - start_row);
CHECK_LT(ps_key, krs[i].end());
Expand Down
18 changes: 9 additions & 9 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ void FillCompute(const nnvm::NodeAttrs& attrs,
});
}

// Per-element kernel body used to populate an index buffer: for element i,
// Map assigns the value i to out[i] via KERNEL_ASSIGN with the kWriteTo request.
// NOTE(review): KERNEL_ASSIGN is defined outside this view; with kWriteTo it
// presumably performs a plain overwrite of out[i] -- confirm.
struct PopulateFullIdxRspKernel {
// i   - position of the element handled by this invocation
// out - destination buffer of index type IType; out[i] receives i
template<typename IType>
MSHADOW_XINLINE static void Map(int i, IType* out) {
KERNEL_ASSIGN(out[i], kWriteTo, i);
}
};

// Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray,
// instead of the usual compact representation.
template<typename xpu>
Expand All @@ -192,21 +199,14 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
auto num_rows = dst->shape()[0];
dst->CheckAndAlloc({Shape1(num_rows)});
auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto idx = dst->aux_data(kIdx);
auto val = dst->data();
Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1));
Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, num_rows, idx.dptr<IType>());
});
});
}

// Per-element kernel body: for element i, Map assigns the value i to out[i]
// via KERNEL_ASSIGN with the kWriteTo request.
// NOTE(review): KERNEL_ASSIGN is defined outside this view; with kWriteTo it
// presumably performs a plain overwrite of out[i] -- confirm.
struct PopulateFullIdxRspKernel {
// i   - position of the element handled by this invocation
// out - destination buffer of index type IType; out[i] receives i
template<typename IType>
MSHADOW_XINLINE static void Map(int i, IType* out) {
KERNEL_ASSIGN(out[i], kWriteTo, i);
}
};

// Fill full indices NDArray with zeros by updating the aux shape.
template<typename xpu>
void PopulateFullIdxRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
Expand Down
31 changes: 17 additions & 14 deletions tests/nightly/dist_sync_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def check_diff_to_scalar(A, x, rank=None):

rate = 2
shape = (2, 3)
big_shape = (1200, 1200) # bigger than BIGARRAY_BOUND
big_shape = (1200, 1200) # bigger than MXNET_KVSTORE_BIGARRAY_BOUND

kv = mx.kv.create('dist_sync')

Expand Down Expand Up @@ -104,24 +104,27 @@ def check_row_sparse_keys(kv, my_rank, nworker):
def check_row_sparse_keys_with_zeros(kv, my_rank, nworker):
nrepeat = 3
# prepare gradient
v = mx.nd.zeros(shape)
big_v = mx.nd.zeros(big_shape)
v = mx.nd.sparse.zeros('row_sparse', shape)
big_v = mx.nd.sparse.zeros('row_sparse', big_shape)
# push
for i in range(nrepeat):
kv.push('11', v.tostype('row_sparse'))
kv.push('100', big_v.tostype('row_sparse'))

kv.push('11', v)
kv.push('100', big_v)
# pull a subset of rows this worker is interested in
all_row_ids = np.arange(shape[0])
val = mx.nd.ones(shape).tostype('row_sparse')
big_val = mx.nd.ones(big_shape).tostype('row_sparse')
kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids, dtype='int64'))
big_num_rows = shape[0]
val = mx.nd.sparse.zeros('row_sparse', shape)
big_val = mx.nd.sparse.zeros('row_sparse', big_shape)
kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array(all_row_ids))
big_all_row_ids = np.arange(big_shape[0])
kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array(big_all_row_ids, dtype='int64'))
kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array(big_all_row_ids))
# verify results
check_diff_to_scalar(val, mx.nd.ones(shape))
check_diff_to_scalar(big_val, mx.nd.ones(big_shape))
check_diff_to_scalar(val, 1)
check_diff_to_scalar(big_val, 1)
# pull empty weights
kv.row_sparse_pull('11', out=val, row_ids=mx.nd.array([]))
kv.row_sparse_pull('100', out=big_val, row_ids=mx.nd.array([]))
check_diff_to_scalar(val, 0)
check_diff_to_scalar(big_val, 0)

def check_big_row_sparse_keys(kv, my_rank, nworker):
mx.random.seed(123)
Expand Down Expand Up @@ -154,7 +157,7 @@ def check_big_row_sparse_keys(kv, my_rank, nworker):
rnd.seed(my_rank)
num_rows = big_shape[0]
row_ids_np = np.random.randint(num_rows, size=num_rows)
row_ids = mx.nd.array(row_ids_np, dtype='int64')
row_ids = mx.nd.array(row_ids_np)
# perform pull
val = mx.nd.zeros(big_shape, stype='row_sparse')
kv.row_sparse_pull('100', out=val, row_ids=row_ids)
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def test_sgd():
if dtype != np.float16:
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape[:2],
dtype, w_stype='csr', g_stype='csr')
# test optimizer with a big shape
big_shape = (54686454, 1)
kwarg = {'momentum': 0.9, 'wd': 0.05}
compare_optimizer(opt1(**kwarg), opt2(**kwarg), big_shape, np.float32)

class PySparseSGD(mx.optimizer.Optimizer):
"""python reference implementation of sgd"""
Expand Down