Skip to content

Commit

Permalink
Change idx type switch for aux data (apache#6860)
Browse files Browse the repository at this point in the history
* Change idx type switch for aux data

* Add mshadow commit
  • Loading branch information
reminisce authored and piiswrong committed Jun 29, 2017
1 parent 88d46ec commit 1e804c1
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 18 deletions.
2 changes: 1 addition & 1 deletion mshadow
4 changes: 2 additions & 2 deletions src/kvstore/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class CommCPU : public Comm {
std::vector<bool> skip(num_in, false);
// the values tensor of the inputs
MSHADOW_TYPE_SWITCH(out->dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, {
std::vector<Tensor<cpu, 2, DType>> in_vals(num_in);
std::vector<Tensor<cpu, 1, IType>> in_indices(num_in);
// offset to the values tensor of all inputs
Expand Down Expand Up @@ -350,7 +350,7 @@ class CommCPU : public Comm {
<< out->storage_type() << " given)";

MSHADOW_TYPE_SWITCH(out->dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, {
std::vector<IType> uniq_row_idx;
GetUniqueRspRowIdx(nds, &uniq_row_idx);
out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())});
Expand Down
12 changes: 6 additions & 6 deletions src/operator/nn/cast_storage-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ inline void CastStorageDnsRspImpl(mshadow::Stream<cpu>* s, const TBlob& dns, NDA
CHECK_EQ(rsp->storage_type(), kRowSparseStorage);
CHECK_EQ(dns.shape_, rsp->shape());
MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type
MSHADOW_INT_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type
MSHADOW_IDX_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type
const index_t num_rows = dns.shape_[0];
const index_t num_cols = dns.shape_[1];
rsp->CheckAndAllocAuxData(rowsparse::kIdx, mshadow::Shape1(num_rows));
Expand Down Expand Up @@ -102,7 +102,7 @@ void CastStorageRspDnsImpl(mshadow::Stream<xpu>* s, const NDArray& rsp, TBlob* d
using namespace mshadow::expr;
CHECK_EQ(rsp.storage_type(), kRowSparseStorage);
MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, {
MSHADOW_INT_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, {
// assign zeros
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(s, dns->Size(), dns->dptr<DType>());
if (rsp.storage_initialized()) {
Expand Down Expand Up @@ -186,8 +186,8 @@ inline void CastStorageDnsCsrImpl(mshadow::Stream<cpu>* s, const TBlob& dns, NDA
CHECK_EQ(dns.shape_.ndim(), 2);
CHECK_EQ(dns.shape_, csr->shape());
MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type
MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type
MSHADOW_INT_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type
MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type
const index_t num_rows = dns.shape_[0];
const index_t num_cols = dns.shape_[1];
csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1));
Expand Down Expand Up @@ -248,8 +248,8 @@ void CastStorageCsrDnsImpl(mshadow::Stream<xpu>* s, const NDArray& csr, TBlob* d
CHECK_EQ(dns->shape_.ndim(), 2);
CHECK_EQ(dns->shape_, csr.shape());
MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { // data type
MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type
MSHADOW_INT_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type
MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type
const index_t num_rows = dns->shape_[0];
const index_t num_cols = dns->shape_[1];
DType* dns_data = dns->dptr<DType>();
Expand Down
4 changes: 2 additions & 2 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
CHECK_GT(weight.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
auto weight_data = weight.dptr<DType>();
auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
Expand Down Expand Up @@ -364,7 +364,7 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
CHECK_GT(mom.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_INT_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
auto weight_data = weight.dptr<DType>();
auto grad_idx = grad.aux_data(kIdx).dptr<IType>();
Expand Down
4 changes: 2 additions & 2 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type
MSHADOW_INT_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type
MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type
MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type
Kernel<set_zero, xpu>::Launch(s, output_data.Size(), output_data.dptr<DType>());
Kernel<SparseRetainRspForward, xpu>::Launch(s, idx_data.Size(), output_data.dptr<DType>(),
Expand Down Expand Up @@ -949,7 +949,7 @@ void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs,
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, { // output data type
MSHADOW_INT_TYPE_SWITCH(in_grad_idx.type_flag_, RType, { // row index data type
MSHADOW_IDX_TYPE_SWITCH(in_grad_idx.type_flag_, RType, { // row index data type
MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type
MXNET_ASSIGN_REQ_SWITCH(req[sr::kArr], req_type, {
Kernel<SparseRetainRspBackward<req_type>, xpu>::Launch(
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
using namespace mxnet_op;
CHECK_EQ(dst->storage_type(), kRowSparseStorage);
MSHADOW_REAL_TYPE_SWITCH(dst->dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
auto num_rows = dst->shape()[0];
dst->CheckAndAlloc({Shape1(num_rows)});
auto idx = dst->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
Expand Down
8 changes: 4 additions & 4 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,8 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
const TBlob data_out = *ret;

MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
if (std::is_same<xpu, cpu>::value) { // cpu parallelization by row blocks
if (kWriteTo == req) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
Expand Down Expand Up @@ -1157,8 +1157,8 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
return;
}
// assume idx indptr share the same type
MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIdx), IType, {
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
auto in_indptr = in.aux_data(kIndPtr).dptr<RType>();
auto out_indptr = out.aux_data(kIndPtr).dptr<RType>();
Expand Down

0 comments on commit 1e804c1

Please sign in to comment.