diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 8133dc656402..4875b92a5adb 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -13,7 +13,7 @@
 #include <...>
 #include "mxnet/ndarray.h"
 #include "../ndarray/ndarray_function.h"
-#include "../operator/tensor/indexing_op.h"
+#include "../operator/tensor/sparse_retain-inl.h"
 namespace mxnet {
 namespace kvstore {
 /**
diff --git a/src/operator/tensor/sparse_retain-inl.h b/src/operator/tensor/sparse_retain-inl.h
index afa0776292ff..8a6e7a5046fc 100644
--- a/src/operator/tensor/sparse_retain-inl.h
+++ b/src/operator/tensor/sparse_retain-inl.h
@@ -179,45 +179,33 @@ struct SparseRetainRspRowBlockKernel {
 };
 
 template<typename xpu>
-void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
-                             const OpContext& ctx,
-                             const std::vector<NDArray>& inputs,
-                             const std::vector<OpReqType>& req,
-                             const std::vector<NDArray>& outputs) {
-  CHECK_EQ(inputs.size(), 2U);
-  CHECK_EQ(outputs.size(), 1U);
-  CHECK_EQ(req.size(), 1U);
-  CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'";
-
-  CHECK_EQ(inputs[sr::kArr].storage_type(), kRowSparseStorage)
-    << "sparse_retain operator only takes row sparse NDArray as input";
-  CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage)
-    << "sparse_retain operator only takes default NDArray as its index array";
-  CHECK_EQ(outputs[sr::kOut].storage_type(), kRowSparseStorage)
-    << "sparse_retain operator only outputs row sparse NDArray";
+void SparseRetainOpForwardRspImpl(mshadow::Stream<xpu> *s,
+                                  const NDArray& input_nd,
+                                  const TBlob& idx_data,
+                                  const OpReqType req,
+                                  NDArray* output_nd) {
+  CHECK_EQ(input_nd.storage_type(), kRowSparseStorage)
+    << "SparseRetainOpForwardRspImpl operator only takes row sparse NDArray as input";
+  CHECK_EQ(output_nd->storage_type(), kRowSparseStorage)
+    << "SparseRetainOpForwardRspImpl operator only outputs row sparse NDArray";
 
-  const NDArray& input_nd = inputs[sr::kArr];
-  const TBlob idx_data = inputs[sr::kIdx].data();
-
-  if (req[sr::kOut] == kNullOp
+  if (req == kNullOp
       || !input_nd.storage_initialized()
       || idx_data.Size() == 0U
      || input_nd.shape()[0] == 0) {
-    FillComputeZerosEx(attrs, ctx, {}, req, outputs);
+    FillZerosRspImpl(s, output_nd);
     return;
   }
 
   const TBlob input_data = input_nd.data();
   const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx);
-  NDArray output_nd = outputs[sr::kOut];
-  output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
-  TBlob output_data = output_nd.data();
-  TBlob output_idx = output_nd.aux_data(rowsparse::kIdx);
+  output_nd->CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
+  TBlob output_data = output_nd->data();
+  TBlob output_idx = output_nd->aux_data(rowsparse::kIdx);
   const auto row_length = input_data.shape_.ProdShape(1, input_data.shape_.ndim());
 
   using namespace mxnet_op;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, {  // output data type
     Kernel<set_zero, xpu>::Launch(s, output_data.Size(), output_data.dptr<DType>());
     MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, {  // row index data type
@@ -239,6 +227,27 @@ void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename xpu>
+void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<NDArray>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'";
+  CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage)
+    << "sparse_retain operator only takes default NDArray as its index array";
+  if (inputs[sr::kArr].storage_type() == kRowSparseStorage) {
+    NDArray output_nd = outputs[sr::kOut];
+    SparseRetainOpForwardRspImpl<xpu>(ctx.get_stream<xpu>(), inputs[sr::kArr],
+                                      inputs[sr::kIdx].data(), req[sr::kOut], &output_nd);
+  } else {
+    LOG(FATAL) << "sparse_retain op only supports row-sparse ndarrays as input";
+  }
+}
+
 template<int req>
 struct SparseRetainRspGradKernel {
   template<typename DType, typename RType, typename IType>
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 577868b5cebf..bfcdc5c404a1 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -189,7 +189,8 @@ def check_sparse_retain(shape, density):
     data = mx.symbol.Variable('data')
     idx = mx.symbol.Variable('indices')
     sym = mx.sym.sparse_retain(data=data, indices=idx)
-    check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'], grad_stype_dict={'data': 'row_sparse'})
+    check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'],
+                           grad_stype_dict={'data': 'row_sparse'})
 
     shape = rand_shape_2d()
     shape_3d = rand_shape_3d()
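Context for the refactor (not part of the patch itself): `sparse_retain` keeps the rows of a row-sparse NDArray whose ids appear in `indices` and treats every other row as zero. Extracting `SparseRetainOpForwardRspImpl` lets callers such as the kvstore (hence the `comm.h` include swap above) invoke that logic with just a stream, an input NDArray, an index TBlob, and an output NDArray, instead of fabricating `nnvm::NodeAttrs` and an `OpContext`. A minimal NumPy reference model of the operator's contract follows; the helper name is made up for illustration, and it works on dense arrays rather than row-sparse storage:

```python
import numpy as np

def sparse_retain_reference(dense, indices):
    """Reference semantics only: keep the listed rows, zero the rest.
    The real operator works on row-sparse storage, not dense arrays."""
    out = np.zeros_like(dense)
    idx = np.asarray(indices, dtype=np.int64)
    out[idx] = dense[idx]  # copy only the retained rows
    return out

# Rows 1 and 3 survive; rows 0 and 2 become zero.
x = np.arange(8, dtype=np.float32).reshape(4, 2)
print(sparse_retain_reference(x, [1, 3]))
```

The all-zero initialization also mirrors the patch's corner-case branch: the refactored implementation short-circuits to `FillZerosRspImpl` when the request is a no-op, the input has no allocated rows, or the index array is empty.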
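At the Python level the operator is exercised through `mx.sym.sparse_retain`, as in the test change above. Assuming the branch also auto-generates the NDArray-level `mx.nd.sparse_retain` and that `mx.nd.cast_storage` accepts `stype='row_sparse'` (both are assumptions about this build, not confirmed by the patch), a quick interactive check would look roughly like:

```python
import mxnet as mx

# Assumed APIs: mx.nd.cast_storage(..., stype='row_sparse') and
# mx.nd.sparse_retain, the NDArray counterpart of mx.sym.sparse_retain.
dense = mx.nd.array([[1, 2], [3, 4], [5, 6], [7, 8]])
rsp = mx.nd.cast_storage(dense, stype='row_sparse')

indices = mx.nd.array([0, 3])                 # row ids to keep
retained = mx.nd.sparse_retain(rsp, indices)  # rows 1 and 2 come back as zeros
print(retained.asnumpy())
```

Note that the forward function still checks `req[sr::kOut] == kWriteTo`, so accumulating (`req='add'`) into an existing output remains unsupported after this refactor.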