address code reviews
Hao Jin committed Apr 13, 2018
1 parent bf7ed00 commit 52ebdbc
Showing 2 changed files with 60 additions and 59 deletions.
108 changes: 54 additions & 54 deletions src/operator/tensor/dot-inl.cuh
@@ -30,6 +30,8 @@
 #include "./util/tensor_util-inl.h"
 #include "./util/tensor_util-inl.cuh"
 
+typedef unsigned long long AtomicIType;
+
 namespace mxnet {
 namespace op {
@@ -453,17 +455,17 @@ struct CscDataIndicesKernel {
                                              const IType* csr_indices,
                                              const CType* csr_indptr,
                                              DType* csc_data,
-                                             unsigned long long* csc_indices,
-                                             unsigned long long* csc_indptr,
-                                             unsigned long long* workspace,
+                                             AtomicIType* csc_indices,
+                                             AtomicIType* csc_indptr,
+                                             AtomicIType* col_counters,
                                              const nnvm::dim_t num_rows,
                                              const nnvm::dim_t num_cols) {
     if (tid < num_rows) {
       for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) {
         // target column
-        IType target_col = csr_indices[i];
-        int target_offset = atomicAdd(&workspace[target_col], 1);
-        int new_pos = csc_indptr[target_col] + target_offset;
+        const IType target_col = csr_indices[i];
+        const int target_offset = atomicAdd(&col_counters[target_col], 1);
+        const int new_pos = csc_indptr[target_col] + target_offset;
         csc_data[new_pos] = csr_data[i];
         csc_indices[new_pos] = tid;
       }
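
Renaming the third argument from workspace to col_counters captures what the kernel actually does: each thread walks one CSR row and, per nonzero, atomically bumps the counter of the target column to claim a unique slot inside that column's CSC segment. A sequential C++ reference of the same scatter, assuming csc_indptr already holds the exclusive prefix sum of the per-column counts (illustration only; float data and int64_t indices are assumptions):

```cpp
#include <cstdint>
#include <vector>

// Reference scatter: each CSR nonzero claims the next free slot in its
// target column. On the GPU the counter bump is an atomicAdd; serially a
// plain post-increment yields the same positions.
void scatter_csr_to_csc(const std::vector<float>& csr_data,
                        const std::vector<int64_t>& csr_indices,
                        const std::vector<int64_t>& csr_indptr,
                        const std::vector<int64_t>& csc_indptr,  // exclusive scan of column counts
                        std::vector<float>* csc_data,
                        std::vector<int64_t>* csc_indices,
                        int64_t num_rows, int64_t num_cols) {
  std::vector<int64_t> col_counters(num_cols + 1, 0);
  for (int64_t row = 0; row < num_rows; ++row) {
    for (int64_t i = csr_indptr[row]; i < csr_indptr[row + 1]; ++i) {
      const int64_t col = csr_indices[i];
      const int64_t offset = col_counters[col]++;    // atomicAdd on GPU
      const int64_t pos = csc_indptr[col] + offset;  // slot inside column segment
      (*csc_data)[pos] = csr_data[i];
      (*csc_indices)[pos] = row;                     // row id becomes the CSC index
    }
  }
}
```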
@@ -486,7 +488,7 @@ struct CsrTransHistogramKernel {
   template<typename IType>
   __device__ __forceinline__ static void Map(int tid,
                                              const IType* in_indices,
-                                             unsigned long long* out_indptr,
+                                             AtomicIType* out_indptr,
                                              const nnvm::dim_t nnz) {
     if (tid < nnz) {
       atomicAdd(&out_indptr[in_indices[tid]], 1);
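
This histogram kernel is phase one of the standard three-phase transpose: count nonzeros per column, exclusive-scan the counts into csc_indptr, then scatter with CscDataIndicesKernel. The scan uses CUB's two-call idiom, sketched standalone below; the commit itself carves the temp storage out of a preallocated workspace instead of calling cudaMalloc:

```cpp
#include <cuda_runtime.h>
#include <cub/cub.cuh>

// Two-call CUB idiom: with a null temp pointer the first call only writes
// the required byte count; the second call performs the in-place exclusive
// prefix sum that turns per-column counts into csc_indptr offsets.
void exclusive_scan_indptr(unsigned long long* d_indptr, int num_cols,
                           cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_indptr, d_indptr,
                                num_cols + 1, stream);  // query pass
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_indptr, d_indptr,
                                num_cols + 1, stream);  // compute pass
  cudaFree(d_temp);
}
```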
@@ -1023,94 +1025,92 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx,
     return;
   }
 
-  // if dot(dense, csr) = dns, transform to csc first
-  if (!transpose_b) {
-    // LOG(FATAL) << "dot(dns, csr) = dns not implemented yet";
-    const nnvm::dim_t csr_rows = rhs.shape()[0];
-    const nnvm::dim_t csr_cols = rhs.shape()[1];
-    const nnvm::dim_t dns_rows = dns.shape_[0];
-    const nnvm::dim_t nnz = rhs.storage_shape().Size();
-
-    MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
-      MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
-        MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
+  MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {  // data type
+    MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // colidx type
+        const CType out_num_rows = ret->shape()[0];
+        const CType out_num_cols = ret->shape()[1];
+        // if dot(dense, csr) = dns, transform to csc first
+        if (!transpose_b) {
+          const nnvm::dim_t num_csr_rows = rhs.shape()[0];
+          const nnvm::dim_t num_csr_cols = rhs.shape()[1];
+          const nnvm::dim_t num_dns_rows = dns.shape_[0];
+          const nnvm::dim_t nnz = rhs.storage_shape().Size();
+
           DType* csc_data_ptr = NULL;
-          unsigned long long* csc_indices_ptr = NULL;
-          unsigned long long* csc_indptr_ptr = NULL;
-          unsigned long long* col_counters = NULL;
-          size_t ull_mem_size = sizeof(unsigned long long);
+          AtomicIType* csc_indices_ptr = NULL;
+          AtomicIType* csc_indptr_ptr = NULL;
+          AtomicIType* col_counters = NULL;
+          size_t ull_num_bytes = sizeof(AtomicIType);
           void* temp_storage = NULL;
           size_t temp_storage_bytes = 0;
-          CType out_num_rows = ret->shape()[0];
-          CType out_num_cols = ret->shape()[1];
 
           // Get necessary temporary storage amount
           cub::DeviceScan::ExclusiveSum(NULL,
                                         temp_storage_bytes,
                                         csc_indices_ptr,
                                         csc_indices_ptr,
-                                        csr_cols+1,
+                                        num_csr_cols + 1,
                                         Stream<gpu>::GetStream(s));
-          temp_storage_bytes += (ull_mem_size - (temp_storage_bytes % ull_mem_size));
+          // Align to multiple of ull_num_bytes
+          temp_storage_bytes += (ull_num_bytes - (temp_storage_bytes % ull_num_bytes));
+          size_t csc_data_size = nnz*sizeof(DType);
+          size_t csc_indices_size = nnz*ull_num_bytes;
+          size_t csc_indptr_size = (num_csr_cols+1)*ull_num_bytes;
+          size_t col_counters_size = (num_csr_cols+1)*ull_num_bytes;
           Tensor<gpu, 1, char> workspace =
               ctx.requested[0].get_space_typed<gpu, 1, char>(
-                  Shape1(nnz*sizeof(DType) + nnz*ull_mem_size +
-                         2*(csr_cols + 1)*ull_mem_size +
+                  Shape1(csc_data_size + csc_indices_size +
+                         csc_indptr_size + col_counters_size +
                          temp_storage_bytes),
                   s);
-          csc_indices_ptr = reinterpret_cast<unsigned long long*>(workspace.dptr_);
-          csc_indptr_ptr = reinterpret_cast<unsigned long long*>(
-              workspace.dptr_ + nnz*ull_mem_size);
-          col_counters = reinterpret_cast<unsigned long long*>(
-              workspace.dptr_ + nnz*ull_mem_size + (csr_cols+1)*ull_mem_size);
-          csc_data_ptr = reinterpret_cast<DType*>(workspace.dptr_ + nnz*ull_mem_size +
-                                                  2*(csr_cols+1)*ull_mem_size);
-          temp_storage = reinterpret_cast<void*>(workspace.dptr_ + nnz*sizeof(DType) +
-                                                 nnz*ull_mem_size + 2*(csr_cols+1)*ull_mem_size);
+          csc_indices_ptr = reinterpret_cast<AtomicIType*>(workspace.dptr_);
+          csc_indptr_ptr = reinterpret_cast<AtomicIType*>(
+              workspace.dptr_ + csc_indices_size);
+          col_counters = reinterpret_cast<AtomicIType*>(
+              workspace.dptr_ + csc_indices_size + csc_indptr_size);
+          csc_data_ptr = reinterpret_cast<DType*>(workspace.dptr_ + csc_indices_size +
+                                                  csc_indptr_size + col_counters_size);
+          temp_storage = reinterpret_cast<void*>(workspace.dptr_ + csc_data_size +
+                                                 csc_indices_size + csc_indptr_size +
+                                                 col_counters_size);
           mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(
-              s, dns_rows*csr_cols, ret->data().dptr<DType>());
+              s, num_dns_rows*num_csr_cols, ret->data().dptr<DType>());
           // Reset values for indptr, ready for histogramming
           mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(
-              s, csr_cols + 1, csc_indptr_ptr);
+              s, num_csr_cols+1, csc_indptr_ptr);
           // Histogramming on col id
           mxnet_op::Kernel<CsrTransHistogramKernel, gpu>::Launch(
               s, nnz, csr_indices.dptr<IType>(), csc_indptr_ptr, nnz);
           cub::DeviceScan::ExclusiveSum(temp_storage,
                                         temp_storage_bytes,
                                         csc_indptr_ptr,
                                         csc_indptr_ptr,
-                                        csr_cols+1,
+                                        num_csr_cols + 1,
                                         Stream<gpu>::GetStream(s));
           // Reset values for col_counter, ready for the final transform
           mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(
-              s, csr_cols+1, col_counters);
+              s, num_csr_cols+1, col_counters);
           // Transform to CSC
           mxnet_op::Kernel<CscDataIndicesKernel, gpu>::Launch(
-              s, csr_rows, csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
+              s, num_csr_rows, csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
               csr_indptr.dptr<CType>(), csc_data_ptr, csc_indices_ptr,
-              csc_indptr_ptr, col_counters, csr_rows, csr_cols);
+              csc_indptr_ptr, col_counters, num_csr_rows, num_csr_cols);
           mxnet_op::Kernel<DotDnsCsrTransDnsKernel, gpu>::Launch(
               s, out_num_rows * out_num_cols, dns.dptr<DType>(),
               csc_data_ptr, csc_indices_ptr, csc_indptr_ptr,
               ret->data().dptr<DType>(), dns.shape_[1],
               out_num_rows, out_num_cols);
-        });
-      });
-    });
-  } else {
-    MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {  // data type
-      MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {  // indptr type
-        MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // colidx type
-          CType out_num_rows = ret->shape()[0];
-          CType out_num_cols = ret->shape()[1];
+        } else {
           mxnet_op::Kernel<DotDnsCsrTransDnsKernel, gpu>::Launch(
               s, out_num_rows * out_num_cols, dns.dptr<DType>(),
               csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
               csr_indptr.dptr<CType>(), ret->data().dptr<DType>(),
               dns.shape_[1], out_num_rows, out_num_cols);
-        });
+        }
       });
     });
-  }
+  });
 }
 
 }  // namespace op
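
The reworked allocation replaces inline sums of nnz*ull_mem_size terms with named section sizes, which makes the single-workspace layout auditable: one char buffer holds csc_indices, csc_indptr, col_counters, csc_data, and the CUB temp storage, in that order. A standalone sketch of the same carving pattern (struct and names are illustrative; DType fixed to float):

```cpp
#include <cstddef>

// Sketch of the single-allocation layout: compute each section's byte
// size up front, then carve one char buffer into consecutive regions.
// The unsigned-long-long sections are all multiples of 8 bytes, so each
// starts 8-byte aligned as long as `base` itself is aligned.
struct CscWorkspace {
  unsigned long long* csc_indices;   // nnz entries
  unsigned long long* csc_indptr;    // num_cols + 1 entries
  unsigned long long* col_counters;  // num_cols + 1 entries
  float* csc_data;                   // nnz entries (DType == float here)
  void* temp_storage;                // CUB scan scratch space
};

CscWorkspace CarveWorkspace(char* base, size_t nnz, size_t num_cols) {
  const size_t ull = sizeof(unsigned long long);
  const size_t indices_sz  = nnz * ull;
  const size_t indptr_sz   = (num_cols + 1) * ull;
  const size_t counters_sz = (num_cols + 1) * ull;
  CscWorkspace w;
  w.csc_indices  = reinterpret_cast<unsigned long long*>(base);
  w.csc_indptr   = reinterpret_cast<unsigned long long*>(base + indices_sz);
  w.col_counters = reinterpret_cast<unsigned long long*>(
      base + indices_sz + indptr_sz);
  w.csc_data = reinterpret_cast<float*>(
      base + indices_sz + indptr_sz + counters_sz);
  w.temp_storage = base + indices_sz + indptr_sz + counters_sz
                   + nnz * sizeof(float);
  return w;
}
```

One quirk worth noting: the alignment line adds a full ull_num_bytes even when temp_storage_bytes is already a multiple, which wastes a few bytes but never under-allocates.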
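Why convert at all: for out = dns · csr with transpose_b false, output column j needs every nonzero of rhs column j, which CSR scatters across rows but CSC stores contiguously. A serial reference of the product the launched kernels compute together (hypothetical helper; float data and int64_t indices assumed):

```cpp
#include <cstdint>

// out[i][j] = sum over nonzeros (k, v) in column j of rhs: dns[i][k] * v.
// With rhs in CSC form, column j's nonzeros sit in the contiguous range
// csc_indptr[j] .. csc_indptr[j + 1).
void DnsCsrDnsReference(const float* dns, int dns_rows, int dns_cols,
                        const float* csc_data, const int64_t* csc_indices,
                        const int64_t* csc_indptr, int out_cols, float* out) {
  for (int i = 0; i < dns_rows; ++i) {
    for (int j = 0; j < out_cols; ++j) {
      float acc = 0.0f;
      for (int64_t p = csc_indptr[j]; p < csc_indptr[j + 1]; ++p) {
        acc += dns[i * dns_cols + csc_indices[p]] * csc_data[p];
      }
      out[i * out_cols + j] = acc;
    }
  }
}
```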
11 changes: 6 additions & 5 deletions src/operator/tensor/dot-inl.h
@@ -224,10 +224,11 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   bool rhs_rsp_or_dns =
       rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
   NDArrayStorageType target_stype;
+  bool hint_has_value = param.forward_stype_hint.has_value();
   if (!dispatched && lhs_stype == kDefaultStorage &&
       rhs_stype == kDefaultStorage) {
     // dns, dns -> dns
-    target_stype = (param.forward_stype_hint.has_value())?
+    target_stype = hint_has_value ?
         static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
         kDefaultStorage;
     if (target_stype == kDefaultStorage) {
@@ -237,7 +238,7 @@
   }
   if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
     // csr.T, rsp/dns -> rsp
-    target_stype = (param.forward_stype_hint.has_value())?
+    target_stype = hint_has_value ?
         static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
         kRowSparseStorage;
     if (target_stype == kRowSparseStorage) {
@@ -248,7 +249,7 @@
   if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns &&
       !param.transpose_a && !param.transpose_b) {
     // csr, rsp/dns -> dns
-    target_stype = (param.forward_stype_hint.has_value())?
+    target_stype = hint_has_value ?
         static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
         kDefaultStorage;
     if (target_stype == kDefaultStorage) {
@@ -260,7 +261,7 @@
       !param.transpose_a) {
     // dns, csr -> csr on CPU
     if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
-      target_stype = (param.forward_stype_hint.has_value())?
+      target_stype = hint_has_value ?
           static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
           kCSRStorage;
       if (target_stype == kCSRStorage) {
@@ -269,7 +270,7 @@
       }
       // dns, csr/csr.T -> dns on GPU
     } else if (dev_mask == mshadow::gpu::kDevMask) {
-      target_stype = (param.forward_stype_hint.has_value())?
+      target_stype = hint_has_value ?
           static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
           kDefaultStorage;
       if (target_stype == kDefaultStorage) {
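All five hunks in dot-inl.h apply the same review feedback: cache param.forward_stype_hint.has_value() in a single hint_has_value bool instead of re-querying the optional in every dispatch branch, and drop the cramped `)?` ternary spelling. A minimal standalone illustration of the pattern (std::optional stands in for the dmlc optional used in MXNet):

```cpp
#include <optional>

enum NDArrayStorageType { kDefaultStorage, kRowSparseStorage, kCSRStorage };

// Each dispatch branch picks a target storage type, preferring the user's
// hint when one was given. Hoisting has_value() into a named bool keeps
// the branches identical in shape and cheap to read.
NDArrayStorageType PickTargetStype(const std::optional<int>& forward_stype_hint,
                                   NDArrayStorageType fallback) {
  const bool hint_has_value = forward_stype_hint.has_value();
  return hint_has_value
             ? static_cast<NDArrayStorageType>(forward_stype_hint.value())
             : fallback;
}
```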
