diff --git a/inference-engine/src/gna_plugin/backend/gna_limitations.hpp b/inference-engine/src/gna_plugin/backend/gna_limitations.hpp index 59dd0478cfa900..114ee45e6fd882 100644 --- a/inference-engine/src/gna_plugin/backend/gna_limitations.hpp +++ b/inference-engine/src/gna_plugin/backend/gna_limitations.hpp @@ -23,6 +23,15 @@ constexpr uint32_t noOfInputsLowPrecDivisor = 16; constexpr uint32_t affineMaxBatchSize = 8; +inline bool IsTransposeSupported(const std::vector& shape) { + auto shape_no_1 = shape; + shape_no_1.erase(std::remove(shape_no_1.begin(), shape_no_1.end(), 1), shape_no_1.end()); + if (shape_no_1.size() != 2) return false; + size_t min, max; + std::tie(min, max) = std::minmax(shape_no_1[0], shape_no_1[1]); + return min <= 8 && max % 8 == 0; +} + namespace Cnn2D { struct RangeLimit { uint32_t min; diff --git a/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp b/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp index 6b0cad24ec2e31..bf424ac7aa9e4d 100644 --- a/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp +++ b/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp @@ -12,6 +12,7 @@ #include #include "gna_plugin_log.hpp" +#include "backend/gna_limitations.hpp" using namespace GNAPluginNS; @@ -58,15 +59,6 @@ static void InsertTranspose(std::shared_ptr prev_node, const std:: } } -static bool IsTransposeSupported(const ngraph::Shape& shape) { - auto shape_no_1 = shape; - shape_no_1.erase(std::remove(shape_no_1.begin(), shape_no_1.end(), 1), shape_no_1.end()); - if (shape_no_1.size() != 2) return false; - size_t min, max; - std::tie(min, max) = std::minmax(shape_no_1[0], shape_no_1[1]); - return min <= 8 && max % 8 == 0; -} - HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() { auto reshape = ngraph::pattern::wrap_type({ngraph::pattern::any_input(), ngraph::pattern::any_input()}, VerifyReshape()); @@ -84,7 +76,7 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() { ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr()); } else { auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr(); - if (!IsTransposeSupported(reshape_node->get_output_shape(0))) return false; + if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false; auto matmul_it = pattern_map.find(matmul1); auto matmul_out = matmul_it != std::end(pattern_map) ? matmul_it->second : pattern_map.at(matmul2); InsertTranspose(reshape_node, matmul_out.get_node_shared_ptr()->get_friendly_name()); @@ -113,7 +105,7 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() { ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr()); } else { auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr(); - if (!IsTransposeSupported(reshape_node->get_input_shape(0))) return false; + if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false; auto matmul_node = pattern_map.at(matmul).get_node_shared_ptr(); InsertTranspose(matmul_node, matmul_node->get_friendly_name()); }