Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
num_group in convolution must be positive
Browse files Browse the repository at this point in the history
  • Loading branch information
Adnios committed Apr 19, 2021
1 parent fcf37f9 commit 3db8883
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
// 1d conv
CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x";
Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW);
CHECK_NE(param_.num_group, 0U) \
<< "num_group must be non-zero";
CHECK_GT(param_.num_group, 0U) \
<< "Range only supports num_group > 0, received "<<param_.num_group;
Shape<3> wshape = Shape3(param_.num_filter / param_.num_group,
mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0]);
Expand Down Expand Up @@ -151,8 +151,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(dshp.ndim(), 4U) \
<< "Input data should be 4D in batch-num_filter-y-x";
Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW);
CHECK_NE(param_.num_group, 0U) \
<< "num_group must be non-zero";
CHECK_GT(param_.num_group, 0U) \
<< "Range only supports num_group > 0, received "<<param_.num_group;
Shape<4> wshape = Shape4(param_.num_filter / param_.num_group,
mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0], param_.kernel[1]);
Expand Down Expand Up @@ -212,8 +212,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(dshp.ndim(), 5U) \
<< "Input data should be 5D in batch-num_filter-depth-y-x";
Shape<5> dshape = ConvertLayout(dshp.get<5>(), param_.layout.value(), kNCDHW);
CHECK_NE(param_.num_group, 0U) \
<< "num_group must be non-zero";
CHECK_GT(param_.num_group, 0U) \
<< "Range only supports num_group > 0, received "<<param_.num_group;
Shape<5> wshape = Shape5(param_.num_filter / param_.num_group,
mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0], param_.kernel[1], param_.kernel[2]);
Expand Down

0 comments on commit 3db8883

Please sign in to comment.