Commit 0577671: fix windows compile error
JiangZhaoh committed May 26, 2020
1 parent 4a789af
Showing 5 changed files with 20 additions and 49 deletions.
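
The commit message does not quote the compiler output, but the pattern of the diff (every size_t loop index under a #pragma omp parallel for becomes int) points at a well-known MSVC limitation: MSVC implements OpenMP 2.0, which requires the index variable of a parallelized for loop to have signed integral type. GCC and Clang accept unsigned indices (OpenMP 3.0+), so the code built on Linux but broke on Windows. A minimal sketch of the failure mode and of the pattern this commit adopts, assuming MSVC with /openmp:

#include <cstddef>

// Fails under MSVC /openmp with error C3016:
// "'i': index variable in OpenMP 'for' statement must have signed integral type".
// GCC/Clang (OpenMP 3.0+) accept it, which is why only the Windows build failed.
void add_unsigned(float* dst, const float* src, std::size_t n) {
#pragma omp parallel for
  for (std::size_t i = 0; i < n; ++i) {
    dst[i] += src[i];
  }
}

// The fix: make both the bound and the index signed.
void add_signed(float* dst, const float* src, int n) {
#pragma omp parallel for
  for (int i = 0; i < n; ++i) {
    dst[i] += src[i];
  }
}
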
43 changes: 7 additions & 36 deletions src/operator/tensor/index_add-inl.h
@@ -60,35 +60,6 @@ inline bool IndexModifyOpType(const nnvm::NodeAttrs& attrs,
return (*out_attrs)[0] != -1;
}

-MSHADOW_XINLINE void index_unravel(const size_t idx, const int ndim,
-                                   const size_t* shape, size_t* ret) {
-  #pragma unroll
-  for (int i = ndim-1, j = idx; i >= 0; --i) {
-    auto tmp = j / shape[i];
-    ret[i] = j - tmp*shape[i];
-    j = tmp;
-  }
-}
-
-MSHADOW_XINLINE size_t index_dot(const int ndim, const size_t* coord, const size_t* stride) {
-  size_t ret = 0;
-  #pragma unroll
-  for (int i = 0; i < ndim; ++i) {
-    ret += coord[i] * stride[i];
-  }
-  return ret;
-}
-
-MSHADOW_XINLINE void vec_calc_stride(const int ndim, const size_t* shape,
-                                     size_t* stride) {
-  size_t cumprod = 1;
-  #pragma unroll
-  for (int i = ndim - 1; i >= 0; --i) {
-    stride[i] = (shape[i] > 1) ? cumprod : 0;
-    cumprod *= shape[i];
-  }
-}
-
template<typename xpu, typename DType>
void IndexAddForwardCalc(mshadow::Stream<xpu> *s,
const int ind_num, DType* out,
@@ -98,7 +69,7 @@ void IndexAddForwardCalc(mshadow::Stream<xpu> *s,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& a_shape,
-const size_t a_tail_size,
+const int a_tail_size,
const int ind_ndim, const int* ind,
const int a_ndim);

@@ -156,7 +127,7 @@ void IndexAddOpForward(const nnvm::NodeAttrs& attrs,
<< " in axis " << i;
}
}
-size_t a_tail_size = a.shape_.ProdShape(ind_ndim, a_ndim);
+int a_tail_size = static_cast<int>(a.shape_.ProdShape(ind_ndim, a_ndim));
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>a_shape, val_shape;
for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = a_ndim - 1; i >= 0; --i, --j) {
a_shape[i] = (j >= 0) ? a.shape_[j] : 1;
@@ -197,14 +168,14 @@ struct IndexAddBackwardAKernel {
MSHADOW_XINLINE static void Map(size_t i, DType* grad_a,
const DType* ograd,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& stride,
-const size_t tail_size, const int ind_num, const int ind_ndim,
+const int tail_size, const int ind_num, const int ind_ndim,
const int32_t* ind_vec, const int req, const int out_ndim) {
size_t id = 0;
int seg = MXNET_SPECIAL_MAX_NDIM - out_ndim;
for (int dim = 0; dim < ind_ndim; ++dim) {
id += stride[seg + dim] * ind_vec[dim * ind_num + i];
}
-for (size_t _i = 0; _i < tail_size; ++_i) {
+for (int _i = 0; _i < tail_size; ++_i) {
KERNEL_ASSIGN(grad_a[id + _i], req, ograd[id + _i]);
}
}
@@ -214,7 +185,7 @@ template<typename xpu, typename DType>
void IndexAddOpBackwardACalc(mshadow::Stream<xpu> *s,
DType* grad_a, const DType* ograd,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& stride,
-const size_t tail_size, const int ind_num,
+const int tail_size, const int ind_num,
const int ind_ndim, const int32_t* ind_vec,
const int req, const int out_ndim);

@@ -225,7 +196,7 @@ void IndexAddOpBackwardValCalc(mshadow::Stream<xpu> *s,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& ograd_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
-const size_t tail_size, const int ind_num,
+const int tail_size, const int ind_num,
const int ind_ndim, const int32_t* ind_vec,
const int out_ndim);

@@ -259,7 +230,7 @@ void IndexAddOpBackward(const nnvm::NodeAttrs& attrs,
int ind_num = ind.shape_[1];
// broadcast 'ind'
int ndim = ograd.shape_.ndim();
-size_t tail_size = ograd.shape_.ProdShape(ind_ndim, ndim);
+int tail_size = static_cast<int>(ograd.shape_.ProdShape(ind_ndim, ndim));
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>ograd_shape, val_shape;
for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = ndim - 1; i >= 0; --i, --j) {
ograd_shape[i] = (j >= 0) ? ograd.shape_[j] : 1;
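
The three helpers deleted above (index_unravel, index_dot, vec_calc_stride) appear to be dead code rather than the compile fix itself: the kernels in this diff already use the equivalent utilities from mxnet_op.h, e.g. mxnet_op::unravel in the loop bodies below. Removing them may also help the Windows build, since the #pragma unroll they contain is a CUDA/Clang extension that MSVC flags as an unknown pragma (warning C4068), which breaks builds that treat warnings as errors. A sketch of the equivalent addressing idiom, assuming the usual mshadow signatures for mxnet_op::unravel and mxnet_op::dot and the include context of index_add-inl.h (the helper read_at is hypothetical, for illustration only):

// Flat index -> multi-dimensional coordinates -> strided offset,
// using the mxnet_op helpers instead of the deleted free functions.
template <typename DType>
MSHADOW_XINLINE DType read_at(const DType* data, index_t flat_idx,
                              const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& shape,
                              const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& stride) {
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> coord = mxnet_op::unravel(flat_idx, shape);
  return data[mxnet_op::dot(coord, stride)];
}
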
8 changes: 4 additions & 4 deletions src/operator/tensor/index_add_backward.cc
@@ -31,7 +31,7 @@ template<typename xpu, typename DType>
void IndexAddOpBackwardACalc(mshadow::Stream<xpu> *s,
DType* grad_a, const DType* ograd,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& stride,
-const size_t tail_size, const int ind_num,
+const int tail_size, const int ind_num,
const int ind_ndim, const int32_t* ind_vec,
const int req, const int out_ndim) {
using namespace mxnet_op;
@@ -48,7 +48,7 @@ struct IndexAddBackwardValCPUKernel {
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& ograd_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
-const size_t ograd_tail_size, const int ind_num,
+const int ograd_tail_size, const int ind_num,
const int ind_ndim, const int32_t* ind_vec,
const int out_ndim) {
size_t id = 0;
@@ -58,7 +58,7 @@
}
id *= ograd_tail_size;
#pragma omp parallel for
-for (size_t _i = 0; _i < ograd_tail_size; ++_i) {
+for (int _i = 0; _i < ograd_tail_size; ++_i) {
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_id =
mxnet_op::unravel(_i, ograd_tail_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
@@ -82,7 +82,7 @@ void IndexAddOpBackwardValCalc(mshadow::Stream<xpu> *s,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& ograd_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
-const size_t tail_size, const int ind_num,
+const int tail_size, const int ind_num,
const int ind_ndim, const int32_t* ind_vec,
const int out_ndim) {
using namespace mxnet_op;
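
One caveat the new code leaves implicit: ProdShape returns a size_t, so the static_cast<int> added in index_add-inl.h silently truncates tail sizes beyond INT_MAX. A hypothetical guard (a sketch, not part of this commit; CHECK_LE is the dmlc logging macro MXNet already uses, and a, ind_ndim, a_ndim come from the surrounding IndexAddOpForward) would make the narrowing explicit:

#include <limits>

// Hypothetical guard before the narrowing cast in IndexAddOpForward:
const size_t prod = a.shape_.ProdShape(ind_ndim, a_ndim);
CHECK_LE(prod, static_cast<size_t>(std::numeric_limits<int>::max()))
    << "index_add: tail size " << prod << " does not fit in int";
const int a_tail_size = static_cast<int>(prod);
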
8 changes: 4 additions & 4 deletions src/operator/tensor/index_add_backward.cu
@@ -34,7 +34,7 @@ template<typename xpu, typename DType>
void IndexAddOpBackwardACalc(mshadow::Stream<xpu> *s,
DType* grad_a, const DType* ograd,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& stride,
-const size_t tail_size, const int ind_num,
+const int tail_size, const int ind_num,
const int ind_ndim, const int32_t* ind_vec,
const int req, const int out_ndim) {
using namespace mxnet_op;
@@ -52,7 +52,7 @@ struct IndexAddBackwardValGPUKernel {
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& ograd_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
-const size_t ograd_tail_size, const int ind_num,
+const int ograd_tail_size, const int ind_num,
const int ind_ndim, const int32_t* ind_vec,
const int out_ndim) {
size_t id = 0;
@@ -61,7 +61,7 @@
id += ograd_pre_stride[seg + dim] * ind_vec[dim * ind_num + i];
}
id *= ograd_tail_size;
-for (size_t _i = 0; _i < ograd_tail_size; ++_i) {
+for (int _i = 0; _i < ograd_tail_size; ++_i) {
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_id =
mxnet_op::unravel(_i, ograd_tail_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
@@ -82,7 +82,7 @@ void IndexAddOpBackwardValCalc(mshadow::Stream<xpu> *s,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& ograd_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
-const size_t tail_size, const int ind_num,
+const int tail_size, const int ind_num,
const int ind_ndim, const int32_t* ind_vec,
const int out_ndim) {
using namespace mxnet_op;
4 changes: 2 additions & 2 deletions src/operator/tensor/index_add_forward.cc
@@ -52,7 +52,7 @@ struct IndexAddForwardCPUKernel {
}
id *= a_tail_size;
#pragma omp parallel for
-for (size_t _i = 0; _i < a_tail_size; ++_i) {
+for (int _i = 0; _i < a_tail_size; ++_i) {
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_id = mxnet_op::unravel(_i, a_tail_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
for (int _j = seg; _j < seg + a_ndim; ++_j) {
@@ -77,7 +77,7 @@ void IndexAddForwardCalc(mshadow::Stream<xpu> *s,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& a_shape,
-const size_t a_tail_size,
+const int a_tail_size,
const int ind_ndim, const int* ind,
const int a_ndim) {
using namespace mxnet_op;
6 changes: 3 additions & 3 deletions src/operator/tensor/index_add_forward.cu
@@ -38,7 +38,7 @@ struct IndexAddForwardGPUKernel {
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& a_pre_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
-const size_t a_tail_size, const int ind_num,
+const int a_tail_size, const int ind_num,
const int ind_ndim, const int* ind,
const int a_ndim) {
size_t id = 0;
@@ -47,7 +47,7 @@
id += a_pre_stride[seg + dim] * ind[dim * ind_num + i];
}
id *= a_tail_size;
-for (size_t _i = 0; _i < a_tail_size; ++_i) {
+for (int _i = 0; _i < a_tail_size; ++_i) {
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_id = mxnet_op::unravel(_i, a_tail_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
for (int _j = seg; _j < seg + a_ndim; ++_j) {
@@ -69,7 +69,7 @@ void IndexAddForwardCalc(mshadow::Stream<xpu> *s,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_stride,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& val_shape,
const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM>& a_shape,
-const size_t a_tail_size,
+const int a_tail_size,
const int ind_ndim, const int* ind,
const int a_ndim) {
using namespace mxnet_op;
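
Note that the .cu kernels changed above contain no OpenMP pragmas, and CUDA device loops are not subject to the MSVC restriction; the int changes in the GPU files presumably just keep the IndexAddForwardCalc and IndexAddOpBackwardValCalc template signatures identical across the CPU and GPU instantiations, since both are declared once in index_add-inl.h.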
