-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[sparse] slice for csr on two dimensions, cpu implementation #8331
Changes from 11 commits
3733381
b680fac
b1ba370
ec07bf4
688c212
1fb751c
72b6d65
50b4eb4
2c3849b
086ba4a
d7c99ec
9b16b69
5f50690
3e1445c
0988a3b
a63e6e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -443,8 +443,7 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs, | |
dispatch_mode, DispatchMode::kFCompute); | ||
} | ||
|
||
if (!dispatched && in_stype == kCSRStorage && param.begin.ndim() <= 1 && | ||
param.end.ndim() <= 1) { | ||
if (!dispatched && in_stype == kCSRStorage) { | ||
dispatched = storage_type_assign(&out_stype, kCSRStorage, | ||
dispatch_mode, dispatch_ex); | ||
} | ||
|
@@ -551,37 +550,27 @@ void SliceCsrIndPtrImpl(const int begin, const int end, RunContext ctx, | |
} | ||
|
||
/* | ||
* Slice a CSR NDArray | ||
* Slice a CSR NDArray for first dimension | ||
* Only implemented for CPU | ||
*/ | ||
template<typename xpu> | ||
void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, | ||
const NDArray &in, OpReqType req, const NDArray &out) { | ||
void SliceDimOneCsrImpl(const TShape &begin, const TShape &end, const OpContext& ctx, | ||
const NDArray &in, const NDArray &out) { | ||
using namespace mshadow; | ||
using namespace mxnet_op; | ||
using namespace csr; | ||
CHECK((std::is_same<xpu, cpu>::value)) << "Slice for CSR input only implemented for CPU"; | ||
if (req == kNullOp) return; | ||
CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported"; | ||
CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported"; | ||
const TShape ishape = in.shape(); | ||
int begin = *param.begin[0]; | ||
if (begin < 0) begin += ishape[0]; | ||
int end = *param.end[0]; | ||
if (end < 0) end += ishape[0]; | ||
int indptr_len = end - begin + 1; | ||
CHECK((std::is_same<xpu, cpu>::value)) << "SliceDimOneCsrImpl is only implemented for CPU"; | ||
nnvm::dim_t begin_row = begin[0]; | ||
nnvm::dim_t end_row = end[0]; | ||
nnvm::dim_t indptr_len = end_row - begin_row + 1; | ||
out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len)); | ||
if (!in.storage_initialized()) { | ||
out.set_aux_shape(kIndPtr, Shape1(0)); | ||
return; | ||
} | ||
// assume idx indptr share the same type | ||
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, { | ||
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, { | ||
MSHADOW_TYPE_SWITCH(in.dtype(), DType, { | ||
auto in_indptr = in.aux_data(kIndPtr).dptr<RType>(); | ||
auto out_indptr = out.aux_data(kIndPtr).dptr<RType>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we avoid auto for in_indptr and out_indptr |
||
SliceCsrIndPtrImpl<cpu, RType>(begin, end, ctx.run_ctx, in_indptr, out_indptr); | ||
SliceCsrIndPtrImpl<cpu, RType>(begin_row, end_row, ctx.run_ctx, in_indptr, out_indptr); | ||
|
||
// retrieve nnz (CPU implementation) | ||
int nnz = out_indptr[indptr_len - 1]; | ||
|
@@ -592,7 +581,7 @@ void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, | |
auto out_idx = out.aux_data(kIdx).dptr<IType>(); | ||
auto in_data = in.data().dptr<DType>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we avoid auto here and use IType and DType |
||
auto out_data = out.data().dptr<DType>(); | ||
int offset = in_indptr[begin]; | ||
int offset = in_indptr[begin_row]; | ||
// this is also a CPU-only implementation | ||
memcpy(out_idx, in_idx + offset, nnz * sizeof(IType)); | ||
memcpy(out_data, in_data + offset, nnz * sizeof(DType)); | ||
|
@@ -601,18 +590,127 @@ void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, | |
}); | ||
} | ||
|
||
// slice a CSRNDArray for two dimensions | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add more documentation for the kernels like this one https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/dot-inl.cuh#L40 |
||
struct SliceDimTwoCsrAssign { | ||
template<typename IType, typename RType, typename DType> | ||
MSHADOW_XINLINE static void Map(int i, IType* out_idx, DType* out_data, | ||
const RType* out_indptr, const IType* in_idx, | ||
const DType* in_data, const RType* in_indptr, | ||
const int begin, const int end) { | ||
RType ind = out_indptr[i]; | ||
for (RType j = in_indptr[i]; j < in_indptr[i+1]; j++) { | ||
if (in_idx[j] >= end) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason to not just do if (in_idx[j] >= begin_col && in_idx < end_col) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indices of csr ndarray is in ascending order per row. So if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was suggesting this change for readability. Also, you would be doing the checks for all in_idx[j] < begin_col which will be avoided with the change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I used this condition, in_idx[j] >= begin_col && in_idx < end_col, at the first time. But according to @eric-haibin-lin 's comments, this logic should be changed to a if/else logic which can jump out of the loop since indices are sorted per row. |
||
break; | ||
} else if (in_idx[j] >= begin) { | ||
out_idx[ind] = in_idx[j] - begin; | ||
out_data[ind] = in_data[j]; | ||
ind++; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
/* | ||
* Slice a CSR NDArray for two dimensions | ||
* Only implemented for CPU | ||
*/ | ||
template<typename xpu> | ||
void SliceDimTwoCsrImpl(const TShape &begin, const TShape &end, const OpContext& ctx, | ||
const NDArray &in, const NDArray &out) { | ||
using namespace mshadow; | ||
using namespace mxnet_op; | ||
using namespace csr; | ||
CHECK((std::is_same<xpu, cpu>::value)) << "SliceDimTwoCsrImpl is only implemented for CPU"; | ||
nnvm::dim_t begin_row = begin[0], end_row = end[0]; | ||
nnvm::dim_t begin_col = begin[1], end_col = end[1]; | ||
nnvm::dim_t indptr_len = end_row - begin_row + 1; | ||
out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len)); | ||
// assume idx indptr share the same type | ||
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, { | ||
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, { | ||
MSHADOW_TYPE_SWITCH(in.dtype(), DType, { | ||
RType *in_indptr = in.aux_data(kIndPtr).dptr<RType>(); | ||
IType *in_idx = in.aux_data(kIdx).dptr<IType>(); | ||
DType *in_data = in.data().dptr<DType>(); | ||
// retrieve nnz (CPU implementation) | ||
RType *out_indptr = out.aux_data(kIndPtr).dptr<RType>(); | ||
int nnz = 0; | ||
out_indptr[0] = 0; | ||
for (nnvm::dim_t i = 0; i < indptr_len - 1; i++) { | ||
out_indptr[i+1] = out_indptr[i]; | ||
for (RType j = in_indptr[i + begin_row]; | ||
j < in_indptr[i + begin_row + 1]; j++) { | ||
if (in_idx[j] >= end_col) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add one line comment for the if, else if logic here. Also why not just if (in_idx[j] >= begin_col && in_idx < end_col) ? |
||
break; | ||
} else if (in_idx[j] >= begin_col) { | ||
out_indptr[i+1]++; | ||
nnz++; | ||
} | ||
} | ||
} | ||
out.CheckAndAllocAuxData(kIdx, Shape1(nnz)); | ||
out.CheckAndAllocData(Shape1(nnz)); | ||
IType *out_idx = out.aux_data(kIdx).dptr<IType>(); | ||
DType *out_data = out.data().dptr<DType>(); | ||
|
||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
Kernel<SliceDimTwoCsrAssign, xpu>::Launch(s, indptr_len - 1, out_idx, out_data, | ||
out_indptr, in_idx, in_data, | ||
in_indptr + begin_row, | ||
begin_col, end_col); | ||
}); | ||
}); | ||
}); | ||
} | ||
|
||
|
||
template<typename xpu> | ||
void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, | ||
const NDArray &in, OpReqType req, const NDArray &out) { | ||
CHECK((std::is_same<xpu, cpu>::value)) << "Slice for CSR input only implemented for CPU"; | ||
if (req == kNullOp) return; | ||
CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported"; | ||
CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported"; | ||
|
||
const TShape ishape = in.shape(); | ||
const TShape oshape = out.shape(); | ||
|
||
uint32_t N = ishape.ndim(); | ||
TShape begin(N), end(N); | ||
for (uint32_t i = 0; i < N; ++i) { | ||
int s = 0; | ||
if (param.begin[i]) { | ||
s = *param.begin[i]; | ||
if (s < 0) s += ishape[i]; | ||
} | ||
begin[i] = s; | ||
end[i] = s + oshape[i]; | ||
} | ||
switch (N) { | ||
case 1: { | ||
SliceDimOneCsrImpl<xpu>(begin, end, ctx, in, out); | ||
break; | ||
} | ||
case 2: { | ||
SliceDimTwoCsrImpl<xpu>(begin, end, ctx, in, out); | ||
break; | ||
} | ||
default: | ||
LOG(FATAL) << "CSR is only for 2-D shape"; | ||
break; | ||
} | ||
} | ||
|
||
template<typename xpu> | ||
void SliceEx(const nnvm::NodeAttrs& attrs, | ||
const OpContext& ctx, | ||
const std::vector<NDArray>& inputs, | ||
const std::vector<OpReqType>& req, | ||
const std::vector<NDArray>& outputs) { | ||
const OpContext& ctx, | ||
const std::vector<NDArray>& inputs, | ||
const std::vector<OpReqType>& req, | ||
const std::vector<NDArray>& outputs) { | ||
CHECK_EQ(inputs.size(), 1); | ||
CHECK_EQ(outputs.size(), 1); | ||
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed); | ||
auto in_stype = inputs[0].storage_type(); | ||
CHECK_NE(in_stype, kDefaultStorage) | ||
<< "SliceEx is not expected to execute for input with default storage type"; | ||
if (in_stype == kCSRStorage) { | ||
SliceCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]); | ||
} else { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -264,7 +264,7 @@ The resulting array's *k*-th dimension contains elements | |
from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``. | ||
|
||
For an input array of non-default storage type (e.g. `csr` or `row_sparse`), it only supports | |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think row_sparse is not supported for slice. Let's remove this sentence in the doc.
|
||
slicing on the first dimension. | ||
slicing on the two dimensions for `csr`. | ||
|
||
Example:: | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens here if input is a CSR Array with all zeroes ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your comments. If input is zeros, kernel launch will return immediately. Unittest for zeros input case is added.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that still true on GPU, when we add GPU support? This PR is dealing with some bugs for zero inputs for dot operator #8470
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For CSRNDArray,
storage_initialized()
returnaux_shape(0).Size() != 0
, I think it is always true for a valid CSRNDArray except for rank-0 array.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to returning csr zeros immediately if nnz=0.