[sparse] slice for csr on two dimensions, cpu implementation #8331
@@ -927,6 +927,140 @@ void SliceAxis(const nnvm::NodeAttrs& attrs,
  }
}

inline bool SliceAxisForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                             const int dev_mask,
                                             DispatchMode* dispatch_mode,
                                             std::vector<int>* in_attrs,
                                             std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1);
  CHECK_EQ(out_attrs->size(), 1);
  const SliceAxisParam& param = nnvm::get<SliceAxisParam>(attrs.parsed);
  const auto& in_stype = in_attrs->at(0);
  auto& out_stype = out_attrs->at(0);
  bool dispatched = false;
  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
  const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
                                         DispatchMode::kFComputeEx;
  if (!dispatched && in_stype == kDefaultStorage) {
    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
  if (!dispatched && in_stype == kCSRStorage && param.axis <= 1) {
    dispatched = storage_type_assign(&out_stype, kCSRStorage,
                                     dispatch_mode, dispatch_ex);
  }
  if (!dispatched) {
    dispatch_fallback(out_attrs, dispatch_mode);
  }
  if (*dispatch_mode == DispatchMode::kFComputeFallback) {
    LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
  }

  return true;
}

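Read as a decision table (a summary of the code above, assuming MXNet's usual storage-dispatch semantics): a dense input always dispatches to the dense FCompute kernel, on any device; a CSR input with axis 0 or 1 keeps CSR storage and dispatches to FComputeEx on CPU, or to the dense fallback on other devices; every other combination falls back to dense, and any fallback is logged through LogStorageFallback.
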
struct SliceAxisOneCsrAssign {
  template<typename IType, typename RType, typename DType>
  MSHADOW_XINLINE static void Map(int i, IType* out_idx, DType* out_data,
                                  const RType* out_indptr, const IType* in_idx,
                                  const DType* in_data, const RType* in_indptr,
                                  int begin, int end) {
    RType ind = out_indptr[i];
    for (int j=in_indptr[i]; j < in_indptr[i+1]; j++) {

[Review] nit: j = in_indptr[i]

      if (in_idx[j] >= begin && in_idx[j] < end) {

[Review] Should we put a break if in_idx[j] >= end, since indices are sorted per row?
[Reply] It's also possible to binary search the lower_bound for begin, although not necessarily faster.

        out_idx[ind] = in_idx[j] - static_cast<IType>(begin);

[Review] Let's perform …

        out_data[ind] = in_data[j];
        ind++;
      }
    }
  }
};

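For reference, the break / lower_bound ideas in the review thread above could look like the sketch below. This is not part of the PR: the name MapSorted is made up, it assumes column indices are sorted ascending within each row, and std::lower_bound keeps it CPU-only (which matches this operator's CPU-only scope).

```cpp
#include <algorithm>  // std::lower_bound

// Hypothetical variant (sketch, not PR code) of SliceAxisOneCsrAssign::Map
// that exploits sorted per-row CSR column indices.
template<typename IType, typename RType, typename DType>
inline void MapSorted(int i, IType* out_idx, DType* out_data,
                      const RType* out_indptr, const IType* in_idx,
                      const DType* in_data, const RType* in_indptr,
                      int begin, int end) {
  RType ind = out_indptr[i];
  // Binary-search the first column index >= begin instead of scanning
  // from the start of the row.
  const IType* it = std::lower_bound(in_idx + in_indptr[i],
                                     in_idx + in_indptr[i + 1],
                                     static_cast<IType>(begin));
  const IType* row_end = in_idx + in_indptr[i + 1];
  for (; it != row_end; ++it) {
    if (*it >= end) break;  // sorted per row: nothing further can match
    out_idx[ind] = *it - static_cast<IType>(begin);
    out_data[ind] = in_data[it - in_idx];
    ++ind;
  }
}
```

Whether this beats the linear scan depends on row density, as the reply notes; for short rows the branch-free linear scan may well be faster.
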
template<typename xpu>
void SliceAxisOneCsrImpl(const SliceAxisParam &param, const OpContext& ctx,
                         const NDArray &in, OpReqType req, const NDArray &out) {

[Review] nit: indentation

  using namespace mshadow;
  using namespace mxnet_op;
  using namespace csr;
  CHECK((std::is_same<xpu, cpu>::value)) << "SliceAxis for CSR input only implemented for CPU";
  if (req == kNullOp) return;
  CHECK_NE(req, kAddTo) << "kAddTo for SliceAxis on CSR input is not supported";
  CHECK_NE(req, kWriteInplace) << "kWriteInplace for SliceAxis on CSR input is not supported";
  int axis, begin, end;
  GetSliceAxisParams(param, in.shape(), &axis, &begin, &end);
  int indptr_len = in.shape()[0] + 1;

[Review] Use nnvm::dim_t

  out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
  // assume idx and indptr share the same type
  MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
    MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
      MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
        RType *in_indptr = in.aux_data(kIndPtr).dptr<RType>();
        IType *in_idx = in.aux_data(kIdx).dptr<IType>();
        DType *in_data = in.data().dptr<DType>();
        // retrieve nnz (CPU implementation)
        RType *out_indptr = out.aux_data(kIndPtr).dptr<RType>();
        int nnz = 0;
        out_indptr[0] = 0;
        for (int i=0; i < indptr_len - 1; i++) {

[Review] also use nnvm::dim_t for i and j

          out_indptr[i+1] = out_indptr[i];
          for (int j=in_indptr[i]; j < in_indptr[i+1]; j++) {
            if (in_idx[j] >= begin && in_idx[j] < end) {

[Review] continue if in_idx[j] >= end instead of scanning the rest, since indices are sorted per row?

              out_indptr[i+1]++;
              nnz++;
            }
          }
        }
        out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
        out.CheckAndAllocData(Shape1(nnz));
        IType *out_idx = out.aux_data(kIdx).dptr<IType>();
        DType *out_data = out.data().dptr<DType>();

        Stream<xpu> *s = ctx.get_stream<xpu>();
        Kernel<SliceAxisOneCsrAssign, xpu>::Launch(s, indptr_len-1, out_idx, out_data,

[Review] Does it work when nnz = 0? Is that tested?
[Reply] Yes. If …

                                                   out_indptr, in_idx, in_data, in_indptr,
                                                   begin, end);
      });
    });
  });
}

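To make the two passes in SliceAxisOneCsrImpl concrete, here is a small self-contained program (illustration only, not MXNet code; the 3x4 matrix and its values are invented) that mirrors the counting pass and the copy pass when slicing columns [1, 3):

```cpp
// Standalone illustration of the two passes: count nnz per row into
// out_indptr, then copy qualifying entries with shifted column indices.
#include <cstdio>
#include <vector>

int main() {
  // 3x4 CSR matrix: row 0 = {0:1, 2:2}, row 1 = {1:3}, row 2 = {1:4, 3:5}
  std::vector<int>   indptr = {0, 2, 3, 5};
  std::vector<int>   idx    = {0, 2, 1, 1, 3};
  std::vector<float> data   = {1, 2, 3, 4, 5};
  int begin = 1, end = 3;  // slice columns [1, 3)

  // Pass 1: build out_indptr and count nnz (the serial CPU loop).
  std::vector<int> out_indptr(indptr.size(), 0);
  int nnz = 0;
  for (size_t i = 0; i + 1 < indptr.size(); ++i) {
    out_indptr[i + 1] = out_indptr[i];
    for (int j = indptr[i]; j < indptr[i + 1]; ++j) {
      if (idx[j] >= begin && idx[j] < end) { out_indptr[i + 1]++; nnz++; }
    }
  }

  // Pass 2: copy entries, shifting column indices by -begin (the kernel).
  std::vector<int>   out_idx(nnz);
  std::vector<float> out_data(nnz);
  for (size_t i = 0; i + 1 < indptr.size(); ++i) {
    int ind = out_indptr[i];
    for (int j = indptr[i]; j < indptr[i + 1]; ++j) {
      if (idx[j] >= begin && idx[j] < end) {
        out_idx[ind] = idx[j] - begin;
        out_data[ind] = data[j];
        ind++;
      }
    }
  }

  for (int v : out_idx) printf("%d ", v);     // prints: 1 0 0
  printf("\n");
  for (float v : out_data) printf("%g ", v);  // prints: 2 3 4
  return 0;
}
```

With begin = 1, column 2 in row 0 becomes column 1, column 3 in row 2 is dropped, and the counting pass produces out_indptr = [0, 1, 2, 3] with nnz = 3.
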
template<typename xpu>
void SliceAxisZeroCsrImpl(const SliceAxisParam &param, const OpContext& ctx,
                          const NDArray &in, OpReqType req, const NDArray &out) {
  int axis, begin, end;
  GetSliceAxisParams(param, in.shape(), &axis, &begin, &end);
  SliceParam slice_param;
  slice_param.begin[0] = begin;
  slice_param.end[0] = end;
  SliceCsrImpl<xpu>(slice_param, ctx, in, req, out);
}

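For axis = 0 the operation reduces to a row slice of the CSR matrix, so this wrapper only repacks begin and end into a SliceParam and delegates to the existing SliceCsrImpl.
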
template<typename xpu>
void SliceAxisEx(const nnvm::NodeAttrs& attrs,
                 const OpContext& ctx,

[Review] nit: indentation

                 const std::vector<NDArray>& inputs,
                 const std::vector<OpReqType>& req,
                 const std::vector<NDArray>& outputs) {
  CHECK_EQ(inputs.size(), 1);
  CHECK_EQ(outputs.size(), 1);

  const SliceAxisParam& param = nnvm::get<SliceAxisParam>(attrs.parsed);
  auto in_stype = inputs[0].storage_type();
  CHECK_NE(in_stype, kDefaultStorage)

[Review] I think you can remove this check and print …

      << "SliceAxisEx is not expected to execute for input with default storage type";
  if (in_stype == kCSRStorage) {
    if (param.axis == 0) {
      SliceAxisZeroCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
    } else if (param.axis == 1) {
      SliceAxisOneCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
    } else {
      LOG(FATAL) << "CSRNDArray is only for 2-D shape";

[Review] Does it fail with negative axis? I think GetSliceAxisParams already handles negative axis for you

    }
  } else {
    LOG(FATAL) << "SliceAxisEx not implemented for storage type " << in_stype;
  }
}

// Backward pass of broadcast over the given axis
template<typename xpu>
void SliceAxisGrad_(const nnvm::NodeAttrs& attrs,

@@ -363,6 +363,8 @@ Examples::
.set_attr_parser(ParamParser<SliceAxisParam>)

[Review] We should add …

.set_attr<nnvm::FInferShape>("FInferShape", SliceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferStorageType>("FInferStorageType", SliceAxisForwardInferStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", SliceAxisEx<cpu>)
.set_attr<FCompute>("FCompute<cpu>", SliceAxis<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice_axis"})
.add_argument("data", "NDArray-or-Symbol", "Source input")

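These two added attributes are what enable the sparse path: FInferStorageType picks the output storage and dispatch mode, and FComputeEx<cpu> supplies the CSR implementation, while dense inputs keep flowing through the pre-existing FCompute kernel.
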
@@ -127,9 +127,27 @@ def check_slice_nd_csr_fallback(shape):
        result_dense = mx.nd.slice(mx.nd.array(A2), begin=(start, shape[1] - 1), end=(end + 1, shape[1]))
        assert same(result_dense.asnumpy(), result.asnumpy())

    def check_sparse_nd_csr_slice_axis(shape):

[Review] let's also add some test cases for negative axis
[Reply] Added.

        stype = 'csr'
        A, _ = rand_sparse_ndarray(shape, stype)
        A2 = A.asnumpy()
        start = rnd.randint(0, shape[0] - 1)
        end = rnd.randint(start + 1, shape[0])
        assert same(mx.nd.slice_axis(A, begin=start, end=end, axis=0).asnumpy(),
                    A2[start:end])
        assert same(mx.nd.slice_axis(A, begin=start-shape[0], end=end, axis=0).asnumpy(),
                    A2[start:end])
        start = rnd.randint(0, shape[1] - 1)
        end = rnd.randint(start + 1, shape[1])
        assert same(mx.nd.slice_axis(A, begin=start, end=end, axis=1).asnumpy(),
                    A2[:, start:end])
        assert same(mx.nd.slice_axis(A, begin=start-shape[1], end=end, axis=1).asnumpy(),
                    A2[:, start:end])

-    shape = (rnd.randint(2, 10), rnd.randint(1, 10))
+    shape = (rnd.randint(2, 10), rnd.randint(2, 10))
    check_sparse_nd_csr_slice(shape)
    check_slice_nd_csr_fallback(shape)
    check_sparse_nd_csr_slice_axis(shape)


def test_sparse_nd_equal():
[Review] No need for & if in_stype is not changed.