diff --git a/mshadow b/mshadow
index 06407689699e..5a11d7544841 160000
--- a/mshadow
+++ b/mshadow
@@ -1 +1 @@
-Subproject commit 06407689699efc043db1ba5a8131abc2c53c4cda
+Subproject commit 5a11d7544841b55a8ac1a65081759dc2289c335d
diff --git a/src/io/inst_vector.h b/src/io/inst_vector.h
index d82bd48e2fa1..0a665bd6811d 100644
--- a/src/io/inst_vector.h
+++ b/src/io/inst_vector.h
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index 513fde306bab..c8572ba5e0cb 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -11,13 +11,14 @@
 namespace mxnet {
 namespace op {
 
+using mshadow::cuda::kBaseThreadNum;
 /*!
- * \brief Kernel of dot(csr, dns1) = dns2
- * Parallelization by output matrix elements
+ * \brief Scalar kernel of dot(csr, dns1) = dns2
+ * Parallelization by output matrix elements: 1 thread/element
  */
 template<int req>
-struct DotCsrDnsDns {
+struct DotCsrDnsDnsScalarKernel {
   /*!
    * \brief This function represents performing an inner product between a row of lhs
    * and a column of rhs and then assigning the value to out[i].
@@ -45,11 +46,52 @@ struct DotCsrDnsDns {
 };
 
 /*!
- * \brief Kernel of dot(csr.T(), dns1) = dns2
- * Parallelization by output matrix elements
+ * \brief Vector kernel of dot(csr, dns1) = dns2
+ * Parallelization by output matrix elements: 1 warp/element
  */
 template<int req>
-struct DotCsrTransDnsDns {
+struct DotCsrDnsDnsVectorKernel {
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
+                                             const CType* col_idx_l, const DType* data_r,
+                                             const int num_cols_r) {
+    __shared__ volatile DType vals[kBaseThreadNum];
+
+    const int warp_id = tid / 32;           // global warp id
+    const int lane = tid & (32-1);          // local thread id within warp
+    const int irow = warp_id / num_cols_r;  // lhs row that this warp computes
+    const int kcol = warp_id % num_cols_r;  // rhs column that this warp computes
+
+    // Range of nnz elements in this row
+    const int low  = static_cast<int>(indptr_l[irow]);
+    const int high = static_cast<int>(indptr_l[irow+1]);
+
+    // Compute running sum per thread
+    DType sum = 0;
+    for (int j = low+lane; j < high; j+=32) {
+      sum += data_l[j] * data_r[col_idx_l[j]*num_cols_r + kcol];
+    }
+    vals[threadIdx.x] = sum;
+
+    // Parallel reduction in shared memory
+    if (lane < 16) {vals[threadIdx.x] += vals[threadIdx.x+16];} __syncwarp();
+    if (lane <  8) {vals[threadIdx.x] += vals[threadIdx.x+ 8];} __syncwarp();
+    if (lane <  4) {vals[threadIdx.x] += vals[threadIdx.x+ 4];} __syncwarp();
+    if (lane <  2) {vals[threadIdx.x] += vals[threadIdx.x+ 2];} __syncwarp();
+    if (lane <  1) {vals[threadIdx.x] += vals[threadIdx.x+ 1];} __syncwarp();
+
+    if (lane == 0) {
+      KERNEL_ASSIGN(out[irow*num_cols_r+kcol], req, vals[threadIdx.x]);
+    }
+  }
+};
+
+/*!
+ * \brief Scalar kernel of dot(csr.T(), dns1) = dns2
+ * Parallelization by output matrix elements: 1 thread/element
+ */
+template<int req>
+struct DotCsrTransDnsDnsScalarKernel {
   /*!
    * \brief This function represents performing an inner product between a column of lhs
    * and a column of rhs and then assigning the value to out[i].
@@ -69,6 +111,8 @@ struct DotCsrTransDnsDns {
     const int irow = i / num_cols;  // col id of the lhs
     const int icol = i % num_cols;  // col id of the rhs
     DType sum = 0;
+
+    // Each thread scans each column with binary search to find nnz elements in its row
     for (int k = 0; k < num_rows_l; ++k) {
       const IType low = indptr_l[k];
       const IType high = indptr_l[k+1];
@@ -93,6 +137,98 @@ struct DotCsrTransDnsDns {
   }
 };
 
+/*!
+ * \brief Warp kernel of dot(csr.T(), dns1) = dns2
+ * Parallelization by columns: 1 warp computes one lhs column for one rhs column
+ */
+template<int req>
+struct DotCsrTransDnsDnsWarpKernel {
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
+                                             const CType* col_idx_l, const DType* data_r,
+                                             const int num_cols_r) {
+    const int warp_id = tid / 32;           // global warp id
+    const int lane = tid & (32-1);          // local thread id within warp
+    const int icol = warp_id / num_cols_r;  // lhs column that this warp computes
+    const int kcol = warp_id % num_cols_r;  // rhs column that this warp computes
+
+    // Compute range of nnz elements in this column
+    const int low  = static_cast<int>(indptr_l[icol]);
+    const int high = static_cast<int>(indptr_l[icol+1]);
+
+    // Iterate through the nnz elements in this column
+    for (int j = low+lane; j < high; j+=32) {
+      const int irow = static_cast<int>(col_idx_l[j]);
+      const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
+      atomicAdd(static_cast<DType*>(&(out[irow*num_cols_r+kcol])), val);
+    }
+  }
+};
+
+/*!
+ * \brief Thread block kernel of dot(csr.T(), dns1) = dns2
+ * Parallelization by columns: 1 thread block computes one lhs column for all rhs columns
+ */
+template<int req>
+struct DotCsrTransDnsDnsThreadBlockKernel {
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
+                                             const CType* col_idx_l, const DType* data_r,
+                                             const int num_cols_r) {
+    const int warps_per_block = blockDim.x / 32;  // number of warps in this thread block
+    const int warp_id = tid / 32;                 // global warp id
+    const int lane = tid & (32-1);                // local thread id within warp
+    const int icol = blockIdx.x;                  // lhs column that this thread block computes
+    const int kcol = warp_id % warps_per_block;   // rhs column where warp starts computing (offset)
+
+    // Compute range of nnz elements in this lhs column
+    const int low  = static_cast<int>(indptr_l[icol]);
+    const int high = static_cast<int>(indptr_l[icol+1]);
+
+    // Iterate through the nnz elements in this lhs column
+    for (int j = low+lane; j < high; j+=32) {
+      const int irow = static_cast<int>(col_idx_l[j]);
+      const DType datum_l = data_l[j];
+      // Iterate over rhs columns that this warp computes
+      for (int k = kcol; k < num_cols_r; k+=warps_per_block) {
+        const DType val = datum_l*data_r[icol*num_cols_r+k];
+        atomicAdd(static_cast<DType*>(&(out[irow*num_cols_r+k])), val);
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Warp block kernel of dot(csr.T(), dns1) = dns2
+ * Parallelization by columns: 1 warp computes one lhs column for all rhs columns
+ */
+template<int req>
+struct DotCsrTransDnsDnsWarpBlockKernel {
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
+                                             const CType* col_idx_l, const DType* data_r,
+                                             const int num_cols_r) {
+    const int warp_id = tid / 32;   // global warp id
+    const int lane = tid & (32-1);  // local thread id within warp
+    const int icol = warp_id;       // lhs column that this warp computes
+
+    // Compute range of nnz elements in this column
+    const int low  = static_cast<int>(indptr_l[icol]);
+    const int high = static_cast<int>(indptr_l[icol+1]);
+
+    // Iterate through the nnz elements in this lhs column
+    for (int j = low+lane; j < high; j+=32) {
+      const int irow = static_cast<int>(col_idx_l[j]);
+      const DType datum_l = data_l[j];
+      // Iterate over all rhs columns
+      for (int k = 0; k < num_cols_r; k++) {
+        const DType val = datum_l*data_r[icol*num_cols_r+k];
+        atomicAdd(static_cast<DType*>(&(out[irow*num_cols_r+k])), val);
+      }
+    }
+  }
+};
+
 inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
                              const NDArray& lhs,
                              const TBlob& rhs,
@@ -109,22 +245,107 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
   const TBlob& data_r = rhs;
   const TBlob data_out = *ret;
 
-  MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+        if (kWriteTo == req) {
+          mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, data_out.Size(), data_out.dptr<DType>());
+        }
+        int num_threads;
+        const int threads_per_warp = 32;
+        const int threads_per_block = kBaseThreadNum;
+        const int num_rows_l = lhs.shape()[0];
+        const int num_cols_r = rhs.shape_[1];
         if (trans_lhs) {
-          MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-            mxnet_op::Kernel<DotCsrTransDnsDns<ReqType>, gpu>::Launch(s, data_out.Size(),
-                data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-                col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0],
-                data_out.shape_[1]);
-          });
+          // Different kernel versions are optimized for different matrix instances
+          // TODO: switch between kernel versions depending on input
+          // (1) 'Scalar kernel'       (one thread       computing one output element                )
+          // (2) 'Warp kernel'         (one warp         computing one lhs column for one rhs column )
+          // (3) 'Thread block kernel' (one thread block computing one lhs column for all rhs columns)
+          // (4) 'Warp block kernel'   (one warp         computing one lhs column for all rhs columns)
+          const int kernel_version = 0;
+          switch (kernel_version) {
+            case 1:
+              num_threads = data_out.Size();
+              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                mxnet_op::Kernel<DotCsrTransDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
+                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_rows_l, num_cols_r);
+              });
+              break;
+            case 2:
+              num_threads = threads_per_warp * num_rows_l * num_cols_r;
+              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                mxnet_op::Kernel<DotCsrTransDnsDnsWarpKernel<ReqType>, gpu>::Launch(s, num_threads,
+                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
+              });
+              break;
+            case 3:
+              num_threads = threads_per_block * num_rows_l;
+              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                mxnet_op::Kernel<DotCsrTransDnsDnsThreadBlockKernel<ReqType>, gpu>::Launch(s, num_threads,
+                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
+              });
+              break;
+            case 4:
+              num_threads = threads_per_warp * num_rows_l;
+              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                mxnet_op::Kernel<DotCsrTransDnsDnsWarpBlockKernel<ReqType>, gpu>::Launch(s, num_threads,
+                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
+              });
+              break;
+            default:
+              num_threads = threads_per_warp * num_rows_l * num_cols_r;
+              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                mxnet_op::Kernel<DotCsrTransDnsDnsWarpKernel<ReqType>, gpu>::Launch(s, num_threads,
+                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
+              });
+              break;
+          }
         } else {
-          MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-            mxnet_op::Kernel<DotCsrDnsDns<ReqType>, gpu>::Launch(s, data_out.Size(),
-                data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-                col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape_[1]);
-          });
+          // Different kernel versions are optimized for different matrix instances
+          // (1) 'Scalar kernel' (one thread computing one output element)
+          // (2) 'Vector kernel' (one warp   computing one output element)
+          const int kernel_version = 0;
+          switch (kernel_version) {
+            case 1:
+              num_threads = data_out.Size();
+              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                mxnet_op::Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
+                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
+              });
+              break;
+            case 2:
+              num_threads = threads_per_warp * num_rows_l * num_cols_r;
+              MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                mxnet_op::Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
+                    data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                    col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
+              });
+              break;
+            default:
+              if (num_cols_r > 4) {
+                num_threads = data_out.Size();
+                MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                  mxnet_op::Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
+                      data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                      col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
+                });
+              } else {
+                num_threads = threads_per_warp * num_rows_l * num_cols_r;
+                MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+                  mxnet_op::Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
+                      data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                      col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
+                });
+              }
+              break;
+          }
         }
       });
     });
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 33cc095c0cee..7440128dce09 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -187,7 +187,8 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
   // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp
-  if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) {
+  // dot(csr.T, dns) = rsp is not yet implemented on gpu
+  if (param.transpose_a && kCSRStorage == (*in_attrs)[0] && ctx.dev_type != Context::kGPU) {
     STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
   } else {
     STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 48e44133216b..79806901f522 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -5,6 +5,7 @@
 from test_operator import *
 from test_optimizer import *
 from test_random import *
+from test_sparse_operator import test_sparse_dot
 import mxnet as mx
 import numpy as np
 from mxnet.test_utils import check_consistency, set_default_context
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 1fc64a7149ea..f2d61d225153 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -103,12 +103,12 @@ def test_dns_to_csr(dns_in):
 
 def test_sparse_dot():
     def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
-        lhs_dns = rand_ndarray(lhs_shape, 'default')
-        lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr')
+        lhs_nd = rand_ndarray(lhs_shape, 'csr', 1)
+        lhs_dns = lhs_nd.todense()
         rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=density)
         rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
         out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
-        if trans_lhs:
+        if trans_lhs and default_context().device_type == 'cpu':
             assert out.storage_type == 'row_sparse'
         else:
             assert out.storage_type == 'default'
@@ -131,6 +131,8 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
                                 rtol=1e-3, atol=1e-4)
 
     lhs_shape = rand_shape_2d(50, 200)
+    test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False)
+    test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True)
     test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False)
     test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True)
     test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
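
Note: the one-warp-per-output-element strategy of DotCsrDnsDnsVectorKernel in the patch above reduces partial sums through a volatile shared-memory array. The sketch below is a minimal standalone CUDA illustration of the same parallelization, not part of the patch: the kernel name, signature, and launch configuration are hypothetical, and it swaps the shared-memory reduction for an intra-warp __shfl_down_sync reduction (a common alternative on CUDA 9+ that needs no shared-memory buffer). It also only implements plain overwrite semantics, not the KERNEL_ASSIGN/req machinery.

// Standalone sketch only (hypothetical names, not part of the patch above):
// one warp computes one element of out = csr * dns, mirroring the
// parallelization of DotCsrDnsDnsVectorKernel, but reducing with warp
// shuffles instead of shared memory.
#include <cuda_runtime.h>

template <typename DType, typename IType, typename CType>
__global__ void CsrDnsDotVectorSketch(DType* out,
                                      const DType* data_l,
                                      const IType* indptr_l,
                                      const CType* col_idx_l,
                                      const DType* data_r,
                                      const int num_rows_l,
                                      const int num_cols_r) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int warp_id = tid / 32;            // one warp per output element
  const int lane = tid & 31;               // lane id within the warp
  if (warp_id >= num_rows_l * num_cols_r) return;
  const int irow = warp_id / num_cols_r;   // lhs row this warp computes
  const int kcol = warp_id % num_cols_r;   // rhs column this warp computes

  // Nonzero range of the lhs row in CSR format
  const int low  = static_cast<int>(indptr_l[irow]);
  const int high = static_cast<int>(indptr_l[irow + 1]);

  // Each lane accumulates a strided partial sum over the row's nonzeros
  DType sum = 0;
  for (int j = low + lane; j < high; j += 32) {
    sum += data_l[j] * data_r[col_idx_l[j] * num_cols_r + kcol];
  }

  // Intra-warp tree reduction via shuffles (CUDA 9+)
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (lane == 0) {
    out[irow * num_cols_r + kcol] = sum;   // overwrite (kWriteTo-style) only
  }
}

// Example launch: 32 threads per output element, e.g.
//   const int num_threads = 32 * num_rows_l * num_cols_r;
//   CsrDnsDotVectorSketch<<<(num_threads + 255) / 256, 256>>>(
//       out, data_l, indptr_l, col_idx_l, data_r, num_rows_l, num_cols_r);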