Skip to content

Commit

Permalink
[BUGFIX] Add check to make sure num_group is non-zero (apache#20186)
Browse files Browse the repository at this point in the history
* add check for group not equal zero

* num_group in convolution must be positive
  • Loading branch information
Adnios authored and chinakook committed Aug 1, 2021
1 parent 04ffc6b commit 9be4117
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
// 1d conv
CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x";
Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW);
CHECK_GT(param_.num_group, 0U) \
<< "Range only supports num_group > 0, received " << param_.num_group;
Shape<3> wshape = Shape3(param_.num_filter / param_.num_group,
mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0]);
Expand Down Expand Up @@ -152,6 +154,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(dshp.ndim(), 4U) \
<< "Input data should be 4D in batch-num_filter-y-x";
Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW);
CHECK_GT(param_.num_group, 0U) \
<< "Range only supports num_group > 0, received " << param_.num_group;
Shape<4> wshape = Shape4(param_.num_filter / param_.num_group,
mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0], param_.kernel[1]);
Expand Down Expand Up @@ -211,6 +215,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(dshp.ndim(), 5U) \
<< "Input data should be 5D in batch-num_filter-depth-y-x";
Shape<5> dshape = ConvertLayout(dshp.get<5>(), param_.layout.value(), kNCDHW);
CHECK_GT(param_.num_group, 0U) \
<< "Range only supports num_group > 0, received " << param_.num_group;
Shape<5> wshape = Shape5(param_.num_filter / param_.num_group,
mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0], param_.kernel[1], param_.kernel[2]);
Expand Down

0 comments on commit 9be4117

Please sign in to comment.