From cffba0406905e7c2d3d97a6dc299c97ac086733b Mon Sep 17 00:00:00 2001 From: eshoguli Date: Tue, 21 Mar 2023 01:33:48 +0000 Subject: [PATCH] SnippetsFunctionBase usage --- .../tests/src/pass/precision_propagation.cpp | 45 +---- .../precision_propagation_function.hpp | 180 +++++++++++------- 2 files changed, 116 insertions(+), 109 deletions(-) diff --git a/src/common/snippets/tests/src/pass/precision_propagation.cpp b/src/common/snippets/tests/src/pass/precision_propagation.cpp index 336d7943535ebb..3c7da4d06aa165 100644 --- a/src/common/snippets/tests/src/pass/precision_propagation.cpp +++ b/src/common/snippets/tests/src/pass/precision_propagation.cpp @@ -17,45 +17,6 @@ namespace snippets { namespace { -/** - * @class DummyAdd - * @brief DummyAdd operation has custom validate_and_infer_types method implementation. - */ -class DummyAdd : public ngraph::opset1::Add { -public: - OPENVINO_OP("DummyAdd", "test::snippets"); - - DummyAdd(const Output& arg0, - const Output& arg1, - const ngraph::op::AutoBroadcastSpec& auto_broadcast = - ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY)) - : ngraph::opset1::Add(arg0, arg1, auto_broadcast) { - constructor_validate_and_infer_types(); - } - - DummyAdd(const ngraph::opset1::Add& add) - : Add(add.get_input_source_output(0), add.get_input_source_output(1), add.get_autob()) { - constructor_validate_and_infer_types(); - } - - DummyAdd() = default; - - void validate_and_infer_types() override { - const auto input_type1 = get_input_element_type(0); - const auto input_type2 = get_input_element_type(1); - - const element::Type output_type = (input_type1 == element::i8) || (input_type2 == element::i8) ? - element::i32 : - get_input_element_type(0); - - set_output_type(0, output_type, get_input_partial_shape(0)); - } - - std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override { - return std::make_shared(new_args.at(0), new_args.at(1), this->get_autob()); - } -}; - class DummyPrecisionPropagationTargetMachine : public DummyTargetMachine { public: DummyPrecisionPropagationTargetMachine( @@ -114,7 +75,7 @@ TEST_P(PrecisionPropagationTest, CompareFunctions) { const auto test_values = std::get<1>(param); const auto input_shapes = std::vector({ shapes.first, shapes.second }); - EmpyInfrastructureSupportStubForPrecisionPropagationFunction infrastructureSupportStub( + PrecisionPropagationAddFunction function_stub( input_shapes, test_values.input_types[0], test_values.input_types[1], @@ -130,7 +91,7 @@ TEST_P(PrecisionPropagationTest, CompareFunctions) { test_values.expected.convertion_before_op2_2, test_values.expected.convertion_after_op2 }); - function = infrastructureSupportStub.getOriginal(); + function = function_stub.getOriginal(); const auto target_machine = std::make_shared( test_values.actual.op1_supported_precisions, @@ -138,7 +99,7 @@ TEST_P(PrecisionPropagationTest, CompareFunctions) { manager.register_pass(target_machine); - function_ref = infrastructureSupportStub.getReference(); + function_ref = function_stub.getReference(); } namespace PrecisionPropagationTestInstantiation { diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_function.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_function.hpp index 7b6f45569cb455..48bd2e024ad1ab 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_function.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_function.hpp @@ -13,9 +13,94 @@ namespace ov { namespace test { namespace snippets { -class PrecisionPropagationFunction { +/** + * @class DummyAdd + * @brief DummyAdd operation has custom validate_and_infer_types method implementation. + */ +class DummyAdd : public ngraph::opset1::Add { public: - template + OPENVINO_OP("DummyAdd", "test::snippets"); + + DummyAdd(const Output& arg0, + const Output& arg1, + const ngraph::op::AutoBroadcastSpec& auto_broadcast = + ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY)) + : ngraph::opset1::Add(arg0, arg1, auto_broadcast) { + constructor_validate_and_infer_types(); + } + + DummyAdd(const ngraph::opset1::Add& add) + : Add(add.get_input_source_output(0), add.get_input_source_output(1), add.get_autob()) { + constructor_validate_and_infer_types(); + } + + DummyAdd() = default; + + void validate_and_infer_types() override { + const auto input_type1 = get_input_element_type(0); + const auto input_type2 = get_input_element_type(1); + + const element::Type output_type = (input_type1 == element::i8) || (input_type2 == element::i8) ? + element::i32 : + get_input_element_type(0); + + set_output_type(0, output_type, get_input_partial_shape(0)); + } + + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override { + return std::make_shared(new_args.at(0), new_args.at(1), this->get_autob()); + } +}; + +class PrecisionPropagationAddFunctionParams { +public: + class Actual { + public: + std::pair convertion_before_op1; + element::Type convertion_before_op2_1; + std::pair convertion_before_op2_2; + }; + + class Expected { + public: + std::pair convertion_before_op1; + element::Type convertion_before_op2_1; + std::pair convertion_before_op2_2; + element::Type convertion_after_op2; + }; +}; + +/** + * @class PrecisionPropagationAddFunction + * @brief PrecisionPropagationAddFunction instance returns reference and original functions. + * + * Input arguments are used to create function in getOriginal or getReference methods only. + * Dont use getLowered method, it is not implemented and throw std::runtime_error exception. + * Note, ov::element::Type_t precision base type input argument is not used. + */ +class PrecisionPropagationAddFunction : public SnippetsFunctionBase { +public: + explicit PrecisionPropagationAddFunction( + const std::vector input_shapes, + const ngraph::element::Type precision1, + const ngraph::element::Type precision2, + const ngraph::element::Type constant_precision, + PrecisionPropagationAddFunctionParams::Actual actual, + PrecisionPropagationAddFunctionParams::Expected expected) : + SnippetsFunctionBase(input_shapes), + precision1(precision1), + precision2(precision2), + constant_precision(constant_precision), + actual(actual), + expected(expected) { + OPENVINO_ASSERT(input_shapes.size() == 2ull, "input_shapes size has to be equal to 2"); + } + + /* + * Don't call this method explicity. You should create the instance of PrecisionPropagationAddFunction before. + * After the method will be called implicitly in getOriginal or getReference methods. + * Note, please, getLowered method is not implemented. + */ static std::shared_ptr get( const ngraph::element::Type precision1, const ngraph::PartialShape& inputShape1, @@ -26,10 +111,29 @@ class PrecisionPropagationFunction { const element::Type convertion_before_op2_1 = element::undefined, const std::pair& convertion_before_op2_2 = std::pair(), const element::Type convertion_after_op2 = {}) { + const auto create_convert = [](std::shared_ptr parent, const element::Type convertion_type) -> std::shared_ptr { + return convertion_type == element::undefined + ? std::dynamic_pointer_cast(parent) + : std::make_shared(parent, convertion_type); + }; + + const auto make_branch = [&create_convert]( + const ngraph::element::Type precision, + const ngraph::PartialShape& inputShape, + const size_t index, + const element::Type convertion_type) -> std::pair, std::shared_ptr> { + const auto parameter = std::make_shared(precision, inputShape); + parameter->set_friendly_name("parameter" + std::to_string(index)); + + std::shared_ptr parent = create_convert(parameter, convertion_type); + + return { parameter, parent }; + }; + const auto branch1 = make_branch(precision1, inputShape1, 1, convertion_before_op1.first); const auto branch2 = make_branch(precision2, inputShape2, 2, convertion_before_op1.second); - std::shared_ptr parent = std::make_shared(branch1.second, branch2.second); + std::shared_ptr parent = std::make_shared(branch1.second, branch2.second); parent->set_friendly_name("add"); parent = create_convert(parent, convertion_before_op2_1); @@ -57,73 +161,15 @@ class PrecisionPropagationFunction { result_out_tensor.set_names({ "result_tensor" }); result->set_friendly_name("result"); - const ngraph::ResultVector results{result}; - const ngraph::ParameterVector parameters{branch1.first, branch2.first}; + const ngraph::ResultVector results{ result }; + const ngraph::ParameterVector parameters{ branch1.first, branch2.first }; const auto model = std::make_shared(results, parameters, "SnippetsPrecisionPropagation"); return model; } -private: - static std::shared_ptr create_convert(std::shared_ptr parent, const element::Type convertion_type) { - return convertion_type == element::undefined - ? std::dynamic_pointer_cast(parent) - : std::make_shared(parent, convertion_type); - } - - static std::pair, std::shared_ptr> make_branch( - const ngraph::element::Type precision, - const ngraph::PartialShape& inputShape, - const size_t index, - const element::Type convertion_type) { - const auto parameter = std::make_shared(precision, inputShape); - parameter->set_friendly_name("parameter" + std::to_string(index)); - - std::shared_ptr parent = create_convert(parameter, convertion_type); - - return {parameter, parent}; - } -}; - -class PrecisionPropagationFunctionParams { -public: - class Actual { - public: - std::pair convertion_before_op1; - element::Type convertion_before_op2_1; - std::pair convertion_before_op2_2; - }; - - class Expected { - public: - std::pair convertion_before_op1; - element::Type convertion_before_op2_1; - std::pair convertion_before_op2_2; - element::Type convertion_after_op2; - }; -}; - -template -class EmpyInfrastructureSupportStubForPrecisionPropagationFunction : public SnippetsFunctionBase { -public: - explicit EmpyInfrastructureSupportStubForPrecisionPropagationFunction( - const std::vector input_shapes, - const ngraph::element::Type precision1, - const ngraph::element::Type precision2, - const ngraph::element::Type constant_precision, - PrecisionPropagationFunctionParams::Actual actual, - PrecisionPropagationFunctionParams::Expected expected) : - SnippetsFunctionBase(input_shapes), - precision1(precision1), - precision2(precision2), - constant_precision(constant_precision), - actual(actual), - expected(expected) { - OPENVINO_ASSERT(input_shapes.size() == 2ull, "input_shapes size has to be equal to 2"); - } - protected: std::shared_ptr initOriginal() const override { - return PrecisionPropagationFunction::get( + return get( precision1, input_shapes[0], precision2, @@ -135,7 +181,7 @@ class EmpyInfrastructureSupportStubForPrecisionPropagationFunction : public Snip } std::shared_ptr initReference() const override { - return PrecisionPropagationFunction::get( + return get( precision1, input_shapes[0], precision2, @@ -150,8 +196,8 @@ class EmpyInfrastructureSupportStubForPrecisionPropagationFunction : public Snip const ngraph::element::Type precision1; const ngraph::element::Type precision2; const ngraph::element::Type constant_precision; - const PrecisionPropagationFunctionParams::Actual actual; - const PrecisionPropagationFunctionParams::Expected expected; + const PrecisionPropagationAddFunctionParams::Actual actual; + const PrecisionPropagationAddFunctionParams::Expected expected; }; } // namespace snippets