diff --git a/inference-engine/src/gna_plugin/transformations/convert_matmul_to_pointwise_convolution.cpp b/inference-engine/src/gna_plugin/transformations/convert_matmul_to_pointwise_convolution.cpp index e49d95ac2f2271..f96ed1dab0e12c 100644 --- a/inference-engine/src/gna_plugin/transformations/convert_matmul_to_pointwise_convolution.cpp +++ b/inference-engine/src/gna_plugin/transformations/convert_matmul_to_pointwise_convolution.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "layers/gna_permute.hpp" #include "backend/gna_limitations.hpp" @@ -62,30 +63,36 @@ static bool Convert(std::shared_ptr matmul_node, ngraph::Shape{1, 1, width, in_channels}); auto reshape_before = std::make_shared(input_node, reshape_const_before, false); reshape_before->set_friendly_name(base_name + "/reshape_in"); + ngraph::copy_runtime_info(input_node, reshape_before); auto transpose_before = std::make_shared(reshape_before, ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, GetPermuteOrder(InferenceEngine::Layout::NHWC, InferenceEngine::Layout::NCHW))); transpose_before->set_friendly_name(base_name + "/transpose_in"); + ngraph::copy_runtime_info(matmul_node, transpose_before); auto weights_reshape_const = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{4}, ngraph::Shape{out_channels, in_channels, 1, 1}); auto weights_reshaped = std::make_shared(weights_node, weights_reshape_const, false); + ngraph::copy_runtime_info(weights_node, weights_reshaped); std::shared_ptr conv_node = std::make_shared(transpose_before, weights_reshaped, ngraph::Strides{1, 1}, ngraph::CoordinateDiff{0, 0}, ngraph::CoordinateDiff{0, 0}, ngraph::Strides{1, 1}, ngraph::op::PadType::VALID); conv_node->set_friendly_name(base_name + "/conv"); + ngraph::copy_runtime_info(transpose_before, conv_node); std::shared_ptr root_node = matmul_node; if (bias != nullptr) { conv_node = std::make_shared(conv_node, bias); + ngraph::copy_runtime_info(transpose_before, conv_node); root_node = add; } if (fq != nullptr) { conv_node = fq->clone_with_new_inputs({conv_node, fq->input_value(1), fq->input_value(2), fq->input_value(3), fq->input_value(4)}); + ngraph::copy_runtime_info(fq, conv_node); root_node = fq; } @@ -93,6 +100,7 @@ static bool Convert(std::shared_ptr matmul_node, ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{4}, GetPermuteOrder(InferenceEngine::Layout::NCHW, InferenceEngine::Layout::NHWC))); transpose_after->set_friendly_name(base_name + "/transpose_out"); + ngraph::copy_runtime_info(conv_node, transpose_after); auto output_shape = matmul_node->get_output_shape(0); output_shape[output_shape.size() - 1] = out_channels; @@ -102,6 +110,7 @@ static bool Convert(std::shared_ptr matmul_node, output_shape); auto reshape_after = std::make_shared(transpose_after, reshape_const_after, false); reshape_after->set_friendly_name(base_name); + ngraph::copy_runtime_info(transpose_after, reshape_after); ngraph::replace_node(root_node, reshape_after); return true; diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/gna_convert_matmul_to_pointwise_convolution.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/gna_convert_matmul_to_pointwise_convolution.cpp new file mode 100644 index 00000000000000..6439c5214e2f1d --- /dev/null +++ b/inference-engine/tests/unit/gna/ngraph/transformations/gna_convert_matmul_to_pointwise_convolution.cpp @@ -0,0 +1,417 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include "transformations/convert_matmul_to_pointwise_convolution.hpp" + +#include "common_test_utils/ngraph_test_utils.hpp" +#include +#include +#include +#include + +namespace testing { + +namespace { + +struct Graph { + std::shared_ptr createFunction(); + + std::shared_ptr input_params; + std::shared_ptr output; +}; + +std::shared_ptr Graph::createFunction() { + auto result = std::make_shared(output); + return std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params}); +} + +// ------------------------------------------------------------------------------------------------------------ + +// TODO: use std::make_unique when C++14 will be available +template +std::unique_ptr createUnique(Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +class CreateGraphDecorator { +public: + CreateGraphDecorator(std::unique_ptr prev_builder = nullptr) : prev_builder_(std::move(prev_builder)) {} + virtual ~CreateGraphDecorator() = default; + virtual Graph build() { + Graph graph; + if (prev_builder_) + graph = prev_builder_->build(); + updateGraph(graph); + return graph; + } +protected: + virtual void updateGraph(Graph&) = 0; +private: + CreateGraphDecorator(const CreateGraphDecorator&) = delete; + CreateGraphDecorator& operator=(const CreateGraphDecorator&) = delete; +private: + std::unique_ptr prev_builder_; +}; + +using CreateGraphDecoratorPtr = std::unique_ptr; + +class CreateBaseDecorator : public CreateGraphDecorator { +public: + // always the first decorator => no prev_builder + CreateBaseDecorator(const ngraph::Shape& input_data_shape, + const ngraph::Shape& input_const_shape) : + CreateGraphDecorator(nullptr), + input_data_shape_(input_data_shape), + input_const_shape_(input_const_shape) {} +protected: + Graph build() override; + void updateGraph(Graph&) override {} +private: + const ngraph::Shape input_data_shape_; + const ngraph::Shape input_const_shape_; +}; + +Graph CreateBaseDecorator::build() { + Graph graph; + graph.input_params = std::make_shared(ngraph::element::i64, + input_data_shape_); + graph.output = ngraph::opset7::Constant::create(ngraph::element::i64, input_const_shape_, {1}); + return graph; +} + +class CreateFakeQuantize : public CreateGraphDecorator { +public: + CreateFakeQuantize(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {} +protected: + void updateGraph(Graph&) override; +}; + +std::shared_ptr createFakeQuantizeNode(std::shared_ptr parent_node) { + auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1}); + auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20}); + auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0}); + auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10}); + return std::make_shared(parent_node, input_low, + input_high, output_low, + output_high, 11); +} + +void CreateFakeQuantize::updateGraph(Graph& graph) { + graph.output = createFakeQuantizeNode(graph.output); +} + +class CreateMatMul : public CreateGraphDecorator { +public: + CreateMatMul(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {} +protected: + void updateGraph(Graph&) override; +}; + +void CreateMatMul::updateGraph(Graph& graph) { + auto matmul_node = std::make_shared(graph.input_params, graph.output); + graph.output = matmul_node; +} + +class CreateAdd : public CreateGraphDecorator { +public: + CreateAdd(CreateGraphDecoratorPtr prev_builder = nullptr) : CreateGraphDecorator(std::move(prev_builder)) {} +protected: + void updateGraph(Graph&) override; +}; + +void CreateAdd::updateGraph(Graph& graph) { + auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}); + auto add_node = std::make_shared(graph.output, bias); + graph.output = add_node; +} + +template::type = true> +CreateGraphDecoratorPtr createBuildDecorator(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8}, + const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) { + CreateGraphDecoratorPtr build_decorator = createUnique(input_data_shape, input_const_shape); + return createUnique(std::move(build_decorator)); +} + +template 0), bool>::type = true> +CreateGraphDecoratorPtr createBuildDecorator(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8}, + const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) { + CreateGraphDecoratorPtr build_decorator = createBuildDecorator(input_data_shape, input_const_shape); + return createUnique(std::move(build_decorator)); +} + +template +Graph createTransformedGraph(const ngraph::Shape& input_data_shape = ngraph::Shape{16, 8}, + const ngraph::Shape& input_const_shape = ngraph::Shape{8, 8}) { + CreateGraphDecoratorPtr build_decorator = createBuildDecorator(input_data_shape, input_const_shape); + return build_decorator->build(); +} + +// ------------------------------------------------------------------------------------------------------------ + +Graph createReferenceGraph(bool addConstFakeQuantizeNode, bool insertAddNode, bool addOutFakeQuantizeNode) { + Graph graph; + + graph.input_params = std::make_shared(ngraph::element::i64, + ngraph::Shape{16, 8}); + auto constant_node = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{8, 8}, {1}); + + auto const_reshape_before = std::make_shared(ngraph::element::Type_t::i64, + ngraph::Shape{4}, + ngraph::Shape{1, 1, 16, 8}); + auto reshape_before = std::make_shared(graph.input_params, const_reshape_before, false); + + auto const_transpose_before = ngraph::opset7::Constant::create(ngraph::element::i64, + ngraph::Shape{4}, + ngraph::Shape{0, 3, 1, 2}); + auto transpose_before = std::make_shared(reshape_before, const_transpose_before); + + std::shared_ptr parent_node = constant_node; + if (addConstFakeQuantizeNode) + parent_node = createFakeQuantizeNode(constant_node); + + auto weights_reshape_const = std::make_shared(ngraph::element::Type_t::i64, + ngraph::Shape{4}, ngraph::Shape{8, 8, 1, 1}); + auto weights_reshaped = std::make_shared(parent_node, weights_reshape_const, false); + + auto conv_node = std::make_shared(transpose_before, + weights_reshaped, + ngraph::Strides{1, 1}, + ngraph::CoordinateDiff{0, 0}, + ngraph::CoordinateDiff{0, 0}, + ngraph::Strides{1, 1}, + ngraph::op::PadType::VALID); + + parent_node = conv_node; + + if (insertAddNode) { + auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}); + auto add_node = std::make_shared(parent_node, bias); + parent_node = add_node; + } + + if (addOutFakeQuantizeNode) + parent_node = createFakeQuantizeNode(parent_node); + + auto const_transpose_after = ngraph::opset7::Constant::create(ngraph::element::i64, + ngraph::Shape{4}, + ngraph::Shape{0, 2, 3, 1}); + auto transpose_after = std::make_shared(parent_node, const_transpose_after); + + auto const_reshape_after = std::make_shared(ngraph::element::Type_t::i64, + ngraph::Shape{2}, + ngraph::Shape{16, 8}); + graph.output = std::make_shared(transpose_after, const_reshape_after, false); + + return graph; +} + +// ------------------------------------------------------------------------------------------------------- + +class ConvertMatmulToPointWiseConvolutionFixture: public CommonTestUtils::TestsCommon, + public ::testing::WithParamInterface> { +public: + void SetUp() override; +public: + std::shared_ptr function, reference_function; + ngraph::pass::Manager pass_manager; +}; + +void ConvertMatmulToPointWiseConvolutionFixture::SetUp() { + // TODO: use auto & [transformed_graph, reference_graph] = this->GetParam() when C++17 + Graph transformed_graph; + Graph reference_graph; + std::tie(transformed_graph, reference_graph, pass_manager) = this->GetParam(); + + function = transformed_graph.createFunction(); + reference_function = reference_graph.createFunction(); +} + +void execute_test(std::shared_ptr function, std::shared_ptr reference_function, ngraph::pass::Manager& pass_manager) { + pass_manager.run_passes(function); + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, reference_function); + ASSERT_TRUE(result.valid); +} + +template +ngraph::pass::Manager createPassManager() { + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + return manager; +} + +TEST_P(ConvertMatmulToPointWiseConvolutionFixture, CompareFunctions) { + execute_test(function, reference_function, pass_manager); +} + +INSTANTIATE_TEST_SUITE_P(ConvertMatmulToPointWiseConvolutionTestSuite, ConvertMatmulToPointWiseConvolutionFixture, + ::testing::Values(std::make_tuple(createTransformedGraph(), + createReferenceGraph(false /* addConstFakeQuantizeNode */, + false /* insertAddNode */, + false /* addOutFakeQuantizeNode */), + createPassManager()), + std::make_tuple(createTransformedGraph(), + createReferenceGraph(true /* addConstFakeQuantizeNode */, + false /* insertAddNode */, + false /* addOutFakeQuantizeNode */), + createPassManager()), + std::make_tuple(createTransformedGraph(), + createReferenceGraph(false /* addConstFakeQuantizeNode */, + true /* insertAddNode */, + false /* addOutFakeQuantizeNode */), + createPassManager()), + std::make_tuple(createTransformedGraph(), + createReferenceGraph(true /* addConstFakeQuantizeNode */, + true /* insertAddNode */, + false /* addOutFakeQuantizeNode */), + createPassManager()), + std::make_tuple(createTransformedGraph(), + createReferenceGraph(false /* addConstFakeQuantizeNode */, + true /* insertAddNode */, + true /* addOutFakeQuantizeNode */), + createPassManager()), + std::make_tuple(createTransformedGraph(), + createReferenceGraph(true /* addConstFakeQuantizeNode */, + true /* insertAddNode */, + true /* addOutFakeQuantizeNode */), + createPassManager()), + std::make_tuple(createTransformedGraph(), + createReferenceGraph(false /* addConstFakeQuantizeNode */, + false /* insertAddNode */, + true /* addOutFakeQuantizeNode */), + createPassManager()), + std::make_tuple(createTransformedGraph(), + createReferenceGraph(true /* addConstFakeQuantizeNode */, + false /* insertAddNode */, + true /* addOutFakeQuantizeNode */), + createPassManager()))); + +// ------------------------------------------------------------------------------------------------------- + +class ITransformedGraphFactory { +public: + virtual ~ITransformedGraphFactory() = default; + virtual Graph createGraph(const ngraph::Shape& input_data_shape, + const ngraph::Shape& input_const_shape) = 0; +}; + +template +class TransformedGraphFactory : public ITransformedGraphFactory { +public: + TransformedGraphFactory() = default; + + Graph createGraph(const ngraph::Shape& input_data_shape, + const ngraph::Shape& input_const_shape) override { + return createTransformedGraph(input_data_shape, input_const_shape); + } +private: + TransformedGraphFactory(const TransformedGraphFactory&) = delete; + TransformedGraphFactory& operator=(const TransformedGraphFactory&) = delete; +}; + +struct FixtureData { + std::shared_ptr graph_factory; + ngraph::pass::Manager pass_manager; + + template + static FixtureData create() { + FixtureData fixture_data; + fixture_data.graph_factory = std::make_shared>(); + fixture_data.pass_manager = createPassManager(); + return fixture_data; + } +}; + +using FixtureInputShapes = std::tuple; + +class ConvertMatmulToPointWiseConvolutionInvalidInputFixture: public CommonTestUtils::TestsCommon, + public ::testing::WithParamInterface> { +public: + void SetUp() override; +public: + std::shared_ptr function; + ngraph::pass::Manager pass_manager; +}; + +void ConvertMatmulToPointWiseConvolutionInvalidInputFixture::SetUp() { + // TODO: use auto & [fixture_data, input_shapes] = this->GetParam() when C++17 + FixtureData fixture_data; + FixtureInputShapes input_shapes; + std::tie(fixture_data, input_shapes) = this->GetParam(); + + ngraph::Shape input_data, input_const; + std::tie(input_data, input_const) = input_shapes; + + function = fixture_data.graph_factory->createGraph(input_data, input_const).createFunction(); + pass_manager = fixture_data.pass_manager; +} + +void execute_test_cloned_function(std::shared_ptr function, + ngraph::pass::Manager& pass_manager) { + std::shared_ptr reference_function = ngraph::clone_function(*function); + pass_manager.run_passes(function); + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, reference_function); + ASSERT_TRUE(result.valid); +} + +std::vector transform_types = { + FixtureData::create(), + FixtureData::create(), + FixtureData::create(), + FixtureData::create(), + FixtureData::create(), + FixtureData::create(), + FixtureData::create(), + FixtureData::create() +}; + +std::vector input_shapes = { + std::make_tuple(ngraph::Shape{16, 16, 16}, ngraph::Shape{16, 16, 16}), + std::make_tuple(ngraph::Shape{16, 9}, ngraph::Shape{9, 9}), + std::make_tuple(ngraph::Shape{16, 65533}, ngraph::Shape{65533, 2}), + std::make_tuple(ngraph::Shape{16, 769}, ngraph::Shape{769, 2}) +}; + +TEST_P(ConvertMatmulToPointWiseConvolutionInvalidInputFixture, CompareFunctions) { + execute_test_cloned_function(function, pass_manager); +} + +INSTANTIATE_TEST_SUITE_P(ConvertMatmulToPointWiseConvolutionInvalidInputTestSuite, ConvertMatmulToPointWiseConvolutionInvalidInputFixture, + ::testing::Combine(::testing::ValuesIn(transform_types), + ::testing::ValuesIn(input_shapes))); + +} // namespace + +} // namespace testing