diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 7ab471009069..2432703291f9 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -30,9 +30,10 @@
 #include
 #include
 #include
-#include "./init_op.h"
+#include "./util/tensor_util-inl.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
+#include "./init_op.h"
 #include "../mxnet_op.h"
 #ifdef __CUDACC__
 #include "./dot-inl.cuh"
@@ -364,19 +365,17 @@ struct DotCsrTransDnsDnsByRowBlocks {
 
 /*!
  * \brief CPU Kernel of dot(csr.T(), dns) = rsp
- * Parallelization by row blocks.
- * This kernel fills up the row_idx array of the rsp
- * with 1 for nonzero rows and 0 for zero rows.
- * The matrix will be compacted after this kernel call.
+ * Parallelization by row blocks which evenly partition the non-zero rows.
  */
 struct DotCsrTransDnsRspByRowBlocks {
   /*!
    * \brief
    * \param i the i-th thread
    */
-  template<typename DType, typename IType, typename CType, typename RType>
+  template<typename DType, typename IType, typename CType, typename RType>
   MSHADOW_CINLINE static void Map(int i,
                                   DType* out,
+                                  nnvm::dim_t* row_flg_sum,
                                   RType* row_idx,
                                   const DType* data_l,
                                   const IType* indptr_l,
@@ -384,21 +383,25 @@ struct DotCsrTransDnsRspByRowBlocks {
                                   const DType* data_r,
                                   const nnvm::dim_t seg_len,
                                   const nnvm::dim_t num_rows_l,
-                                  const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t nnr,
                                   const nnvm::dim_t num_cols) {
     using nnvm::dim_t;
     const dim_t seg_start = i * seg_len;
-    if (seg_start >= num_rows) return;
+    if (seg_start >= nnr) return;
     const dim_t seg_end = (i + 1) * seg_len;
+    const dim_t col_start = row_idx[seg_start];
+    const dim_t col_end = seg_end >= nnr ? (row_idx[nnr-1] + 1) : row_idx[seg_end];
     for (dim_t j = 0; j < num_rows_l; ++j) {
       if (indptr_l[j] == indptr_l[j+1]) continue;
       const dim_t offset_r = j * num_cols;
       for (IType k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
         const CType col_idx = col_idx_l[k];
-        if (col_idx < seg_start || col_idx >= seg_end) continue;
-        const dim_t offset_out = col_idx * num_cols;
-        row_idx[col_idx] = 1;
+        if (col_idx < col_start || col_idx >= col_end) continue;
+
+        const nnvm::dim_t rsp_row = row_flg_sum[col_idx] - 1;
+        const nnvm::dim_t offset_out = rsp_row * num_cols;
         const DType val = data_l[k];
+
         for (dim_t l = 0; l < num_cols; ++l) {
           out[offset_out+l] += data_r[offset_r+l] * val;
         }
@@ -605,43 +608,51 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
   const TBlob& data_r = rhs;
 
-  // pre-allocate spaces for ret using the dense dimension size
-  ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])});
-  const TBlob data_out = ret->data();
-  const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx);
-
   MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
-        MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, {  // row idx type
+        MSHADOW_IDX_TYPE_SWITCH(ret->aux_type(rowsparse::kIdx), RType, {  // row idx type
+          const dim_t num_rows = lhs.shape()[1];
+          size_t workspace_size = 2 * (num_rows * sizeof(dim_t));
+          mshadow::Tensor<cpu, 1, char> workspace =
+            ctx.requested[0].get_space_typed<cpu, 1, char>(
+              mshadow::Shape1(workspace_size), s);
+          dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
+          dim_t* prefix_sum = row_flg + num_rows;
+
+          Fill<false>(s, TBlob(row_flg, mshadow::Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
+          mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch(s, lhs.aux_shape(csr::kIdx)[0], row_flg,
+                                                          col_idx_l.dptr<CType>());
+
+          prefix_sum[0] = row_flg[0];
+          for (nnvm::dim_t i = 1; i < num_rows; i++) {
+            prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
+          }
+          dim_t nnr = prefix_sum[num_rows - 1];
+
+          if (nnr == 0) {
+            FillZerosRspImpl(s, *ret);
+            return;
+          }
+
+          ret->CheckAndAlloc({mshadow::Shape1(nnr)});
+          const TBlob& data_out = ret->data();
+          const TBlob& row_idx = ret->aux_data(rowsparse::kIdx);
+
           dim_t num_threads = data_out.Size();
           mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, data_out.dptr<DType>());
-          RType* row_idx = row_idx_out.dptr<RType>();
-          num_threads = row_idx_out.Size();
-          mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, row_idx);
-          num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
-          dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
+          RType* row_idx_out = row_idx.dptr<RType>();
+
+          mxnet_op::Kernel<FillRspRowIdxKernel, cpu>::Launch(s, num_rows,
+            row_idx_out, prefix_sum, num_rows);
+
+          num_threads = mxnet_op::get_num_threads<cpu>(nnr);
+          dim_t seg_len = (nnr + num_threads - 1) / num_threads;
           if (trans_lhs) {
             mxnet_op::Kernel<DotCsrTransDnsRspByRowBlocks, cpu>::Launch(s, num_threads,
-              data_out.dptr<DType>(), row_idx, data_l.dptr<DType>(),
-              indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), data_r.dptr<DType>(),
-              seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
-            dim_t nnr = 0;
-            nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
-            if (0 == nnr) {
-              FillZerosRspImpl(s, *ret);
-              return;
-            }
-            ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
-            mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
-            dim_t idx = 0;
-            for (index_t i = 0; i < ret->shape()[0]; ++i) {
-              if (row_idx[i] > 0) {
-                row_idx[idx] = i;
-                mshadow::Copy(rsp_data[idx], rsp_data[i], s);
-                ++idx;
-              }
-            }
+              data_out.dptr<DType>(), prefix_sum, row_idx_out, data_l.dptr<DType>(),
+              indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), data_r.dptr<DType>(),
+              seg_len, lhs.shape()[0], nnr, ret->shape()[1]);
           } else {
             LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet.";
          }
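
Note on the approach: the patch replaces the old flow (allocate a dense-sized output, flag non-zero rows, then copy rows down to compact the matrix) with a two-pass scheme: flag every output row that owns at least one non-zero, take a prefix sum to obtain both the non-zero-row count (nnr) and each row's compacted destination, allocate the output at exactly nnr rows, and let the kernel accumulate straight into the compacted slots. The following is a minimal, single-threaded C++ sketch of that scheme using plain std::vector containers; CsrMatrix, RspMatrix, and CsrTransDotDnsToRsp are hypothetical names for illustration only, not MXNet APIs, and the parallel row-block partitioning of the real kernel is omitted.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical CSR container; MXNet's real NDArray/TBlob types differ.
struct CsrMatrix {
  int64_t num_rows, num_cols;
  std::vector<float> data;       // non-zero values
  std::vector<int64_t> indptr;   // row pointers, size num_rows + 1
  std::vector<int64_t> col_idx;  // column index of each non-zero
};

// Hypothetical row-sparse result: only non-zero rows are stored.
struct RspMatrix {
  int64_t num_cols;
  std::vector<int64_t> row_idx;  // original row index of each stored row
  std::vector<float> data;       // row_idx.size() x num_cols, row-major
};

// dot(csr.T(), dns) = rsp via the mark -> prefix-sum -> compact scheme.
RspMatrix CsrTransDotDnsToRsp(const CsrMatrix& lhs, const std::vector<float>& dns) {
  const int64_t out_rows = lhs.num_cols;  // rows of lhs.T()
  const int64_t out_cols = static_cast<int64_t>(dns.size()) / lhs.num_rows;

  // Pass 1: flag every output row that receives at least one product.
  std::vector<int64_t> row_flg(out_rows, 0);
  for (int64_t c : lhs.col_idx) row_flg[c] = 1;

  // Inclusive prefix sum: prefix[r] - 1 is the compacted slot of row r,
  // and the last entry is the number of non-zero rows (nnr).
  std::vector<int64_t> prefix(out_rows, 0);
  int64_t run = 0;
  for (int64_t r = 0; r < out_rows; ++r) { run += row_flg[r]; prefix[r] = run; }
  const int64_t nnr = run;

  // Allocate the output at exactly nnr rows and fill its row index array.
  RspMatrix out{out_cols, std::vector<int64_t>(nnr),
                std::vector<float>(nnr * out_cols, 0.f)};
  for (int64_t r = 0; r < out_rows; ++r)
    if (row_flg[r]) out.row_idx[prefix[r] - 1] = r;

  // Pass 2: accumulate directly into the compacted slot -- no post-hoc
  // row copying of a dense-sized buffer, unlike the code the patch removes.
  for (int64_t j = 0; j < lhs.num_rows; ++j) {
    for (int64_t k = lhs.indptr[j]; k < lhs.indptr[j + 1]; ++k) {
      const int64_t slot = prefix[lhs.col_idx[k]] - 1;
      const float val = lhs.data[k];
      for (int64_t l = 0; l < out_cols; ++l)
        out.data[slot * out_cols + l] += dns[j * out_cols + l] * val;
    }
  }
  return out;
}

int main() {
  // lhs = [[0, 2, 0], [0, 0, 3]] as 2x3 CSR, so lhs.T() is 3x2.
  CsrMatrix lhs{2, 3, {2.f, 3.f}, {0, 1, 2}, {1, 2}};
  std::vector<float> dns = {1.f, 4.f,   // 2x2 dense right-hand side
                            5.f, 6.f};
  RspMatrix out = CsrTransDotDnsToRsp(lhs, dns);
  // Row 0 of lhs.T() is all-zero, so only rows 1 and 2 are stored:
  // prints "row 1: 2 8" and "row 2: 15 18".
  for (size_t i = 0; i < out.row_idx.size(); ++i) {
    std::cout << "row " << out.row_idx[i] << ":";
    for (int64_t l = 0; l < out.num_cols; ++l)
      std::cout << " " << out.data[i * out.num_cols + l];
    std::cout << "\n";
  }
  return 0;
}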