
Commit
changed variable types from index_t to dim_t
stefanhenneking committed Jul 31, 2017
1 parent 6cdf419 commit bcc3d64
Showing 1 changed file with 123 additions and 95 deletions.
src/operator/tensor/cast_storage-inl.cuh: 123 additions & 95 deletions (218 changes)
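This commit swaps mshadow's index_t for nnvm::dim_t throughout the GPU cast-storage kernels. As a minimal sketch of the pattern (assuming, as in nnvm/tuple.h, that nnvm::dim_t is a signed 64-bit integer, while mshadow's index_t defaults to a 32-bit unsigned type), a kernel signature changes like this:

    // before (illustrative excerpt)
    __device__ __forceinline__ static void Map(int tid, RType* row_flg, const DType* dns,
                                               const index_t num_rows, const index_t row_length);
    // after: row counts, offsets, and loop bounds are carried as nnvm::dim_t
    __device__ __forceinline__ static void Map(int tid, RType* row_flg, const DType* dns,
                                               const nnvm::dim_t num_rows, const nnvm::dim_t row_length);

Using one signed, wide type for these quantities presumably avoids 32-bit overflow on large tensors and keeps the kernels consistent with the dim_t-based shape code in nnvm.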
@@ -8,6 +8,7 @@

#include <mxnet/base.h>
#include <mxnet/operator.h>
+#include <nnvm/tuple.h>

#include <cub/cub.cuh>

@@ -31,11 +32,12 @@ struct MarkRspRowIdxThreadKernel {
__device__ __forceinline__ static void Map(int tid,
RType* row_flg,
const DType* dns,
-                                             const index_t num_rows,
-                                             const index_t row_length) {
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t row_length) {
+    using nnvm::dim_t;
if (tid < num_rows) {
-      index_t j = 0;
-      index_t offset = tid * row_length;
+      dim_t j = 0;
+      dim_t offset = tid * row_length;
for (; j < row_length; ++j) {
if (dns[offset+j] != 0) {
break;
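For reference, the marking step computes, for every row of the dense input, whether it contains any non-zero element; the thread kernel above assigns one thread per row. A host-side sketch of the same semantics (illustrative only, not part of this file; dns is assumed row-major with num_rows * row_length elements):

    // Reference CPU version of the row-marking step: row_flg[i] = 1 iff row i has a non-zero.
    template <typename DType, typename RType>
    void MarkNonZeroRowsRef(RType* row_flg, const DType* dns,
                            int64_t num_rows, int64_t row_length) {
      for (int64_t i = 0; i < num_rows; ++i) {
        row_flg[i] = 0;
        const int64_t offset = i * row_length;
        for (int64_t j = 0; j < row_length; ++j) {
          if (dns[offset + j] != 0) {
            row_flg[i] = 1;  // mark as one for non-zero row
            break;           // fine on the host; the warp/block kernels avoid break to limit divergence
          }
        }
      }
    }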
@@ -59,27 +61,28 @@ struct MarkRspRowIdxWarpKernel {
__device__ __forceinline__ static void Map(int tid,
RType* row_flg,
const DType* dns,
-                                             const index_t num_rows,
-                                             const index_t row_length) {
-    typedef cub::WarpReduce<index_t> WarpReduce;
-    const index_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t row_length) {
+    using nnvm::dim_t;
+    typedef cub::WarpReduce<dim_t> WarpReduce;
+    const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
__shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block];

-    const index_t warp_id   = tid / 32;          // global warp id
-    const index_t warp_lane = threadIdx.x / 32;  // local warp id within thread block
-    const index_t lane      = tid & (32-1);      // local thread id within warp
+    const dim_t warp_id   = tid / 32;          // global warp id
+    const dim_t warp_lane = threadIdx.x / 32;  // local warp id within thread block
+    const dim_t lane      = tid & (32-1);      // local thread id within warp

if (warp_id < num_rows) {
-      index_t flg = 0;
-      index_t offset = warp_id * row_length;
-      for (index_t j = lane; j < row_length; j+=32) {
+      dim_t flg = 0;
+      dim_t offset = warp_id * row_length;
+      for (dim_t j = lane; j < row_length; j+=32) {
if (dns[offset+j] != 0) {
// avoid break: causes slower performance on sparse tensors (<20% density),
// due to thread divergence
flg++;
}
}
-      index_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(flg);
+      dim_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(flg);
if (lane == 0) {
if (aggr > 0) {
row_flg[warp_id] = 1; // mark as one for non-zero row
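The warp kernel marks one row per warp and uses cub::WarpReduce to sum the per-lane non-zero counts. A stripped-down sketch of that reduction pattern (hypothetical helper, warp size assumed to be 32 as these kernels require):

    #include <cub/cub.cuh>

    // Each lane contributes a partial count; only lane 0 receives the warp-wide sum.
    template <int WARPS_PER_BLOCK>
    __device__ int WarpCount(int lane_count, int warp_lane) {
      typedef cub::WarpReduce<int> WarpReduce;
      __shared__ typename WarpReduce::TempStorage temp_storage[WARPS_PER_BLOCK];
      return WarpReduce(temp_storage[warp_lane]).Sum(lane_count);
    }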
@@ -100,22 +103,23 @@ struct MarkRspRowIdxBlockKernel {
__device__ __forceinline__ static void Map(int tid,
RType* row_flg,
const DType* dns,
-                                             const index_t num_rows,
-                                             const index_t row_length) {
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t row_length) {
+    using nnvm::dim_t;
using mshadow::cuda::kBaseThreadNum;
-    typedef cub::BlockReduce<index_t, kBaseThreadNum> BlockReduce;
+    typedef cub::BlockReduce<dim_t, kBaseThreadNum> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
if (blockIdx.x < num_rows) {
-      index_t flg = 0;
-      index_t offset = blockIdx.x * row_length;
-      for (index_t j = threadIdx.x; j < row_length; j+=kBaseThreadNum) {
+      dim_t flg = 0;
+      dim_t offset = blockIdx.x * row_length;
+      for (dim_t j = threadIdx.x; j < row_length; j+=kBaseThreadNum) {
if (dns[offset+j] != 0) {
// avoid break: causes slower performance on sparse tensors (<20% density),
// due to thread divergence
flg++;
}
}
-      index_t aggr = BlockReduce(temp_storage).Sum(flg);
+      dim_t aggr = BlockReduce(temp_storage).Sum(flg);
if (threadIdx.x == 0) {
if (aggr > 0) {
row_flg[blockIdx.x] = 1; // mark as one for non-zero row
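The block kernel is the same idea one level up: one thread block per row, with cub::BlockReduce aggregating the per-thread counts. A minimal sketch of that primitive (hypothetical helper; kBaseThreadNum plays the role of BLOCK_THREADS in the real kernel):

    #include <cub/cub.cuh>

    // Only thread 0 of the block receives the valid aggregate.
    template <int BLOCK_THREADS>
    __device__ int BlockCount(int thread_count) {
      typedef cub::BlockReduce<int, BLOCK_THREADS> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;
      return BlockReduce(temp_storage).Sum(thread_count);
    }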
@@ -143,9 +147,9 @@ struct FillRspRowIdxKernel {
__device__ __forceinline__ static void Map(int tid,
RType* row_idx,
const RType* row_flg_sum,
-                                             const index_t num_rows) {
+                                             const nnvm::dim_t num_rows) {
if (tid < num_rows) {
-      index_t prev = (tid == 0)? 0 : row_flg_sum[tid-1];
+      nnvm::dim_t prev = (tid == 0)? 0 : row_flg_sum[tid-1];
if (row_flg_sum[tid] > prev) {
row_idx[prev] = tid;
}
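FillRspRowIdxKernel turns the inclusive prefix sum of the row flags into a compact list of non-zero row ids. A small worked example of that mapping (values chosen for illustration):

    // Suppose 4 rows, of which rows 1 and 3 are non-zero:
    //   row_flg      = {0, 1, 0, 1}
    //   row_flg_sum  = {0, 1, 1, 2}   // inclusive prefix sum, computed before this kernel
    // For each tid: prev = (tid == 0) ? 0 : row_flg_sum[tid-1];
    // row_flg_sum[tid] > prev means row tid is non-zero, and it is written at position prev:
    //   row_idx      = {1, 3}         // nnr = row_flg_sum[num_rows-1] = 2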
@@ -172,12 +176,13 @@ struct FillRspValsKernel {
DType* rsp_val,
const RType* row_idx,
const DType* dns,
-                                             const index_t nnr,
-                                             const index_t row_length) {
+                                             const nnvm::dim_t nnr,
+                                             const nnvm::dim_t row_length) {
+    using nnvm::dim_t;
if (tid < nnr*row_length) {
-      const index_t row_id = tid / row_length;
-      const index_t row_el = tid % row_length;
-      const index_t dns_idx = row_idx[row_id] * row_length + row_el;
+      const dim_t row_id = tid / row_length;
+      const dim_t row_el = tid % row_length;
+      const dim_t dns_idx = row_idx[row_id] * row_length + row_el;
rsp_val[tid] = dns[dns_idx];
}
}
@@ -195,15 +200,16 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
CHECK_EQ(dns.shape_, rsp->shape());
using mshadow::Shape1;
using mxnet_op::Kernel;
+  using nnvm::dim_t;
mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type
-      const index_t num_rows = dns.shape_[0];
-      const index_t row_length = dns.shape_.ProdShape(1, dns.shape_.ndim());
-      const index_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
-      const index_t threads_per_block = mshadow::cuda::kBaseThreadNum;
-      const index_t min_num_warps = 512;
-      index_t num_threads;
+      const dim_t num_rows = dns.shape_[0];
+      const dim_t row_length = dns.shape_.ProdShape(1, dns.shape_.ndim());
+      const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
+      const dim_t threads_per_block = mshadow::cuda::kBaseThreadNum;
+      const dim_t min_num_warps = 512;
+      dim_t num_threads;
// TODO: remove kernel dependency on warpSize=32
if (threads_per_warp != 32) {
LOG(FATAL) << "CastStorageDnsRspImpl GPU kernels expect warpSize=32";
@@ -276,7 +282,7 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
CUDA_CALL(cudaMemcpy(&nnr, &row_flg[num_rows-1], sizeof(RType), cudaMemcpyDeviceToHost));

// Allocate rsp tensor row index array and fill
-      rsp->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(static_cast<index_t>(nnr)));
+      rsp->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(static_cast<dim_t>(nnr)));
if (0 == nnr) return;
RType* row_idx = rsp->aux_data(rowsparse::kIdx).dptr<RType>();
num_threads = num_rows;
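Taken together, the dense-to-row_sparse path is a mark/scan/compact pipeline; the prefix sum itself sits in an elided portion of this function and is assumed here to be an inclusive scan over row_flg (the read-back of row_flg[num_rows-1] as nnr above relies on it). An outline for orientation:

    // 1. MarkRspRowIdx{Thread,Warp,Block}Kernel: row_flg[i] = 1 iff dense row i has a non-zero.
    // 2. Inclusive prefix sum over row_flg (assumed, e.g. a cub device-wide scan), so that
    //    row_flg[num_rows-1] equals nnr, the number of non-zero rows.
    // 3. Copy nnr to the host and allocate the rsp row-index and value storage.
    // 4. FillRspRowIdxKernel: scatter the ids of the non-zero rows into row_idx.
    // 5. FillRspValsKernel: gather the nnr * row_length values of those rows into rsp_val.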
@@ -309,15 +315,18 @@ struct FillCsrIndPtrThreadKernel {
*/
template<typename DType, typename IType>
__device__ __forceinline__ static void Map(int tid,
-                                             IType* indptr, const DType* dns,
-                                             const index_t num_rows, const index_t num_cols) {
+                                             IType* indptr,
+                                             const DType* dns,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
+    using nnvm::dim_t;
if (tid == 0) {
indptr[tid] = 0;
}
if (tid < num_rows) {
-      index_t nnz = 0;
-      const index_t offset = tid * num_cols;
-      for (index_t j = 0; j < num_cols; ++j) {
+      dim_t nnz = 0;
+      const dim_t offset = tid * num_cols;
+      for (dim_t j = 0; j < num_cols; ++j) {
if (dns[offset+j] != 0) {
nnz++;
}
@@ -344,13 +353,17 @@ struct FillCsrColIdxAndValsThreadKernel {
*/
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
-                                             DType* val, CType* col_idx,
-                                             const IType* indptr, const DType* dns,
-                                             const index_t num_rows, const index_t num_cols) {
+                                             DType* val,
+                                             CType* col_idx,
+                                             const IType* indptr,
+                                             const DType* dns,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
+    using nnvm::dim_t;
if (tid < num_rows) {
-      const index_t offset = tid * num_cols;
-      index_t k = indptr[tid];
-      for (index_t j = 0; j < num_cols; ++j) {
+      const dim_t offset = tid * num_cols;
+      dim_t k = indptr[tid];
+      for (dim_t j = 0; j < num_cols; ++j) {
if (dns[offset+j] != 0) {
val[k] = dns[offset+j];
col_idx[k] = j;
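For the CSR path, the thread kernels compute one indptr entry per row and then fill col_idx and val row by row. A small worked example of the target format (illustrative values):

    // Dense 3 x 4 input (row-major):
    //   { 0, 7, 0, 0,
    //     0, 0, 0, 0,
    //     5, 0, 0, 2 }
    // CSR output:
    //   indptr  = {0, 1, 1, 3}   // indptr[i+1] - indptr[i] = non-zeros in row i
    //   col_idx = {1, 0, 3}      // column of each stored value, rows in order
    //   val     = {7, 5, 2}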
@@ -368,27 +381,30 @@ struct FillCsrColIdxAndValsThreadKernel {
struct FillCsrIndPtrWarpKernel {
template<typename DType, typename IType>
__device__ __forceinline__ static void Map(int tid,
-                                             IType* indptr, const DType* dns,
-                                             const index_t num_rows, const index_t num_cols) {
-    typedef cub::WarpReduce<index_t> WarpReduce;
-    const index_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
+                                             IType* indptr,
+                                             const DType* dns,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
+    using nnvm::dim_t;
+    typedef cub::WarpReduce<dim_t> WarpReduce;
+    const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
__shared__ typename WarpReduce::TempStorage temp_storage[warps_per_block];

if (tid == 0) {
indptr[tid] = 0;
}
-    const index_t warp_id   = tid / 32;          // global warp id
-    const index_t warp_lane = threadIdx.x / 32;  // local warp id within thread block
-    const index_t lane      = tid & (32-1);      // local thread id within warp
+    const dim_t warp_id   = tid / 32;          // global warp id
+    const dim_t warp_lane = threadIdx.x / 32;  // local warp id within thread block
+    const dim_t lane      = tid & (32-1);      // local thread id within warp
if (warp_id < num_rows) {
-      index_t lane_nnz = 0;
-      const index_t offset = warp_id * num_cols;
-      for (index_t j = lane; j < num_cols; j+=32) {
+      dim_t lane_nnz = 0;
+      const dim_t offset = warp_id * num_cols;
+      for (dim_t j = lane; j < num_cols; j+=32) {
if (dns[offset+j] != 0) {
lane_nnz++;
}
}
-      index_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(lane_nnz);
+      dim_t aggr = WarpReduce(temp_storage[warp_lane]).Sum(lane_nnz);
if (lane == 0) {
indptr[warp_id+1] = aggr;
}
@@ -403,22 +419,26 @@ struct FillCsrIndPtrWarpKernel {
struct FillCsrColIdxAndValsWarpKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
-                                             DType* val, CType* col_idx,
-                                             const IType* indptr, const DType* dns,
-                                             const index_t num_rows, const index_t num_cols) {
-    typedef cub::WarpScan<index_t> WarpScan;
-    const index_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
+                                             DType* val,
+                                             CType* col_idx,
+                                             const IType* indptr,
+                                             const DType* dns,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
+    using nnvm::dim_t;
+    typedef cub::WarpScan<dim_t> WarpScan;
+    const dim_t warps_per_block = mshadow::cuda::kBaseThreadNum / 32;
__shared__ typename WarpScan::TempStorage temp_storage[warps_per_block];
-    __shared__ volatile index_t warp_nnz[warps_per_block];
+    __shared__ volatile dim_t warp_nnz[warps_per_block];

-    const index_t warp_id   = tid / 32;          // global warp id
-    const index_t warp_lane = threadIdx.x / 32;  // local warp id within thread block
-    const index_t lane      = tid & (32-1);      // local thread id within warp
+    const dim_t warp_id   = tid / 32;          // global warp id
+    const dim_t warp_lane = threadIdx.x / 32;  // local warp id within thread block
+    const dim_t lane      = tid & (32-1);      // local thread id within warp
if (warp_id < num_rows) {
-      const index_t offset = warp_id * num_cols;
-      index_t k = indptr[warp_id];
-      index_t nnz;
-      for (index_t j = lane; j < num_cols+lane; j+=32) {
+      const dim_t offset = warp_id * num_cols;
+      dim_t k = indptr[warp_id];
+      dim_t nnz;
+      for (dim_t j = lane; j < num_cols+lane; j+=32) {
nnz = 0;
if (j < num_cols) {
if (dns[offset+j] != 0) {
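FillCsrColIdxAndValsWarpKernel declares a cub::WarpScan and a shared warp_nnz slot, so each 32-column strip is presumably compacted with a warp-level exclusive prefix sum over the per-lane non-zero flags before writing. A minimal sketch of that primitive (hypothetical helper, not the elided body of this kernel):

    #include <cub/cub.cuh>

    // Turn per-lane counts into per-lane write offsets: after the scan, lane i holds the
    // sum of the counts of lanes 0..i-1, i.e. its slot relative to the warp's base offset k.
    template <int WARPS_PER_BLOCK>
    __device__ void LaneOffset(int lane_count, int warp_lane, int& lane_offset) {
      typedef cub::WarpScan<int> WarpScan;
      __shared__ typename WarpScan::TempStorage temp_storage[WARPS_PER_BLOCK];
      WarpScan(temp_storage[warp_lane]).ExclusiveSum(lane_count, lane_offset);
    }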
@@ -453,24 +473,27 @@ struct FillCsrColIdxAndValsWarpKernel {
struct FillCsrIndPtrBlockKernel {
template<typename DType, typename IType>
__device__ __forceinline__ static void Map(int tid,
-                                             IType* indptr, const DType* dns,
-                                             const index_t num_rows, const index_t num_cols) {
+                                             IType* indptr,
+                                             const DType* dns,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
using mshadow::cuda::kBaseThreadNum;
-    typedef cub::BlockReduce<index_t, kBaseThreadNum> BlockReduce;
+    using nnvm::dim_t;
+    typedef cub::BlockReduce<dim_t, kBaseThreadNum> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

if (tid == 0) {
indptr[tid] = 0;
}
if (blockIdx.x < num_rows) {
-      index_t lane_nnz = 0;
-      const index_t offset = blockIdx.x * num_cols;
-      for (index_t j = threadIdx.x; j < num_cols; j+=kBaseThreadNum) {
+      dim_t lane_nnz = 0;
+      const dim_t offset = blockIdx.x * num_cols;
+      for (dim_t j = threadIdx.x; j < num_cols; j+=kBaseThreadNum) {
if (dns[offset+j] != 0) {
lane_nnz++;
}
}
-      index_t aggr = BlockReduce(temp_storage).Sum(lane_nnz);
+      dim_t aggr = BlockReduce(temp_storage).Sum(lane_nnz);
if (threadIdx.x == 0) {
indptr[blockIdx.x+1] = aggr;
}
@@ -485,19 +508,23 @@ struct FillCsrIndPtrBlockKernel {
struct FillCsrColIdxAndValsBlockKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid,
-                                             DType* val, CType* col_idx,
-                                             const IType* indptr, const DType* dns,
-                                             const index_t num_rows, const index_t num_cols) {
+                                             DType* val,
+                                             CType* col_idx,
+                                             const IType* indptr,
+                                             const DType* dns,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
using mshadow::cuda::kBaseThreadNum;
-    typedef cub::BlockScan<index_t, kBaseThreadNum> BlockScan;
+    using nnvm::dim_t;
+    typedef cub::BlockScan<dim_t, kBaseThreadNum> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
-    __shared__ volatile index_t block_nnz;
+    __shared__ volatile dim_t block_nnz;

if (blockIdx.x < num_rows) {
-      const index_t offset = blockIdx.x * num_cols;
-      index_t k = indptr[blockIdx.x];
-      index_t nnz;
-      for (index_t j = threadIdx.x; j < num_cols+threadIdx.x; j+=kBaseThreadNum) {
+      const dim_t offset = blockIdx.x * num_cols;
+      dim_t k = indptr[blockIdx.x];
+      dim_t nnz;
+      for (dim_t j = threadIdx.x; j < num_cols+threadIdx.x; j+=kBaseThreadNum) {
nnz = 0;
if (j < num_cols) {
if (dns[offset+j] != 0) {
@@ -538,16 +565,17 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx,
CHECK_EQ(dns.shape_, csr->shape());
using mshadow::Shape1;
using mxnet_op::Kernel;
+  using nnvm::dim_t;
mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col_idx type
-        const index_t num_rows = dns.shape_[0];
-        const index_t num_cols = dns.shape_[1];
-        const index_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
-        const index_t threads_per_block = mshadow::cuda::kBaseThreadNum;
-        const index_t min_num_warps = 512;
-        index_t num_threads;
+        const dim_t num_rows = dns.shape_[0];
+        const dim_t num_cols = dns.shape_[1];
+        const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
+        const dim_t threads_per_block = mshadow::cuda::kBaseThreadNum;
+        const dim_t min_num_warps = 512;
+        dim_t num_threads;
// TODO: remove kernel dependency on warpSize=32
if (threads_per_warp != 32) {
LOG(FATAL) << "CastStorageDnsCsrImpl GPU kernels expect warpSize=32";
@@ -622,8 +650,8 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx,
CUDA_CALL(cudaMemcpy(&nnz, &(indptr[num_rows]), sizeof(IType), cudaMemcpyDeviceToHost));

// Allocate column index array and data array of the csr matrix
-        csr->CheckAndAllocAuxData(csr::kIdx, Shape1(static_cast<index_t>(nnz)));
-        csr->CheckAndAllocData(Shape1(static_cast<index_t>(nnz)));
+        csr->CheckAndAllocAuxData(csr::kIdx, Shape1(static_cast<dim_t>(nnz)));
+        csr->CheckAndAllocData(Shape1(static_cast<dim_t>(nnz)));

// Compute and fill column index array and data array of the csr matrix
switch (kernel_version) {
