From 113122bfe839aa5661a198d6974b256797569405 Mon Sep 17 00:00:00 2001 From: Henneking Date: Thu, 10 Aug 2017 11:45:38 -0700 Subject: [PATCH 1/4] change CPU kernel inline directives, data types, and function doc --- src/operator/tensor/cast_storage-inl.h | 172 +++++++++++++++---------- 1 file changed, 102 insertions(+), 70 deletions(-) diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h index 7d4f39ff9c58..acb30a9eff2b 100644 --- a/src/operator/tensor/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -43,10 +43,13 @@ namespace op { struct MarkRspRowIdx { // i represents the row index of the tensor data template - MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* data, - const index_t row_length) { - index_t j = 0; - index_t offset = i * row_length; + MSHADOW_CINLINE static void Map(int i, + RType* row_idx, + const DType* data, + const nnvm::dim_t row_length) { + using nnvm::dim_t; + dim_t j = 0; + dim_t offset = i * row_length; for (; j < row_length; ++j) { if (data[offset+j] != 0) { break; @@ -69,20 +72,22 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, NDArray* rsp) { using namespace rowsparse; using namespace mshadow; + using nnvm::dim_t; CHECK(rsp != nullptr); CHECK_EQ(rsp->storage_type(), kRowSparseStorage); CHECK_EQ(dns.shape_, rsp->shape()); mshadow::Stream* s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(rsp->aux_type(kIdx), RType, { // row idx type - const index_t num_rows = dns.shape_[0]; - const index_t row_length = dns.shape_.ProdShape(1, dns.shape_.ndim()); + const dim_t num_rows = dns.shape_[0]; + const dim_t row_length = dns.shape_.ProdShape(1, dns.shape_.ndim()); rsp->CheckAndAllocAuxData(kIdx, Shape1(num_rows)); TBlob row_idx_blob = rsp->aux_data(kIdx); RType* row_idx = row_idx_blob.dptr(); - mxnet_op::Kernel::Launch(s, num_rows, row_idx, - dns.dptr(), row_length); - index_t nnr = 0; + dim_t num_threads = num_rows; + mxnet_op::Kernel::Launch(s, num_threads, + row_idx, dns.dptr(), row_length); + dim_t nnr = 0; nnr = common::ParallelAccumulate(row_idx, num_rows, nnr); rsp->set_aux_shape(kIdx, Shape1(nnr)); if (0 == nnr) return; @@ -91,8 +96,8 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, rsp->CheckAndAllocData(storage_shape); auto dns_data = dns.get_with_shape(Shape2(num_rows, row_length), s); auto rsp_data = rsp->data().get_with_shape(Shape2(nnr, row_length), s); - size_t idx = 0; - for (index_t i = 0; i < num_rows; ++i) { + dim_t idx = 0; + for (dim_t i = 0; i < num_rows; ++i) { if (row_idx[i] > 0) { row_idx[idx] = i; Copy(rsp_data[idx], dns_data[i], s); @@ -106,12 +111,16 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, // TODO(haibin) Use memcopy instead will be much faster than assigning each individual element struct CastStorageRspDnsKernel { template - MSHADOW_XINLINE static void Map(int i, const index_t row_length, const IType* idx, - const DType *data, DType* dns) { - auto rid = idx[i]; - auto dns_offset = rid * row_length; - auto rsp_offset = i * row_length; - for (size_t col = 0; col < row_length; col++) { + MSHADOW_XINLINE static void Map(int i, + const nnvm::dim_t row_length, + const IType* idx, + const DType *data, + DType* dns) { + using nnvm::dim_t; + IType rid = idx[i]; + dim_t dns_offset = rid * row_length; + dim_t rsp_offset = i * row_length; + for (dim_t col = 0; col < row_length; col++) { dns[dns_offset + col] = data[rsp_offset + col]; } } @@ -122,9 +131,12 @@ struct CastStorageRspDnsKernel { * since the shape is known at binding stage. */ template -void CastStorageRspDnsImpl(const OpContext& ctx, const NDArray& rsp, TBlob* dns) { +void CastStorageRspDnsImpl(const OpContext& ctx, + const NDArray& rsp, + TBlob* dns) { mshadow::Stream* s = ctx.get_stream(); CHECK_EQ(rsp.storage_type(), kRowSparseStorage); + using nnvm::dim_t; MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { // assign zeros @@ -134,11 +146,12 @@ void CastStorageRspDnsImpl(const OpContext& ctx, const NDArray& rsp, TBlob* dns) auto in_idx = rsp.aux_data(rowsparse::kIdx).FlatTo1D(s).dptr_; auto in_data = rsp.data().dptr(); auto out_data = dns->dptr(); - auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size(); - const auto shape = rsp.shape(); - auto row_length = shape.ProdShape(1, shape.ndim()); - mxnet_op::Kernel::Launch(s, num_rows, row_length, in_idx, - in_data, out_data); + auto shape = rsp.shape(); + const dim_t num_rows = rsp.aux_shape(rowsparse::kIdx).Size(); + const dim_t row_length = shape.ProdShape(1, shape.ndim()); + const dim_t num_threads = num_rows; + mxnet_op::Kernel::Launch(s, num_threads, + row_length, in_idx, in_data, out_data); } }); }); @@ -150,18 +163,22 @@ void CastStorageRspDnsImpl(const OpContext& ctx, const NDArray& rsp, TBlob* dns) struct FillCsrIndPtr { /*! * \brief - * \param i the i-th row of the dns tensor - * \param indptr indptr of the csr tensor - * \param dns the dns tensor - * \param num_rows - * \param num_cols + * \param i the i-th row of the dns tensor + * \param indptr the indptr of the csr tensor + * \param dns the dns tensor + * \param num_rows number of rows of the dns tensor + * \param num_cols number of columns of the dns tensor */ template - MSHADOW_XINLINE static void Map(int i, IType* indptr, const DType* dns, - const int num_rows, const int num_cols) { + MSHADOW_CINLINE static void Map(int i, + IType* indptr, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + using nnvm::dim_t; indptr[i+1] = 0; - const int offset = i * num_cols; - for (int j = 0; j < num_cols; ++j) { + const dim_t offset = i * num_cols; + for (dim_t j = 0; j < num_cols; ++j) { if (dns[offset+j] != 0) { ++indptr[i+1]; } @@ -175,21 +192,26 @@ struct FillCsrIndPtr { struct FillCsrColIdxAndVals { /*! * \brief - * \param i the i-th row of the dns tensor - * \param val value array of the csr - * \param col_idx column idx array of the csr - * \param indptr indptr array of the csr - * \param dns the dns tensor - * \param num_rows number of rows of the dns - * \param num_cols number of columns of the dns + * \param i the i-th row of the dns tensor + * \param val value array of the csr tensor + * \param col_idx column idx array of the csr tensor + * \param indptr indptr array of the csr tensor + * \param dns dns tensor + * \param num_rows number of rows of the dns tensor + * \param num_cols number of columns of the dns tensor */ template - MSHADOW_XINLINE static void Map(int i, DType* val, CType* col_idx, - const IType* indptr, const DType* dns, - const index_t num_rows, const index_t num_cols) { - const index_t offset = i * num_cols; + MSHADOW_CINLINE static void Map(int i, + DType* val, + CType* col_idx, + const IType* indptr, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + using nnvm::dim_t; + const dim_t offset = i * num_cols; IType k = indptr[i]; - for (index_t j = 0; j < num_cols; ++j) { + for (dim_t j = 0; j < num_cols; ++j) { if (dns[offset+j] != 0) { val[k] = dns[offset+j]; col_idx[k] = j; @@ -210,29 +232,31 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, CHECK_EQ(csr->storage_type(), kCSRStorage); CHECK_EQ(dns.shape_.ndim(), 2); CHECK_EQ(dns.shape_, csr->shape()); + using mshadow::Shape1; + using nnvm::dim_t; mshadow::Stream* s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type - const index_t num_rows = dns.shape_[0]; - const index_t num_cols = dns.shape_[1]; + const dim_t num_rows = dns.shape_[0]; + const dim_t num_cols = dns.shape_[1]; csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1)); IType* indptr = csr->aux_data(csr::kIndPtr).dptr(); DType* dns_data = dns.dptr(); - mxnet_op::Kernel::Launch(s, num_rows, indptr, - dns_data, num_rows, num_cols); + dim_t num_threads = num_rows; + mxnet_op::Kernel::Launch(s, num_threads, + indptr, dns_data, num_rows, num_cols); // single thread to accumulate indptr // indptr[num_rows] indicates the number of non-zero elements indptr[0] = 0; - for (index_t i = 0; i < num_rows; ++i) { + for (dim_t i = 0; i < num_rows; ++i) { indptr[i+1] += indptr[i]; } // allocate column idx array and value array - csr->CheckAndAllocAuxData(csr::kIdx, - mshadow::Shape1(static_cast(indptr[num_rows]))); - csr->CheckAndAllocData(mshadow::Shape1(static_cast(indptr[num_rows]))); + csr->CheckAndAllocAuxData(csr::kIdx, Shape1(static_cast(indptr[num_rows]))); + csr->CheckAndAllocData(Shape1(static_cast(indptr[num_rows]))); // fill col_idx and value arrays of the csr - mxnet_op::Kernel::Launch(s, num_rows, + mxnet_op::Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); }); @@ -246,19 +270,22 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, struct CopyCsrDataToDns { /*! * \brief - * \param i the i-th row of the dns tensor - * \param dns_data data blob of the dns tensor - * \param col_idx column idx array of the csr - * \param indptr indptr array of the csr - * \param csr_data data blob of the csr tensor - * \param num_cols number of columns of the dns + * \param i the i-th row of the dns tensor + * \param dns_data data blob of the dns tensor + * \param col_idx column idx array of the csr tensor + * \param indptr indptr array of the csr tensor + * \param csr_data data blob of the csr tensor + * \param num_cols number of columns of the dns tensor */ template - MSHADOW_XINLINE static void Map(int i, DType* dns_data, const CType* col_idx, - const IType* indptr, const DType* csr_data, - const int num_cols) { - const int offset = i * num_cols; - for (auto j = indptr[i]; j < indptr[i+1]; ++j) { + MSHADOW_XINLINE static void Map(int i, + DType* dns_data, + const CType* col_idx, + const IType* indptr, + const DType* csr_data, + const nnvm::dim_t num_cols) { + const nnvm::dim_t offset = i * num_cols; + for (IType j = indptr[i]; j < indptr[i+1]; ++j) { dns_data[offset+col_idx[j]] = csr_data[j]; } } @@ -268,25 +295,30 @@ struct CopyCsrDataToDns { * \brief Casts a csr matrix to dns format. */ template -void CastStorageCsrDnsImpl(const OpContext& ctx, const NDArray& csr, TBlob* dns) { +void CastStorageCsrDnsImpl(const OpContext& ctx, + const NDArray& csr, + TBlob* dns) { CHECK(dns != nullptr); CHECK_EQ(csr.storage_type(), kCSRStorage); CHECK_EQ(dns->shape_.ndim(), 2); CHECK_EQ(dns->shape_, csr.shape()); + using nnvm::dim_t; mshadow::Stream* s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type - const index_t num_rows = dns->shape_[0]; - const index_t num_cols = dns->shape_[1]; + const dim_t num_rows = dns->shape_[0]; + const dim_t num_cols = dns->shape_[1]; DType* dns_data = dns->dptr(); - mxnet_op::Kernel::Launch(s, dns->shape_.Size(), dns_data); + dim_t num_threads = dns->shape_.Size(); + mxnet_op::Kernel::Launch(s, num_threads, dns_data); if (!csr.storage_initialized()) return; const IType* indptr = csr.aux_data(csr::kIndPtr).dptr(); const CType* col_idx = csr.aux_data(csr::kIdx).dptr(); const DType* csr_data = csr.data().dptr(); - mxnet_op::Kernel::Launch(s, num_rows, dns_data, - col_idx, indptr, csr_data, num_cols); + num_threads = num_rows; + mxnet_op::Kernel::Launch(s, num_threads, + dns_data, col_idx, indptr, csr_data, num_cols); }); }); }); From b2cc87d1fadab98aca12b6d1f5cbecefde9d08bf Mon Sep 17 00:00:00 2001 From: Henneking Date: Thu, 10 Aug 2017 11:50:29 -0700 Subject: [PATCH 2/4] update dot dtype switch to use 32 and 64bit floating point only --- src/operator/tensor/dot-inl.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 79b87da77011..907ea26d323a 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -69,7 +69,7 @@ void DotForward_(const nnvm::NodeAttrs& attrs, << "Binary function only support input/output with the same type"; CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64) << "dot only supports float32 and float64"; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1) { CHECK_NE(req[0], kAddTo) << "AddTo not yet suported"; Tensor out = outputs[0].get(s); @@ -127,7 +127,7 @@ void DotBackward_(const nnvm::NodeAttrs& attrs, Stream *s = ctx.get_stream(); CHECK_NE(req[0], kWriteInplace); CHECK_NE(req[1], kWriteInplace); - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { if (inputs[1].ndim() == 1 && inputs[2].ndim() == 1) { Tensor mout_grad = inputs[0].get(s); Tensor mlhs_data = inputs[1].get(s); @@ -492,7 +492,7 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx, const TBlob& data_r = rhs; const TBlob data_out = *ret; - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type dim_t num_threads; @@ -549,7 +549,7 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx, const TBlob data_out = ret->data(); const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, { // row idx type @@ -609,7 +609,7 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx, mshadow::Stream* s = ctx.get_stream(); if (!lhs.storage_initialized() || !rhs.storage_initialized()) { if (kWriteTo == req) { - MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { // data type + MSHADOW_SGL_DBL_TYPE_SWITCH(ret->type_flag_, DType, { // data type mxnet_op::Kernel::Launch( s, ret->Size(), ret->dptr()); }); @@ -624,7 +624,7 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx, const TBlob data_r = rhs.data(); const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // row idx type @@ -691,7 +691,7 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, const TBlob data_out = ret->data(); const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // row idx type @@ -861,7 +861,7 @@ void BatchDotForward_(const nnvm::NodeAttrs& attrs, << "Binary function only support input/output with the same type"; CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64) << "dot only supports float32 and float64"; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { mshadow::Tensor out = outputs[0].get(s); mshadow::Tensor mlhs = inputs[0].get(s); mshadow::Tensor mrhs = inputs[1].get(s); @@ -903,7 +903,7 @@ void BatchDotBackward_(const nnvm::NodeAttrs& attrs, CHECK_NE(req[0], kWriteInplace); CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64) << "dot only supports float32 and float64"; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, { mshadow::Tensor mout_grad = inputs[0].get(s); mshadow::Tensor mlhs_data = inputs[1].get(s); mshadow::Tensor mrhs_data = inputs[2].get(s); From e06b755f4e78d183beb2b8a3bd6e3e23fdc74060 Mon Sep 17 00:00:00 2001 From: Henneking Date: Thu, 10 Aug 2017 12:55:58 -0700 Subject: [PATCH 3/4] use type_assign instead of STORAGE_TYPE_ASSIGN_CHECK --- src/operator/tensor/dot-inl.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 907ea26d323a..7b7d82b01b91 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -205,11 +205,11 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); const DotParam& param = nnvm::get(attrs.parsed); // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp - // TODO(stefan/haibin): don't enforce kRowSparseStorage if out_attrs has already been set + // TODO(stefan/haibin/jun): check type_assign return value if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage); + type_assign(&((*out_attrs)[0]), kRowSparseStorage); } else { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); + type_assign(&((*out_attrs)[0]), kDefaultStorage); } return true; } @@ -221,11 +221,11 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), 2U); const DotParam& param = nnvm::get(attrs.parsed); - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); + type_assign(&((*out_attrs)[0]), kDefaultStorage); if (!param.transpose_a && kCSRStorage == (*in_attrs)[1]) { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage); + type_assign(&((*out_attrs)[1]), kRowSparseStorage); } else { - STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage); + type_assign(&((*out_attrs)[1]), kDefaultStorage); } return true; } From 427389e5ed22b6ec5952c7b19f308bb72db386d6 Mon Sep 17 00:00:00 2001 From: Henneking Date: Thu, 10 Aug 2017 14:46:43 -0700 Subject: [PATCH 4/4] added tensor_util-inl.cuh file for common tensor operator GPU kernels --- src/operator/tensor/cast_storage-inl.cuh | 214 +++-------------- src/operator/tensor/dot-inl.cuh | 92 +------ src/operator/tensor/util/tensor_util-inl.cuh | 240 +++++++++++++++++++ 3 files changed, 279 insertions(+), 267 deletions(-) create mode 100644 src/operator/tensor/util/tensor_util-inl.cuh diff --git a/src/operator/tensor/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh index 1a47b68f7aac..afef53e979ea 100644 --- a/src/operator/tensor/cast_storage-inl.cuh +++ b/src/operator/tensor/cast_storage-inl.cuh @@ -25,162 +25,20 @@ #ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ #define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ +#include #include #include #include - -#include +#include "./util/tensor_util-inl.cuh" namespace mxnet { namespace op { /*! - * \brief Thread kernel for marking non-zero rows of a tensor. - * Parallelized by tensor rows: 1 thread/row - */ -struct MarkRspRowIdxThreadKernel { - /*! - * \brief - * \param tid global thread id - * \param row_flg row flag array to mark non-zero rows - * \param dns dense matrix data - * \param num_rows number of rows (size of first dimension of tensor) - * \param row_length number of elements per row - */ - template - __device__ __forceinline__ static void Map(int tid, - RType* row_flg, - const DType* dns, - const nnvm::dim_t num_rows, - const nnvm::dim_t row_length) { - using nnvm::dim_t; - if (tid < num_rows) { - dim_t j = 0; - dim_t offset = tid * row_length; - for (; j < row_length; ++j) { - if (dns[offset+j] != 0) { - break; - } - } - if (j < row_length) { - row_flg[tid] = 1; // mark as one for non-zero row - } else { - row_flg[tid] = 0; // mark as zero for zero row - } - } - } -}; - -/*! - * \brief Warp kernel for marking non-zero rows of a tensor. - * Parallelized by tensor rows: 1 warp/row - */ -struct MarkRspRowIdxWarpKernel { - template - __device__ __forceinline__ static void Map(int tid, - RType* row_flg, - const DType* dns, - const nnvm::dim_t num_rows, - const nnvm::dim_t row_length) { - using nnvm::dim_t; - typedef cub::WarpReduce WarpReduce; - const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32; - __shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block]; - - const dim_t warp_id = tid / 32; // global warp id - const dim_t warp_lane = threadIdx.x / 32; // local warp id within thread block - const dim_t lane = tid & (32-1); // local thread id within warp - - if (warp_id < num_rows) { - dim_t flg = 0; - dim_t offset = warp_id * row_length; - for (dim_t j = lane; j < row_length; j+=32) { - if (dns[offset+j] != 0) { - // avoid break: causes slower performance on sparse tensors (<20% density), - // due to thread divergence - flg++; - } - } - dim_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(flg); - if (lane == 0) { - if (aggr > 0) { - row_flg[warp_id] = 1; // mark as one for non-zero row - } else { - row_flg[warp_id] = 0; // mark as zero for zero row - } - } - } - } -}; - -/*! - * \brief Block kernel for marking non-zero rows of a tensor. - * Parallelized by tensor rows: 1 threadBlock/row - */ -struct MarkRspRowIdxBlockKernel { - template - __device__ __forceinline__ static void Map(int tid, - RType* row_flg, - const DType* dns, - const nnvm::dim_t num_rows, - const nnvm::dim_t row_length) { - using nnvm::dim_t; - using mshadow::cuda::kBaseThreadNum; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - if (blockIdx.x < num_rows) { - dim_t flg = 0; - dim_t offset = blockIdx.x * row_length; - for (dim_t j = threadIdx.x; j < row_length; j+=kBaseThreadNum) { - if (dns[offset+j] != 0) { - // avoid break: causes slower performance on sparse tensors (<20% density), - // due to thread divergence - flg++; - } - } - dim_t aggr = BlockReduce(temp_storage).Sum(flg); - if (threadIdx.x == 0) { - if (aggr > 0) { - row_flg[blockIdx.x] = 1; // mark as one for non-zero row - } else { - row_flg[blockIdx.x] = 0; // mark as zero for zero row - } - } - } - } -}; - -/*! - * \brief Kernel for filling the row index array of the rsp tensor. - * Parallelized by tensor rows: 1 thread/row - */ -struct FillRspRowIdxKernel { - /*! - * \brief - * \param tid global thread id - * \param row_idx row index array to store indices of non-zero rows - * \param row_flg_sum inclusive prefix sum array over marked row flag array - * \param num_rows number of rows (size of first dimension of tensor) - */ - template - __device__ __forceinline__ static void Map(int tid, - RType* row_idx, - const RType* row_flg_sum, - const nnvm::dim_t num_rows) { - if (tid < num_rows) { - nnvm::dim_t prev = (tid == 0)? 0 : row_flg_sum[tid-1]; - if (row_flg_sum[tid] > prev) { - row_idx[prev] = tid; - } - } - } -}; - -/*! - * \brief Kernel for filling the value array of the rsp tensor. + * \brief GPU Kernel for filling the value array of the rsp tensor. * Parallelized by rsp tensor elements: 1 thread/element */ -struct FillRspValsKernel { +struct CastDnsRspValsKernel { /*! * \brief * \param tid global thread id @@ -243,7 +101,7 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, LOG(FATAL) << "CastStorageDnsRspImpl GPU kernels expect warpSize=32"; } // Determine temporary device storage requirements - RType* row_flg = NULL; + dim_t* row_flg = NULL; void* d_temp_storage = NULL; size_t temp_storage_bytes = 0; cub::DeviceScan::InclusiveSum(d_temp_storage, @@ -254,10 +112,10 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, mshadow::Stream::GetStream(s)); // Allocate temp storage for marking non-zero rows and for cub's prefix sum - auto workspace = AllocateTempDataForCast(ctx, Shape1(num_rows*sizeof(RType) + auto workspace = AllocateTempDataForCast(ctx, Shape1(num_rows*sizeof(dim_t) + temp_storage_bytes)); - row_flg = reinterpret_cast(workspace.dptr_); - d_temp_storage = workspace.dptr_ + num_rows*sizeof(RType); + row_flg = reinterpret_cast(workspace.dptr_); + d_temp_storage = workspace.dptr_ + num_rows*sizeof(dim_t); // Mark non-zero rows as 'one' in row_flg // Different kernel versions are optimized for different matrix instances @@ -268,31 +126,31 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, switch (kernel_version) { case 1: num_threads = num_rows; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, row_flg, dns.dptr(), num_rows, row_length); break; case 2: num_threads = num_rows * threads_per_warp; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, row_flg, dns.dptr(), num_rows, row_length); break; case 3: num_threads = num_rows * threads_per_block; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, row_flg, dns.dptr(), num_rows, row_length); break; default: if (row_length < threads_per_warp) { num_threads = num_rows; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, row_flg, dns.dptr(), num_rows, row_length); } else if (row_length < threads_per_block || num_rows > min_num_warps) { num_threads = num_rows * threads_per_warp; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, row_flg, dns.dptr(), num_rows, row_length); } else { num_threads = num_rows * threads_per_block; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, row_flg, dns.dptr(), num_rows, row_length); } break; @@ -306,11 +164,11 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, mshadow::Stream::GetStream(s)); // Get total number of non-zero rows from device - RType nnr = 0; - CUDA_CALL(cudaMemcpy(&nnr, &row_flg[num_rows-1], sizeof(RType), cudaMemcpyDeviceToHost)); + dim_t nnr = 0; + CUDA_CALL(cudaMemcpy(&nnr, &row_flg[num_rows-1], sizeof(dim_t), cudaMemcpyDeviceToHost)); // Allocate rsp tensor row index array and fill - rsp->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(static_cast(nnr))); + rsp->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(nnr)); if (0 == nnr) return; RType* row_idx = rsp->aux_data(rowsparse::kIdx).dptr(); num_threads = num_rows; @@ -322,7 +180,7 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, storage_shape[0] = nnr; rsp->CheckAndAllocData(storage_shape); num_threads = nnr * row_length; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, rsp->data().dptr(), row_idx, dns.dptr(), nnr, row_length); }); }); @@ -332,7 +190,7 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx, * \brief Thread kernel for initializing the indptr in a csr matrix. * Parallelized by matrix rows: 1 thread/row */ -struct FillCsrIndPtrThreadKernel { +struct CastDnsCsrIndPtrThreadKernel { /*! * \brief * \param tid global thread id @@ -368,7 +226,7 @@ struct FillCsrIndPtrThreadKernel { * \brief Thread kernel for initializing the col_idx and value array of the csr matrix. * Parallelized by matrix rows: 1 thread/row */ -struct FillCsrColIdxAndValsThreadKernel { +struct CastDnsCsrColIdxAndValsThreadKernel { /*! * \brief * \param tid global thread id @@ -406,7 +264,7 @@ struct FillCsrColIdxAndValsThreadKernel { * \brief Warp kernel for initializing the indptr in a csr matrix. * Parallelized by matrix rows: 1 warp/row */ -struct FillCsrIndPtrWarpKernel { +struct CastDnsCsrIndPtrWarpKernel { template __device__ __forceinline__ static void Map(int tid, IType* indptr, @@ -444,7 +302,7 @@ struct FillCsrIndPtrWarpKernel { * \brief Warp kernel for initializing the col_idx and value array of the csr matrix. * Parallelized by matrix rows: 1 warp/row */ -struct FillCsrColIdxAndValsWarpKernel { +struct CastDnsCsrColIdxAndValsWarpKernel { template __device__ __forceinline__ static void Map(int tid, DType* val, @@ -498,7 +356,7 @@ struct FillCsrColIdxAndValsWarpKernel { * \brief Block kernel for initializing the indptr in a csr matrix. * Parallelized by matrix rows: 1 threadBlock/row */ -struct FillCsrIndPtrBlockKernel { +struct CastDnsCsrIndPtrBlockKernel { template __device__ __forceinline__ static void Map(int tid, IType* indptr, @@ -533,7 +391,7 @@ struct FillCsrIndPtrBlockKernel { * \brief Block kernel for initializing the col_idx and value array of the csr matrix. * Parallelized by matrix rows: 1 threadBlock/row */ -struct FillCsrColIdxAndValsBlockKernel { +struct CastDnsCsrColIdxAndValsBlockKernel { template __device__ __forceinline__ static void Map(int tid, DType* val, @@ -620,31 +478,31 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, switch (kernel_version) { case 1: num_threads = num_rows; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); break; case 2: num_threads = num_rows * threads_per_warp; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); break; case 3: num_threads = num_rows * threads_per_block; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); break; default: if (num_cols < threads_per_warp) { num_threads = num_rows; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); } else if (num_cols < threads_per_block || num_rows > min_num_warps) { num_threads = num_rows * threads_per_warp; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); } else { num_threads = num_rows * threads_per_block; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, indptr, dns_data, num_rows, num_cols); } break; @@ -685,36 +543,36 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, switch (kernel_version) { case 1: num_threads = num_rows; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); break; case 2: num_threads = num_rows * threads_per_warp; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); break; case 3: num_threads = num_rows * threads_per_block; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); break; default: if (num_cols < threads_per_warp) { num_threads = num_rows; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); } else if (num_cols < threads_per_block || num_rows > min_num_warps) { num_threads = num_rows * threads_per_warp; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); } else { num_threads = num_rows * threads_per_block; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), indptr, dns_data, num_rows, num_cols); } diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh index f8cd4faf8632..41c3faaf419f 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -27,97 +27,11 @@ #include #include - -#include +#include "./util/tensor_util-inl.cuh" namespace mxnet { namespace op { -/*! - * \brief GPU auxiliary kernel to flag non-zero rows of an rsp matrix with indices. - * Parallelized by matrix rows: 1 thread/row - */ -struct SetRspRowFlgKernel { - /*! - * \brief - * \param tid global thread id - * \param row_flg array to flag storage indices of non-zero rows - * \param row_idx rsp matrix row index array storing indices of non-zero rows - * \param nnr rsp matrix number of non-zero rows (storage shape) - */ - template - __device__ __forceinline__ static void Map(int tid, - RType* row_flg, - const RType* row_idx, - const nnvm::dim_t nnr) { - if (tid < nnr) { - row_flg[row_idx[tid]] = tid+1; - } - } -}; - -/*! - * \brief GPU auxiliary kernel for marking non-zero columns of a csr matrix. - * Parallelized by matrix rows: 1 warp/row - */ -struct MarkCsrZeroColsWarpKernel { - /*! - * \brief - * \param tid global thread id - * \param col_idx csr matrix column indices - * \param indptr csr matrix row index pointer - * \param num_rows csr matrix number of rows - * \param num_cols csr matrix number of columns - */ - template - __device__ __forceinline__ static void Map(int tid, - nnvm::dim_t* flg, - const CType* col_idx, - const IType* indptr, - const nnvm::dim_t num_rows, - const nnvm::dim_t num_cols) { - typedef unsigned long long int uint64_cu; - static_assert(sizeof(uint64_cu) == sizeof(nnvm::dim_t), "unexpected sizeof dim_t"); - - const nnvm::dim_t warp_id = tid / 32; // global warp id - const nnvm::dim_t lane = tid & (32-1); // local thread id within warp - - if (warp_id < num_rows) { - uint64_cu zero = 0; - uint64_cu one = 1; - for (IType j = indptr[warp_id]+lane; j < indptr[warp_id+1]; j+=32) { - atomicCAS(reinterpret_cast(flg+col_idx[j]), zero, one); - } - } - } -}; - -/*! - * \brief GPU auxiliary kernel for filling the row index array of an rsp matrix. - * Parallelized by matrix rows: 1 thread/row - */ -struct FillRspRowIdxKernel { - /*! - * \brief - * \param tid global thread id - * \param row_idx row index array to store indices of non-zero rows - * \param row_flg_sum inclusive prefix sum array over 0/1 marked row flag array - * \param num_rows rsp matrix number of rows (shape) - */ - template - __device__ __forceinline__ static void Map(int tid, - RType* row_idx, - const nnvm::dim_t* row_flg_sum, - const nnvm::dim_t num_rows) { - if (tid < num_rows) { - nnvm::dim_t prev = (tid == 0)? 0 : row_flg_sum[tid-1]; - if (row_flg_sum[tid] > prev) { - row_idx[prev] = static_cast(tid); - } - } - } -}; - /*! * \brief GPU scalar kernel of dot(csr, dns1) = dns2 * Parallelization by output matrix elements: 1 thread/element @@ -721,7 +635,7 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx, num_threads = num_cols_l; Kernel::Launch(s, num_threads, row_flg_out); num_threads = num_rows_l * threads_per_warp; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, row_flg_out, col_idx_l.dptr(), indptr_l.dptr(), num_rows_l, num_cols_l); cub::DeviceScan::InclusiveSum(d_temp_storage, @@ -840,7 +754,7 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, num_threads = num_cols_l; Kernel::Launch(s, num_threads, row_flg_out); num_threads = num_rows_l * threads_per_warp; - Kernel::Launch(s, num_threads, + Kernel::Launch(s, num_threads, row_flg_out, col_idx_l.dptr(), indptr_l.dptr(), num_rows_l, num_cols_l); cub::DeviceScan::InclusiveSum(d_temp_storage, diff --git a/src/operator/tensor/util/tensor_util-inl.cuh b/src/operator/tensor/util/tensor_util-inl.cuh new file mode 100644 index 000000000000..cf268e7ae9fc --- /dev/null +++ b/src/operator/tensor/util/tensor_util-inl.cuh @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file tensor_util-inl.cuh + * \brief commonly utilized tensor operator GPU kernels + */ +#ifndef MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_CUH_ +#define MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_CUH_ + +#include +#include +#include + +namespace mxnet { +namespace op { + +/*! + * \brief Thread kernel for marking non-zero rows of a tensor. + * Parallelized by tensor rows: 1 thread/row + */ +struct MarkRspRowThreadKernel { + /*! + * \brief + * \param tid global thread id + * \param row_flg row flag array to mark non-zero rows + * \param dns dense matrix data + * \param num_rows number of rows (size of first dimension of tensor) + * \param row_length number of elements per row + */ + template + __device__ __forceinline__ static void Map(int tid, + nnvm::dim_t* row_flg, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t row_length) { + using nnvm::dim_t; + if (tid < num_rows) { + dim_t j = 0; + dim_t offset = tid * row_length; + for (; j < row_length; ++j) { + if (dns[offset+j] != 0) { + break; + } + } + if (j < row_length) { + row_flg[tid] = 1; // mark as one for non-zero row + } else { + row_flg[tid] = 0; // mark as zero for zero row + } + } + } +}; + +/*! + * \brief Warp kernel for marking non-zero rows of a tensor. + * Parallelized by tensor rows: 1 warp/row + */ +struct MarkRspRowWarpKernel { + template + __device__ __forceinline__ static void Map(int tid, + nnvm::dim_t* row_flg, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t row_length) { + using nnvm::dim_t; + typedef cub::WarpReduce WarpReduce; + const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32; + __shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block]; + + const dim_t warp_id = tid / 32; // global warp id + const dim_t warp_lane = threadIdx.x / 32; // local warp id within thread block + const dim_t lane = tid & (32-1); // local thread id within warp + + if (warp_id < num_rows) { + dim_t flg = 0; + dim_t offset = warp_id * row_length; + for (dim_t j = lane; j < row_length; j+=32) { + if (dns[offset+j] != 0) { + // avoid break: causes slower performance on sparse tensors (<20% density), + // due to thread divergence + flg++; + } + } + dim_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(flg); + if (lane == 0) { + if (aggr > 0) { + row_flg[warp_id] = 1; // mark as one for non-zero row + } else { + row_flg[warp_id] = 0; // mark as zero for zero row + } + } + } + } +}; + +/*! + * \brief Block kernel for marking non-zero rows of a tensor. + * Parallelized by tensor rows: 1 threadBlock/row + */ +struct MarkRspRowBlockKernel { + template + __device__ __forceinline__ static void Map(int tid, + nnvm::dim_t* row_flg, + const DType* dns, + const nnvm::dim_t num_rows, + const nnvm::dim_t row_length) { + using nnvm::dim_t; + using mshadow::cuda::kBaseThreadNum; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + if (blockIdx.x < num_rows) { + dim_t flg = 0; + dim_t offset = blockIdx.x * row_length; + for (dim_t j = threadIdx.x; j < row_length; j+=kBaseThreadNum) { + if (dns[offset+j] != 0) { + // avoid break: causes slower performance on sparse tensors (<20% density), + // due to thread divergence + flg++; + } + } + dim_t aggr = BlockReduce(temp_storage).Sum(flg); + if (threadIdx.x == 0) { + if (aggr > 0) { + row_flg[blockIdx.x] = 1; // mark as one for non-zero row + } else { + row_flg[blockIdx.x] = 0; // mark as zero for zero row + } + } + } + } +}; + +/*! + * \brief GPU kernel to flag non-zero rows of an rsp tensor with indices. + * Parallelized by matrix rows: 1 thread/row + */ +struct SetRspRowFlgKernel { + /*! + * \brief + * \param tid global thread id + * \param row_flg array to flag storage indices of non-zero rows + * \param row_idx rsp matrix row index array storing indices of non-zero rows + * \param nnr rsp matrix number of non-zero rows (storage shape) + */ + template + __device__ __forceinline__ static void Map(int tid, + RType* row_flg, + const RType* row_idx, + const nnvm::dim_t nnr) { + if (tid < nnr) { + row_flg[row_idx[tid]] = tid+1; + } + } +}; + +/*! + * \brief GPU kernel for filling the row index array of an rsp tensor. + * Parallelized by tensor rows: 1 thread/row + */ +struct FillRspRowIdxKernel { + /*! + * \brief + * \param tid global thread id + * \param row_idx row index array to store indices of non-zero rows + * \param row_flg_sum inclusive prefix sum array over 0/1 marked row flag array + * \param num_rows rsp tensor number of rows (shape) + */ + template + __device__ __forceinline__ static void Map(int tid, + RType* row_idx, + const nnvm::dim_t* row_flg_sum, + const nnvm::dim_t num_rows) { + if (tid < num_rows) { + nnvm::dim_t prev = (tid == 0)? 0 : row_flg_sum[tid-1]; + if (row_flg_sum[tid] > prev) { + row_idx[prev] = static_cast(tid); + } + } + } +}; + +/*! + * \brief GPU kernel for marking non-zero columns of a csr matrix. + * Parallelized by matrix rows: 1 warp/row + */ +struct MarkCsrColWarpKernel { + /*! + * \brief + * \param tid global thread id + * \param flg flg array to mark non-zero columns + * \param col_idx csr matrix column indices + * \param indptr csr matrix row index pointer + * \param num_rows csr matrix number of rows + * \param num_cols csr matrix number of columns + */ + template + __device__ __forceinline__ static void Map(int tid, + nnvm::dim_t* flg, + const CType* col_idx, + const IType* indptr, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + typedef unsigned long long int uint64_cu; + static_assert(sizeof(uint64_cu) == sizeof(nnvm::dim_t), "unexpected sizeof dim_t"); + + const nnvm::dim_t warp_id = tid / 32; // global warp id + const nnvm::dim_t lane = tid & (32-1); // local thread id within warp + + if (warp_id < num_rows) { + uint64_cu zero = 0; + uint64_cu one = 1; + for (IType j = indptr[warp_id]+lane; j < indptr[warp_id+1]; j+=32) { + atomicCAS(reinterpret_cast(flg+col_idx[j]), zero, one); + } + } + } +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_TENSOR_UTIL_TENSOR_UTIL_INL_CUH_