From 5876fc990060d99c2588c1d20acff24b43d0029d Mon Sep 17 00:00:00 2001 From: Yida Wang Date: Fri, 22 Feb 2019 10:05:23 -0800 Subject: [PATCH] [RELAY][PASS]use attribute registration style in the mac count pass (#2645) --- src/relay/pass/mac_count.cc | 172 +++++++++++----------- tests/python/relay/test_pass_mac_count.py | 1 - 2 files changed, 82 insertions(+), 91 deletions(-) diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index 5709d7d0ea31..500312117c5b 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -16,19 +16,88 @@ namespace tvm { namespace relay { -namespace { +namespace mac_count { -bool IsConv2DNode(const ExprNode* node) { - const auto* call_node = dynamic_cast(node); - return call_node != nullptr && call_node->attrs.as(); +inline int64_t GetCartesianProd(Array arr) { + int64_t ret = 1; + for (size_t i = 0; i < arr.size(); i++) { + const auto* intImm = arr[i].as(); + ret *= static_cast(intImm->value); + } + return ret; +} + +/* + * \brief Preparation function for MAC count. + * \param call_node The call node. + * \return The number of MACs. + */ +using FMacCount = runtime::TypedPackedFunc< + int64_t(const Call& call_node)>; + +//---------------------------------------------- +// Per operator defs for MAC count +//---------------------------------------------- + +int64_t ConvMacCount(const Call& call_node) { + if (!call_node->checked_type_.defined()) { + LOG(WARNING) << "The infer type pass should be called before the mac count pass"; + return 0; + } + Array args = call_node->args; + CHECK(args.size() == 2) + << "The number of input arguments of a CONV 2D node should be 2."; + const auto* conv_2d_attr = call_node->attrs.as(); + const auto* data_type = args[0]->checked_type().as(); + Array data_shape = data_type->shape; + std::string data_layout = conv_2d_attr->data_layout; + int32_t C_ind = Layout(data_layout).Indexof('C'); + int32_t c_ind = Layout(data_layout).Indexof('c'); + CHECK(C_ind != -1) + << "There is no input channel dimension."; + int64_t input_channel = static_cast(data_shape[C_ind].as()->value); + if (c_ind != -1) + input_channel *= static_cast(data_shape[c_ind].as()->value); + Array kernel_size = conv_2d_attr->kernel_size; + CHECK(kernel_size.size() == 2) + << "The dimension of the kernel size in Conv 2D should be 2."; + const auto* expr = call_node->checked_type().as(); + Array output_tensor = expr->shape; + CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) + << "The dimension of the output tensor in Conv 2D should be 4 or 5."; + int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + return count; } -bool IsDenseNode(const ExprNode* node) { - const auto* call_node = dynamic_cast(node); - return call_node != nullptr && call_node->attrs.as(); +int64_t DenseMacCount(const Call& call_node) { + if (!call_node->checked_type_.defined()) { + LOG(WARNING) << "The infer type pass should be called before the mac count pass"; + return 0; + } + Array args = call_node->args; + CHECK(args.size() == 2) + << "The number of input arguments of a Dense node should be 2."; + const auto* data_type = args[0]->checked_type().as(); + const auto* weight_type = args[1]->checked_type().as(); + Array data_shape = data_type->shape; + Array weight_shape = weight_type->shape; + CHECK(data_shape.size() == 2 && weight_shape.size() == 2) + << "The dimension of an input tensor to Dense node should be 2."; + int64_t d1 = static_cast(data_shape[0].as()->value); + int64_t d2 = static_cast(data_shape[1].as()->value); + int64_t d3 = static_cast(weight_shape[0].as()->value); + int64_t d4 = static_cast(weight_shape[1].as()->value); + CHECK(d2 == d4) + << "The dimensions of input arguments do not match."; + int64_t count = d1 * d2 * d3; + return count; } -} // namespace +RELAY_REGISTER_OP("nn.conv2d") +.set_attr("FMacCount", ConvMacCount); + +RELAY_REGISTER_OP("nn.dense") +.set_attr("FMacCount", DenseMacCount); class MacCounter : private ExprVisitor { public: @@ -44,91 +113,13 @@ class MacCounter : private ExprVisitor { private: void VisitExpr_(const CallNode* call_node) final { - if (IsConv2DNode(call_node)) { - count_ += ComputeConv2DMacs(call_node); - } else if (IsDenseNode(call_node)) { - count_ += ComputeDenseMacs(call_node); - } + static const auto& fprep = + Op::GetAttr("FMacCount"); + auto f = fprep.get(call_node->op, nullptr); + if (f != nullptr) count_ += f(GetRef(call_node)); ExprVisitor::VisitExpr_(call_node); } - /* - * \brief Get the number of MACs of a CONV 2D node. - * \param call_node The CONV 2D call node. - * \return The number of MACs. - */ - int64_t ComputeConv2DMacs(const CallNode* call_node) { - CHECK(IsConv2DNode(call_node)) - << "The input call node must be a CONV 2D node."; - if (!call_node->checked_type_.defined()) { - LOG(WARNING) << "The infer type pass should be called before the mac count pass"; - return 0; - } - Array args = call_node->args; - CHECK(args.size() == 2) - << "The number of input arguments of a CONV 2D node should be 2."; - const auto* conv_2d_attr = call_node->attrs.as(); - const auto* data_type = args[0]->checked_type().as(); - Array data_shape = data_type->shape; - std::string data_layout = conv_2d_attr->data_layout; - int32_t C_ind = Layout(data_layout).Indexof('C'); - int32_t c_ind = Layout(data_layout).Indexof('c'); - CHECK(C_ind != -1) - << "There is no input channel dimension."; - int64_t input_channel = static_cast(data_shape[C_ind].as()->value); - if (c_ind != -1) - input_channel *= static_cast(data_shape[c_ind].as()->value); - Array kernel_size = conv_2d_attr->kernel_size; - CHECK(kernel_size.size() == 2) - << "The dimension of the kernel size in Conv 2D should be 2."; - const auto* expr = call_node->checked_type().as(); - Array output_tensor = expr->shape; - CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D should be 4 or 5."; - int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); - return count; - } - - /* - * \brief Get the number of MACs of a Dense node. - * \param call_node The Dense call node. - * \return The number of MACs. - */ - int64_t ComputeDenseMacs(const CallNode* call_node) { - CHECK(IsDenseNode(call_node)) - << "The input call node must be a Dense node."; - if (!call_node->checked_type_.defined()) { - LOG(WARNING) << "The infer type pass should be called before the mac count pass"; - return 0; - } - Array args = call_node->args; - CHECK(args.size() == 2) - << "The number of input arguments of a Dense node should be 2."; - const auto* data_type = args[0]->checked_type().as(); - const auto* weight_type = args[1]->checked_type().as(); - Array data_shape = data_type->shape; - Array weight_shape = weight_type->shape; - CHECK(data_shape.size() == 2 && weight_shape.size() == 2) - << "The dimension of an input tensor to Dense node should be 2."; - int64_t d1 = static_cast(data_shape[0].as()->value); - int64_t d2 = static_cast(data_shape[1].as()->value); - int64_t d3 = static_cast(weight_shape[0].as()->value); - int64_t d4 = static_cast(weight_shape[1].as()->value); - CHECK(d2 == d4) - << "The dimensions of input arguments do not match."; - int64_t count = d1 * d2 * d3; - return count; - } - - int64_t GetCartesianProd(Array arr) { - int64_t ret = 1; - for (size_t i = 0; i < arr.size(); i++) { - const auto* intImm = arr[i].as(); - ret *= static_cast(intImm->value); - } - return ret; - } - int64_t count_; }; @@ -141,5 +132,6 @@ TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber") *ret = GetTotalMacNumber(args[0]); }); +} // namespace mac_count } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py index 56a0f5490cac..0c0144e246d3 100644 --- a/tests/python/relay/test_pass_mac_count.py +++ b/tests/python/relay/test_pass_mac_count.py @@ -1,7 +1,6 @@ """Unit tests for MAC counter.""" import tvm from tvm import relay -import sys def test_gemm(): n = 512