Skip to content

Commit

Permalink
fix convert_function_to_cnn_network.cpp not using friendly names; fix…
Browse files Browse the repository at this point in the history
… transpose_nchw to update the last node's name
  • Loading branch information
evkotov committed Mar 8, 2023
1 parent dafd919 commit 225e39d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ NodePair SwapNodes(NodePtr first_node, NodePtr second_node) {
return std::make_pair(new_first_node, new_second_node);
}

#if 0
/**
* @brief SwapOutputs has much better performance than SwapNodes and covers the most of the real situations
* but cannot work when the consumers count greater than one
Expand Down Expand Up @@ -85,9 +86,9 @@ NodePair Swap(NodePtr first_node, NodePtr second_node) {
new_nodes = SwapNodes(first_node, second_node);
else
new_nodes = SwapOutputs(first_node, second_node);

return new_nodes;
}
#endif

NodePtr GetPatternNode(const PatternValueMap& pattern_map, const NodeVector& node_labels) {
for (const auto& node_label : node_labels) {
Expand All @@ -112,8 +113,11 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});

#if 0
const NodePair new_nodes = Swap(transpose, unary);
#else
const NodePair new_nodes = SwapNodes(transpose, unary);
#endif

register_new_node(new_nodes.first);
register_new_node(new_nodes.second);
Expand Down Expand Up @@ -148,9 +152,11 @@ ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBack
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});

#if 0
const NodePair new_nodes = Swap(unary, transpose);

#else
const NodePair new_nodes = SwapNodes(unary, transpose);
#endif
register_new_node(new_nodes.first);
register_new_node(new_nodes.second);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#include "transformations/rt_info/fused_names_attribute.hpp"
#include "transformations/rt_info/primitives_priority_attribute.hpp"
#include "transformations/utils/utils.hpp"
#include "../../src/ops/gna_convolution.hpp"
#include "../../src/ops/gna_max_pool.hpp"

namespace Builder {

Expand Down Expand Up @@ -2069,19 +2071,11 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
const auto isInternalConstLayer = [](const std::shared_ptr<::ngraph::op::Constant>& constLayer,
const std::shared_ptr<::ngraph::Node>& consumerLayer,
bool keep_constants) -> bool {

const auto isGNAConvolution = [](const std::shared_ptr<::ngraph::Node> &node) -> bool {
return (node->get_friendly_name().find("gna_convolution") != std::string::npos);
};
const auto isGNAMaxPool = [](const std::shared_ptr<::ngraph::Node> &node) -> bool {
return (node->get_friendly_name().find("gna_max_pool") != std::string::npos);
};

if (((::ngraph::as_type_ptr<::ngraph::op::ConvolutionIE>(consumerLayer) ||
::ngraph::as_type_ptr<::ngraph::op::FullyConnected>(consumerLayer)) &&
::ngraph::as_type_ptr<::ngraph::op::FullyConnected>(consumerLayer) ||
::ngraph::as_type_ptr<ov::intel_gna::op::GNAConvolution>(consumerLayer) ||
::ngraph::as_type_ptr<ov::intel_gna::op::GNAMaxPool>(consumerLayer)) &&
!keep_constants) ||
isGNAConvolution(consumerLayer) ||
isGNAMaxPool(consumerLayer) ||
::ngraph::as_type_ptr<::ngraph::op::v1::BinaryConvolution>(consumerLayer) ||
::ngraph::as_type_ptr<::ngraph::op::DeconvolutionIE>(consumerLayer) ||
::ngraph::as_type_ptr<::ngraph::op::v1::DeformableConvolution>(consumerLayer) ||
Expand Down
26 changes: 2 additions & 24 deletions src/plugins/intel_gna/src/transformations/transpose_nchw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,8 @@ bool DoTransformation(Node convolution) {

ov::copy_runtime_info(convolution_node, {transpose_before, transpose_const, conv_new, transpose_after});

const bool has_parent_param = HasParentNode<ngraph::opset8::Parameter>(convolution);
const bool has_child_result = HasChildNode<ngraph::opset8::Result>(convolution);
if (has_parent_param != has_child_result) {
if (has_parent_param) {
transpose_before->set_friendly_name(convolution_node->get_friendly_name());
} else {
transpose_after->set_friendly_name(convolution_node->get_friendly_name());
}
} else {
conv_new->set_friendly_name(convolution_node->get_friendly_name());
}
ov::replace_output_update_name(convolution->output(0), transpose_after->output(0));

convolution->output(0).replace(transpose_after->output(0));
return true;
}

Expand Down Expand Up @@ -189,19 +178,8 @@ bool DoTransformation(Node max_pool) {

ov::copy_runtime_info(max_pool_node, {transpose_before, transpose_const, max_pool_new, transpose_after});

const bool has_parent_param = HasParentNode<ngraph::opset8::Parameter>(max_pool);
const bool has_child_result = HasChildNode<ngraph::opset8::Result>(max_pool);
if (has_parent_param != has_child_result) {
if (has_parent_param) {
transpose_before->set_friendly_name(max_pool->get_friendly_name());
} else {
transpose_after->set_friendly_name(max_pool->get_friendly_name());
}
} else {
max_pool_new->set_friendly_name(max_pool->get_friendly_name());
}
ov::replace_output_update_name(max_pool->output(0), transpose_after->output(0));

max_pool->output(0).replace(transpose_after->output(0));
return true;
}

Expand Down

0 comments on commit 225e39d

Please sign in to comment.