diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index a65936840662..19cac543bf50 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -30,6 +30,8 @@
 #include "./util/tensor_util-inl.h"
 #include "./util/tensor_util-inl.cuh"
 
+typedef unsigned long long AtomicIType;
+
 namespace mxnet {
 namespace op {
 
@@ -453,17 +455,17 @@ struct CscDataIndicesKernel {
                                              const IType* csr_indices,
                                              const CType* csr_indptr,
                                              DType* csc_data,
-                                             unsigned long long* csc_indices,
-                                             unsigned long long* csc_indptr,
-                                             unsigned long long* workspace,
+                                             AtomicIType* csc_indices,
+                                             AtomicIType* csc_indptr,
+                                             AtomicIType* col_counters,
                                              const nnvm::dim_t num_rows,
                                              const nnvm::dim_t num_cols) {
     if (tid < num_rows) {
       for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) {
         // target column
-        IType target_col = csr_indices[i];
-        int target_offset = atomicAdd(&workspace[target_col], 1);
-        int new_pos = csc_indptr[target_col] + target_offset;
+        const IType target_col = csr_indices[i];
+        const int target_offset = atomicAdd(&col_counters[target_col], 1);
+        const int new_pos = csc_indptr[target_col] + target_offset;
         csc_data[new_pos] = csr_data[i];
         csc_indices[new_pos] = tid;
       }
@@ -486,7 +488,7 @@ struct CsrTransHistogramKernel {
   template<typename IType>
   __device__ __forceinline__ static void Map(int tid,
                                              const IType* in_indices,
-                                             unsigned long long* out_indptr,
+                                             AtomicIType* out_indptr,
                                              const nnvm::dim_t nnz) {
     if (tid < nnz) {
       atomicAdd(&out_indptr[in_indices[tid]], 1);
@@ -1023,54 +1025,60 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx,
     return;
   }
 
-  // if dot(dense, csr) = dns, transform to csc first
-  if (!transpose_b) {
-    // LOG(FATAL) << "dot(dns, csr) = dns not implemented yet";
-    const nnvm::dim_t csr_rows = rhs.shape()[0];
-    const nnvm::dim_t csr_cols = rhs.shape()[1];
-    const nnvm::dim_t dns_rows = dns.shape_[0];
-    const nnvm::dim_t nnz = rhs.storage_shape().Size();
-
-    MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
-      MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
-        MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
+  MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {     // data type
+    MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {    // colidx type
+      MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {   // indptr type
+        const CType out_num_rows = ret->shape()[0];
+        const CType out_num_cols = ret->shape()[1];
+        // if dot(dense, csr) = dns, transform to csc first
+        if (!transpose_b) {
+          const nnvm::dim_t num_csr_rows = rhs.shape()[0];
+          const nnvm::dim_t num_csr_cols = rhs.shape()[1];
+          const nnvm::dim_t num_dns_rows = dns.shape_[0];
+          const nnvm::dim_t nnz = rhs.storage_shape().Size();
+
           DType* csc_data_ptr = NULL;
-          unsigned long long* csc_indices_ptr = NULL;
-          unsigned long long* csc_indptr_ptr = NULL;
-          unsigned long long* col_counters = NULL;
-          size_t ull_mem_size = sizeof(unsigned long long);
+          AtomicIType* csc_indices_ptr = NULL;
+          AtomicIType* csc_indptr_ptr = NULL;
+          AtomicIType* col_counters = NULL;
+          size_t ull_num_bytes = sizeof(AtomicIType);
           void* temp_storage = NULL;
           size_t temp_storage_bytes = 0;
-          CType out_num_rows = ret->shape()[0];
-          CType out_num_cols = ret->shape()[1];
+          // Get necessary temporary storage amount
           cub::DeviceScan::ExclusiveSum(NULL,
                                         temp_storage_bytes,
                                         csc_indices_ptr,
                                         csc_indices_ptr,
-                                        csr_cols+1,
+                                        num_csr_cols + 1,
                                         Stream<gpu>::GetStream(s));
-          temp_storage_bytes += (ull_mem_size - (temp_storage_bytes % ull_mem_size));
+          // Align to multiple of ull_num_bytes
+          temp_storage_bytes += (ull_num_bytes - (temp_storage_bytes % ull_num_bytes));
+          size_t csc_data_size = nnz*sizeof(DType);
+          size_t csc_indices_size = nnz*ull_num_bytes;
+          size_t csc_indptr_size = (num_csr_cols+1)*ull_num_bytes;
+          size_t col_counters_size = (num_csr_cols+1)*ull_num_bytes;
           Tensor<gpu, 1, char> workspace =
             ctx.requested[0].get_space_typed<gpu, 1, char>(
-              Shape1(nnz*sizeof(DType) + nnz*ull_mem_size +
-                     2*(csr_cols + 1)*ull_mem_size +
+              Shape1(csc_data_size + csc_indices_size +
+                     csc_indptr_size + col_counters_size +
                      temp_storage_bytes), s);
-          csc_indices_ptr = reinterpret_cast<unsigned long long*>(workspace.dptr_);
-          csc_indptr_ptr = reinterpret_cast<unsigned long long*>(
-            workspace.dptr_ + nnz*ull_mem_size);
-          col_counters = reinterpret_cast<unsigned long long*>(
-            workspace.dptr_ + nnz*ull_mem_size + (csr_cols+1)*ull_mem_size);
-          csc_data_ptr = reinterpret_cast<DType*>(workspace.dptr_ + nnz*ull_mem_size +
-                                                  2*(csr_cols+1)*ull_mem_size);
-          temp_storage = reinterpret_cast<void*>(workspace.dptr_ + nnz*sizeof(DType) +
-                                                 nnz*ull_mem_size + 2*(csr_cols+1)*ull_mem_size);
+          csc_indices_ptr = reinterpret_cast<AtomicIType*>(workspace.dptr_);
+          csc_indptr_ptr = reinterpret_cast<AtomicIType*>(
+            workspace.dptr_ + csc_indices_size);
+          col_counters = reinterpret_cast<AtomicIType*>(
+            workspace.dptr_ + csc_indices_size + csc_indptr_size);
+          csc_data_ptr = reinterpret_cast<DType*>(workspace.dptr_ + csc_indices_size +
+                                                  csc_indptr_size + col_counters_size);
+          temp_storage = reinterpret_cast<void*>(workspace.dptr_ + csc_data_size +
+                                                 csc_indices_size + csc_indptr_size +
+                                                 col_counters_size);
           mxnet_op::Kernel<set_zero, gpu>::Launch(
-            s, dns_rows*csr_cols, ret->data().dptr<DType>());
+            s, num_dns_rows*num_csr_cols, ret->data().dptr<DType>());
           // Reset values for indptr, ready for histogramming
           mxnet_op::Kernel<set_zero, gpu>::Launch(
-            s, csr_cols + 1, csc_indptr_ptr);
+            s, num_csr_cols+1, csc_indptr_ptr);
           // Histogramming on col id
           mxnet_op::Kernel<CsrTransHistogramKernel, gpu>::Launch(
             s, nnz, csr_indices.dptr<IType>(), csc_indptr_ptr, nnz);
@@ -1078,39 +1086,31 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx,
                                         temp_storage_bytes,
                                         csc_indptr_ptr,
                                         csc_indptr_ptr,
-                                        csr_cols+1,
+                                        num_csr_cols + 1,
                                         Stream<gpu>::GetStream(s));
           // Reset values for col_counter, ready for the final transform
           mxnet_op::Kernel<set_zero, gpu>::Launch(
-            s, csr_cols+1, col_counters);
+            s, num_csr_cols+1, col_counters);
           // Transform to CSC
           mxnet_op::Kernel<CscDataIndicesKernel, gpu>::Launch(
-            s, csr_rows, csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
+            s, num_csr_rows, csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
             csr_indptr.dptr<CType>(), csc_data_ptr, csc_indices_ptr,
-            csc_indptr_ptr, col_counters, csr_rows, csr_cols);
+            csc_indptr_ptr, col_counters, num_csr_rows, num_csr_cols);
           mxnet_op::Kernel<DotDnsCsrTransDnsKernel, gpu>::Launch(
             s, out_num_rows * out_num_cols, dns.dptr<DType>(),
             csc_data_ptr, csc_indices_ptr, csc_indptr_ptr,
             ret->data().dptr<DType>(), dns.shape_[1],
             out_num_rows, out_num_cols);
-        });
-      });
-    });
-  } else {
-    MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {  // data type
-      MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {  // indptr type
-        MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // colidx type
-          CType out_num_rows = ret->shape()[0];
-          CType out_num_cols = ret->shape()[1];
+        } else {
           mxnet_op::Kernel<DotDnsCsrTransDnsKernel, gpu>::Launch(
             s, out_num_rows * out_num_cols, dns.dptr<DType>(),
             csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
             csr_indptr.dptr<CType>(), ret->data().dptr<DType>(),
             dns.shape_[1], out_num_rows, out_num_cols);
-        });
+        }
       });
     });
-  }
+  });
 }
 
 }  // namespace op
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 2270304b90a3..7ca33fe13b2d 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -224,10 +224,11 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage ||
                         rhs_stype == kDefaultStorage;
   NDArrayStorageType target_stype;
+  bool hint_has_value = param.forward_stype_hint.has_value();
   if (!dispatched && lhs_stype == kDefaultStorage &&
       rhs_stype == kDefaultStorage) {
     // dns, dns -> dns
-    target_stype = (param.forward_stype_hint.has_value())?
+    target_stype = hint_has_value ?
                    static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
                    kDefaultStorage;
     if (target_stype == kDefaultStorage) {
@@ -237,7 +238,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   }
   if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
     // csr.T, rsp/dns -> rsp
-    target_stype = (param.forward_stype_hint.has_value())?
+    target_stype = hint_has_value ?
                    static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
                    kRowSparseStorage;
     if (target_stype == kRowSparseStorage) {
@@ -248,7 +249,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns &&
       !param.transpose_a && !param.transpose_b) {
     // csr, rsp/dns -> dns
-    target_stype = (param.forward_stype_hint.has_value())?
+    target_stype = hint_has_value ?
                    static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
                    kDefaultStorage;
     if (target_stype == kDefaultStorage) {
@@ -260,7 +261,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
       !param.transpose_a) {
     // dns, csr -> csr on CPU
     if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
-      target_stype = (param.forward_stype_hint.has_value())?
+      target_stype = hint_has_value ?
                      static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
                      kCSRStorage;
       if (target_stype == kCSRStorage) {
@@ -269,7 +270,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
       }
       // dns, csr/csr.T -> dns on GPU
     } else if (dev_mask == mshadow::gpu::kDevMask) {
-      target_stype = (param.forward_stype_hint.has_value())?
+      target_stype = hint_has_value ?
                      static_cast<NDArrayStorageType>(param.forward_stype_hint.value()) :
                      kDefaultStorage;
       if (target_stype == kDefaultStorage) {
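
Reviewer note: the `!transpose_b` branch is the classic three-phase CSR-to-CSC conversion: histogram the column ids into `csc_indptr` (`CsrTransHistogramKernel`), exclusive-scan the counts into per-column offsets (`cub::DeviceScan::ExclusiveSum`), then scatter every entry into its column, with `col_counters` handing out slots inside each column via `atomicAdd` (`CscDataIndicesKernel`). Since the CSC arrays of A are exactly the CSR arrays of A^T, the converted matrix can then go through the same transpose-style dot kernel as the `else` branch. A minimal serial sketch of the same three phases, as hypothetical standalone code with no MXNet/CUB dependencies (on the GPU the counter increments are atomics):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // 2x3 CSR matrix A = [[1 0 2], [0 3 0]]
  std::vector<float> csr_data{1.f, 2.f, 3.f};
  std::vector<int> csr_indices{0, 2, 1};
  std::vector<int> csr_indptr{0, 2, 3};
  const int num_rows = 2, num_cols = 3, nnz = 3;

  // Phase 1: histogram of column ids (CsrTransHistogramKernel above).
  std::vector<int> csc_indptr(num_cols + 1, 0);
  for (int i = 0; i < nnz; ++i) ++csc_indptr[csr_indices[i]];

  // Phase 2: exclusive scan turns counts into per-column offsets
  // (cub::DeviceScan::ExclusiveSum above).
  int running = 0;
  for (int c = 0; c <= num_cols; ++c) {
    const int count = csc_indptr[c];
    csc_indptr[c] = running;
    running += count;
  }

  // Phase 3: scatter; col_counters hands out slots within each column
  // (the atomicAdd in CscDataIndicesKernel above).
  std::vector<int> col_counters(num_cols, 0), csc_indices(nnz);
  std::vector<float> csc_data(nnz);
  for (int row = 0; row < num_rows; ++row) {
    for (int i = csr_indptr[row]; i < csr_indptr[row + 1]; ++i) {
      const int col = csr_indices[i];
      const int pos = csc_indptr[col] + col_counters[col]++;
      csc_data[pos] = csr_data[i];
      csc_indices[pos] = row;  // CSC stores row ids where CSR stored col ids
    }
  }

  for (int i = 0; i < nnz; ++i)
    std::printf("csc_data[%d] = %g (row %d)\n", i, csc_data[i], csc_indices[i]);
  return 0;  // expect values 1, 3, 2 with rows 0, 1, 0
}
```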
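
The final dot launches then compute one output element per thread: `out[i][j]` accumulates `dns[i][indices[k]] * data[k]` over row `j` of the supplied CSR arrays, which yields `dns * B^T` for a CSR matrix `B`, and therefore `dns * A` when fed A's CSC arrays. The kernel body itself is not part of this diff, so the sketch below is an assumption based on that scheme, again as hypothetical standalone C++:

```cpp
#include <cstdio>
#include <vector>

// Serial stand-in for the per-output-element dot kernel launched above
// (assumed scheme; one GPU thread corresponds to one tid here).
void DnsCsrTransDns(const std::vector<float>& dns, int dns_cols,
                    const std::vector<float>& data,
                    const std::vector<int>& indices,
                    const std::vector<int>& indptr,
                    std::vector<float>* out, int out_rows, int out_cols) {
  for (int tid = 0; tid < out_rows * out_cols; ++tid) {
    const int i = tid / out_cols, j = tid % out_cols;
    float sum = 0.f;
    // Row j of the CSR arrays = column j of A when given A's CSC arrays.
    for (int k = indptr[j]; k < indptr[j + 1]; ++k)
      sum += dns[i * dns_cols + indices[k]] * data[k];
    (*out)[tid] = sum;
  }
}

int main() {
  // dns is 1x2; CSC arrays of the 2x3 example A = [[1 0 2], [0 3 0]].
  std::vector<float> dns{1.f, 2.f}, csc_data{1.f, 3.f, 2.f};
  std::vector<int> csc_indices{0, 1, 0}, csc_indptr{0, 1, 2, 3};
  std::vector<float> out(3);
  DnsCsrTransDns(dns, 2, csc_data, csc_indices, csc_indptr, &out, 1, 3);
  for (float v : out) std::printf("%g ", v);  // expect: 1 6 2
  return 0;
}
```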