Skip to content

Commit

Permalink
GNA convert matmul to pointwise convolution transformation unit tests (
Browse files Browse the repository at this point in the history
…#6524)

* ConvertMatmulToPointWiseConvolutionTest first test

* add ConvertMatmulToPointWiseConvolutionFqTest

* use general functions to create test subgraphs

* use general funstion to append node; add ConvertMatmulWithBiasToPointWiseConvolutionTest

* add ConvertMatmulWithBiasToPointWiseConvolutionFqTest

* use decorator instead of bool function arguments

* remove unused functions

* cleanup

* add ConvertMatmulWithFqToPointWiseConvolutionTest

* add ConvertMatmulWithFqToPointWiseConvolutionFqTest

* add ConvertMatmulWithFqToPointWiseConvolutionTestNoAddNode

* remove debug

* add ConvertMatmulToPointWiseConvolutionTestInputRank3

* use TEST_P for ConvertMatmulToPointWiseConvolution tests

* use testing::values fixture instead of multiple tests

* cleanup

* use combine tests for invalid inputs

* code style cleanup

* fix unique_ptr build under Windows

* code review fixes: function template params

* code review fixes: remove duplicated test entry

* fix function arguments alignments
  • Loading branch information
evkotov authored Jul 16, 2021
1 parent f48ea5d commit c64b809
Show file tree
Hide file tree
Showing 2 changed files with 426 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>

#include "layers/gna_permute.hpp"
#include "backend/gna_limitations.hpp"
Expand Down Expand Up @@ -62,37 +63,44 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
ngraph::Shape{1, 1, width, in_channels});
auto reshape_before = std::make_shared<ngraph::opset7::Reshape>(input_node, reshape_const_before, false);
reshape_before->set_friendly_name(base_name + "/reshape_in");
ngraph::copy_runtime_info(input_node, reshape_before);

auto transpose_before = std::make_shared<ngraph::opset7::Transpose>(reshape_before,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
GetPermuteOrder(InferenceEngine::Layout::NHWC, InferenceEngine::Layout::NCHW)));
transpose_before->set_friendly_name(base_name + "/transpose_in");
ngraph::copy_runtime_info(matmul_node, transpose_before);

auto weights_reshape_const = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{4}, ngraph::Shape{out_channels, in_channels, 1, 1});
auto weights_reshaped = std::make_shared<ngraph::opset7::Reshape>(weights_node, weights_reshape_const, false);
ngraph::copy_runtime_info(weights_node, weights_reshaped);

std::shared_ptr<ngraph::Node> conv_node = std::make_shared<ngraph::opset7::Convolution>(transpose_before, weights_reshaped,
ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0},
ngraph::Strides{1, 1}, ngraph::op::PadType::VALID);
conv_node->set_friendly_name(base_name + "/conv");
ngraph::copy_runtime_info(transpose_before, conv_node);

std::shared_ptr<ngraph::Node> root_node = matmul_node;
if (bias != nullptr) {
conv_node = std::make_shared<ngraph::opset7::Add>(conv_node, bias);
ngraph::copy_runtime_info(transpose_before, conv_node);
root_node = add;
}

if (fq != nullptr) {
conv_node = fq->clone_with_new_inputs({conv_node, fq->input_value(1), fq->input_value(2),
fq->input_value(3), fq->input_value(4)});
ngraph::copy_runtime_info(fq, conv_node);
root_node = fq;
}

auto transpose_after = std::make_shared<ngraph::opset7::Transpose>(conv_node,
ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4},
GetPermuteOrder(InferenceEngine::Layout::NCHW, InferenceEngine::Layout::NHWC)));
transpose_after->set_friendly_name(base_name + "/transpose_out");
ngraph::copy_runtime_info(conv_node, transpose_after);

auto output_shape = matmul_node->get_output_shape(0);
output_shape[output_shape.size() - 1] = out_channels;
Expand All @@ -102,6 +110,7 @@ static bool Convert(std::shared_ptr<ngraph::Node> matmul_node,
output_shape);
auto reshape_after = std::make_shared<ngraph::opset7::Reshape>(transpose_after, reshape_const_after, false);
reshape_after->set_friendly_name(base_name);
ngraph::copy_runtime_info(transpose_after, reshape_after);

ngraph::replace_node(root_node, reshape_after);
return true;
Expand Down
Loading

0 comments on commit c64b809

Please sign in to comment.