-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU #11113
Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU #11113
Conversation
c481dff
to
99b4bf8
Compare
@haojin2 @eric-haibin-lin for the review :) |
src/operator/tensor/dot-inl.h
Outdated
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, | ||
DispatchMode::kFComputeEx); | ||
// dns, csr/csr.T -> dns on CPU | ||
} else if (target_stype == kDefaultStorage) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: extra space at "== kDefaultStorage"
src/operator/tensor/dot-inl.h
Outdated
@@ -327,7 +331,8 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, | |||
dispatched = true; | |||
} | |||
} | |||
if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a && | |||
// if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove unused code.
src/operator/tensor/dot-inl.h
Outdated
, data_r.dptr<DType>(), indptr_r.dptr<IType>() | ||
, col_idx_r.dptr<CType>(), seg_len | ||
, dns.shape_[0], dns.shape_[1] | ||
, rhs.shape()[0], rhs.shape()[1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For line breaks please follow google c++ style guide, a sample may be found here: https://gist.github.com/davidzchen/9187878#file-sample-google-c-L110-L114
src/operator/tensor/dot-inl.h
Outdated
*/ | ||
inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev, | ||
const TBlob& dns, const NDArray& rhs, | ||
const OpReqType req, NDArray* ret, | ||
const bool transpose_b) { | ||
LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU"; | ||
if (kNullOp == req) return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: personally I would prefer "req == kNullOp".
src/operator/tensor/dot-inl.h
Outdated
MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type | ||
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // col idx type | ||
dim_t num_threads; | ||
if (kWriteTo == req) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Personally I would prefer "req == kWriteTo"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx, I have modified those according to your suggestion.
5b78686
to
7fad590
Compare
7fad590
to
dd989d6
Compare
@@ -264,11 +264,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, | |||
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && | |||
!param.transpose_a) { | |||
target_stype = hint_has_value ? target_stype : kCSRStorage; | |||
// dns, csr -> csr on CPU |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also update doc in https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/dot.cc#L63-L64
src/operator/tensor/dot-inl.h
Outdated
struct DotDnsCsrTransDnsByRowBlocks { | ||
/*! | ||
* \brief | ||
* \param i the i-th thread |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please complete documentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx, it's done.
@haojin2 @eric-haibin-lin I have solved the problems mentioned in your comments. Can you accept this PR ? |
LGTM, will wait for @eric-haibin-lin to take a final look. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution! a few comments
src/operator/tensor/dot-inl.h
Outdated
const TBlob& data_l = dns; | ||
const TBlob data_out = ret->data(); | ||
|
||
MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fp16 will fail with this branch. I think there's a MSHADOW_REAL_TYPE_SWITCH
src/operator/tensor/dot-inl.h
Outdated
MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type | ||
MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // col idx type | ||
dim_t num_threads; | ||
if (req == kWriteTo) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
req == writeto || req == writeinplace
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx, it's Done ! May you review it again :)
Thx, it's Done ! May you review it again :) @eric-haibin-lin |
…che#11113) * implement dot(dns, csr/csr.T)=dns on cpu * complete documentaion related to dot(dns, csr/csr.T)=dns on cpu * support fp16 by replacing MSHADOW_SGL_DBL_TYPE_SWITCH with MSHADOW_REAL_TYPE_SWITCH
…che#11113) * implement dot(dns, csr/csr.T)=dns on cpu * complete documentaion related to dot(dns, csr/csr.T)=dns on cpu * support fp16 by replacing MSHADOW_SGL_DBL_TYPE_SWITCH with MSHADOW_REAL_TYPE_SWITCH
Description
Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
We implemented dot(dns,csr/csr.T)=dns on cpu inspired by the implementation of dot(dns,csr/csr.T)=dns on gpu by @haojin2 . It gains 190 times speed up when density of csr is 0.01%. The details as follows.
The benchmark script is here: https://github.com/XiaotaoChen/incubator-mxnet/blob/mytest/example/sparse/cxt-test/test_dot.py . And also can test this with haojin's script by replacing mx.gpu() with mx.cpu() : #10371
@pengzhao-intel @TaoLv