diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h
index c74f1f936031..911c369b3e69 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -495,6 +495,91 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
   });
 }
 
+/*!
+ * \brief Kernel for performing elemwise op between dense and csr matrix
+ * \param i           global thread id
+ * \param req         type of request
+ * \param out         output array
+ * \param dns_data    data array of dense input
+ * \param csr_data    data array of csr input
+ * \param csr_indices indices array of csr input
+ * \param csr_indptr  indptr array of csr input
+ * \param num_rows    number of rows of both inputs
+ * \param num_cols    number of columns of both inputs
+ */
+template<int req, typename OP, bool reverse>
+struct ElemwiseDnsCsrCsrKernel {
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data,
+                                  const DType* csr_data, const IType* csr_indices,
+                                  const CType* csr_indptr, const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    if (i < num_rows) {
+      for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) {
+        KERNEL_ASSIGN(out[j], req, reverse ?
+                                   OP::Map(dns_data[i * num_cols + csr_indices[j]], csr_data[j]) :
+                                   OP::Map(csr_data[j], dns_data[i * num_cols + csr_indices[j]]));
+      }
+    }
+  }
+};
+
+/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+template<typename xpu, typename OP>
+void ElemwiseBinaryOp::DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
+                                   const OpContext &ctx,
+                                   const NDArray &dns,
+                                   const NDArray &csr,
+                                   const OpReqType req,
+                                   const NDArray &output,
+                                   const bool reverse) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  CHECK_EQ(dns.storage_type(), kDefaultStorage);
+  CHECK_EQ(csr.storage_type(), kCSRStorage);
+  CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo";
+  if (req == kNullOp) return;
+  const bool supported_op = std::is_same<OP, mshadow_op::mul>::value;
+  CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul";
+  const nnvm::dim_t num_csr_rows = csr.shape()[0];
+  const nnvm::dim_t num_csr_cols = csr.shape()[1];
+  const nnvm::dim_t nnz = csr.storage_shape()[0];
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  output.CheckAndAlloc({Shape1(num_csr_rows + 1), Shape1(nnz)});
+  if (csr.storage_initialized()) {
+    TBlob csr_data = csr.data();
+    TBlob csr_indices = csr.aux_data(kIdx);
+    TBlob csr_indptr = csr.aux_data(kIndPtr);
+    MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
+        MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
+          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+            if (reverse) {
+              Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, true>, xpu>::Launch(
+                s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
+                csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
+                num_csr_rows, num_csr_cols);
+            } else {
+              Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, false>, xpu>::Launch(
+                s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
+                csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
+                num_csr_rows, num_csr_cols);
+            }
+            Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(),
+                 csr.aux_data(kIdx).FlatTo1D<xpu, IType>(), s);
+            Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, CType>(),
+                 csr.aux_data(kIndPtr).FlatTo1D<xpu, CType>(), s);
+          });
+        });
+      });
+    });
+  } else {
+    FillZerosCsrImpl(s, output);
+  }
+}
+
 /*!
  * \brief Kernel for performing elemwise op between dense and rsp tensor
  * \param i global thread id
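Note on the kernel: it parallelizes over the rows of the CSR input. Thread i walks the stored
nonzeros of row i (csr_indptr[i] through csr_indptr[i+1]) and combines each with the dense
element at the same (row, column) position, so the output reuses the CSR layout of the sparse
input. A minimal NumPy sketch of that indexing, assuming a flat dense buffer and scipy-style
CSR arrays (the function name is illustrative, not part of the patch):

    import numpy as np

    def dns_csr_mul(dns, csr_data, csr_indices, csr_indptr, num_rows, num_cols):
        # dns is the dense input flattened to 1-D, as the kernel sees it.
        out = np.empty_like(csr_data)  # one output value per stored nonzero
        for i in range(num_rows):  # one kernel thread per CSR row
            for j in range(csr_indptr[i], csr_indptr[i + 1]):
                out[j] = csr_data[j] * dns[i * num_cols + csr_indices[j]]
        return out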
diff --git a/src/operator/tensor/elemwise_binary_op.cc b/src/operator/tensor/elemwise_binary_op.cc
index e8ba2fa72345..9ccbacc2f654 100644
--- a/src/operator/tensor/elemwise_binary_op.cc
+++ b/src/operator/tensor/elemwise_binary_op.cc
@@ -63,6 +63,11 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
   const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
                                          DispatchMode::kFComputeEx;
+  const int ograd_stype = in_attrs->at(0);
+  const int lhs_stype = in_attrs->at(1);
+  const int rhs_stype = in_attrs->at(2);
+  int& lhs_grad_stype = out_attrs->at(0);
+  int& rhs_grad_stype = out_attrs->at(1);
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
     dispatched = storage_type_assign(out_attrs, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
@@ -74,6 +79,22 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
                                      dispatch_mode, dispatch_ex);
     }
   }
+  if (!dispatched && ograd_stype == kDefaultStorage &&
+      ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+       (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+    const bool reverse = (lhs_stype == kCSRStorage);
+    if (reverse &&
+        type_assign(&lhs_grad_stype, kDefaultStorage) &&
+        type_assign(&rhs_grad_stype, kCSRStorage)) {
+      DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+      dispatched = true;
+    } else if (!reverse &&
+               type_assign(&lhs_grad_stype, kCSRStorage) &&
+               type_assign(&rhs_grad_stype, kDefaultStorage)) {
+      DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+      dispatched = true;
+    }
+  }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
   }
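For z = lhs * rhs the backward pass computes lhs_grad = ograd * rhs and rhs_grad = ograd * lhs,
so with a dense ograd and exactly one CSR input, the gradient of the dense input comes out CSR
and the gradient of the CSR input comes out dense. A plain-Python sketch of the rule the new
branch encodes (the helper name is hypothetical):

    def backward_grad_stypes(ograd_stype, lhs_stype, rhs_stype):
        # Mirrors the branch above, for elemwise_mul with a dense ograd.
        if ograd_stype == 'default':
            if lhs_stype == 'csr' and rhs_stype == 'default':
                return ('default', 'csr')  # lhs_grad = dns*dns, rhs_grad = dns*csr
            if lhs_stype == 'default' and rhs_stype == 'csr':
                return ('csr', 'default')  # lhs_grad = dns*csr, rhs_grad = dns*dns
        return None  # other combinations take the earlier branches or the fallback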
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index a5b73dadd3ac..ad4b3e7cc4a3 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -165,12 +165,11 @@ class ElemwiseBinaryOp : public OpBase {
     typename xpu,
     typename LOP,
     typename ROP,
-    typename DType,
     bool in0_ok_dense = false,
     bool in1_ok_dense = false,
     bool in2_ok_dense = false,
     typename BackupCompute>
-  static inline void BackwardUseInEx_(const nnvm::NodeAttrs &attrs,
+  static inline void RspRspOpBackward(const nnvm::NodeAttrs &attrs,
                                       const OpContext &ctx,
                                       const std::vector<NDArray> &inputs,
                                       const std::vector<OpReqType> &req,
                                       const std::vector<NDArray> &outputs) {
@@ -200,6 +199,33 @@ class ElemwiseBinaryOp : public OpBase {
     }
   }
 
+  template<typename xpu, typename LOP, typename ROP>
+  static inline void DnsCsrCsrOpBackward(const nnvm::NodeAttrs &attrs,
+                                         const OpContext &ctx,
+                                         const std::vector<NDArray> &inputs,
+                                         const std::vector<OpReqType> &req,
+                                         const std::vector<NDArray> &outputs) {
+    const bool supported_ops = std::is_same<mshadow_op::right, LOP>::value &&
+                               std::is_same<mshadow_op::left, ROP>::value;
+    CHECK(supported_ops)
+      << "Only backward for mul is supported (LOP should be right, ROP should be left)";
+    const NDArray& out_grad = inputs[0];
+    const NDArray& lhs_in = inputs[1];
+    const NDArray& rhs_in = inputs[2];
+    const NDArray& lhs_grad = outputs[0];
+    const NDArray& rhs_grad = outputs[1];
+    const bool reverse = (outputs[0].storage_type() == kCSRStorage);
+    if (reverse) {
+      DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, rhs_in, req[0], lhs_grad, false);
+      Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), lhs_in.data()}, {req[1]},
+                                    {rhs_grad.data()});
+    } else {
+      DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, lhs_in, req[1], rhs_grad, false);
+      Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), rhs_in.data()}, {req[0]},
+                                    {lhs_grad.data()});
+    }
+  }
+
  public:
   /*! \brief Binary op handling for lhr/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
   template<typename OP>
@@ -232,44 +258,54 @@ class ElemwiseBinaryOp : public OpBase {
   /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
   template<typename OP>
   static void CsrCsrOp(mshadow::Stream<cpu> *s,
-                        const nnvm::NodeAttrs &attrs,
-                        const OpContext &ctx,
-                        const NDArray &lhs,
-                        const NDArray &rhs,
-                        OpReqType req,
-                        const NDArray &output);
+                       const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const NDArray &lhs,
+                       const NDArray &rhs,
+                       OpReqType req,
+                       const NDArray &output);
 
   /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
   template<typename OP>
   static void CsrCsrOp(mshadow::Stream<gpu> *s,
-                        const nnvm::NodeAttrs &attrs,
-                        const OpContext &ctx,
-                        const NDArray &lhs,
-                        const NDArray &rhs,
-                        OpReqType req,
-                        const NDArray &output);
+                       const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const NDArray &lhs,
+                       const NDArray &rhs,
+                       OpReqType req,
+                       const NDArray &output);
 
   /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
   template<typename xpu, typename OP>
   static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
-                           const nnvm::NodeAttrs &attrs,
-                           const OpContext &ctx,
-                           const NDArray &lhs,
-                           const NDArray &rhs,
-                           OpReqType req,
-                           const NDArray &output,
-                           const bool reverse);
+                          const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &lhs,
+                          const NDArray &rhs,
+                          OpReqType req,
+                          const NDArray &output,
+                          const bool reverse);
+
+  /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+  template<typename xpu, typename OP>
+  static void DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &lhs,
+                          const NDArray &rhs,
+                          OpReqType req,
+                          const NDArray &output,
+                          const bool reverse);
 
   /*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
   template<typename xpu, typename OP>
   static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
-                           const nnvm::NodeAttrs &attrs,
-                           const OpContext &ctx,
-                           const NDArray &lhs,
-                           const NDArray &rhs,
-                           OpReqType req,
-                           const NDArray &output,
-                           const bool reverse);
+                          const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &lhs,
+                          const NDArray &rhs,
+                          OpReqType req,
+                          const NDArray &output,
+                          const bool reverse);
 
  public:
   /*!
@@ -336,6 +372,14 @@ class ElemwiseBinaryOp : public OpBase {
       dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
                                        dispatch_mode, dispatch_ex);
     }
+    if (!dispatched &&
+        ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+         (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+      // csr, dns -> csr
+      // dns, csr -> csr
+      dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                       dispatch_mode, DispatchMode::kFComputeEx);
+    }
     if (!dispatched) {
       dispatched = dispatch_fallback(out_attrs, dispatch_mode);
     }
@@ -540,6 +584,14 @@ class ElemwiseBinaryOp : public OpBase {
                 req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false, false);
     } else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
       ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    } else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+                (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
+               out_stype == kCSRStorage) {
+      const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : inputs[1];
+      const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
+      const bool reverse = (lhs_stype == kCSRStorage);
+
+      DnsCsrCsrOp<xpu, OP>(attrs, ctx, dns, csr, req[0], outputs[0], reverse);
     } else {
       LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
     }
@@ -635,16 +687,21 @@ class ElemwiseBinaryOp : public OpBase {
     using namespace common;
     CHECK_EQ(inputs.size(), 3U);
     CHECK_EQ(outputs.size(), 2U);  // lhs input grad, rhs input grad
+    const auto out_grad_stype = inputs[0].storage_type();
     const auto lhs_grad_stype = outputs[0].storage_type();
     const auto rhs_grad_stype = outputs[1].storage_type();
     if (ContainsOnlyStorage(inputs, kRowSparseStorage) &&
         (lhs_grad_stype == kDefaultStorage || lhs_grad_stype == kRowSparseStorage) &&
         (rhs_grad_stype == kDefaultStorage || rhs_grad_stype == kRowSparseStorage)) {
       // rsp, rsp, rsp -> [dns, rsp], [dns, rsp]
-      MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
-        BackwardUseInEx_<xpu, LOP, ROP, DType, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
-          attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
-      });
+      RspRspOpBackward<xpu, LOP, ROP, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
+        attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
+    }
+    if (((lhs_grad_stype == kDefaultStorage && rhs_grad_stype == kCSRStorage) ||
+         (lhs_grad_stype == kCSRStorage && rhs_grad_stype == kDefaultStorage)) &&
+        out_grad_stype == kDefaultStorage) {
+      // dns, csr, dns -> [csr, dns] / csr, dns, dns -> [dns, csr]
+      DnsCsrCsrOpBackward<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
     }
   }
 };  // class ElemwiseBinaryOp
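With the storage inference and the ComputeDnsLRValueEx branch above, elemwise_mul of one dense
and one CSR NDArray produces a CSR output instead of falling back to dense. A hedged usage
sketch of the imperative API (the tests below use the symbolic mx.sym.sparse.elemwise_mul; the
mx.nd.sparse spelling here is assumed to mirror it):

    import mxnet as mx

    dns = mx.nd.random.uniform(shape=(3, 4))
    csr = mx.nd.random.uniform(shape=(3, 4)).tostype('csr')
    out = mx.nd.sparse.elemwise_mul(dns, csr)  # dispatched to DnsCsrCsrOp
    assert out.stype == 'csr'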
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu
index 9c1fd0e14f35..5cdd8947dd49 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -51,7 +51,9 @@ NNVM_REGISTER_OP(_backward_sub)
                                                                 mshadow_op::negation>);
 
 NNVM_REGISTER_OP(elemwise_mul)
-.set_attr<FCompute>("FCompute", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>);
+.set_attr<FCompute>("FCompute", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx",
+                      ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);
 
 NNVM_REGISTER_OP(_backward_mul)
 .set_attr<FCompute>("FCompute",
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 226db70a2acb..b2ff0fecb5a7 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -329,9 +329,19 @@ def elemwise_mul_stype(lstype, rstype):
             return 'row_sparse'
         elif lstype == 'row_sparse' and rstype == 'default':
             return 'row_sparse'
+        elif lstype == 'default' and rstype == 'csr':
+            return 'csr'
+        elif lstype == 'csr' and rstype == 'default':
+            return 'csr'
         else:
             return 'default'
 
+    def elemwise_mul_lhs_grad_stype(lstype, rstype):
+        return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), rstype)
+
+    def elemwise_mul_rhs_grad_stype(lstype, rstype):
+        return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), lstype)
+
     def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
                                   lhs_grad_stype=None, rhs_grad_stype=None,
                                   lhs_density=.5, rhs_density=.5,
@@ -378,8 +388,8 @@ def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
                                 lambda l, r: mx.sym.sparse.elemwise_mul(l, r),
                                 lambda l, r: l * r,
                                 lambda outg, l, r: (outg * r, outg * l),
-                                elemwise_mul_stype(lhs_stype, rhs_stype),
-                                elemwise_mul_stype(lhs_stype, rhs_stype),
+                                elemwise_mul_lhs_grad_stype(lhs_stype, rhs_stype),
+                                elemwise_mul_rhs_grad_stype(lhs_stype, rhs_stype),
                                 expected_result_storage_type=elemwise_mul_stype(lhs_stype, rhs_stype),
                                 ograd_density=ograd_density,
                                 force_lr_overlap=force_lr_overlap,
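The two grad-stype helpers encode that each gradient of z = l * r is ograd times the other
input, with ograd taking the forward output's storage type: lhs_grad = ograd * r and
rhs_grad = ograd * l. Worked through for one case (plain Python, assuming the helpers above):

    # l dense, r csr: the forward output, and therefore ograd, is csr
    assert elemwise_mul_stype('default', 'csr') == 'csr'
    # lhs_grad = ograd * r = csr * csr, which falls back to dense
    assert elemwise_mul_lhs_grad_stype('default', 'csr') == 'default'
    # rhs_grad = ograd * l = csr * dns, which stays csr
    assert elemwise_mul_rhs_grad_stype('default', 'csr') == 'csr'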