
Change STORAGE_TYPE_ASSIGN_CHECK to type_assign for fallback
reminisce committed Aug 3, 2017
1 parent d900ad6 commit 9ce878f
Showing 2 changed files with 38 additions and 16 deletions.
src/operator/tensor/sparse_retain-inl.h: 51 changes (36 additions, 15 deletions)
@@ -57,8 +57,14 @@ inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                                 std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
-  STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kArr, kRowSparseStorage);
-  STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kOut, kRowSparseStorage);
+  if ((*in_attrs)[sr::kArr] == kRowSparseStorage) {
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kIdx, kDefaultStorage);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kOut, kRowSparseStorage);
+  } else {  // fallback
+    type_assign(&(in_attrs->at(sr::kArr)), kDefaultStorage);
+    type_assign(&(in_attrs->at(sr::kIdx)), kDefaultStorage);
+    type_assign(&(out_attrs->at(sr::kOut)), kDefaultStorage);
+  }
   return true;
 }

@@ -68,10 +74,16 @@ inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                                  std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 2U);
-  STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kOut, kDefaultStorage);
-  STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kIdx, kDefaultStorage);
-  STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kArr, kRowSparseStorage);
-  STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kIdx, kDefaultStorage);
+  if (out_attrs->at(sr::kArr) == kRowSparseStorage) {
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kOut, kDefaultStorage);
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kIdx, kDefaultStorage);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kIdx, kDefaultStorage);
+  } else {
+    type_assign(&(in_attrs->at(sr::kOut)), kDefaultStorage);
+    type_assign(&(in_attrs->at(sr::kIdx)), kDefaultStorage);
+    type_assign(&(out_attrs->at(sr::kArr)), kDefaultStorage);
+    type_assign(&(out_attrs->at(sr::kIdx)), kDefaultStorage);
+  }
   return true;
 }
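
Note on the two inference hunks above: STORAGE_TYPE_ASSIGN_CHECK fails the whole storage-type inference pass when a slot already holds a conflicting storage type, while type_assign only attempts the assignment and reports whether it succeeded, so a dense input can now fall through to the dense fallback branch instead of hitting a fatal check. A minimal, self-contained sketch of that distinction (the helper bodies and "_sketch" names below are illustrative assumptions modeled on the real helpers in src/operator/operator_common.h, not copies of them):

#include <cassert>
#include <stdexcept>
#include <vector>

// Illustrative storage-type codes; MXNet defines its own enum for these.
constexpr int kUndefinedStorage = -1;
constexpr int kDefaultStorage   = 0;
constexpr int kRowSparseStorage = 1;

// Best-effort assignment: fill an unset slot, succeed if it already matches,
// otherwise just report failure so the caller can choose a fallback.
bool type_assign_sketch(int* slot, int wanted) {
  if (*slot == kUndefinedStorage) { *slot = wanted; return true; }
  return *slot == wanted;
}

// Hard assignment: a conflict is fatal, the way STORAGE_TYPE_ASSIGN_CHECK
// aborts inference (a CHECK in MXNet; an exception in this sketch).
void storage_type_assign_check_sketch(std::vector<int>* attrs, size_t i, int wanted) {
  if (!type_assign_sketch(&attrs->at(i), wanted)) {
    throw std::runtime_error("storage type mismatch");
  }
}

int main() {
  // Dense data input: the pre-commit code would have failed a hard check
  // here; the fallback branch only nudges every slot toward kDefaultStorage.
  std::vector<int> in_attrs  = {kDefaultStorage, kUndefinedStorage};  // {arr, idx}
  std::vector<int> out_attrs = {kUndefinedStorage};                   // {out}
  type_assign_sketch(&in_attrs[0], kDefaultStorage);
  type_assign_sketch(&in_attrs[1], kDefaultStorage);
  type_assign_sketch(&out_attrs[0], kDefaultStorage);
  assert(out_attrs[0] == kDefaultStorage);

  // Row-sparse data input: the strict, sparse-specific assignments still apply.
  std::vector<int> rsp_in = {kRowSparseStorage, kUndefinedStorage};
  storage_type_assign_check_sketch(&rsp_in, 1, kDefaultStorage);  // idx stays dense
  return 0;
}

Both inference functions follow the same shape: keep the hard checks on the row-sparse path they were written for, and use best-effort assignments when the operator has to fall back to dense storage.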

@@ -184,13 +196,13 @@ void SparseRetainOpForwardRspImpl(mshadow::Stream<xpu> *s,
                                   const TBlob& idx_data,
                                   const OpReqType req,
                                   NDArray* output_nd) {
+  if (req == kNullOp) return;
   CHECK_EQ(input_nd.storage_type(), kRowSparseStorage)
     << "SparseRetainOpForwardRspImpl operator only takes row sparse NDArray as input";
   CHECK_EQ(output_nd->storage_type(), kRowSparseStorage)
     << "SparseRetainOpForwardRspImpl operator only outputs row sparse NDArray";
 
-  if (req == kNullOp
-      || !input_nd.storage_initialized()
+  if (!input_nd.storage_initialized()
       || idx_data.Size() == 0U
       || input_nd.shape()[0] == 0) {
     FillZerosRspImpl(s, output_nd);
@@ -270,11 +282,13 @@ void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs,
                               const std::vector<NDArray>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
-  CHECK_EQ(inputs.size(), 2U);
-  CHECK_EQ(outputs.size(), 2U);
   CHECK_EQ(req.size(), 2U);
-  CHECK_NE(req[sr::kArr], kWriteInplace);
-  CHECK_EQ(req[sr::kIdx], kNullOp)
+  CHECK_EQ(req[sr::kIdx], kNullOp);
+  if (req[sr::kArr] == kNullOp) return;
+  CHECK_EQ(req[sr::kArr], kWriteTo);
+
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 2U)
     << "sparse_retain does not support calculating gradients of indices";
 
   CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage)
@@ -284,17 +298,24 @@ void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs[sr::kArr].storage_type(), kRowSparseStorage)
     << "sparse_retain backward only outputs row sparse NDArray as grad of input";
 
-  const TBlob out_grad_data = inputs[sr::kOut].data();
+  using namespace mxnet_op;
+  using namespace mshadow;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
   const TBlob idx_data = inputs[sr::kIdx].data();
+  if (idx_data.Size() == 0U) {
+    NDArray output = outputs[sr::kArr];
+    FillZerosRspImpl<xpu>(s, &output);
+    return;
+  }
+
+  const TBlob out_grad_data = inputs[sr::kOut].data();
 
   NDArray in_grad_nd = outputs[sr::kArr];
   in_grad_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
   TBlob in_grad_data = in_grad_nd.data();
   TBlob in_grad_idx = in_grad_nd.aux_data(rowsparse::kIdx);
   const auto row_length = out_grad_data.shape_.ProdShape(1, out_grad_data.shape_.ndim());
 
-  using namespace mxnet_op;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, {  // output data type
     MSHADOW_IDX_TYPE_SWITCH(in_grad_idx.type_flag_, RType, {  // row index data type
       MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, {  // index array data type
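
For context on the backward hunk above: the gradient of sparse_retain with respect to the retained array only copies out_grad rows at the retained indices and is zero everywhere else, so an empty index array means an all-zero gradient, which is what the new FillZerosRspImpl early return produces before any allocation happens. A dense-logic sketch of that gradient rule (shapes, values, and names here are illustrative, not the kernel's actual signature):

#include <cstdio>
#include <vector>

int main() {
  const int rows = 4, cols = 2;
  // Upstream gradient flowing into sparse_retain's output (dense here).
  std::vector<float> out_grad(rows * cols, 1.0f);
  // Row indices that sparse_retain kept in the forward pass.
  std::vector<int> retained = {0, 3};
  // Gradient w.r.t. the input array: zero except at the retained rows.
  std::vector<float> in_grad(rows * cols, 0.0f);

  for (int r : retained) {
    for (int c = 0; c < cols; ++c) {
      in_grad[r * cols + c] = out_grad[r * cols + c];
    }
  }
  // With retained empty, the loop never runs and in_grad stays all zero,
  // matching the early FillZerosRspImpl + return in the code above.

  for (int r = 0; r < rows; ++r) {
    std::printf("row %d: %.0f %.0f\n", r, in_grad[r * cols], in_grad[r * cols + 1]);
  }
  return 0;
}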
tests/python/gpu/test_operator_gpu.py: 3 changes (2 additions, 1 deletion)
@@ -14,7 +14,8 @@
 from test_nn import *
 #from test_rnn import *
 from test_gluon_rnn import *
-from test_sparse_operator import test_cast_storage_ex, test_sparse_dot, test_sparse_nd_zeros
+from test_sparse_operator import test_cast_storage_ex, test_sparse_dot
+from test_sparse_operator import test_sparse_nd_zeros, test_sparse_retain
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse
 
 set_default_context(mx.gpu(0))