Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Bug fix for CopyFromTo; the sparse SGD test now passes on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed May 15, 2017
1 parent c4c03e2 commit 0ee7f1d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 24 deletions.
12 changes: 7 additions & 5 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ class NDArray {
: static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_),
aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_),
aux_shapes(aux_shapes_) {
shandle.ctx = ctx;
var = Engine::Get()->NewVariable();
// aux_handles always reflect the correct number of aux data
for (size_t i = 0; i < aux_shapes.size(); i++) {
Expand Down Expand Up @@ -795,20 +796,17 @@ inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) {
template<typename from_xpu, typename to_xpu>
inline void CopyFromToDnsImpl(const NDArray from, NDArray *to, RunContext ctx) {
  using namespace mshadow;
  // Both arrays must share the same (dense) storage type for a raw blob copy.
  CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type";
  TBlob tmp = to->data();
  ndarray::Copy<from_xpu, to_xpu>(from.data(), &tmp,
                                  from.ctx(), to->ctx(), ctx);
  // Qualify std::is_same explicitly instead of `using namespace std;`:
  // a blanket using-directive in a header risks name collisions for every
  // includer. If either endpoint is a GPU, the copy was issued on a CUDA
  // stream asynchronously — block until it completes so the destination
  // buffer is valid when this function returns.
  if (std::is_same<from_xpu, mshadow::gpu>::value ||
      std::is_same<to_xpu, mshadow::gpu>::value) {
    // Wait GPU kernel to complete
    ctx.get_stream<gpu>()->Wait();
  }
}

// Make a copy of an NDArray based on storage type
template<typename from_xpu, typename to_xpu>
void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
using namespace std;
using namespace mshadow;
// if storage type doesn't match, cast the storage first
auto from_stype = from.storage_type();
auto to_stype = to->storage_type();
Expand Down Expand Up @@ -836,6 +834,10 @@ void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) {
// indices and values instead of the superset.
LOG(FATAL) << "Not implemented yet";
}
if (is_same<from_xpu, mshadow::gpu>::value || is_same<to_xpu, mshadow::gpu>::value) {
// Wait GPU kernel to complete
ctx.get_stream<gpu>()->Wait();
}
}

/*!
Expand Down
22 changes: 3 additions & 19 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,6 @@ inline void SGDUpdate(const nnvm::NodeAttrs& attrs,
});
}

// TODO(haibin) duplicated code. remove me
// Type-dispatch macro for NDArray auxiliary (index) dtypes: binds the runtime
// type enum `type` to a concrete typedef `DType` (uint8_t for kUint8, int32_t
// for kInt32) and executes the caller-supplied statements in __VA_ARGS__.
// Any other enum value aborts via LOG(FATAL). NOTE: comments cannot be placed
// inside the macro body — a backslash-continued line would splice into a
// trailing // comment.
#define NDARRAY_IDX_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown idx type enum " << type; \
}

/*! \brief kernel for sparse sgd
*/
template<int req>
Expand Down Expand Up @@ -145,6 +126,7 @@ inline void SparseSGDUpdateDnsRspImpl(const SGDParam& param,
auto &out = outputs[0];
CHECK_EQ(weight.storage_type(), kDefaultStorage);
CHECK_EQ(grad.storage_type(), kRowSparseStorage);
if (!grad.storage_initialized()) return;

MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
NDARRAY_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
Expand Down Expand Up @@ -291,6 +273,8 @@ inline void SparseSGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
auto &grad = inputs[1];
auto &mom = inputs[2];
auto &out = outputs[0];
if (!grad.storage_initialized()) return;

MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
NDARRAY_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Expand Down
3 changes: 3 additions & 0 deletions src/operator/optimizer_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ NNVM_REGISTER_OP(sgd_mom_update)
// GPU registration of the sparse SGD update operator (FComputeEx path,
// which handles non-default storage types such as row-sparse gradients).
NNVM_REGISTER_OP(sparse_sgd_update)
.set_attr<FComputeEx>(FCOMP_EX_GPU, SparseSGDUpdateEx<gpu>);

// GPU registration of the sparse SGD-with-momentum update operator.
NNVM_REGISTER_OP(sparse_sgd_mom_update)
.set_attr<FComputeEx>(FCOMP_EX_GPU, SparseSGDMomUpdateEx<gpu>);

// Dense Adam update uses the regular FCompute GPU kernel.
NNVM_REGISTER_OP(adam_update)
.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu>);

Expand Down

0 comments on commit 0ee7f1d

Please sign in to comment.