diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index e9aee83ba284..cf61b4302a59 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -1201,24 +1201,6 @@ struct NDArrayFunctionReg #define MXNET_REGISTER_NDARRAY_FUN(name) \ DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name) -#define NDARRAY_IDX_TYPE_SWITCH(type, DType, ...) \ - switch (type) { \ - case mshadow::kUint8: \ - { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kInt32: \ - { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - LOG(FATAL) << "Unknown idx type enum " << type; \ - } - } // namespace mxnet namespace dmlc { diff --git a/mshadow b/mshadow index 1d633cbf0bf5..fd5edc4e9a93 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 1d633cbf0bf5b13196cfbd95e6666df3c5f4a8a7 +Subproject commit fd5edc4e9a9301943025820d6b2d99feeb172806 diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 05e5e3badf3f..358511aa0592 100755 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -129,7 +129,7 @@ inline void SparseSGDUpdateDnsRspImpl(const SGDParam& param, if (!grad.storage_initialized()) return; MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { - NDARRAY_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { auto weight_data = weight.data().FlatTo2D(s); auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); @@ -276,7 +276,7 @@ inline void SparseSGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, if (!grad.storage_initialized()) return; MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { - NDARRAY_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { auto weight_data = weight.data().FlatTo2D(s); auto 
grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 3d024b365ac0..0f3f587782d4 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -19,10 +19,10 @@ namespace mxnet { namespace op { template void UnaryLaunch(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; Stream *s = ctx.get_stream(); @@ -238,6 +238,52 @@ struct FillRspRowIdx { } }; +/*! + * \brief Kernel for marking row_idx of a RSP matrix per row + */ +struct MarkRspRowIdx { + // i represents the row index of the matrix data + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* data, + const int invalid_rid, const int num_cols) { + int j = 0; + int offset = i * num_cols; + for (; j < num_cols; ++j) { + if (data[offset+j] != 0) { + break; + } + } + if (num_cols == j) { + row_idx[i] = invalid_rid; // mark zero row as invalid + } else { + row_idx[i] = i; + } + } +}; + +struct CopyDnsToRsp{ + // i represents the row index of the matrix data + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, DType* rsp_data, + const DType* dns_data, const int num_rows, const int num_cols) { + int j = 0; + int offset = i * num_cols; + for (; j < num_cols; ++j) { + if (dns_data[offset+j] != 0) { + break; + } + } + if (num_cols == j) { + row_idx[i] = num_rows; + } else { + row_idx[i] = i; + for (j = 0; j < num_cols; ++j) { + rsp_data[offset+j] = dns_data[offset+j]; + } + } + } +}; + /*! 
* \brief * Given a DNS storage type tensor, create a RSP type sparse tensor @@ -257,39 +303,14 @@ void CastStorageDnsRspImpl(mshadow::Stream *s, const TBlob& dns, NDArray* r CHECK_EQ(rsp->storage_type(), kRowSparseStorage); CHECK_EQ(dns.shape_, rsp->shape()); - rsp->CheckAndAllocAuxData(rowsparse::kIdx, mshadow::Shape1(dns.shape_[0])); MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type - NDARRAY_IDX_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type - RType* row_idx = rsp->aux_data(rowsparse::kIdx).dptr(); + MSHADOW_INT_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type const index_t num_rows = dns.shape_[0]; const index_t num_cols = dns.shape_[1]; - // Fill input_data.shape_[0] into row_idx array - mxnet_op::Kernel::Launch(s, num_rows, row_idx, dns.dptr(), - num_rows, num_cols); - - // single thread scanning row_idx array to find out number of non-zero rows - index_t nnr = 0; // number of non-zero rows - for (index_t i = 0; i < num_rows; ++i) { - if (row_idx[i] < static_cast(num_rows)) ++nnr; - } - if (0 == nnr) { - rsp->SetAuxShape(rowsparse::kIdx, TShape(mshadow::Shape1(0))); - return; // zero matrix - } - rsp->CheckAndAllocData(mshadow::Shape2(nnr, num_cols)); - // TODO(junwu): single thread for compressing row_idx and copying data - // from dns to rsp, might be a bottleneck. 
- auto in_tensor = dns.FlatTo2D(s); - auto out_tensor = rsp->data().FlatTo2D(s); - int last_nnr_id = -1; // last non-zero row id - for (index_t i = 0; i < num_rows; ++i) { - if (row_idx[i] < static_cast(num_rows)) { // non-zero row found - row_idx[++last_nnr_id] = row_idx[i]; - mshadow::Copy(out_tensor[last_nnr_id], in_tensor[i], s); - } - } - // update effective size (not capacity) of the row_idx of rsp - rsp->SetAuxShape(rowsparse::kIdx, mshadow::Shape1(last_nnr_id+1)); + rsp->CheckAndAlloc({TShape({num_rows})}); + RType* row_idx = rsp->aux_data(rowsparse::kIdx).dptr(); + mxnet_op::Kernel::Launch(s, num_rows, row_idx, rsp->data().dptr(), + dns.dptr(), num_rows, num_cols); }); }); } @@ -310,7 +331,6 @@ struct CastStorageRspDnsKernel { } }; - /*! * \brief This function assumes that the meomry for dns has been allocated already * since the shape is known at binding stage. @@ -321,7 +341,7 @@ void CastStorageRspDnsImpl(mshadow::Stream *s, const NDArray& rsp, TBlob* d using namespace mshadow::expr; CHECK_EQ(rsp.storage_type(), kRowSparseStorage); MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { - NDARRAY_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { + MSHADOW_INT_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { // assign zeros mxnet_op::Kernel::Launch(s, dns->Size(), dns->dptr()); if (rsp.storage_initialized()) { @@ -416,8 +436,8 @@ void CastStorageDnsCsrImpl(mshadow::Stream *s, const TBlob& dns, NDArray* c CHECK_EQ(dns.shape_, csr->shape()); MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type - NDARRAY_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type - NDARRAY_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type + MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type const index_t num_rows = dns.shape_[0]; const index_t num_cols = dns.shape_[1]; csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1)); @@ 
-487,8 +507,8 @@ void CastStorageCsrDnsImpl(mshadow::Stream *s, const NDArray& csr, TBlob* d CHECK_EQ(dns->shape_, csr.shape()); MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { // data type - NDARRAY_IDX_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type - NDARRAY_IDX_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type + MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type const index_t num_rows = dns->shape_[0]; const index_t num_cols = dns->shape_[1]; DType* dns_data = dns->dptr(); @@ -520,8 +540,8 @@ inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, template void CastStorageComputeImpl(mshadow::Stream *s, - const NDArray& input, - const NDArray& output) { + const NDArray& input, + const NDArray& output) { using namespace mshadow; using namespace mshadow::expr; const auto src_stype = input.storage_type(); @@ -542,6 +562,23 @@ void CastStorageComputeImpl(mshadow::Stream *s, LOG(FATAL) << "Not implemented"; } } + +template +void CastStorageToDefault(mshadow::Stream *s, + const NDArray& input, + TBlob* ret) { + using namespace mshadow; + using namespace mshadow::expr; + const auto src_stype = input.storage_type(); + if (src_stype == kRowSparseStorage) { + CastStorageRspDnsImpl(s, input, ret); + } else if (src_stype == kCSRStorage) { + CastStorageCsrDnsImpl(s, input, ret); + } else { + LOG(FATAL) << "Not implemented"; + } +} + template void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index c8b2c95c8a15..12523e237cf2 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -387,7 +387,7 @@ void SparseEmbeddingOpBackwardDnsDnsRsp(const nnvm::NodeAttrs& attrs, unsigned int num_rows = output.shape()[0]; output.CheckAndAlloc({mshadow::Shape1(num_rows)}); MSHADOW_TYPE_SWITCH(output.dtype(), 
DType, { - NDARRAY_IDX_TYPE_SWITCH(idx.dtype(), IType, { + MSHADOW_INT_TYPE_SWITCH(idx.dtype(), IType, { MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, { // input embedding indice, each idx in [0, input_dim) auto idx_data = idx.data().FlatTo1D(s); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 51e4869e94f8..8b304afc202a 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -15,7 +15,6 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "broadcast_reduce_op.h" -#include "./elemwise_unary_op.h" #if MXNET_USE_CUDA #include @@ -324,6 +323,7 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs, struct DotParam : public dmlc::Parameter { bool transpose_a; bool transpose_b; + int out_stype; // output storage type DMLC_DECLARE_PARAMETER(DotParam) { DMLC_DECLARE_FIELD(transpose_a) .describe("If true then transpose the first input before dot.") @@ -483,13 +483,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); - // dot(csr, dns) = rsp is a requirement from users - if (kCSRStorage == in_attrs->at(0) && kDefaultStorage == in_attrs->at(1)) { - // dot(csr, dns) = rsp - out_attrs->at(0) = kRowSparseStorage; - } else { // fallback - return ElemwiseStorageType<2, 1>(attrs, in_attrs, out_attrs); - } + out_attrs->at(0) = kDefaultStorage; return true; } @@ -498,14 +492,8 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 3U); CHECK_EQ(out_attrs->size(), 2U); - if (kDefaultStorage == in_attrs->at(0) - && kCSRStorage == in_attrs->at(1) - && kDefaultStorage == in_attrs->at(2)) { - out_attrs->at(0) = kDefaultStorage; - out_attrs->at(1) = kRowSparseStorage; - } else { // fallback - return ElemwiseStorageType<3, 2>(attrs, in_attrs, out_attrs); - } + out_attrs->at(0) = kDefaultStorage; + 
out_attrs->at(1) = kDefaultStorage; return true; } @@ -515,14 +503,14 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, * is determined by trans_csr and trans_dns, respectively. * For now we only implemented the case when trans_dns = false. */ -template +template struct DotCsrDnsDns; /*! * \brief Kernel of dot(csr, dns1) = dns2 */ -template<> -struct DotCsrDnsDns { +template +struct DotCsrDnsDns { /*! * \brief This function represents performing an inner product between a row of lhs * and a column of rhs and then assigning the value to out[i]. @@ -540,18 +528,20 @@ struct DotCsrDnsDns { const int num_cols) { const int irow = i / num_cols; // row id of the lhs const int icol = i % num_cols; // col id of the rhs + DType sum = 0; for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs - out[i] += data_l[j] * data_r[cur_col*num_cols+icol]; + sum += data_l[j] * data_r[cur_col*num_cols+icol]; } + KERNEL_ASSIGN(out[i], req, sum); } }; /*! * \brief Kernel of dot(csr.T(), dns1) = dns2 */ -template<> -struct DotCsrDnsDns { +template +struct DotCsrDnsDns { /*! * \brief This function represents performing an inner product between a column of lhs * and a column of rhs and then assigning the value to out[i]. 
@@ -570,6 +560,7 @@ struct DotCsrDnsDns { const int num_cols) { const int irow = i / num_cols; // col id of the lhs const int icol = i % num_cols; // col id of the rhs + DType sum = 0; for (int k = 0; k < num_rows_l; ++k) { const IType low = indptr_l[k]; const IType high = indptr_l[k+1]; @@ -587,9 +578,10 @@ struct DotCsrDnsDns { } } if (j >= 0) { - out[i] += data_l[j] * data_r[k*num_cols+icol]; + sum += data_l[j] * data_r[k*num_cols+icol]; } } + KERNEL_ASSIGN(out[i], req, sum); } }; @@ -612,199 +604,20 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, const TBlob data_r = rhs.data(); const TBlob data_out = ret->data(); - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type - NDARRAY_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type - NDARRAY_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - if (trans_lhs) { + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type if (!lhs.storage_initialized()) return; - mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], - rhs.shape()[1]); - } else { - if (!lhs.storage_initialized()) return; - mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), - data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), rhs.shape()[1]); - } - }); - }); - }); -} - -/*! - * \brief Tempalte declaration of dot(csr, dns) = rsp. - * Whether csr and dns are transposed before dot operation - * is determined by trans_csr and trans_dns, respectively. - * For now we only implemented the case when trans_dns = false. - */ -template -struct DotCsrDnsRsp; - -/*! 
- * \brief Kernel of dot(csr, dns) = rsp - */ -template<> -struct DotCsrDnsRsp { - /*! - * \brief This function represents performing an inner product between a row of lhs - * and a column of rhs and then assigning the value to out[i]. - * \param i i-th element in out 1D view - * \param out output matrix's non-zero rows - * \param row_idx output matrix row_idx in RSP format - * \param data_l csr values of lhs - * \param indptr_l csr indptr of lhs - * \param col_idx_l csr col_idx of lhs - * \param data_r dense data of rhs - * \param num_cols number of columns of output - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const RType* row_idx, const DType* data_l, - const IType* indptr_l, const CType* col_idx_l, - const DType* data_r, const int num_cols) { - const int irow = row_idx[i/num_cols]; // row id of the lhs - const int icol = i % num_cols; // col id of the rhs - for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { - const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs - out[i] += data_l[j] * data_r[cur_col*num_cols+icol]; - } - } -}; - -/*! - * \brief Kernel of dot(csr.T(), dns) = rsp - */ -template<> -struct DotCsrDnsRsp { - /*! - * \brief This function represents performing an inner product between a column of lhs - * and a column of rhs and then assigning the value to out[i]. 
- * \param i i-th element in out 1D view - * \param out output row sparse matrix in 1-D view - * \param row_idx aux_data of out - * \param data_l csr values of lhs - * \param indptr_l csr indptr of lhs - * \param col_idx_l csr col_idx of lhs - * \param data_r dense data of rhs - * \param num_rows_l number of rows of lhs - * \param num_cols number of columns of outputs - */ - template - MSHADOW_XINLINE static void Map(int i, DType* out, const RType* row_idx, - const DType* data_l, const IType* indptr_l, - const CType* col_idx_l, const DType* data_r, - const int num_rows_l, const int num_cols) { - const int irow = row_idx[i/num_cols]; // col id of the lhs - const int icol = i % num_cols; // col id of the rhs - for (int k = 0; k < num_rows_l; ++k) { - const IType low = indptr_l[k]; - const IType high = indptr_l[k+1]; - if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; - int j = -1, l = low, r = high - 1; - while (l <= r) { - int m = l + (r - l) / 2; - if (col_idx_l[m] == irow) { - j = m; break; - } - if (col_idx_l[m] < irow) { - l = m + 1; - } else { - r = m - 1; - } - } - if (j >= 0) { - out[i] += data_l[j] * data_r[k*num_cols+icol]; - } - } - } -}; - -/*! 
- * \brief Implementation of dot(csr, dns) = rsp - * dot(csr.T(), dns) = rsp - * Must be called by DotForwardEx - */ -template -void DotCsrDnsRspImpl(const OpContext& ctx, - const NDArray& lhs, - const NDArray& rhs, - const OpReqType req, - const bool trans_lhs, - NDArray* ret) { - if (kNullOp == req) return; - - CHECK_EQ(lhs.storage_type(), kCSRStorage) << "lhs must be csr type in DotCsrDnsRspImpl"; - CHECK_EQ(rhs.storage_type(), kDefaultStorage) << "rhs must be dns type in DotCsrDnsRspImpl"; - CHECK_EQ(ret->storage_type(), kRowSparseStorage) << "ret must be rsp type in DotCsrDnsRspImpl"; - - mshadow::Stream *s = ctx.get_stream(); - const TBlob data_l = lhs.data(); - const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); - const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob data_r = rhs.data(); - - MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type - NDARRAY_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type - NDARRAY_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type - NDARRAY_IDX_TYPE_SWITCH(ret->aux_type(rowsparse::kIdx), RType, { // row idx type if (trans_lhs) { - // TODO(junwu): When performing dot(csr.T(), dns) to get a rsp matrix, - // we first allocate a dns tblob as a temporary placeholder for the output. - // Then we cast the dns to a rsp matrix. We take this approach - // instead of generating a rsp output with the actual storage size - // because it's difficult to calculate the number of non-zero columns - // of the csr for allocating the memory of the output rsp. - // We will revisit this approach in the future to see if there are - // better ways. - - // get temporary space as an intermediate result of dot(csr.T(), dns). 
- // requested[0] is temp space resource - mshadow::Tensor out_tmp = - ctx.requested[0].get_space_typed( - mshadow::Shape2(ret->shape()[0], ret->shape()[1]), s); - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, out_tmp.shape_.Size(), out_tmp.dptr_); - } - if (lhs.storage_initialized()) { - // generate a temporary dns output - mxnet_op::Kernel, xpu>::Launch( - s, out_tmp.shape_.Size(), out_tmp.dptr_, data_l.dptr(), - indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), - lhs.shape()[0], out_tmp.shape_[1]); - } - // cast dns to rsp - CastStorageDnsRspImpl(s, TBlob(out_tmp), ret); + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), lhs.shape()[0], + rhs.shape()[1]); } else { - // TODO(junwu): check whether the following code is a bottleneck - // allocate output NDArray (single thread) - index_t nnr = 0; // number of non-zero rows in csr - const IType* indptr = indptr_l.dptr(); - for (int i = 0; i < static_cast(indptr_l.Size())-1; ++i) { - if (indptr[i] < indptr[i+1]) ++nnr; - } - ret->CheckAndAlloc({mshadow::Shape1(nnr)}); - // fill in row_idx_out (single thread) - const TBlob data_out = ret->data(); - const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); - RType* row_idx = row_idx_out.dptr(); - for (int i = 0, k = 0; i < static_cast(indptr_l.Size())-1; ++i) { - if (indptr[i] < indptr[i+1]) { - row_idx[k++] = i; - } - } - if (kWriteTo == req) { - mxnet_op::Kernel::Launch( - s, data_out.Size(), data_out.dptr()); - } - mxnet_op::Kernel, xpu>::Launch( - s, data_out.Size(), data_out.dptr(), row_idx_out.dptr(), - data_l.dptr(), indptr_l.dptr(), col_idx_l.dptr(), - data_r.dptr(), data_out.shape_[1]); + mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), rhs.shape()[1]); } }); }); @@ -812,57 +625,6 @@ void DotCsrDnsRspImpl(const OpContext& ctx, }); } -template -void DotForwardEx(const 
nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - CHECK_EQ(req.size(), 1U); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; - - NDArray ret = outputs[0]; // get rid of the const qualifier - if (inputs[0].storage_type() == kCSRStorage - && inputs[1].storage_type() == kDefaultStorage - && outputs[0].storage_type() == kDefaultStorage) { - DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); - } else if (inputs[0].storage_type() == kCSRStorage - && inputs[1].storage_type() == kDefaultStorage - && outputs[0].storage_type() == kRowSparseStorage) { - DotCsrDnsRspImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); - } else { - // TODO(junwu): add fallback mechanism - LOG(FATAL) << "Not supported"; - } -} - -/*! - * \brief Backward of - * 1. out = dot(csr, dns) - * grad(csr) = dot(grad(out), dns.T()) - * grad(dns) = dot(csr.T(), grad(out)) - * - * 2. out = dot(csr.T(), dns) - * grad(csr) = dot(dns, grad(out).T()) - * grad(dns) = dot(csr, grad(out)) - * - * Assume the gradient of the op's output is a dense matrix. - * This function must be called by DotBackwardEx. 
- */ -template -void DotBackwardCsrDnsRsp(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - const DotParam& param = nnvm::get(attrs.parsed); - NDArray ret = outputs[1]; - DotCsrDnsRspImpl(ctx, inputs[1], inputs[0], req[1], !param.transpose_a, &ret); -} - template void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -874,40 +636,6 @@ void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0], req[1], !param.transpose_a, &ret); } -template -void DotBackwardEx(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - CHECK_EQ(inputs.size(), 3U); - CHECK_EQ(outputs.size(), 2U); - CHECK_EQ(req.size(), 2U); - CHECK_EQ(kNullOp, req[0]) - << "sparse dot does not support computing the gradient of the csr/lhs"; - CHECK_NE(req[1], kWriteInplace) << "DotBackwardCsrDnsRsp does not support WriteInplace"; - - // TODO(junwu): check whether this CHECK is reasonable - const DotParam& param = nnvm::get(attrs.parsed); - CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; - if (inputs[0].storage_type() == kDefaultStorage // ograd dns format - // dns, csr, dns => *, rsp - && inputs[1].storage_type() == kCSRStorage // csr input lhs of the op - && inputs[2].storage_type() == kDefaultStorage // dns input rhs of the op - && outputs[1].storage_type() == kRowSparseStorage) { // grad(rhs) rsp format - DotBackwardCsrDnsRsp(attrs, ctx, inputs, req, outputs); - } else if (inputs[0].storage_type() == kDefaultStorage // ograd dns format - // dns, csr, dns => *, dns - && inputs[1].storage_type() == kCSRStorage // csr input lhs of the op - && inputs[2].storage_type() == kDefaultStorage // dns input rhs of the op - && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format - DotBackwardCsrDnsDns(attrs, ctx, 
inputs, req, outputs); - } else { - // TODO(junwu): add fallback mechanism - LOG(FATAL) << "Not supported"; - } -} - inline bool DotShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { @@ -951,6 +679,57 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, return true; } +template +void DotForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; + + NDArray ret = outputs[0]; // get rid of the const qualifier + if (inputs[0].storage_type() == kCSRStorage + && inputs[1].storage_type() == kDefaultStorage + && outputs[0].storage_type() == kDefaultStorage) { + DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else { // TODO(junwu): add fallback + LOG(FATAL) << "Not supported dot operation for lhs.storage_type = " + << inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type() + << ", out.storage_type = " << outputs[0].storage_type(); + } +} + +template +void DotBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_EQ(kNullOp, req[0]) + << "sparse dot does not support computing the gradient of the csr/lhs"; + CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + // TODO(junwu): check whether this CHECK is reasonable + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; + if (inputs[0].storage_type() == kDefaultStorage // ograd dns format + // dns, csr, dns => *, dns + && 
inputs[1].storage_type() == kCSRStorage // csr input lhs of the op + && inputs[2].storage_type() == kDefaultStorage // dns input rhs of the op + && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format + DotBackwardCsrDnsDns(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; + } +} + template void BatchDotForward_(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -1252,7 +1031,7 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx, CHECK_EQ(in.aux_type(kIndPtr), in.aux_type(kIdx)) << "The type for indptr and indices are different. This is not implemented yet."; // assume idx indptr share the same type - NDARRAY_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), IType, { + MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), IType, { MSHADOW_TYPE_SWITCH(in.dtype(), DType, { auto in_indptr = in.aux_data(kIndPtr).dptr(); auto out_indptr = out.aux_data(kIndPtr).dptr(); diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 8a34d4d591a0..ea6f835a60aa 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -5,9 +5,10 @@ from numpy.testing import assert_allclose from mxnet.test_utils import * + def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None): - lhs = mx.symbol.Variable('lhs', storage_type = lhs_stype) - rhs = mx.symbol.Variable('rhs', storage_type = rhs_stype) + lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype) + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) if lhs_grad_stype is not None: lhs._set_attr(grad_stype_hint=str(lhs_grad_stype)) if rhs_grad_stype is not None: @@ -20,26 +21,28 @@ def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_ out_np = lhs_np + rhs_np test = mx.symbol.elemwise_add(lhs, rhs) - location = {'lhs':lhs_nd, 'rhs':rhs_nd} + location = {'lhs': 
lhs_nd, 'rhs': rhs_nd} check_symbolic_forward(test, location, [out_np]) check_numeric_gradient(test, location) check_symbolic_backward(test, location, [out_np], [out_np, out_np]) + def test_elemwise_add_ex(): - shape = (rnd.randint(1, 10),rnd.randint(1, 10)) + shape = (rnd.randint(1, 10), rnd.randint(1, 10)) check_elemwise_add_ex('default_storage', 'default_storage', shape) check_elemwise_add_ex('default_storage', 'row_sparse', shape) check_elemwise_add_ex('row_sparse', 'default_storage', shape) check_elemwise_add_ex('row_sparse', 'row_sparse', shape, - lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse') + lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse') + # TODO(haibin) randomize this test def test_elemwise_add_ex_multiple_stages(): # prep data shape = (4, 2) - ds_np = np.array([[1,2],[3,4],[5,6],[7,8]]) - sp_np1 = np.array([[5,10],[0,0],[0,0],[0,0]]) - sp_np2 = np.array([[0,0],[5,10],[0,0],[0,0]]) + ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]]) + sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]]) val1 = mx.nd.array([[5, 10]]); val2 = mx.nd.array([[5, 10]]); @@ -53,20 +56,21 @@ def test_elemwise_add_ex_multiple_stages(): sp_data1 = mx.symbol.Variable('sp_data1', storage_type='row_sparse') sp_data2 = mx.symbol.Variable('sp_data2', storage_type='row_sparse') ds_data = mx.symbol.Variable('ds_data') - plus = mx.symbol.elemwise_add(sp_data1, sp_data2, name='plus') + plus = mx.symbol.elemwise_add(sp_data1, sp_data2, name='plus') # sparse + dense = dense - test = mx.symbol.elemwise_add(plus, ds_data) - check_symbolic_forward(test, {'sp_data1':sp_nd1, 'sp_data2':sp_nd2, - 'ds_data':ds_nd}, [sp_np1 + sp_np2 + ds_np]) + test = mx.symbol.elemwise_add(plus, ds_data) + check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np]) - arr_grads = [mx.nd.zeros(shape) for i in range(3)] - exec_test = test.bind(default_context(), 
args={'sp_data1':sp_nd1, 'sp_data2':sp_nd2, - 'ds_data':ds_nd}, args_grad=arr_grads) + arr_grads = [mx.nd.zeros(shape) for i in range(3)] + exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, args_grad=arr_grads) exec_test.forward(is_train=True) assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np) - exec_test.backward(out_grads = exec_test.outputs) + exec_test.backward(out_grads=exec_test.outputs) assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) + # TODO(haibin) also add test for backward pass def test_cast_storage_ex(): def test_rsp_to_dns(shape): @@ -91,17 +95,18 @@ def test_csr_to_dns(shape): assert_almost_equal(mx_dns.asnumpy(), np_dns) def test_dns_to_csr(dns_in): - dns_in= np.array(dns_in) + dns_in = np.array(dns_in) csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), storage_type='csr') ret = mx.nd.cast_storage(csr_out, storage_type='default_storage') assert same(ret.asnumpy(), dns_in) - shape = (rnd.randint(1, 10),rnd.randint(1, 10)) + shape = (rnd.randint(1, 10), rnd.randint(1, 10)) test_rsp_to_dns(shape) test_dns_to_rsp(shape) test_csr_to_dns((4, 4)) test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]]) + # TODO(junwu): The backward of the operator dot cannot be tested for now # since the backend function CopyFromTo does not support taking two arguments # of the different storage types. Will add backward test after removing this @@ -109,41 +114,39 @@ def test_dns_to_csr(dns_in): # the same impl function of dot(csr, dns) = rsp and it has been tested # in the forward test cases as the following. 
def test_sparse_dot(): - def test_dot_csr_dns_rsp(csr_shape, dns_shape, dns_grad_stype, trans_csr): + def test_dot_csr_dns(csr_shape, dns_shape, trans_csr): dns1 = rand_ndarray(csr_shape, 'default_storage') dns2 = rand_ndarray(dns_shape, 'default_storage') csr = mx.nd.cast_storage(dns1, storage_type='csr') - rsp_out = mx.nd.dot(csr, dns2, transpose_a=trans_csr) - rsp_expected = mx.nd.dot(dns1, dns2, transpose_a=trans_csr) - out_np = rsp_expected.asnumpy() + out = mx.nd.dot(csr, dns2, transpose_a=trans_csr) + assert out.storage_type == 'default_storage' + out_expected = mx.nd.dot(dns1, dns2, transpose_a=trans_csr) + out_np = out_expected.asnumpy() backward_trans = not trans_csr - rhs_backward_grad = mx.nd.dot(dns1, rsp_expected, transpose_a=backward_trans).asnumpy() - # TODO(junwu): may need to compare rsp_out and rsp_expected in rsp format - # instead of converting them to the dense format - assert same(rsp_out.asnumpy(), out_np) + rhs_backward_grad = mx.nd.dot(dns1, out_expected, transpose_a=backward_trans).asnumpy() + assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) # test symbolic forward lhs = mx.symbol.Variable('lhs', storage_type='csr') rhs = mx.symbol.Variable('rhs', storage_type='default_storage') - rhs._set_attr(grad_stype_hint=str(dns_grad_stype)) # TODO(haibin) since backward op is not fully implemented, here we add a dense zero ndarray # so that the output gradient is dense. 
zeros = mx.symbol.Variable('zero', storage_type='default_storage') sym_dot = mx.symbol.dot(lhs, rhs, transpose_a=trans_csr) test = mx.symbol.elemwise_add(sym_dot, zeros) - location = {'lhs':csr, 'rhs':dns2, 'zero':mx.nd.zeros(rsp_expected.shape)} - expected = {'rhs':rhs_backward_grad, 'zero':out_np} + location = {'lhs': csr, 'rhs': dns2, 'zero': mx.nd.zeros(out_expected.shape)} + expected = {'rhs': rhs_backward_grad, 'zero': out_np} # dot(lhs, rhs) + zeros - check_symbolic_forward(test, location, [rsp_expected.asnumpy()]) + check_symbolic_forward(test, location, [out_expected.asnumpy()], rtol=1e-3, atol=1e-4) check_symbolic_backward(test, location, [out_np], expected, - grad_req={'lhs': 'null', 'rhs': 'write', 'zero' : 'write'}) + grad_req={'lhs': 'null', 'rhs': 'write', 'zero': 'write'}, + rtol=1e-3, atol=1e-4) + + lhs_shape = (rnd.randint(1, 10), rnd.randint(1, 10)) + test_dot_csr_dns(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), False) + test_dot_csr_dns(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), True) - lhs_shape = (rnd.randint(1, 10),rnd.randint(1, 10)) - test_dot_csr_dns_rsp(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False) - test_dot_csr_dns_rsp(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True) - test_dot_csr_dns_rsp(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default_storage', False) - test_dot_csr_dns_rsp(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default_storage', True) def test_sparse_embedding(): in_dim = 10 @@ -187,6 +190,7 @@ def check_csr_slice(shape, sliced_input): check_csr_slice(shape, True) check_csr_slice(shape, False) + if __name__ == '__main__': test_elemwise_add_ex() test_elemwise_add_ex_multiple_stages()