
[sparse] slice for csr on two dimensions, cpu implementation (apache#8331)

* slice axis for csr (cpu impl)

* fix index bug and use kernel launch

* small fix

* misc updates to address comments

* fix type

* csr slice

* unittest

* fix lint

* address comments

* return csr zeros before kernel launch if nnz=0

* fix
ZiyueHuang authored and Olivier committed Nov 9, 2017
1 parent a63a42a commit 81ea103
Showing 3 changed files with 188 additions and 60 deletions.
193 changes: 160 additions & 33 deletions src/operator/tensor/matrix_op-inl.h
@@ -443,8 +443,7 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatch_mode, DispatchMode::kFCompute);
}

-if (!dispatched && in_stype == kCSRStorage && param.begin.ndim() <= 1 &&
-    param.end.ndim() <= 1) {
+if (!dispatched && in_stype == kCSRStorage) {
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, dispatch_ex);
}
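
With the relaxed condition, a two-dimensional begin/end on a csr input now stays on the sparse path rather than requiring begin/end of at most one dimension. A usage-level sketch (values made up; assumes the MXNet sparse Python API of this era):

import mxnet as mx

a = mx.nd.array([[1, 0, 2], [0, 0, 3], [4, 5, 0]]).tostype('csr')
b = mx.nd.slice(a, begin=(0, 1), end=(2, 3))   # 2-D begin/end on a csr input
print(b.stype)                                 # expected: 'csr'
print(b.asnumpy())                             # expected: [[0. 2.], [0. 3.]]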
@@ -551,48 +550,43 @@ void SliceCsrIndPtrImpl(const int begin, const int end, RunContext ctx,
}

/*
-* Slice a CSR NDArray
+* Slice a CSR NDArray on the first dimension
* Only implemented for CPU
*/
template<typename xpu>
-void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
-                  const NDArray &in, OpReqType req, const NDArray &out) {
+void SliceDimOneCsrImpl(const TShape &begin, const TShape &end, const OpContext& ctx,
+                        const NDArray &in, const NDArray &out) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
CHECK((std::is_same<xpu, cpu>::value)) << "Slice for CSR input only implemented for CPU";
if (req == kNullOp) return;
CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported";
CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported";
const TShape ishape = in.shape();
int begin = *param.begin[0];
if (begin < 0) begin += ishape[0];
int end = *param.end[0];
if (end < 0) end += ishape[0];
int indptr_len = end - begin + 1;
CHECK((std::is_same<xpu, cpu>::value)) << "SliceDimOneCsrImpl is only implemented for CPU";
nnvm::dim_t begin_row = begin[0];
nnvm::dim_t end_row = end[0];
nnvm::dim_t indptr_len = end_row - begin_row + 1;
out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
if (!in.storage_initialized()) {
out.set_aux_shape(kIndPtr, Shape1(0));
return;
}
// assume idx and indptr share the same type
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
-auto in_indptr = in.aux_data(kIndPtr).dptr<RType>();
-auto out_indptr = out.aux_data(kIndPtr).dptr<RType>();
-SliceCsrIndPtrImpl<cpu, RType>(begin, end, ctx.run_ctx, in_indptr, out_indptr);
+RType* in_indptr = in.aux_data(kIndPtr).dptr<RType>();
+RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
+SliceCsrIndPtrImpl<cpu, RType>(begin_row, end_row, ctx.run_ctx, in_indptr, out_indptr);

// retrieve nnz (CPU implementation)
int nnz = out_indptr[indptr_len - 1];
// return csr zeros if nnz = 0
if (nnz == 0) {
out.set_aux_shape(kIdx, Shape1(0));
return;
}
// copy indices and values
out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));
-auto in_idx = in.aux_data(kIdx).dptr<IType>();
-auto out_idx = out.aux_data(kIdx).dptr<IType>();
-auto in_data = in.data().dptr<DType>();
-auto out_data = out.data().dptr<DType>();
-int offset = in_indptr[begin];
+IType* in_idx = in.aux_data(kIdx).dptr<IType>();
+IType* out_idx = out.aux_data(kIdx).dptr<IType>();
+DType* in_data = in.data().dptr<DType>();
+DType* out_data = out.data().dptr<DType>();
+int offset = in_indptr[begin_row];
// this is also a CPU-only implementation
memcpy(out_idx, in_idx + offset, nnz * sizeof(IType));
memcpy(out_data, in_data + offset, nnz * sizeof(DType));
@@ -601,18 +595,151 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
});
}
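
The first-dimension slice above amounts to an indptr window plus one contiguous copy: SliceCsrIndPtrImpl re-bases in_indptr[begin_row..end_row] to start at zero, after which the kept idx/data entries form a single span that the two memcpy calls move wholesale. A minimal numpy sketch of the same logic (slice_csr_rows is a hypothetical helper, not part of the patch):

import numpy as np

def slice_csr_rows(indptr, indices, data, begin_row, end_row):
    # re-base the indptr window so the output starts at zero
    offset = indptr[begin_row]
    out_indptr = indptr[begin_row:end_row + 1] - offset
    nnz = out_indptr[-1]              # same as out_indptr[indptr_len - 1] above
    if nnz == 0:                      # csr zeros: empty idx/data arrays
        return out_indptr, indices[:0], data[:0]
    # kept rows occupy one contiguous span of idx/data, so one copy suffices
    return out_indptr, indices[offset:offset + nnz].copy(), data[offset:offset + nnz].copy()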

/*!
* \brief slice a CSRNDArray on two dimensions
*/
struct SliceDimTwoCsrAssign {
/*!
* \brief This function slices a CSRNDArray on axis one between begin_col and end_col
* \param i loop index
* \param out_idx output csr ndarray column indices
* \param out_data output csr ndarray data
* \param out_indptr output csr ndarray row index pointer
* \param in_idx input csr ndarray column indices
* \param in_data input csr ndarray data
* \param in_indptr input csr ndarray row index pointer
* \param begin_col begin column index
* \param end_col end column index
*/
template<typename IType, typename RType, typename DType>
MSHADOW_XINLINE static void Map(int i,
IType* out_idx, DType* out_data,
const RType* out_indptr,
const IType* in_idx, const DType* in_data,
const RType* in_indptr,
const int begin_col, const int end_col) {
RType ind = out_indptr[i];
for (RType j = in_indptr[i]; j < in_indptr[i+1]; j++) {
// indices of CSRNDArray are in ascending order per row
if (in_idx[j] >= end_col) {
break;
} else if (in_idx[j] >= begin_col) {
out_idx[ind] = in_idx[j] - begin_col;
out_data[ind] = in_data[j];
ind++;
}
}
}
};
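
Because out_indptr is computed before the kernel launch, row i writes only the half-open range [out_indptr[i], out_indptr[i+1]), disjoint across rows, which is what makes this Map safe to run in parallel over rows. A line-for-line Python analogue (assign_row is a hypothetical name; as in the Launch call below, in_indptr is assumed already shifted by begin_row):

def assign_row(i, out_idx, out_data, out_indptr,
               in_idx, in_data, in_indptr, begin_col, end_col):
    ind = out_indptr[i]
    for j in range(in_indptr[i], in_indptr[i + 1]):
        if in_idx[j] >= end_col:      # per-row column indices are sorted ascending
            break
        if in_idx[j] >= begin_col:
            out_idx[ind] = in_idx[j] - begin_col   # re-base into the sliced window
            out_data[ind] = in_data[j]
            ind += 1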

/*
* Slice a CSR NDArray on two dimensions
* Only implemented for CPU
*/
template<typename xpu>
void SliceDimTwoCsrImpl(const TShape &begin, const TShape &end, const OpContext& ctx,
const NDArray &in, const NDArray &out) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
CHECK((std::is_same<xpu, cpu>::value)) << "SliceDimTwoCsrImpl is only implemented for CPU";
nnvm::dim_t begin_row = begin[0], end_row = end[0];
nnvm::dim_t begin_col = begin[1], end_col = end[1];
nnvm::dim_t indptr_len = end_row - begin_row + 1;
out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
// assume idx and indptr share the same type
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
RType *in_indptr = in.aux_data(kIndPtr).dptr<RType>();
IType *in_idx = in.aux_data(kIdx).dptr<IType>();
DType *in_data = in.data().dptr<DType>();
// retrieve nnz (CPU implementation)
RType *out_indptr = out.aux_data(kIndPtr).dptr<RType>();
int nnz = 0;
out_indptr[0] = 0;
// loop through the indptr array and corresponding indices to count nnz
for (nnvm::dim_t i = 0; i < indptr_len - 1; i++) {
out_indptr[i+1] = out_indptr[i];
for (RType j = in_indptr[i + begin_row];
j < in_indptr[i + begin_row + 1]; j++) {
// indices of CSRNDArray are in ascending order per row
if (in_idx[j] >= end_col) {
break;
} else if (in_idx[j] >= begin_col) {
out_indptr[i+1]++;
nnz++;
}
}
}
// returns zeros in csr format if nnz = 0
if (nnz == 0) {
out.set_aux_shape(kIdx, Shape1(0));
return;
}
out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));
IType *out_idx = out.aux_data(kIdx).dptr<IType>();
DType *out_data = out.data().dptr<DType>();

Stream<xpu> *s = ctx.get_stream<xpu>();
Kernel<SliceDimTwoCsrAssign, xpu>::Launch(s, indptr_len - 1, out_idx, out_data,
out_indptr, in_idx, in_data,
in_indptr + begin_row,
begin_col, end_col);
});
});
});
}
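
The counting pass above builds out_indptr and nnz entirely on the CPU before any output is allocated, so csr zeros can be returned before the kernel launch when nnz is 0 (per the commit message). A numpy sketch of just that pass (count_sliced_nnz is a hypothetical helper):

import numpy as np

def count_sliced_nnz(indptr, indices, begin_row, end_row, begin_col, end_col):
    # out_indptr[i+1] accumulates entries of row i falling in [begin_col, end_col)
    out_indptr = np.zeros(end_row - begin_row + 1, dtype=indptr.dtype)
    for i in range(end_row - begin_row):
        lo, hi = indptr[begin_row + i], indptr[begin_row + i + 1]
        row_cols = indices[lo:hi]
        kept = np.count_nonzero((row_cols >= begin_col) & (row_cols < end_col))
        out_indptr[i + 1] = out_indptr[i] + kept
    return out_indptr, int(out_indptr[-1])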


template<typename xpu>
void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
const NDArray &in, OpReqType req, const NDArray &out) {
CHECK((std::is_same<xpu, cpu>::value)) << "Slice for CSR input only implemented for CPU";
if (req == kNullOp) return;
CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported";
CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported";

const TShape ishape = in.shape();
const TShape oshape = out.shape();

uint32_t N = ishape.ndim();
TShape begin(N), end(N);
for (uint32_t i = 0; i < N; ++i) {
int s = 0;
if (param.begin[i]) {
s = *param.begin[i];
if (s < 0) s += ishape[i];
}
begin[i] = s;
end[i] = s + oshape[i];
}
switch (N) {
case 1: {
SliceDimOneCsrImpl<xpu>(begin, end, ctx, in, out);
break;
}
case 2: {
SliceDimTwoCsrImpl<xpu>(begin, end, ctx, in, out);
break;
}
default:
LOG(FATAL) << "CSR is only for 2-D shape";
break;
}
}
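
The dispatcher derives normalized begin/end bounds from the slice parameters and the already-inferred output shape, resolving negative indices before either implementation runs. A rough Python analogue (normalize_slice is hypothetical; None stands in for an unset dmlc::optional entry):

def normalize_slice(param_begin, ishape, oshape):
    begin, end = [], []
    for i in range(len(ishape)):
        s = param_begin[i] if param_begin[i] is not None else 0
        if s < 0:
            s += ishape[i]             # negative indices count from the end
        begin.append(s)
        end.append(s + oshape[i])      # the output shape pins the end bound
    return begin, end

For example, normalize_slice([-2, None], [5, 4], [2, 4]) returns ([3, 0], [5, 4]).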

template<typename xpu>
void SliceEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1);
CHECK_EQ(outputs.size(), 1);
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
auto in_stype = inputs[0].storage_type();
CHECK_NE(in_stype, kDefaultStorage)
<< "SliceEx is not expected to execute for input with default storage type";
if (in_stype == kCSRStorage) {
SliceCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
} else {
6 changes: 4 additions & 2 deletions src/operator/tensor/matrix_op.cc
@@ -263,8 +263,10 @@ and ``end=(e_1, e_2, ... e_n)`` indices will result in an array with the shape
The resulting array's *k*-th dimension contains elements
from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``.
-For an input array of non-default storage type(e.g. `csr` or `row_sparse`), it only supports
-slicing on the first dimension.
+The storage type of ``slice`` output depends on storage types of inputs
+- slice(csr) = csr
+- otherwise, ``slice`` generates output with default storage
Example::
49 changes: 24 additions & 25 deletions tests/python/unittest/test_sparse_ndarray.py
@@ -105,32 +105,31 @@ def check_sparse_nd_setitem(stype, shape, dst):


def test_sparse_nd_slice():
-    def check_sparse_nd_csr_slice(shape):
-        stype = 'csr'
-        A, _ = rand_sparse_ndarray(shape, stype)
-        A2 = A.asnumpy()
-        start = rnd.randint(0, shape[0] - 1)
-        end = rnd.randint(start + 1, shape[0])
-        assert same(A[start:end].asnumpy(), A2[start:end])
-        assert same(A[start - shape[0]:end].asnumpy(), A2[start:end])
-        assert same(A[start:].asnumpy(), A2[start:])
-        assert same(A[:end].asnumpy(), A2[:end])
-        ind = rnd.randint(-shape[0], shape[0] - 1)
-        assert same(A[ind].asnumpy(), A2[ind][np.newaxis, :])
+    shape = (rnd.randint(2, 10), rnd.randint(2, 10))
+    stype = 'csr'
+    A, _ = rand_sparse_ndarray(shape, stype)
+    A2 = A.asnumpy()
+    start = rnd.randint(0, shape[0] - 1)
+    end = rnd.randint(start + 1, shape[0])
+    assert same(A[start:end].asnumpy(), A2[start:end])
+    assert same(A[start - shape[0]:end].asnumpy(), A2[start:end])
+    assert same(A[start:].asnumpy(), A2[start:])
+    assert same(A[:end].asnumpy(), A2[:end])
+    ind = rnd.randint(-shape[0], shape[0] - 1)
+    assert same(A[ind].asnumpy(), A2[ind][np.newaxis, :])

+    start_col = rnd.randint(0, shape[1] - 1)
+    end_col = rnd.randint(start_col + 1, shape[1])
+    result = mx.nd.slice(A, begin=(start, start_col), end=(end, end_col))
+    result_dense = mx.nd.slice(mx.nd.array(A2), begin=(start, start_col), end=(end, end_col))
+    assert same(result_dense.asnumpy(), result.asnumpy())

-    def check_slice_nd_csr_fallback(shape):
-        stype = 'csr'
-        A, _ = rand_sparse_ndarray(shape, stype)
-        A2 = A.asnumpy()
-        start = rnd.randint(0, shape[0] - 1)
-        end = rnd.randint(start + 1, shape[0])
-        result = mx.nd.sparse.slice(A, begin=(start, shape[1] - 1), end=(end + 1, shape[1]))
-        result_dense = mx.nd.slice(mx.nd.array(A2), begin=(start, shape[1] - 1), end=(end + 1, shape[1]))
-        assert same(result_dense.asnumpy(), result.asnumpy())

-    shape = (rnd.randint(2, 10), rnd.randint(1, 10))
-    check_sparse_nd_csr_slice(shape)
-    check_slice_nd_csr_fallback(shape)
+    A = mx.nd.sparse.zeros('csr', shape)
+    A2 = A.asnumpy()
+    assert same(A[start:end].asnumpy(), A2[start:end])
+    result = mx.nd.slice(A, begin=(start, start_col), end=(end, end_col))
+    result_dense = mx.nd.slice(mx.nd.array(A2), begin=(start, start_col), end=(end, end_col))
+    assert same(result_dense.asnumpy(), result.asnumpy())
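
One further cross-check that could be added here, against scipy's own two-dimensional csr slicing (illustrative only, not part of this patch):

import numpy as np
import scipy.sparse as sp

dense = np.random.rand(5, 7)
dense[dense < 0.7] = 0                # make it sparse-ish
csr = sp.csr_matrix(dense)
expected = dense[1:4, 2:6]
actual = csr[1:4, 2:6].toarray()      # scipy also slices csr on both dimensions
assert np.allclose(actual, expected)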


def test_sparse_nd_equal():
