Skip to content

Commit

Permalink
Move IsTransposeSupported function to GNA limitations file
Browse files Browse the repository at this point in the history
  • Loading branch information
elilobanova committed Jun 29, 2021
1 parent b7e2062 commit 39f1fb3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
9 changes: 9 additions & 0 deletions inference-engine/src/gna_plugin/backend/gna_limitations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ constexpr uint32_t noOfInputsLowPrecDivisor = 16;

constexpr uint32_t affineMaxBatchSize = 8;

inline bool IsTransposeSupported(const std::vector<size_t>& shape) {
auto shape_no_1 = shape;
shape_no_1.erase(std::remove(shape_no_1.begin(), shape_no_1.end(), 1), shape_no_1.end());
if (shape_no_1.size() != 2) return false;
size_t min, max;
std::tie(min, max) = std::minmax(shape_no_1[0], shape_no_1[1]);
return min <= 8 && max % 8 == 0;
}

namespace Cnn2D {
struct RangeLimit {
uint32_t min;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ngraph/rt_info.hpp>

#include "gna_plugin_log.hpp"
#include "backend/gna_limitations.hpp"

using namespace GNAPluginNS;

Expand Down Expand Up @@ -58,15 +59,6 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
}
}

static bool IsTransposeSupported(const ngraph::Shape& shape) {
auto shape_no_1 = shape;
shape_no_1.erase(std::remove(shape_no_1.begin(), shape_no_1.end(), 1), shape_no_1.end());
if (shape_no_1.size() != 2) return false;
size_t min, max;
std::tie(min, max) = std::minmax(shape_no_1[0], shape_no_1[1]);
return min <= 8 && max % 8 == 0;
}

HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>({ngraph::pattern::any_input(),
ngraph::pattern::any_input()}, VerifyReshape());
Expand All @@ -84,7 +76,7 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
} else {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (!IsTransposeSupported(reshape_node->get_output_shape(0))) return false;
if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false;
auto matmul_it = pattern_map.find(matmul1);
auto matmul_out = matmul_it != std::end(pattern_map) ? matmul_it->second : pattern_map.at(matmul2);
InsertTranspose(reshape_node, matmul_out.get_node_shared_ptr()->get_friendly_name());
Expand Down Expand Up @@ -113,7 +105,7 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
} else {
auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
if (!IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
auto matmul_node = pattern_map.at(matmul).get_node_shared_ptr();
InsertTranspose(matmul_node, matmul_node->get_friendly_name());
}
Expand Down

0 comments on commit 39f1fb3

Please sign in to comment.