[MXNET-399] Elemwise_mul between dense and csr on CPU & GPU (#10894)
* Support elemwise_mul between dense (dns) and CSR matrices

* Address review comments and add backward support when ograd is dense
haojin2 authored and eric-haibin-lin committed May 31, 2018
1 parent 5109b00 commit 9feecce
Showing 5 changed files with 210 additions and 35 deletions.
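
For context, a minimal forward-path sketch of the new behavior (an illustration only, assuming an MXNet build that contains this commit; elemwise_mul with one dense and one CSR input now stays on the sparse path and returns a CSR result):

import mxnet as mx
import numpy as np

# Build a dense operand and a sparse (CSR) operand.
dns = mx.nd.array(np.random.rand(3, 4))
mask = (np.random.rand(3, 4) > 0.7).astype('float32')
csr = mx.nd.array(np.random.rand(3, 4) * mask).tostype('csr')

# Only the stored entries of `csr` are multiplied against `dns`.
out = mx.nd.sparse.elemwise_mul(dns, csr)
print(out.stype)  # expected: 'csr'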
src/operator/tensor/elemwise_binary_op-inl.h (85 additions, 0 deletions)
@@ -495,6 +495,91 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
});
}

/*!
* \brief Kernel for performing elemwise op between dense and csr matrix
* \param i global thread id
* \param req type of request
* \param out output array
* \param dns_data data array of dense input
* \param csr_data data array of csr input
* \param csr_indices indices array of csr input
* \param csr_indptr indptr array of csr input
* \param num_rows number of rows of both inputs
* \param num_cols number of columns of both inputs
*/
template<int req, typename OP, bool reverse>
struct ElemwiseDnsCsrCsrKernel {
template<typename DType, typename IType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data,
const DType* csr_data, const IType* csr_indices,
const CType* csr_indptr, const nnvm::dim_t num_rows,
const nnvm::dim_t num_cols) {
if (i < num_rows) {
for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) {
KERNEL_ASSIGN(out[j], req, reverse ?
OP::Map(dns_data[i * num_cols + csr_indices[j]], csr_data[j]) :
OP::Map(csr_data[j], dns_data[i * num_cols + csr_indices[j]]));
}
}
}
};

/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
void ElemwiseBinaryOp::DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &dns,
const NDArray &csr,
const OpReqType req,
const NDArray &output,
const bool reverse) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
CHECK_EQ(dns.storage_type(), kDefaultStorage);
CHECK_EQ(csr.storage_type(), kCSRStorage);
if (req == kNullOp) return;
CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo";
const bool supported_op = std::is_same<OP, mshadow_op::mul>::value;
CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul";
const nnvm::dim_t num_csr_rows = csr.shape()[0];
const nnvm::dim_t num_csr_cols = csr.shape()[1];
const nnvm::dim_t nnz = csr.storage_shape()[0];
Stream<xpu> *s = ctx.get_stream<xpu>();

output.CheckAndAlloc({Shape1(num_csr_rows + 1), Shape1(nnz)});
if (csr.storage_initialized()) {
TBlob csr_data = csr.data();
TBlob csr_indices = csr.aux_data(kIdx);
TBlob csr_indptr = csr.aux_data(kIndPtr);
MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
MXNET_ASSIGN_REQ_SWITCH(req, Req, {
if (reverse) {
Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, true>, xpu>::Launch(
s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
num_csr_rows, num_csr_cols);
} else {
Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, false>, xpu>::Launch(
s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
num_csr_rows, num_csr_cols);
}
Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(),
csr.aux_data(kIdx).FlatTo1D<xpu, IType>(), s);
Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, CType>(),
csr.aux_data(kIndPtr).FlatTo1D<xpu, CType>(), s);
});
});
});
});
} else {
FillZerosCsrImpl(s, output);
}
}

/*!
* \brief Kernel for performing elemwise op between dense and rsp tensor
* \param i global thread id
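
The ElemwiseDnsCsrCsrKernel above visits each row of the CSR input and, for every stored entry, applies OP to the CSR value and the dense element at the same (row, column) position, so the output reuses the CSR input's indices and indptr unchanged. A NumPy sketch of that per-row loop (dns_csr_mul is a hypothetical helper for illustration, not part of the patch):

import numpy as np

def dns_csr_mul(dns, csr_data, csr_indices, csr_indptr):
    # One pass over the nnz stored entries: out[j] pairs csr_data[j] with
    # the dense element at (row i, column csr_indices[j]).
    out = np.empty_like(csr_data)
    for i in range(len(csr_indptr) - 1):
        for j in range(csr_indptr[i], csr_indptr[i + 1]):
            out[j] = csr_data[j] * dns[i, csr_indices[j]]
    return out  # shares indices/indptr with the csr input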
src/operator/tensor/elemwise_binary_op.cc (21 additions, 0 deletions)
@@ -63,6 +63,11 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
DispatchMode::kFComputeEx;
const int ograd_stype = in_attrs->at(0);
const int lhs_stype = in_attrs->at(1);
const int rhs_stype = in_attrs->at(2);
int& lhs_grad_stype = out_attrs->at(0);
int& rhs_grad_stype = out_attrs->at(1);
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
@@ -74,6 +79,22 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
dispatch_mode, dispatch_ex);
}
}
if (!dispatched && ograd_stype == kDefaultStorage &&
((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
const bool reverse = (lhs_stype == kCSRStorage);
if (reverse &&
type_assign(&lhs_grad_stype, kDefaultStorage) &&
type_assign(&rhs_grad_stype, kCSRStorage)) {
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
dispatched = true;
} else if (!reverse &&
type_assign(&lhs_grad_stype, kCSRStorage) &&
type_assign(&rhs_grad_stype, kDefaultStorage)) {
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
dispatched = true;
}
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
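
The new branch above encodes a simple rule: with a dense ograd and exactly one CSR input, the gradient of the dense input inherits the CSR pattern (it is ograd scaled entrywise by the CSR operand), while the gradient of the CSR input is dense. A small Python mirror of that dispatch decision (a sketch of the rule, not the actual C++ inference):

def backward_grad_stypes(ograd, lhs, rhs):
    # Dense ograd, exactly one csr input (mirrors the branch above).
    if ograd == 'default' and {lhs, rhs} == {'csr', 'default'}:
        if lhs == 'csr':  # the "reverse" case
            return ('default', 'csr')  # grad_lhs = ograd*rhs is dense
        return ('csr', 'default')      # grad_lhs = ograd*rhs keeps csr
    return None  # handled by other branches or the fallback

assert backward_grad_stypes('default', 'csr', 'default') == ('default', 'csr')
assert backward_grad_stypes('default', 'default', 'csr') == ('csr', 'default')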
src/operator/tensor/elemwise_binary_op.h (89 additions, 32 deletions)
@@ -165,12 +165,11 @@ class ElemwiseBinaryOp : public OpBase {
typename xpu,
typename LOP,
typename ROP,
typename DType,
bool in0_ok_dense = false,
bool in1_ok_dense = false,
bool in2_ok_dense = false,
typename BackupCompute>
static inline void BackwardUseInEx_(const nnvm::NodeAttrs &attrs,
static inline void RspRspOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
@@ -200,6 +199,33 @@
}
}

template<typename xpu, typename LOP, typename ROP>
static inline void DnsCsrCsrOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const bool supported_ops = std::is_same<mshadow_op::right, LOP>::value &&
std::is_same<mshadow_op::left, ROP>::value;
CHECK(supported_ops)
<< "Only backward for mul is supported (LOP should be right, ROP should be left)";
const NDArray& out_grad = inputs[0];
const NDArray& lhs_in = inputs[1];
const NDArray& rhs_in = inputs[2];
const NDArray& lhs_grad = outputs[0];
const NDArray& rhs_grad = outputs[1];
const bool reverse = (outputs[0].storage_type() == kCSRStorage);
if (reverse) {
DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, rhs_in, req[0], lhs_grad, false);
Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), lhs_in.data()}, {req[1]},
{rhs_grad.data()});
} else {
DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, lhs_in, req[1], rhs_grad, false);
Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), rhs_in.data()}, {req[0]},
{lhs_grad.data()});
}
}

public:
/*! \brief Binary op handling for lhs/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
template<typename OP>
@@ -232,44 +258,54 @@
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<cpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);

/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<gpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output);

/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);

/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);

/*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);
const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &lhs,
const NDArray &rhs,
OpReqType req,
const NDArray &output,
const bool reverse);

public:
/*!
@@ -336,6 +372,14 @@ class ElemwiseBinaryOp : public OpBase {
dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
dispatch_mode, dispatch_ex);
}
if (!dispatched &&
((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
// csr, dns -> csr
// dns, csr -> csr
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
@@ -540,6 +584,14 @@ class ElemwiseBinaryOp : public OpBase {
req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false, false);
} else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
} else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
out_stype == kCSRStorage) {
const NDArray& dns = (lhs_stype == kDefaultStorage) ? inputs[0] : inputs[1];
const NDArray& csr = (lhs_stype == kCSRStorage) ? inputs[0] : inputs[1];
const bool reverse = (lhs_stype == kCSRStorage);

DnsCsrCsrOp<xpu, OP>(attrs, ctx, dns, csr, req[0], outputs[0], reverse);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
@@ -635,16 +687,21 @@ class ElemwiseBinaryOp : public OpBase {
using namespace common;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U); // lhs input grad, rhs input grad
const auto out_grad_stype = inputs[0].storage_type();
const auto lhs_grad_stype = outputs[0].storage_type();
const auto rhs_grad_stype = outputs[1].storage_type();
if (ContainsOnlyStorage(inputs, kRowSparseStorage) &&
(lhs_grad_stype == kDefaultStorage || lhs_grad_stype == kRowSparseStorage) &&
(rhs_grad_stype == kDefaultStorage || rhs_grad_stype == kRowSparseStorage)) {
// rsp, rsp, rsp -> [dns, rsp], [dns, rsp]
MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
BackwardUseInEx_<xpu, LOP, ROP, DType, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
});
RspRspOpBackward<xpu, LOP, ROP, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
}
if (((lhs_grad_stype == kDefaultStorage && rhs_grad_stype == kCSRStorage) ||
(lhs_grad_stype == kCSRStorage && rhs_grad_stype == kDefaultStorage)) &&
out_grad_stype == kDefaultStorage) {
// (lhs, rhs, ograd) = (dns, csr, dns) -> grads [csr, dns] / (csr, dns, dns) -> grads [dns, csr]
DnsCsrCsrOpBackward<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
}
}
}; // class ElemwiseBinaryOp
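
For reference, the backward that DnsCsrCsrOpBackward implements is just the product rule for the elementwise multiply, with g the incoming gradient:

\[
\frac{\partial \ell}{\partial \mathrm{lhs}} = g \circ \mathrm{rhs},
\qquad
\frac{\partial \ell}{\partial \mathrm{rhs}} = g \circ \mathrm{lhs}.
\]

When g is dense, the product involving the CSR operand keeps that operand's sparsity pattern and is computed by DnsCsrCsrOp; the product involving the dense operand is fully dense and goes through the ordinary Compute path.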
src/operator/tensor/elemwise_binary_op_basic.cu (3 additions, 1 deletion)
@@ -51,7 +51,9 @@ NNVM_REGISTER_OP(_backward_sub)
mshadow_op::negation>);

NNVM_REGISTER_OP(elemwise_mul)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>);
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>)
.set_attr<FComputeEx>("FComputeEx<gpu>",
ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);

NNVM_REGISTER_OP(_backward_mul)
.set_attr<FCompute>("FCompute<gpu>",
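
The added FComputeEx<gpu> registration is what routes dense/CSR inputs of elemwise_mul to ComputeDnsLRValueEx on GPU instead of falling back to dense compute. A quick interactive check (assumes a CUDA-enabled MXNet build containing this commit):

import mxnet as mx

ctx = mx.gpu(0)
dns = mx.nd.ones((2, 3), ctx=ctx)
csr = mx.nd.ones((2, 3), ctx=ctx).tostype('csr')
# With the sparse kernel registered for GPU this stays on the csr path:
print(mx.nd.sparse.elemwise_mul(dns, csr).stype)  # expected: 'csr'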
tests/python/unittest/test_sparse_operator.py (12 additions, 2 deletions)
@@ -329,9 +329,19 @@ def elemwise_mul_stype(lstype, rstype):
return 'row_sparse'
elif lstype == 'row_sparse' and rstype == 'default':
return 'row_sparse'
elif lstype == 'default' and rstype == 'csr':
return 'csr'
elif lstype == 'csr' and rstype == 'default':
return 'csr'
else:
return 'default'

def elemwise_mul_lhs_grad_stype(lstype, rstype):
return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), rstype)

def elemwise_mul_rhs_grad_stype(lstype, rstype):
return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), lstype)
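
# Worked trace of the two helpers above (assumption: the csr/csr branch of
# elemwise_mul_stype, outside this hunk, returns 'csr'):
#   lhs='default', rhs='csr'  =>  ograd stype = 'csr'
#     lhs grad stype = elemwise_mul_stype('csr', 'csr')     = 'csr'
#     rhs grad stype = elemwise_mul_stype('csr', 'default') = 'csr'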

def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
lhs_grad_stype=None, rhs_grad_stype=None,
lhs_density=.5, rhs_density=.5,
@@ -378,8 +388,8 @@ def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
lambda l, r: mx.sym.sparse.elemwise_mul(l, r),
lambda l, r: l * r,
lambda outg, l, r: (outg * r, outg * l),
elemwise_mul_stype(lhs_stype, rhs_stype),
elemwise_mul_stype(lhs_stype, rhs_stype),
elemwise_mul_lhs_grad_stype(lhs_stype, rhs_stype),
elemwise_mul_rhs_grad_stype(lhs_stype, rhs_stype),
expected_result_storage_type=elemwise_mul_stype(lhs_stype, rhs_stype),
ograd_density=ograd_density,
force_lr_overlap=force_lr_overlap,