From 66b21b5bcdd4622d021ab09db575006076887cb3 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Mon, 16 Mar 2020 16:02:05 -0700 Subject: [PATCH] fixing batch_norm and layer_norm for large tensors (#17805) Co-authored-by: Rohit Kumar Srivastava --- src/operator/nn/batch_norm.cc | 2 +- src/operator/nn/layer_norm.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 97acced29d6e..df0357369fed 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -330,7 +330,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, : param.axis); CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis; - const int channelCount = dshape[channelAxis]; + const index_t channelCount = dshape[channelAxis]; if (!mxnet::ndim_is_known(dshape)) { return false; diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index d385b93e9cff..e3d641af4015 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -47,7 +47,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, CHECK(axis >= 0 && axis < dshape.ndim()) << "Channel axis out of range: axis=" << param.axis; - const int channelCount = dshape[axis]; + const index_t channelCount = dshape[axis]; if (!mxnet::ndim_is_known(dshape)) { return false;