Minor fixes sparse ops #160

Merged
214 changes: 36 additions & 178 deletions src/operator/tensor/cast_storage-inl.cuh
@@ -25,162 +25,20 @@
#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_
#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_

-#include <cub/cub.cuh>
#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <nnvm/tuple.h>

+#include <cub/cub.cuh>
+#include "./util/tensor_util-inl.cuh"

namespace mxnet {
namespace op {

-/*!
-* \brief Thread kernel for marking non-zero rows of a tensor.
-* Parallelized by tensor rows: 1 thread/row
-*/
-struct MarkRspRowIdxThreadKernel {
-/*!
-* \brief
-* \param tid global thread id
-* \param row_flg row flag array to mark non-zero rows
-* \param dns dense matrix data
-* \param num_rows number of rows (size of first dimension of tensor)
-* \param row_length number of elements per row
-*/
-template<typename DType, typename RType>
-__device__ __forceinline__ static void Map(int tid,
-RType* row_flg,
-const DType* dns,
-const nnvm::dim_t num_rows,
-const nnvm::dim_t row_length) {
-using nnvm::dim_t;
-if (tid < num_rows) {
-dim_t j = 0;
-dim_t offset = tid * row_length;
-for (; j < row_length; ++j) {
-if (dns[offset+j] != 0) {
-break;
-}
-}
-if (j < row_length) {
-row_flg[tid] = 1; // mark as one for non-zero row
-} else {
-row_flg[tid] = 0; // mark as zero for zero row
-}
-}
-}
-};
-
-/*!
-* \brief Warp kernel for marking non-zero rows of a tensor.
-* Parallelized by tensor rows: 1 warp/row
-*/
-struct MarkRspRowIdxWarpKernel {
-template<typename DType, typename RType>
-__device__ __forceinline__ static void Map(int tid,
-RType* row_flg,
-const DType* dns,
-const nnvm::dim_t num_rows,
-const nnvm::dim_t row_length) {
-using nnvm::dim_t;
-typedef cub::WarpReduce<dim_t> WarpReduce;
-const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
-__shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block];
-
-const dim_t warp_id = tid / 32; // global warp id
-const dim_t warp_lane = threadIdx.x / 32; // local warp id within thread block
-const dim_t lane = tid & (32-1); // local thread id within warp
-
-if (warp_id < num_rows) {
-dim_t flg = 0;
-dim_t offset = warp_id * row_length;
-for (dim_t j = lane; j < row_length; j+=32) {
-if (dns[offset+j] != 0) {
-// avoid break: causes slower performance on sparse tensors (<20% density),
-// due to thread divergence
-flg++;
-}
-}
-dim_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(flg);
-if (lane == 0) {
-if (aggr > 0) {
-row_flg[warp_id] = 1; // mark as one for non-zero row
-} else {
-row_flg[warp_id] = 0; // mark as zero for zero row
-}
-}
-}
-}
-};

-/*!
-* \brief Block kernel for marking non-zero rows of a tensor.
-* Parallelized by tensor rows: 1 threadBlock/row
-*/
-struct MarkRspRowIdxBlockKernel {
-template<typename DType, typename RType>
-__device__ __forceinline__ static void Map(int tid,
-RType* row_flg,
-const DType* dns,
-const nnvm::dim_t num_rows,
-const nnvm::dim_t row_length) {
-using nnvm::dim_t;
-using mshadow::cuda::kBaseThreadNum;
-typedef cub::BlockReduce<dim_t, kBaseThreadNum> BlockReduce;
-__shared__ typename BlockReduce::TempStorage temp_storage;
-if (blockIdx.x < num_rows) {
-dim_t flg = 0;
-dim_t offset = blockIdx.x * row_length;
-for (dim_t j = threadIdx.x; j < row_length; j+=kBaseThreadNum) {
-if (dns[offset+j] != 0) {
-// avoid break: causes slower performance on sparse tensors (<20% density),
-// due to thread divergence
-flg++;
-}
-}
-dim_t aggr = BlockReduce(temp_storage).Sum(flg);
-if (threadIdx.x == 0) {
-if (aggr > 0) {
-row_flg[blockIdx.x] = 1; // mark as one for non-zero row
-} else {
-row_flg[blockIdx.x] = 0; // mark as zero for zero row
-}
-}
-}
-}
-};
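For context: the three marking kernels deleted above (reintroduced as MarkRspRowThreadKernel/MarkRspRowWarpKernel/MarkRspRowBlockKernel, presumably via the new ./util/tensor_util-inl.cuh include) all follow one pattern: each thread counts non-zeros in a strided slice of its row, and a cub reduction aggregates the counts into a 0/1 row flag. A minimal standalone sketch of the block-per-row variant, with a hypothetical kernel name and a fixed BLOCK_THREADS standing in for mshadow::cuda::kBaseThreadNum:

```cuda
#include <cub/cub.cuh>

constexpr int BLOCK_THREADS = 256;  // stand-in for mshadow::cuda::kBaseThreadNum

// One thread block per row: threads count non-zeros in a strided slice of the
// row, cub::BlockReduce sums the per-thread counts, and thread 0 writes the flag.
template <typename DType>
__global__ void MarkNonZeroRowsSketch(int* row_flg, const DType* dns,
                                      int num_rows, int row_length) {
  typedef cub::BlockReduce<int, BLOCK_THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  if (blockIdx.x < num_rows) {
    int flg = 0;
    const int offset = blockIdx.x * row_length;
    for (int j = threadIdx.x; j < row_length; j += BLOCK_THREADS) {
      if (dns[offset + j] != 0) ++flg;  // no early break: avoids divergence
    }
    const int aggr = BlockReduce(temp_storage).Sum(flg);  // valid in thread 0
    if (threadIdx.x == 0) {
      row_flg[blockIdx.x] = (aggr > 0) ? 1 : 0;
    }
  }
}
```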

-/*!
-* \brief Kernel for filling the row index array of the rsp tensor.
-* Parallelized by tensor rows: 1 thread/row
-*/
-struct FillRspRowIdxKernel {
-/*!
-* \brief
-* \param tid global thread id
-* \param row_idx row index array to store indices of non-zero rows
-* \param row_flg_sum inclusive prefix sum array over marked row flag array
-* \param num_rows number of rows (size of first dimension of tensor)
-*/
-template<typename RType>
-__device__ __forceinline__ static void Map(int tid,
-RType* row_idx,
-const RType* row_flg_sum,
-const nnvm::dim_t num_rows) {
-if (tid < num_rows) {
-nnvm::dim_t prev = (tid == 0)? 0 : row_flg_sum[tid-1];
-if (row_flg_sum[tid] > prev) {
-row_idx[prev] = tid;
-}
-}
-}
-};
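The deleted FillRspRowIdxKernel is the compaction step of a mark/scan/fill pipeline: after an inclusive prefix sum of the row flags, the running total increases at exactly the non-zero rows, and row_flg_sum[tid-1] is then the row's slot in the compacted index array. A host-side walkthrough on a toy 4-row matrix (illustration only; the real steps run as the GPU kernels in CastStorageDnsRspImpl below):

```cpp
#include <cstdio>

int main() {
  const int num_rows = 4;
  int row_flg[num_rows] = {0, 1, 0, 1};  // rows 1 and 3 hold non-zeros
  // Inclusive prefix sum (done on device by cub::DeviceScan::InclusiveSum).
  for (int i = 1; i < num_rows; ++i) row_flg[i] += row_flg[i - 1];
  // row_flg is now {0, 1, 1, 2}; the last entry is nnr, the non-zero row count.
  const int nnr = row_flg[num_rows - 1];
  int row_idx[2] = {0, 0};
  // FillRspRowIdxKernel logic: row tid owns slot prev = row_flg[tid-1]
  // exactly when the scanned value increased at position tid.
  for (int tid = 0; tid < num_rows; ++tid) {
    const int prev = (tid == 0) ? 0 : row_flg[tid - 1];
    if (row_flg[tid] > prev) row_idx[prev] = tid;
  }
  printf("nnr=%d, row_idx={%d, %d}\n", nnr, row_idx[0], row_idx[1]);
  // Prints: nnr=2, row_idx={1, 3}
  return 0;
}
```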

/*!
-* \brief Kernel for filling the value array of the rsp tensor.
+* \brief GPU Kernel for filling the value array of the rsp tensor.
* Parallelized by rsp tensor elements: 1 thread/element
*/
-struct FillRspValsKernel {
+struct CastDnsRspValsKernel {
/*!
* \brief
* \param tid global thread id
@@ -243,7 +101,7 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
LOG(FATAL) << "CastStorageDnsRspImpl GPU kernels expect warpSize=32";
}
// Determine temporary device storage requirements
-RType* row_flg = NULL;
+dim_t* row_flg = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::InclusiveSum(d_temp_storage,
@@ -254,10 +112,10 @@
mshadow::Stream<gpu>::GetStream(s));
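This call is the first half of cub's two-phase idiom: invoked with a null workspace pointer, DeviceScan::InclusiveSum only writes the required workspace size into temp_storage_bytes; the same call is repeated further down with real memory to perform the scan. A minimal sketch of the idiom in isolation (using cudaMalloc, whereas the code here carves the workspace out of a single MXNet temp allocation):

```cpp
#include <cub/cub.cuh>

// d_in/d_out: device arrays of num_items elements (in-place is allowed).
void InclusiveSumSketch(const int* d_in, int* d_out, int num_items,
                        cudaStream_t stream) {
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // Phase 1: null workspace pointer -> only temp_storage_bytes is computed.
  cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes,
                                d_in, d_out, num_items, stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Phase 2: identical call with a real workspace performs the scan.
  cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes,
                                d_in, d_out, num_items, stream);
  cudaFree(d_temp_storage);
}
```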

// Allocate temp storage for marking non-zero rows and for cub's prefix sum
-auto workspace = AllocateTempDataForCast<gpu, 1, char>(ctx, Shape1(num_rows*sizeof(RType)
+auto workspace = AllocateTempDataForCast<gpu, 1, char>(ctx, Shape1(num_rows*sizeof(dim_t)
+ temp_storage_bytes));
-row_flg = reinterpret_cast<RType*>(workspace.dptr_);
-d_temp_storage = workspace.dptr_ + num_rows*sizeof(RType);
+row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
+d_temp_storage = workspace.dptr_ + num_rows*sizeof(dim_t);

// Mark non-zero rows as 'one' in row_flg
// Different kernel versions are optimized for different matrix instances
@@ -268,31 +126,31 @@
switch (kernel_version) {
case 1:
num_threads = num_rows;
-Kernel<MarkRspRowIdxThreadKernel, gpu>::Launch(s, num_threads,
+Kernel<MarkRspRowThreadKernel, gpu>::Launch(s, num_threads,
row_flg, dns.dptr<DType>(), num_rows, row_length);
break;
case 2:
num_threads = num_rows * threads_per_warp;
-Kernel<MarkRspRowIdxWarpKernel, gpu>::Launch(s, num_threads,
+Kernel<MarkRspRowWarpKernel, gpu>::Launch(s, num_threads,
row_flg, dns.dptr<DType>(), num_rows, row_length);
break;
case 3:
num_threads = num_rows * threads_per_block;
-Kernel<MarkRspRowIdxBlockKernel, gpu>::Launch(s, num_threads,
+Kernel<MarkRspRowBlockKernel, gpu>::Launch(s, num_threads,
row_flg, dns.dptr<DType>(), num_rows, row_length);
break;
default:
if (row_length < threads_per_warp) {
num_threads = num_rows;
-Kernel<MarkRspRowIdxThreadKernel, gpu>::Launch(s, num_threads,
+Kernel<MarkRspRowThreadKernel, gpu>::Launch(s, num_threads,
row_flg, dns.dptr<DType>(), num_rows, row_length);
} else if (row_length < threads_per_block || num_rows > min_num_warps) {
num_threads = num_rows * threads_per_warp;
-Kernel<MarkRspRowIdxWarpKernel, gpu>::Launch(s, num_threads,
+Kernel<MarkRspRowWarpKernel, gpu>::Launch(s, num_threads,
row_flg, dns.dptr<DType>(), num_rows, row_length);
} else {
num_threads = num_rows * threads_per_block;
-Kernel<MarkRspRowIdxBlockKernel, gpu>::Launch(s, num_threads,
+Kernel<MarkRspRowBlockKernel, gpu>::Launch(s, num_threads,
row_flg, dns.dptr<DType>(), num_rows, row_length);
}
break;
@@ -306,11 +164,11 @@
mshadow::Stream<gpu>::GetStream(s));

// Get total number of non-zero rows from device
-RType nnr = 0;
-CUDA_CALL(cudaMemcpy(&nnr, &row_flg[num_rows-1], sizeof(RType), cudaMemcpyDeviceToHost));
+dim_t nnr = 0;
+CUDA_CALL(cudaMemcpy(&nnr, &row_flg[num_rows-1], sizeof(dim_t), cudaMemcpyDeviceToHost));

// Allocate rsp tensor row index array and fill
-rsp->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(static_cast<dim_t>(nnr)));
+rsp->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(nnr));
if (0 == nnr) return;
RType* row_idx = rsp->aux_data(rowsparse::kIdx).dptr<RType>();
num_threads = num_rows;
@@ -322,7 +180,7 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
storage_shape[0] = nnr;
rsp->CheckAndAllocData(storage_shape);
num_threads = nnr * row_length;
-Kernel<FillRspValsKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsRspValsKernel, gpu>::Launch(s, num_threads,
rsp->data().dptr<DType>(), row_idx, dns.dptr<DType>(), nnr, row_length);
});
});
@@ -332,7 +190,7 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
* \brief Thread kernel for initializing the indptr in a csr matrix.
* Parallelized by matrix rows: 1 thread/row
*/
-struct FillCsrIndPtrThreadKernel {
+struct CastDnsCsrIndPtrThreadKernel {
/*!
* \brief
* \param tid global thread id
@@ -368,7 +226,7 @@ struct FillCsrIndPtrThreadKernel {
* \brief Thread kernel for initializing the col_idx and value array of the csr matrix.
* Parallelized by matrix rows: 1 thread/row
*/
-struct FillCsrColIdxAndValsThreadKernel {
+struct CastDnsCsrColIdxAndValsThreadKernel {
/*!
* \brief
* \param tid global thread id
@@ -406,7 +264,7 @@ struct FillCsrColIdxAndValsThreadKernel {
* \brief Warp kernel for initializing the indptr in a csr matrix.
* Parallelized by matrix rows: 1 warp/row
*/
-struct FillCsrIndPtrWarpKernel {
+struct CastDnsCsrIndPtrWarpKernel {
template<typename DType, typename IType>
__device__ __forceinline__ static void Map(int tid,
IType* indptr,
Expand Down Expand Up @@ -444,7 +302,7 @@ struct FillCsrIndPtrWarpKernel {
* \brief Warp kernel for initializing the col_idx and value array of the csr matrix.
* Parallelized by matrix rows: 1 warp/row
*/
-struct FillCsrColIdxAndValsWarpKernel {
+struct CastDnsCsrColIdxAndValsWarpKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
DType* val,
Expand Down Expand Up @@ -498,7 +356,7 @@ struct FillCsrColIdxAndValsWarpKernel {
* \brief Block kernel for initializing the indptr in a csr matrix.
* Parallelized by matrix rows: 1 threadBlock/row
*/
-struct FillCsrIndPtrBlockKernel {
+struct CastDnsCsrIndPtrBlockKernel {
template<typename DType, typename IType>
__device__ __forceinline__ static void Map(int tid,
IType* indptr,
Expand Down Expand Up @@ -533,7 +391,7 @@ struct FillCsrIndPtrBlockKernel {
* \brief Block kernel for initializing the col_idx and value array of the csr matrix.
* Parallelized by matrix rows: 1 threadBlock/row
*/
-struct FillCsrColIdxAndValsBlockKernel {
+struct CastDnsCsrColIdxAndValsBlockKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
DType* val,
Expand Down Expand Up @@ -620,31 +478,31 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx,
switch (kernel_version) {
case 1:
num_threads = num_rows;
-Kernel<FillCsrIndPtrThreadKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrIndPtrThreadKernel, gpu>::Launch(s, num_threads,
indptr, dns_data, num_rows, num_cols);
break;
case 2:
num_threads = num_rows * threads_per_warp;
-Kernel<FillCsrIndPtrWarpKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrIndPtrWarpKernel, gpu>::Launch(s, num_threads,
indptr, dns_data, num_rows, num_cols);
break;
case 3:
num_threads = num_rows * threads_per_block;
-Kernel<FillCsrIndPtrBlockKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrIndPtrBlockKernel, gpu>::Launch(s, num_threads,
indptr, dns_data, num_rows, num_cols);
break;
default:
if (num_cols < threads_per_warp) {
num_threads = num_rows;
-Kernel<FillCsrIndPtrThreadKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrIndPtrThreadKernel, gpu>::Launch(s, num_threads,
indptr, dns_data, num_rows, num_cols);
} else if (num_cols < threads_per_block || num_rows > min_num_warps) {
num_threads = num_rows * threads_per_warp;
-Kernel<FillCsrIndPtrWarpKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrIndPtrWarpKernel, gpu>::Launch(s, num_threads,
indptr, dns_data, num_rows, num_cols);
} else {
num_threads = num_rows * threads_per_block;
-Kernel<FillCsrIndPtrBlockKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrIndPtrBlockKernel, gpu>::Launch(s, num_threads,
indptr, dns_data, num_rows, num_cols);
}
break;
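The indptr kernels above store per-row non-zero counts; a prefix sum over those counts (in code elided from this diff) then turns them into CSR row offsets, so that row i's entries occupy [indptr[i], indptr[i+1]). A toy walkthrough for a 3x4 dense matrix (illustration only):

```cpp
#include <cstdio>

int main() {
  // dns = [[0, 7, 0, 0],
  //        [0, 0, 0, 0],
  //        [3, 0, 0, 5]]
  const int counts[3] = {1, 0, 2};  // non-zeros per row
  int indptr[4] = {0, counts[0], counts[1], counts[2]};
  for (int i = 1; i <= 3; ++i) indptr[i] += indptr[i - 1];
  // indptr = {0, 1, 1, 3}; indptr[3] is the total non-zero count (nnz).
  printf("%d %d %d %d\n", indptr[0], indptr[1], indptr[2], indptr[3]);
  return 0;
}
```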
@@ -685,36 +543,36 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx,
switch (kernel_version) {
case 1:
num_threads = num_rows;
-Kernel<FillCsrColIdxAndValsThreadKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrColIdxAndValsThreadKernel, gpu>::Launch(s, num_threads,
csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
indptr, dns_data, num_rows, num_cols);
break;
case 2:
num_threads = num_rows * threads_per_warp;
-Kernel<FillCsrColIdxAndValsWarpKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrColIdxAndValsWarpKernel, gpu>::Launch(s, num_threads,
csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
indptr, dns_data, num_rows, num_cols);
break;
case 3:
num_threads = num_rows * threads_per_block;
-Kernel<FillCsrColIdxAndValsBlockKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrColIdxAndValsBlockKernel, gpu>::Launch(s, num_threads,
csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
indptr, dns_data, num_rows, num_cols);
break;
default:
if (num_cols < threads_per_warp) {
num_threads = num_rows;
-Kernel<FillCsrColIdxAndValsThreadKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrColIdxAndValsThreadKernel, gpu>::Launch(s, num_threads,
csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
indptr, dns_data, num_rows, num_cols);
} else if (num_cols < threads_per_block || num_rows > min_num_warps) {
num_threads = num_rows * threads_per_warp;
-Kernel<FillCsrColIdxAndValsWarpKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrColIdxAndValsWarpKernel, gpu>::Launch(s, num_threads,
csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
indptr, dns_data, num_rows, num_cols);
} else {
num_threads = num_rows * threads_per_block;
-Kernel<FillCsrColIdxAndValsBlockKernel, gpu>::Launch(s, num_threads,
+Kernel<CastDnsCsrColIdxAndValsBlockKernel, gpu>::Launch(s, num_threads,
csr->data().dptr<DType>(), csr->aux_data(csr::kIdx).dptr<CType>(),
indptr, dns_data, num_rows, num_cols);
}