Skip to content

Commit

Permalink
Fix compile
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Aug 2, 2017
1 parent bc3ae9b commit d931310
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/kvstore/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <thread>
#include "mxnet/ndarray.h"
#include "../ndarray/ndarray_function.h"
#include "../operator/tensor/indexing_op.h"
#include "../operator/tensor/sparse_retain-inl.h"
namespace mxnet {
namespace kvstore {
/**
Expand Down
61 changes: 35 additions & 26 deletions src/operator/tensor/sparse_retain-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,45 +179,33 @@ struct SparseRetainRspRowBlockKernel {
};

template<typename xpu>
void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'";

CHECK_EQ(inputs[sr::kArr].storage_type(), kRowSparseStorage)
<< "sparse_retain operator only takes row sparse NDArray as input";
CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage)
<< "sparse_retain operator only takes default NDArray as its index array";
CHECK_EQ(outputs[sr::kOut].storage_type(), kRowSparseStorage)
<< "sparse_retain operator only outputs row sparse NDArray";
void SparseRetainOpForwardRspImpl(mshadow::Stream<xpu> *s,
const NDArray& input_nd,
const TBlob& idx_data,
const OpReqType req,
NDArray* output_nd) {
CHECK_EQ(input_nd.storage_type(), kRowSparseStorage)
<< "SparseRetainOpForwardRspImpl operator only takes row sparse NDArray as input";
CHECK_EQ(output_nd->storage_type(), kRowSparseStorage)
<< "SparseRetainOpForwardRspImpl operator only outputs row sparse NDArray";

const NDArray& input_nd = inputs[sr::kArr];
const TBlob idx_data = inputs[sr::kIdx].data();

if (req[sr::kOut] == kNullOp
if (req == kNullOp
|| !input_nd.storage_initialized()
|| idx_data.Size() == 0U
|| input_nd.shape()[0] == 0) {
FillComputeZerosEx<xpu>(attrs, ctx, {}, req, outputs);
FillZerosRspImpl(s, output_nd);
return;
}

const TBlob input_data = input_nd.data();
const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx);

NDArray output_nd = outputs[sr::kOut];
output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
TBlob output_data = output_nd.data();
TBlob output_idx = output_nd.aux_data(rowsparse::kIdx);
output_nd->CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
TBlob output_data = output_nd->data();
TBlob output_idx = output_nd->aux_data(rowsparse::kIdx);
const auto row_length = input_data.shape_.ProdShape(1, input_data.shape_.ndim());

using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type
Kernel<set_zero, xpu>::Launch(s, output_data.Size(), output_data.dptr<DType>());
MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type
Expand All @@ -239,6 +227,27 @@ void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
});
}

template<typename xpu>
void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<NDArray>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<NDArray>& outputs) {
  // Dispatch wrapper for the sparse_retain forward pass: validates the
  // (data, indices) -> (output) arity/req, then hands off to the
  // row-sparse implementation.
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  // Only full overwrite of the output is supported.
  CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'";
  // The index array must be a dense NDArray.
  CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage)
      << "sparse_retain operator only takes default NDArray as its index array";
  // Guard clause: anything other than row-sparse input is fatal
  // (LOG(FATAL) aborts, so no fall-through past this branch).
  if (inputs[sr::kArr].storage_type() != kRowSparseStorage) {
    LOG(FATAL) << "sparse_retain op only supports row-sparse ndarrays as input";
  }
  // NDArray copies share the underlying storage, so writing through the
  // local handle updates the caller-visible output.
  NDArray retained = outputs[sr::kOut];
  SparseRetainOpForwardRspImpl<xpu>(ctx.get_stream<xpu>(), inputs[sr::kArr],
                                    inputs[sr::kIdx].data(), req[sr::kOut],
                                    &retained);
}

template<int req>
struct SparseRetainRspGradKernel {
template<typename DType, typename RType, typename IType>
Expand Down
3 changes: 2 additions & 1 deletion tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def check_sparse_retain(shape, density):
data = mx.symbol.Variable('data')
idx = mx.symbol.Variable('indices')
sym = mx.sym.sparse_retain(data=data, indices=idx)
check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'], grad_stype_dict={'data': 'row_sparse'})
check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'],
grad_stype_dict={'data': 'row_sparse'})

shape = rand_shape_2d()
shape_3d = rand_shape_3d()
Expand Down

0 comments on commit d931310

Please sign in to comment.