This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Optimized gpu dot kernels #6937

Merged 4 commits on Jul 11, 2017
2 changes: 1 addition & 1 deletion mshadow
1 change: 1 addition & 0 deletions src/io/inst_vector.h
@@ -12,6 +12,7 @@
#include <mxnet/base.h>
#include <dmlc/base.h>
#include <mshadow/tensor.h>
#include <mshadow/tensor_blob.h>
#include <vector>
#include <string>

257 changes: 239 additions & 18 deletions src/operator/tensor/dot-inl.cuh
@@ -11,13 +11,14 @@

namespace mxnet {
namespace op {
using mshadow::cuda::kBaseThreadNum;

/*!
* \brief Kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements
* \brief Scalar kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements: 1 thread/element
*/
template<int req>
struct DotCsrDnsDns {
struct DotCsrDnsDnsScalarKernel {
/*!
* \brief This function represents performing an inner product between a row of lhs
* and a column of rhs and then assigning the value to out[i].
@@ -45,11 +46,52 @@ struct DotCsrDnsDns {
};
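The body of this scalar kernel is collapsed in the diff view above. As a rough sketch only (parameter names assumed to mirror the vector kernel below), one thread per output element would do the following:

  // Sketch, not the PR's exact code: one thread computes out[i] as the inner
  // product of csr row irow with dense column icol.
  template<typename DType, typename IType, typename CType>
  __device__ __forceinline__ static void Map(int i, DType* out, const DType* data_l,
                                             const IType* indptr_l, const CType* col_idx_l,
                                             const DType* data_r, const int num_cols_r) {
    const int irow = i / num_cols_r;  // lhs row / output row
    const int icol = i % num_cols_r;  // rhs column / output column
    DType sum = 0;
    for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) {
      sum += data_l[j] * data_r[col_idx_l[j]*num_cols_r + icol];
    }
    KERNEL_ASSIGN(out[i], req, sum);  // req is the enclosing struct's template parameter
  }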

/*!
* \brief Kernel of dot(csr.T(), dns1) = dns2
* Parallelization by output matrix elements
* \brief Vector kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements: 1 warp/element
*/
template<int req>
struct DotCsrTransDnsDns {
struct DotCsrDnsDnsVectorKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols_r) {
__shared__ volatile DType vals[kBaseThreadNum];

const int warp_id = tid / 32; // global warp id
const int lane = tid & (32-1); // local thread id within warp
const int irow = warp_id / num_cols_r; // lhs row that this warp computes
const int kcol = warp_id % num_cols_r; // rhs column that this warp computes

// Range of nnz elements in this row
const int low = static_cast<int>(indptr_l[irow]);
const int high = static_cast<int>(indptr_l[irow+1]);

// Compute running sum per thread
DType sum = 0;
for (int j = low+lane; j < high; j+=32) {
sum += data_l[j] * data_r[col_idx_l[j]*num_cols_r + kcol];
}
vals[threadIdx.x] = sum;

// Parallel reduction in shared memory
if (lane < 16) {vals[threadIdx.x] += vals[threadIdx.x+16];} __syncwarp();
if (lane < 8) {vals[threadIdx.x] += vals[threadIdx.x+ 8];} __syncwarp();
if (lane < 4) {vals[threadIdx.x] += vals[threadIdx.x+ 4];} __syncwarp();
if (lane < 2) {vals[threadIdx.x] += vals[threadIdx.x+ 2];} __syncwarp();
if (lane < 1) {vals[threadIdx.x] += vals[threadIdx.x+ 1];} __syncwarp();

if (lane == 0) {
KERNEL_ASSIGN(out[irow*num_cols_r+kcol], req, vals[threadIdx.x]);
}
}
};

/*!
* \brief Scalar kernel of dot(csr.T(), dns1) = dns2
* Parallelization by output matrix elements: 1 thread/element
*/
template<int req>
struct DotCsrTransDnsDnsScalarKernel {
/*!
* \brief This function represents performing an inner product between a column of lhs
* and a column of rhs and then assigning the value to out[i].
@@ -69,6 +111,8 @@ struct DotCsrTransDnsDns {
const int irow = i / num_cols; // col id of the lhs
const int icol = i % num_cols; // col id of the rhs
DType sum = 0;

// Each thread scans each column with binary search to find nnz elements in its row
for (int k = 0; k < num_rows_l; ++k) {
const IType low = indptr_l[k];
const IType high = indptr_l[k+1];
@@ -93,6 +137,98 @@
}
};
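The binary-search part of this kernel is collapsed above. A sketch of what the per-row scan described by the comment does (assuming the names visible in the surrounding lines: low/high bound the nonzeros of lhs row k, irow is the lhs column owned by this thread, icol the rhs column, num_cols the rhs width):

  // Sketch, not the PR's exact code: binary-search lhs row k's column indices
  // for irow and accumulate the matching nonzero if present.
  if (low != high && irow >= col_idx_l[low] && irow <= col_idx_l[high-1]) {
    int l = static_cast<int>(low), r = static_cast<int>(high) - 1;
    while (l <= r) {
      const int m = l + (r - l) / 2;
      if (col_idx_l[m] == irow) {
        sum += data_l[m] * data_r[k*num_cols + icol];
        break;
      } else if (col_idx_l[m] < irow) {
        l = m + 1;
      } else {
        r = m - 1;
      }
    }
  }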

/*!
* \brief Warp kernel of dot(csr.T(), dns1) = dns2
* Parallelization by columns: 1 warp computes one lhs column for one rhs column
*/
template<int req>
struct DotCsrTransDnsDnsWarpKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols_r) {
const int warp_id = tid / 32; // global warp id
const int lane = tid & (32-1); // local thread id within warp
const int icol = warp_id / num_cols_r; // lhs column that this warp computes
const int kcol = warp_id % num_cols_r; // rhs column that this warp computes

// Compute range of nnz elements in this column
const int low = static_cast<int>(indptr_l[icol]);
const int high = static_cast<int>(indptr_l[icol+1]);

// Iterate through the nnz elements in this column
for (int j = low+lane; j < high; j+=32) {
const int irow = static_cast<int>(col_idx_l[j]);
const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+kcol])), val);
}
}
};

/*!
* \brief Thread block kernel of dot(csr.T(), dns1) = dns2
* Parallelization by columns: 1 thread block computes one lhs column for all rhs columns
*/
template<int req>
struct DotCsrTransDnsDnsThreadBlockKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols_r) {
const int warps_per_block = blockDim.x / 32; // number of warps in this thread block
const int warp_id = tid / 32; // global warp id
const int lane = tid & (32-1); // local thread id within warp
const int icol = blockIdx.x; // lhs column that this thread block computes
const int kcol = warp_id % warps_per_block; // rhs column where warp starts computing (offset)

// Compute range of nnz elements in this lhs column
const int low = static_cast<int>(indptr_l[icol]);
const int high = static_cast<int>(indptr_l[icol+1]);

// Iterate through the nnz elements in this lhs column
for (int j = low+lane; j < high; j+=32) {
const int irow = static_cast<int>(col_idx_l[j]);
const DType datum_l = data_l[j];
// Iterate over rhs columns that this warp computes
for (int k = kcol; k < num_cols_r; k+=warps_per_block) {
const DType val = datum_l*data_r[icol*num_cols_r+k];
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
}
}
}
};

/*!
* \brief Warp block kernel of dot(csr.T(), dns1) = dns2
* Parallelization by columns: 1 warp computes one lhs column for all rhs columns
*/
template<int req>
struct DotCsrTransDnsDnsWarpBlockKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols_r) {
const int warp_id = tid / 32; // global warp id
const int lane = tid & (32-1); // local thread id within warp
const int icol = warp_id; // lhs column that this warp computes

// Compute range of nnz elements in this column
const int low = static_cast<int>(indptr_l[icol]);
const int high = static_cast<int>(indptr_l[icol+1]);

// Iterate through the nnz elements in lhs column
for (int j = low+lane; j < high; j+=32) {
const int irow = static_cast<int>(col_idx_l[j]);
const DType datum_l = data_l[j];
// Iterate over all rhs columns
for (int k = 0; k < num_cols_r; k++) {
const DType val = datum_l*data_r[icol*num_cols_r+k];
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
}
}
}
};
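All three transposed variants above scatter partial products into the output with atomicAdd rather than KERNEL_ASSIGN, because warps working on different lhs columns may hit the same output row concurrently. Device-side atomicAdd is only provided for a limited set of floating-point types (float, and double on newer architectures), which is presumably why the launch code below narrows the type dispatch to MSHADOW_SGL_DBL_TYPE_SWITCH. A minimal sketch of the accumulation pattern (the helper name is hypothetical, not part of the PR):

  // Hypothetical helper illustrating the scatter-accumulate used by the kernels above.
  // DType must have a device atomicAdd overload (float, or double on sm_60+).
  template<typename DType>
  __device__ __forceinline__ void ScatterAdd(DType* out, const int irow, const int kcol,
                                             const int num_cols_r, const DType val) {
    atomicAdd(&out[irow*num_cols_r + kcol], val);
  }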

inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
const NDArray& lhs,
const TBlob& rhs,
@@ -109,22 +245,107 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
const TBlob& data_r = rhs;
const TBlob data_out = *ret;

MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
if (kWriteTo == req) {
mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, data_out.Size(), data_out.dptr<DType>());
}
int num_threads;
const int threads_per_warp = 32;
const int threads_per_block = kBaseThreadNum;
const int num_rows_l = lhs.shape()[0];
const int num_cols_r = rhs.shape_[1];
if (trans_lhs) {
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDns<ReqType>, gpu>::Launch(s, data_out.Size(),
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0],
data_out.shape_[1]);
});
// Different kernel versions are optimized for different matrix instances
// TODO: switch between kernel versions depending on input
// (1) 'Scalar kernel' (one thread computing one output element )
// (2) 'Warp kernel' (one warp computing one lhs column for one rhs column )
// (3) 'Thread block kernel' (one thread block computing one lhs column for all rhs columns)
// (4) 'Warp block kernel' (one warp computing one lhs column for all rhs columns)
const int kernel_version = 0;
switch (kernel_version) {
case 1:
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_rows_l, num_cols_r);
});
break;
case 2:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsWarpKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
case 3:
num_threads = threads_per_block * num_rows_l;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsThreadBlockKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
case 4:
num_threads = threads_per_warp * num_rows_l;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsWarpBlockKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
default:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsWarpKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
}
} else {
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDns<ReqType>, gpu>::Launch(s, data_out.Size(),
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape_[1]);
});
// Different kernel versions are optimized for different matrix instances
// (1) 'Scalar kernel' (one thread computing one output element)
// (2) 'Vector kernel' (one warp computing one output element)
const int kernel_version = 0;
switch (kernel_version) {
case 1:
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
case 2:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
default:
if (num_cols_r > 4) {
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
} else {
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
}
break;
}
}
});
});
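The TODO in the transposed branch above leaves kernel_version as a hard-coded constant. Purely as an illustration of what an input-dependent choice could look like (not part of this PR; the helper and thresholds below are hypothetical), mirroring the num_cols_r > 4 check already used in the non-transposed branch:

  // Hypothetical heuristic, not in this PR: pick a transposed-dot kernel variant
  // from the rhs width.
  inline int PickTransKernelVersion(const int num_cols_r) {
    if (num_cols_r == 1) return 4;   // warp-block kernel: one warp covers the single rhs column
    if (num_cols_r <= 8) return 2;   // warp kernel: one warp per (lhs column, rhs column) pair
    return 3;                        // thread-block kernel: reuse each lhs nonzero across many rhs columns
  }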
3 changes: 2 additions & 1 deletion src/operator/tensor/dot-inl.h
@@ -187,7 +187,8 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
// csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp
if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) {
// dot(csr.T,dns)=rsp not yet implemented on gpu
if (param.transpose_a && kCSRStorage == (*in_attrs)[0] && ctx.dev_type != Context::kGPU) {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
} else {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
1 change: 1 addition & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -5,6 +5,7 @@
from test_operator import *
from test_optimizer import *
from test_random import *
from test_sparse_operator import test_sparse_dot
Member:
@stefanhenneking
Did you run nosetests --verbose tests/python/gpu/test_operator_gpu.py after you added test_sparse_dot to the gpu test suite? It fails on my p2 machine with the most recent version of the sparse branch (0d8d6c4).
The last test (test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, 0.05)) fails because that case is not implemented yet.

terminate called after throwing an instance of 'dmlc::Error'
  what():  [17:40:43] src/engine/./threaded_engine.h:329: [17:40:43] src/operator/tensor/./dot-inl.h:580: DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet

Stack trace returned 10 entries:
[bt] (0) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(_ZN4dmlc15LogMessageFatalD1Ev+0x3f) [0x7f133a7cde69]
[bt] (1) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(_ZN5mxnet2op16DotCsrRspDnsImplIN7mshadow3gpuEEEvPNS2_6StreamIT_EERKNS_7NDArrayESA_NS_9OpReqTypeEbPNS_5TBlobE+0x6ff) [0x7f133bdf1bed]
[bt] (2) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(_ZN5mxnet2op12DotForwardExIN7mshadow3gpuEEEvRKN4nnvm9NodeAttrsERKNS_9OpContextERKSt6vectorINS_7NDArrayESaISC_EERKSB_INS_9OpReqTypeESaISH_EESG_+0x562) [0x7f133bdec4ff]
[bt] (3) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(_ZNSt17_Function_handlerIFvRKN4nnvm9NodeAttrsERKN5mxnet9OpContextERKSt6vectorINS4_7NDArrayESaIS9_EERKS8_INS4_9OpReqTypeESaISE_EESD_EPSJ_E9_M_invokeERKSt9_Any_dataS3_S7_SD_SI_SD_+0x91) [0x7f133aba34a4]
[bt] (4) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(_ZNKSt8functionIFvRKN4nnvm9NodeAttrsERKN5mxnet9OpContextERKSt6vectorINS4_7NDArrayESaIS9_EERKS8_INS4_9OpReqTypeESaISE_EESD_EEclES3_S7_SD_SI_SD_+0xa6) [0x7f133b41effe]
[bt] (5) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(_ZN5mxnet4exec18FComputeExExecutor3RunENS_10RunContextEb+0x6a) [0x7f133b460ea2]
[bt] (6) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(+0x2436816) [0x7f133b479816]
[bt] (7) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(+0x2437e3a) [0x7f133b47ae3a]
[bt] (8) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(_ZNKSt8functionIFvN5mxnet10RunContextENS0_6engine18CallbackOnCompleteEEEclES1_S3_+0x6e) [0x7f133b446d38]
[bt] (9) /home/ubuntu/haibin-mxnet/python/mxnet/../../lib/libmxnet.so(_ZN5mxnet6engine14ThreadedEngine15ExecuteOprBlockENS_10RunContextEPNS0_8OprBlockE+0x1e9) [0x7f133b4527d5]

Contributor Author:
For some reason it did pass the tests; I noticed this when I implemented the cast_storage op and fixed it in #7081.

Member:
I see.

import mxnet as mx
import numpy as np
from mxnet.test_utils import check_consistency, set_default_context
8 changes: 5 additions & 3 deletions tests/python/unittest/test_sparse_operator.py
@@ -103,12 +103,12 @@ def test_dns_to_csr(dns_in):

def test_sparse_dot():
def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
lhs_dns = rand_ndarray(lhs_shape, 'default')
lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr')
lhs_nd = rand_ndarray(lhs_shape, 'csr', 1)
lhs_dns = lhs_nd.todense()
rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=density)
rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
if trans_lhs:
if trans_lhs and default_context().device_type is 'cpu':
assert out.storage_type == 'row_sparse'
else:
assert out.storage_type == 'default'
@@ -131,6 +131,8 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
rtol=1e-3, atol=1e-4)

lhs_shape = rand_shape_2d(50, 200)
test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False)
test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False)
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)