This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[sparse] slice for csr on two dimensions, cpu implementation #8331

Merged: 16 commits, merged on Nov 8, 2017
Changes from 4 commits
134 changes: 134 additions & 0 deletions src/operator/tensor/matrix_op-inl.h
@@ -927,6 +927,140 @@ void SliceAxis(const nnvm::NodeAttrs& attrs,
}
}

inline bool SliceAxisForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
const SliceAxisParam& param = nnvm::get<SliceAxisParam>(attrs.parsed);
const auto& in_stype = in_attrs->at(0);
[Reviewer] No need for & if in_stype is not changed.
auto& out_stype = out_attrs->at(0);
bool dispatched = false;
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
DispatchMode::kFComputeEx;
if (!dispatched && in_stype == kDefaultStorage) {
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched && in_stype == kCSRStorage && param.axis <= 1) {
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched) {
dispatch_fallback(out_attrs, dispatch_mode);
}
if (*dispatch_mode == DispatchMode::kFComputeFallback) {
LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
}

return true;
}

struct SliceAxisOneCsrAssign {
template<typename IType, typename RType, typename DType>
MSHADOW_XINLINE static void Map(int i, IType* out_idx, DType* out_data,
const RType* out_indptr, const IType* in_idx,
const DType* in_data, const RType* in_indptr,
int begin, int end) {
RType ind = out_indptr[i];
for (int j=in_indptr[i]; j < in_indptr[i+1]; j++) {
[Reviewer] nit: j = in_indptr[i]
if (in_idx[j] >= begin && in_idx[j] < end) {
[Reviewer] Should we put a break if in_idx[j] >= end, since indices are sorted per row?
[Reviewer] It's also possible to binary search the lower_bound for begin, although not necessarily faster.
out_idx[ind] = in_idx[j] - static_cast<IType>(begin);
[Reviewer] Let's perform static_cast<IType>(begin) once and cache the result.
out_data[ind] = in_data[j];
ind++;
}
}
}
};
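The suggestions above (break out of the scan once in_idx[j] >= end, optionally binary-search the lower bound of begin, and cast begin only once) could be folded into the per-row copy roughly as in the following sketch. It operates on plain arrays, and slice_row_columns is a hypothetical helper name illustrating the reviewers' ideas; it is not the kernel this PR adds.

#include <algorithm>

// Sketch: copy one CSR row's entries whose column index falls in [begin, end),
// re-basing the kept column indices at begin. Illustration only.
template <typename IType, typename RType, typename DType>
void slice_row_columns(int i, IType* out_idx, DType* out_data,
                       const RType* out_indptr, const IType* in_idx,
                       const DType* in_data, const RType* in_indptr,
                       int begin, int end) {
  const IType begin_cast = static_cast<IType>(begin);  // cast once and reuse
  const IType end_cast = static_cast<IType>(end);
  // Column indices are sorted within a row, so binary-search the first index >= begin.
  const IType* row_begin = in_idx + in_indptr[i];
  const IType* row_end = in_idx + in_indptr[i + 1];
  const IType* it = std::lower_bound(row_begin, row_end, begin_cast);
  RType ind = out_indptr[i];
  for (; it != row_end && *it < end_cast; ++it) {  // stop early: the rest are >= end
    out_idx[ind] = *it - begin_cast;
    out_data[ind] = in_data[it - in_idx];
    ++ind;
  }
}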

template<typename xpu>
void SliceAxisOneCsrImpl(const SliceAxisParam &param, const OpContext& ctx,
const NDArray &in, OpReqType req, const NDArray &out) {
[Reviewer] nit: indentation
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
CHECK((std::is_same<xpu, cpu>::value)) << "SliceAxis for CSR input only implemented for CPU";
if (req == kNullOp) return;
CHECK_NE(req, kAddTo) << "kAddTo for SliceAxis on CSR input is not supported";
CHECK_NE(req, kWriteInplace) << "kWriteInplace for SliceAxis on CSR input is not supported";
int axis, begin, end;
GetSliceAxisParams(param, in.shape(), &axis, &begin, &end);
int indptr_len = in.shape()[0] + 1;
[Reviewer] Use nnvm::dim_t (int64_t) instead because shape[i] is 64 bits.
out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
// assume idx and indptr share the same type
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
RType *in_indptr = in.aux_data(kIndPtr).dptr<RType>();
IType *in_idx = in.aux_data(kIdx).dptr<IType>();
DType *in_data = in.data().dptr<DType>();
// retrieve nnz (CPU implementation)
RType *out_indptr = out.aux_data(kIndPtr).dptr<RType>();
int nnz = 0;
out_indptr[0] = 0;
for (int i=0; i < indptr_len - 1; i++) {
[Reviewer] Also use nnvm::dim_t for i and j.
out_indptr[i+1] = out_indptr[i];
for (int j=in_indptr[i]; j < in_indptr[i+1]; j++) {
if (in_idx[j] >= begin && in_idx[j] < end) {
[Reviewer] Continue if in_idx[j] >= end instead of scanning the rest, since indices are sorted per row?
out_indptr[i+1]++;
nnz++;
}
}
}
out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));
IType *out_idx = out.aux_data(kIdx).dptr<IType>();
DType *out_data = out.data().dptr<DType>();

Stream<xpu> *s = ctx.get_stream<xpu>();
Kernel<SliceAxisOneCsrAssign, xpu>::Launch(s, indptr_len-1, out_idx, out_data,
[Reviewer] Does it work when nnz = 0? Is that tested?
[Author] Yes. If nnz = 0, the kernel launch will return immediately. A test for slice_axis(zeros, ...) is added.
out_indptr, in_idx, in_data, in_indptr,
begin, end);
});
});
});
}
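For readers unfamiliar with the two-pass pattern used above, the following self-contained sketch shows the same axis-1 CSR slice on std::vector containers: the first pass builds out.indptr (and hence the total nnz), the second copies the surviving entries and re-bases their column indices. It uses 64-bit loop counters, in the spirit of the dim_t review comments, and an empty result (nnz = 0) needs no special case. CsrView and slice_columns are hypothetical illustrations, not MXNet code.

#include <cstdint>
#include <vector>

// Minimal stand-in for a CSR matrix, for illustration only.
struct CsrView {
  std::vector<int64_t> indptr;   // length rows + 1
  std::vector<int64_t> indices;  // column index per stored value
  std::vector<float> data;       // stored values
};

// Sketch of the axis-1 slice: keep columns in [begin, end) and shift them to start at 0.
inline CsrView slice_columns(const CsrView& in, int64_t begin, int64_t end) {
  const int64_t rows = static_cast<int64_t>(in.indptr.size()) - 1;
  CsrView out;
  out.indptr.assign(rows + 1, 0);
  // Pass 1: count surviving entries per row to build out.indptr.
  for (int64_t i = 0; i < rows; ++i) {
    out.indptr[i + 1] = out.indptr[i];
    for (int64_t j = in.indptr[i]; j < in.indptr[i + 1]; ++j) {
      if (in.indices[j] >= begin && in.indices[j] < end) ++out.indptr[i + 1];
    }
  }
  const int64_t nnz = out.indptr[rows];
  out.indices.resize(nnz);
  out.data.resize(nnz);
  // Pass 2: copy the surviving entries, re-basing column indices at begin.
  for (int64_t i = 0; i < rows; ++i) {
    int64_t ind = out.indptr[i];
    for (int64_t j = in.indptr[i]; j < in.indptr[i + 1]; ++j) {
      if (in.indices[j] >= begin && in.indices[j] < end) {
        out.indices[ind] = in.indices[j] - begin;
        out.data[ind] = in.data[j];
        ++ind;
      }
    }
  }
  return out;
}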

template<typename xpu>
void SliceAxisZeroCsrImpl(const SliceAxisParam &param, const OpContext& ctx,
const NDArray &in, OpReqType req, const NDArray &out) {
int axis, begin, end;
GetSliceAxisParams(param, in.shape(), &axis, &begin, &end);
SliceParam slice_param;
slice_param.begin[0] = begin;
slice_param.end[0] = end;
SliceCsrImpl<xpu>(slice_param, ctx, in, req, out);
}

template<typename xpu>
void SliceAxisEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
[Reviewer] nit: indentation
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1);
CHECK_EQ(outputs.size(), 1);

const SliceAxisParam& param = nnvm::get<SliceAxisParam>(attrs.parsed);
auto in_stype = inputs[0].storage_type();
CHECK_NE(in_stype, kDefaultStorage)
[Reviewer] I think you can remove this check and print operator_info(ctx, ..) in line 1060.
<< "SliceAxisEx is not expected to execute for input with default storage type";
if (in_stype == kCSRStorage) {
if (param.axis == 0) {
SliceAxisZeroCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
} else if (param.axis == 1) {
SliceAxisOneCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
} else {
LOG(FATAL) << "CSRNDArray is only for 2-D shape";
[Reviewer] Does it fail with negative axis? I think GetSliceAxisParams already handles negative axis for you.
}
} else {
LOG(FATAL) << "SliceAxisEx not implemented for storage type" << in_stype;
}
}
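Regarding the negative-axis question above: the usual convention, which the reviewer suggests GetSliceAxisParams already follows, is to add ndim to a negative axis before dispatching, so axis = -1 on a 2-D CSR maps to axis = 1. A minimal sketch of that normalization follows; normalize_axis is a hypothetical helper, not the actual GetSliceAxisParams.

#include <cassert>

// Hypothetical illustration of negative-axis normalization; the real
// GetSliceAxisParams may differ in bounds checks and error reporting.
inline int normalize_axis(int axis, int ndim) {
  if (axis < 0) axis += ndim;        // e.g. axis = -1 with ndim = 2 becomes axis = 1
  assert(axis >= 0 && axis < ndim);  // reject out-of-range axes
  return axis;
}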

// Backward pass of broadcast over the given axis
template<typename xpu>
void SliceAxisGrad_(const nnvm::NodeAttrs& attrs,
2 changes: 2 additions & 0 deletions src/operator/tensor/matrix_op.cc
@@ -363,6 +363,8 @@ Examples::
.set_attr_parser(ParamParser<SliceAxisParam>)
[Reviewer] We should add MXNET_ADD_SPARSE_OP_ALIAS to this op. We should also update the description for output storage type (see sparse ops like add_n).
.set_attr<nnvm::FInferShape>("FInferShape", SliceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferStorageType>("FInferStorageType", SliceAxisForwardInferStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", SliceAxisEx<cpu>)
.set_attr<FCompute>("FCompute<cpu>", SliceAxis<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice_axis"})
.add_argument("data", "NDArray-or-Symbol", "Source input")
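Per the review comment above, other sparse ops register a _sparse_-prefixed alias via the MXNET_ADD_SPARSE_OP_ALIAS macro. A sketch of how that suggestion might look in the registration chain, assuming the macro behaves as it does for ops like add_n; this is a reviewer suggestion, not part of the four commits shown in this diff.

NNVM_REGISTER_OP(slice_axis)
MXNET_ADD_SPARSE_OP_ALIAS(slice_axis)  // assumed to add a _sparse_-prefixed alias for the op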
20 changes: 19 additions & 1 deletion tests/python/unittest/test_sparse_ndarray.py
@@ -127,9 +127,27 @@ def check_slice_nd_csr_fallback(shape):
result_dense = mx.nd.slice(mx.nd.array(A2), begin=(start, shape[1] - 1), end=(end + 1, shape[1]))
assert same(result_dense.asnumpy(), result.asnumpy())

shape = (rnd.randint(2, 10), rnd.randint(1, 10))
def check_sparse_nd_csr_slice_axis(shape):
[Reviewer] Let's also add some test cases for negative axis.
[Author] Added.
stype = 'csr'
A, _ = rand_sparse_ndarray(shape, stype)
A2 = A.asnumpy()
start = rnd.randint(0, shape[0] - 1)
end = rnd.randint(start + 1, shape[0])
assert same(mx.nd.slice_axis(A, begin=start, end=end, axis=0).asnumpy(),
A2[start:end])
assert same(mx.nd.slice_axis(A, begin=start-shape[0], end=end, axis=0).asnumpy(),
A2[start:end])
start = rnd.randint(0, shape[1] - 1)
end = rnd.randint(start + 1, shape[1])
assert same(mx.nd.slice_axis(A, begin=start, end=end, axis=1).asnumpy(),
A2[:, start:end])
assert same(mx.nd.slice_axis(A, begin=start-shape[1], end=end, axis=1).asnumpy(),
A2[:, start:end])

shape = (rnd.randint(2, 10), rnd.randint(2, 10))
check_sparse_nd_csr_slice(shape)
check_slice_nd_csr_fallback(shape)
check_sparse_nd_csr_slice_axis(shape)


def test_sparse_nd_equal():