From fcf37f9975e349d503e129fc64e964ddfa3bf28e Mon Sep 17 00:00:00 2001 From: Adnios <2780199647@qq.com> Date: Sat, 17 Apr 2021 17:58:13 +0800 Subject: [PATCH 1/2] add check for group not equal zero --- src/operator/nn/convolution.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 556918a572a8..be33db805dbd 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -99,6 +99,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, // 1d conv CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x"; Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW); + CHECK_NE(param_.num_group, 0U) \ + << "num_group must be non-zero"; Shape<3> wshape = Shape3(param_.num_filter / param_.num_group, mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0]); @@ -149,6 +151,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(dshp.ndim(), 4U) \ << "Input data should be 4D in batch-num_filter-y-x"; Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW); + CHECK_NE(param_.num_group, 0U) \ + << "num_group must be non-zero"; Shape<4> wshape = Shape4(param_.num_filter / param_.num_group, mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0], param_.kernel[1]); @@ -208,6 +212,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(dshp.ndim(), 5U) \ << "Input data should be 5D in batch-num_filter-depth-y-x"; Shape<5> dshape = ConvertLayout(dshp.get<5>(), param_.layout.value(), kNCDHW); + CHECK_NE(param_.num_group, 0U) \ + << "num_group must be non-zero"; Shape<5> wshape = Shape5(param_.num_filter / param_.num_group, mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0], param_.kernel[1], param_.kernel[2]); From 2824daa6ada9a02c0ca72852d86a2d883f2d099b Mon Sep 17 00:00:00 2001 From: Adnios <2780199647@qq.com> Date: Mon, 19 Apr 2021 09:37:20 +0800 Subject: [PATCH 2/2] num_group in convolution must be positive --- src/operator/nn/convolution.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index be33db805dbd..cbfadf9b6450 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -99,8 +99,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, // 1d conv CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x"; Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW); - CHECK_NE(param_.num_group, 0U) \ - << "num_group must be non-zero"; + CHECK_GT(param_.num_group, 0U) \ + << "Range only supports num_group > 0, received " << param_.num_group; Shape<3> wshape = Shape3(param_.num_filter / param_.num_group, mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0]); @@ -151,8 +151,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(dshp.ndim(), 4U) \ << "Input data should be 4D in batch-num_filter-y-x"; Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW); - CHECK_NE(param_.num_group, 0U) \ - << "num_group must be non-zero"; + CHECK_GT(param_.num_group, 0U) \ + << "Range only supports num_group > 0, received " << param_.num_group; Shape<4> wshape = Shape4(param_.num_filter / param_.num_group, mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0], param_.kernel[1]); @@ -212,8 +212,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(dshp.ndim(), 5U) \ << "Input data should be 5D in batch-num_filter-depth-y-x"; Shape<5> dshape = ConvertLayout(dshp.get<5>(), param_.layout.value(), kNCDHW); - CHECK_NE(param_.num_group, 0U) \ - << "num_group must be non-zero"; + CHECK_GT(param_.num_group, 0U) \ + << "Range only supports num_group > 0, received " << param_.num_group; Shape<5> wshape = Shape5(param_.num_filter / param_.num_group, mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1, param_.kernel[0], param_.kernel[1], param_.kernel[2]);