From 843b407f9554b5af0441bc52d130dd02d669b854 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Tue, 15 Oct 2019 21:18:43 +0000 Subject: [PATCH] adding large tensor support for dropout operator --- src/operator/nn/dropout-inl.h | 12 ++++++------ src/operator/tensor/elemwise_binary_broadcast_op.h | 5 +++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index eda9051fd0a2..6387dff96eb7 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -182,10 +182,10 @@ class DropoutOp { * \param input_data Input data to perform the dropout on * \param pkeep Dropout rate (keep when the generated random number is less than this value) */ - MSHADOW_XINLINE static void Map(int id, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, - const int N, - const int step, + const index_t N, + const index_t step, DType *dropout_out, DType *mask_out, const DType *input_data, @@ -199,10 +199,10 @@ class DropoutOp { }; struct BernoulliKernel { /*! \brief Bernoulli kernel for generating mask */ - MSHADOW_XINLINE static void Map(int id, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, - const int N, - const int step, + const index_t N, + const index_t step, DType *mask_out, const real_t pkeep) { RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, { diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 6a612e6f1cd5..3d3bcfacbd05 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -151,9 +151,10 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet: *new_oshape = mxnet::TShape(odim, 1); int bl = oshape.ndim() - lshape.ndim(); int br = oshape.ndim() - rshape.ndim(); - int j = 0, lprod = 1, rprod = 1, oprod = 1; + int j = 0; + index_t lprod = 1, rprod = 1, oprod = 1; for (int i = 0; i < oshape.ndim(); ++i) { - int l = 1, r = 1, o = oshape[i]; + index_t l = 1, r = 1, o = oshape[i]; if (i >= bl) l = lshape[i-bl]; if (i >= br) r = rshape[i-br]; if ((lprod != rprod || l != r) &&