diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index c9ee4eec0337..3d77fabe6fe9 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -30,7 +30,9 @@ #include #include #include +#include #include +#include "pattern_util.h" namespace tvm { namespace relay { @@ -65,7 +67,7 @@ int64_t ConvMacCount(const Call& call_node) { } Array args = call_node->args; CHECK(args.size() == 2) - << "The number of input arguments of a CONV 2D node should be 2."; + << "The number of input arguments of a CONV 2D node should be 2."; const auto* conv_2d_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); Array data_shape = data_type->shape; @@ -73,18 +75,21 @@ int64_t ConvMacCount(const Call& call_node) { int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); CHECK(C_ind != -1) - << "There is no input channel dimension."; + << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_attr->kernel_size; CHECK(kernel_size.size() == 2) - << "The dimension of the kernel size in Conv 2D should be 2."; + << "The dimension of the kernel in Conv 2D should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D should be 4 or 5."; - int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + << "The dimension of the output tensor in Conv 2D should be 4 or 5."; + int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + CHECK_EQ(input_channel % conv_2d_attr->groups, 0) + << "The number of input channels is not divisible by groups."; + count *= input_channel/conv_2d_attr->groups; 
return count; } diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py index 5a975fd41364..98ba1ad6325d 100644 --- a/tests/python/relay/test_pass_mac_count.py +++ b/tests/python/relay/test_pass_mac_count.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Unit tests for MAC counter.""" +import numpy as np import tvm from tvm import relay @@ -99,7 +100,35 @@ def test_simple_network(): expect_count = 231411712 assert compute_count == expect_count +def test_depthwise_conv2d(): + batch_size = 1 + dshape = (batch_size, 64, 56, 56) + weight_conv = relay.var("weight_depthwiseconv", shape=(64, 1, 3, 3)) + data1 = relay.var("data1", shape=dshape) + data2 = relay.var("data2", shape=dshape) + depthwise_conv2d_1 = relay.nn.conv2d( + data1, + weight_conv, + kernel_size=(3, 3), + padding=(1, 1), + groups=64) + depthwise_conv2d_2 = relay.nn.conv2d( + data2, + weight_conv, + kernel_size=(3, 3), + padding=(1, 1), + groups=64) + add = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + func = relay.Function([data1, data2, weight_conv], + relay.Tuple(tvm.convert([depthwise_conv2d_1, + depthwise_conv2d_2, + add]))) + func = relay.ir_pass.infer_type(func) + compute_count = relay.ir_pass.get_total_mac_number(func) + assert compute_count == 2 * np.prod(dshape) * 3*3 + if __name__ == "__main__": test_conv() test_gemm() test_simple_network() + test_depthwise_conv2d()