From 087b10ff00f5f75b75acf842cc2a2e376212a6de Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 23 Mar 2023 09:16:04 +0000 Subject: [PATCH] Snippets: precision propagation (#14996) --- src/bindings/python/tests/__init__.py | 1 - .../python/tests/test_onnx/test_backend.py | 5 - .../python/tests_compatibility/__init__.py | 1 - .../test_onnx/test_backend.py | 5 - .../snippets/include/snippets/generator.hpp | 15 +- .../snippets/include/snippets/op/subgraph.hpp | 16 +- .../snippets/pass/align_element_type.hpp | 46 --- .../snippets/pass/fq_decomposition.hpp | 5 +- .../snippets/pass/propagate_precision.hpp | 48 +++ src/common/snippets/src/op/subgraph.cpp | 80 +++-- .../snippets/src/pass/align_element_type.cpp | 99 ------ .../snippets/src/pass/collapse_subgraph.cpp | 6 +- .../snippets/src/pass/fq_decomposition.cpp | 12 - .../snippets/src/pass/propagate_precision.cpp | 293 +++++++++++++++++ .../snippets/tests/include/lowering_utils.hpp | 6 +- .../include/pass/precision_propagation.hpp | 54 ++++ .../snippets/tests/src/lowering_utils.cpp | 14 +- .../tests/src/pass/precision_propagation.cpp | 294 ++++++++++++++++++ .../precision_propagation_convert_test.cpp | 153 +++++++++ .../precision_propagation_get_precisions.cpp | 45 +++ src/core/src/pass/visualize_tree.cpp | 4 +- .../intel_cpu/src/emitters/cpu_generator.cpp | 10 +- .../src/emitters/jit_dnnl_emitters.cpp | 4 + .../src/emitters/jit_dnnl_emitters.hpp | 2 + .../src/emitters/jit_eltwise_emitters.cpp | 204 +++++++++--- .../src/emitters/jit_eltwise_emitters.hpp | 66 ++-- .../intel_cpu/src/emitters/jit_emitter.cpp | 6 +- .../intel_cpu/src/emitters/jit_emitter.hpp | 8 +- .../src/emitters/jit_snippets_emitters.cpp | 15 +- .../src/emitters/jit_snippets_emitters.hpp | 9 + src/plugins/intel_cpu/src/nodes/eltwise.cpp | 61 ++-- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 29 +- .../remove_converts.cpp | 38 +++ .../remove_converts.hpp | 27 ++ .../snippets/check_broadcast.cpp | 81 +++++ .../precision_propagation_convertion.cpp 
| 37 +++ .../ngraph_transformations/mul_add_to_fma.cpp | 2 +- .../include/snippets/check_broadcast.hpp | 38 +++ .../precision_propagation_convertion.hpp | 33 ++ .../fuse_fake_quantize_transformation.cpp | 2 +- .../shared/src/snippets/check_broadcast.cpp | 89 ++++++ .../plugin/shared/src/snippets/convert.cpp | 4 +- .../precision_propagation_convertion.cpp | 48 +++ ...cision_propagation_convertion_function.hpp | 49 +++ .../precision_propagation_function.hpp | 131 ++++++++ .../include/snippets_helpers.hpp | 1 + ...cision_propagation_convertion_function.cpp | 92 ++++++ .../src/precision_propagation_function.cpp | 105 +++++++ 48 files changed, 2066 insertions(+), 327 deletions(-) delete mode 100644 src/common/snippets/include/snippets/pass/align_element_type.hpp create mode 100644 src/common/snippets/include/snippets/pass/propagate_precision.hpp delete mode 100644 src/common/snippets/src/pass/align_element_type.cpp create mode 100644 src/common/snippets/src/pass/propagate_precision.cpp create mode 100644 src/common/snippets/tests/include/pass/precision_propagation.hpp create mode 100644 src/common/snippets/tests/src/pass/precision_propagation.cpp create mode 100644 src/common/snippets/tests/src/pass/precision_propagation_convert_test.cpp create mode 100644 src/common/snippets/tests/src/pass/precision_propagation_get_precisions.cpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/remove_converts.cpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/remove_converts.hpp create mode 100644 src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/check_broadcast.cpp create mode 100644 src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/precision_propagation_convertion.cpp create mode 100644 src/tests/functional/plugin/shared/include/snippets/check_broadcast.hpp create mode 100644 src/tests/functional/plugin/shared/include/snippets/precision_propagation_convertion.hpp create mode 100644 
src/tests/functional/plugin/shared/src/snippets/check_broadcast.cpp create mode 100644 src/tests/functional/plugin/shared/src/snippets/precision_propagation_convertion.cpp create mode 100644 src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_convertion_function.hpp create mode 100644 src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_function.hpp create mode 100644 src/tests/ngraph_helpers/snippets_ngraph_functions/src/precision_propagation_convertion_function.cpp create mode 100644 src/tests/ngraph_helpers/snippets_ngraph_functions/src/precision_propagation_function.cpp diff --git a/src/bindings/python/tests/__init__.py b/src/bindings/python/tests/__init__.py index 06d8dfb043480f..a426ce8424ec71 100644 --- a/src/bindings/python/tests/__init__.py +++ b/src/bindings/python/tests/__init__.py @@ -117,7 +117,6 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): xfail_issue_63033 = xfail_test(reason="BatchNormalization: Training mode is not supported") xfail_issue_63036 = xfail_test(reason="Changes in ConvTranspose padding") -xfail_issue_63039 = xfail_test(reason="Result mismatches with UINT8 operations") xfail_issue_63043 = xfail_test(reason="Recurrent node expects constants as W, R, B inputs.") skip_rng_tests = pytest.mark.skip(reason="Tests use random number generator with no seed.") diff --git a/src/bindings/python/tests/test_onnx/test_backend.py b/src/bindings/python/tests/test_onnx/test_backend.py index c681f376348142..dc30a9bda3806b 100644 --- a/src/bindings/python/tests/test_onnx/test_backend.py +++ b/src/bindings/python/tests/test_onnx/test_backend.py @@ -37,7 +37,6 @@ xfail_issue_58033, xfail_issue_63033, xfail_issue_63036, - xfail_issue_63039, xfail_issue_63043, xfail_issue_63137, xfail_issue_63138, @@ -278,10 +277,6 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None "OnnxBackendNodeModelTest.test_batchnorm_example_training_mode_cpu", ), (xfail_issue_63036, 
"OnnxBackendNodeModelTest.test_convtranspose_autopad_same_cpu"), - ( - xfail_issue_63039, - "OnnxBackendNodeModelTest.test_div_uint8_cpu", - ), ( xfail_issue_63043, "OnnxBackendNodeModelTest.test_gru_batchwise_cpu", diff --git a/src/bindings/python/tests_compatibility/__init__.py b/src/bindings/python/tests_compatibility/__init__.py index 7b5d7217cd8ed1..24d2050a3a9d77 100644 --- a/src/bindings/python/tests_compatibility/__init__.py +++ b/src/bindings/python/tests_compatibility/__init__.py @@ -122,7 +122,6 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): xfail_issue_63033 = xfail_test(reason="BatchNormalization: Training mode is not supported") xfail_issue_63036 = xfail_test(reason="Changes in ConvTranspose padding") -xfail_issue_63039 = xfail_test(reason="Result mismatches with UINT8 operations") xfail_issue_63043 = xfail_test(reason="Recurrent node expects constants as W, R, B inputs.") skip_rng_tests = pytest.mark.skip(reason="Tests use random number generator with no seed.") diff --git a/src/bindings/python/tests_compatibility/test_onnx/test_backend.py b/src/bindings/python/tests_compatibility/test_onnx/test_backend.py index 89b7afcb47e4af..53ec35731cbc5f 100644 --- a/src/bindings/python/tests_compatibility/test_onnx/test_backend.py +++ b/src/bindings/python/tests_compatibility/test_onnx/test_backend.py @@ -37,7 +37,6 @@ xfail_issue_58033, xfail_issue_63033, xfail_issue_63036, - xfail_issue_63039, xfail_issue_63043, xfail_issue_63137, xfail_issue_63138, @@ -282,10 +281,6 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None "OnnxBackendNodeModelTest.test_batchnorm_example_training_mode_cpu", ), (xfail_issue_63036, "OnnxBackendNodeModelTest.test_convtranspose_autopad_same_cpu"), - ( - xfail_issue_63039, - "OnnxBackendNodeModelTest.test_div_uint8_cpu", - ), ( xfail_issue_63043, "OnnxBackendNodeModelTest.test_gru_batchwise_cpu", diff --git a/src/common/snippets/include/snippets/generator.hpp 
b/src/common/snippets/include/snippets/generator.hpp index ab3156a108e3e1..939b4f4d43c33d 100644 --- a/src/common/snippets/include/snippets/generator.hpp +++ b/src/common/snippets/include/snippets/generator.hpp @@ -16,6 +16,8 @@ namespace snippets { auto getRegisters(std::shared_ptr& n) -> ngraph::snippets::RegInfo; +typedef std::pair(const std::shared_ptr&)>, + std::function>(const std::shared_ptr&)>> jitters_value; /** * @interface TargetMachine * @brief Base class Target machine representation. Target derives from this class to provide generator information about supported emitters @@ -51,7 +53,16 @@ class TargetMachine { if (jitter == jitters.end()) { throw ngraph_error(std::string("Target code emitter is not available for ") + type.name + " operation."); } - return jitter->second; + return jitter->second.first; + } + + std::function>(const std::shared_ptr&)> + get_supported_precisions(const ngraph::DiscreteTypeInfo type) const { + auto jitter = jitters.find(type); + if (jitter == jitters.end()) { + throw ngraph_error(std::string("Target code emitter is not available for ") + type.name + " operation."); + } + return jitter->second.second; } /** @@ -64,7 +75,7 @@ class TargetMachine { virtual ~TargetMachine() = default; protected: - std::map(std::shared_ptr)>> jitters; + std::map jitters; }; /** diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index ec55f076301c64..46e6633f61b8aa 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -101,11 +101,17 @@ class Subgraph : public ov::op::util::SubGraphOp { bool is_quantized() const { return config.m_is_quantized; } bool has_type_relaxed_ops() const { return config.m_has_type_relaxed_ops; } bool has_domain_sensitive_ops() const { return config.m_has_domain_sensitive_ops; } - - snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& 
input_shapes, ngraph::pass::Manager& opt, + snippets::Schedule generate(const BlockedShapeVector& output_shapes, + const BlockedShapeVector& input_shapes, + ngraph::pass::Manager& pre_dialect, + ngraph::pass::Manager& post_dialect, + ngraph::pass::Manager& post_precision, const void* compile_params = nullptr); snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr); - snippets::Schedule generate(ngraph::pass::Manager &opt, const void* compile_params = nullptr); + snippets::Schedule generate(ngraph::pass::Manager& pre_dialect, + ngraph::pass::Manager& post_dialect, + ngraph::pass::Manager& post_precision, + const void* compile_params = nullptr); snippets::Schedule generate(const void* compile_params = nullptr); ov::PartialShape canonicalize(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes); std::vector reshape_body(const std::vector& input_shapes); @@ -132,6 +138,8 @@ class Subgraph : public ov::op::util::SubGraphOp { // This check returns True if Constant op which is input of this op should be inside Subgraph body static auto constant_input_should_be_inside_body(const std::shared_ptr& node) -> bool; + static bool check_broadcast(const std::shared_ptr& node) noexcept; + private: void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes); void convert_to_snippet_dialect(); @@ -164,8 +172,6 @@ class Subgraph : public ov::op::util::SubGraphOp { public: // True if Subgraph contains FakeQuantize -> FQ decomposition should be called bool m_is_quantized = false; - // True if we should align element types indise body - bool m_is_needed_to_align_precision = false; // True if Subgraph contains TypeRelaxed nodes -> for several streams in tp mode we should copy body using mutexes // because TypeRelaxed::copy_with_new_inputs() isn't save-thread method bool m_has_type_relaxed_ops = false; diff --git 
a/src/common/snippets/include/snippets/pass/align_element_type.hpp b/src/common/snippets/include/snippets/pass/align_element_type.hpp deleted file mode 100644 index 0b1f831091c4cc..00000000000000 --- a/src/common/snippets/include/snippets/pass/align_element_type.hpp +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include - -namespace ngraph { -namespace snippets { -namespace pass { - -/** - * @interface AlignElementType - * @brief Wrap sequence of operations which doesn't support execution on original element type by ConvertSaturation - * and reset element type for type relaxed nodes inside body to align element type between nodes. - * Example 1: - * - After FQ decomposition there may be Convert[U8/I8]. If after the Convert there are other operations - * that don't support U8/I8, new ConvertSaturation[exec_type] will be inserted after the FQ decomposition - * to execute these operations on supported element type - * Example 2: - * - Input[I8] -> Unsupported I8 op -> Movement op -> Output[I8]. There will be inserted two ConvertSaturation: - * * ConvertSatiration[exec_type] before op which is unsupported I8 - * * ConvertSaturation[I8] before Movement op to return original low precision. 
- * Note: We cannot just remove original Convert[I8/U8] in Example 1 because we should cover two things: - * * allow execution of operations on supported element type for them - * * keep computations mathematically equivalent to the original function - * Thus, for these cases we should have the following pipeline: FP32 -> Convert[I8/U8] -> Convert[FP32] -> FP32 - * Note: We shouldn't call validate_and_infer_type() after Convert insertions to avoid element type conflicts on inputs of ops - * @ingroup snippets - */ -class AlignElementType: public ngraph::pass::FunctionPass { -public: - OPENVINO_RTTI("AlignElementType", "0"); - AlignElementType(const ov::element::Type exec_type = ov::element::f32); - bool run_on_model(const std::shared_ptr& m) override; - - static bool opNeedsAlignElementType(const std::shared_ptr& n, const ov::element::Type exec_type = ov::element::f32); -private: - ov::element::Type exec_type; -}; - -} // namespace pass -} // namespace snippets -} // namespace ngraph diff --git a/src/common/snippets/include/snippets/pass/fq_decomposition.hpp b/src/common/snippets/include/snippets/pass/fq_decomposition.hpp index 284640d8c18122..cfb9ff41955867 100644 --- a/src/common/snippets/include/snippets/pass/fq_decomposition.hpp +++ b/src/common/snippets/include/snippets/pass/fq_decomposition.hpp @@ -29,7 +29,7 @@ namespace pass { * * Expand brackets: * round(x * (levels-1) / (ih - il) - il * (levels-1) / (ih - il)) * (oh - ol) / (levels-1) + ol - * + * * Marking: * - isc := (levels-1) / (ih - il) * - ish := -il * isc @@ -37,7 +37,7 @@ namespace pass { * - osh := ol * Final expression: * round(x * isc + ish) * osc + osh - * + * * Some optimizations (example for scalars): * 1. If output element type of FQ is U8 and il = 0, ish = 0, osc = 1, osh = 0, there is enough expression: x * isc * 2. 
If output element type of FQ is I8 and ish ~= 128, osc = 1, osh ~= -128, il * isc ~= -128, ih * isc ~= 127 there is enough expression: x * isc @@ -54,7 +54,6 @@ class FakeQuantizeDecomposition : public ngraph::pass::MatcherPass { public: FakeQuantizeDecomposition(); - static bool isAllScalarConstant(const std::shared_ptr& node); static bool getScalesAndShifts(const std::shared_ptr& fq_node, std::vector& cl, std::vector& ch, diff --git a/src/common/snippets/include/snippets/pass/propagate_precision.hpp b/src/common/snippets/include/snippets/pass/propagate_precision.hpp new file mode 100644 index 00000000000000..d0920766f632fd --- /dev/null +++ b/src/common/snippets/include/snippets/pass/propagate_precision.hpp @@ -0,0 +1,48 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include "snippets/generator.hpp" + +namespace ngraph { +namespace snippets { +namespace pass { + +/** + * @class PropagatePrecision + * @ingroup snippets + * @brief PropagatePrecision transformation propagate precision from parameters to results. 
+ */ +class PropagatePrecision: public ngraph::pass::FunctionPass { +public: + OPENVINO_RTTI("PropagatePrecision", "0"); + PropagatePrecision(const std::shared_ptr& target_machine); + bool run_on_model(const std::shared_ptr& m) override; + + static std::vector get_precisions( + const std::vector& input_precisions, + const std::set>& supported_precisions) noexcept; + + // if can_be_removed returns true then actual convertion (actual_before => actual_after) + // can be replaced to required (actual_before => required_after) + static bool can_be_removed( + const element::Type& actual_before, + const element::Type& actual_after, + const element::Type& required_after) noexcept; + + // if can_be_fused returns true then actual convertion can be replaced to required + static bool can_be_fused( + const element::Type& actual, + const element::Type& required) noexcept; + +private: + const std::shared_ptr target_machine; +}; + +} // namespace pass +} // namespace snippets +} // namespace ngraph diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 07f13ae8defb57..20b6edb17b9d14 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -11,6 +11,7 @@ #include "snippets/pass/insert_movebroadcast.hpp" #include "snippets/pass/broadcast_to_movebroadcast.hpp" #include "snippets/pass/load_movebroadcast_to_broadcastload.hpp" +#include "snippets/pass/propagate_precision.hpp" #include "snippets/pass/assign_registers.hpp" #include "snippets/pass/convert_constants.hpp" #include "snippets/pass/convert_power_to_powerstatic.hpp" @@ -18,7 +19,6 @@ #include "snippets/pass/insert_loops.hpp" #include "snippets/pass/transpose_decomposition.hpp" #include "snippets/pass/transform_convert.hpp" -#include "snippets/pass/align_element_type.hpp" #include "snippets/pass/matmul_to_brgemm.hpp" #include "snippets/pass/fuse_transpose_brgemm.hpp" #include "snippets/pass/softmax_decomposition.hpp" @@ -62,10 +62,6 @@ void 
snippets::op::Subgraph::init_config() { ov::is_type(op); config.m_has_type_relaxed_ops = config.m_has_type_relaxed_ops || std::dynamic_pointer_cast(op); - config.m_is_needed_to_align_precision = config.m_is_needed_to_align_precision || - is_quantized() || - has_type_relaxed_ops() || - snippets::pass::AlignElementType::opNeedsAlignElementType(op, execution_element_type); config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops || ov::is_type(op) || ov::is_type(op) || @@ -359,6 +355,14 @@ ov::PartialShape snippets::op::Subgraph::canonicalize(const BlockedShapeVector& return master_shape; } +bool snippets::op::Subgraph::check_broadcast(const std::shared_ptr& node) noexcept { + const auto elementwise = std::dynamic_pointer_cast(node); + return + (elementwise == nullptr) || + (elementwise->get_input_partial_shape(0).size() == elementwise->get_input_partial_shape(1).size()) || + (elementwise->get_autob().m_type != ov::op::AutoBroadcastType::PDPD); +} + void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes) { // We should insert Convert before Results to set original output element type if needed @@ -369,35 +373,34 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu const auto convert = std::make_shared( body_results[i]->get_input_node_shared_ptr(0), needed_out_type); body_results[i]->set_argument(0, convert); + body_results[i]->validate_and_infer_types(); } } // We should change existing element type to original for Parameters if needed - const auto& body_parameters = body_ptr()->get_parameters(); + const auto& parameters = body_ptr()->get_parameters(); for (size_t i = 0; i < inputShapes.size(); ++i) { const auto needed_in_type = std::get<2>(inputShapes[i]); - if (body_parameters[i]->get_element_type() != needed_in_type) { - body_parameters[i]->set_element_type(needed_in_type); - config.m_is_needed_to_align_precision = true; - } - } + const auto& 
parameter = parameters[i]; + if (parameter->get_element_type() != needed_in_type) { + const auto parameter_output = parameter->output(0); + const auto convert = std::make_shared( + parameter_output, + parameter_output.get_element_type()); + ngraph::copy_runtime_info(parameter, convert); + + for (const auto input : parameter_output.get_target_inputs()) { + const auto& input_node = input.get_node(); + if (input_node == convert.get()) { + continue; + } + input_node->set_argument(input.get_index(), convert->output(0)); + } - // We should align element type inside body using the corresponding pass: - // - Insert Convert before operations that doesn't support original element type for execution - // - Insert reverse Convert before operations that support original element type - // but have inputs that doesn't support it (because before them will be inserted Convert with exec_type - first point) - // - Then we should use ConstantFolding pass to convert element type of Scalars before inference. - // - Eliminate redundant Converts which can be inserted in AlignElementType() pass - ngraph::pass::Manager manager; - if (config.m_is_needed_to_align_precision) { - manager.register_pass(execution_element_type); - manager.register_pass(); - // TODO [100041] : In some cases AlignElementType pass can insert extra Convert because - // the pass doesn't know real precisions in real time. 
- // We call EliminateConverts pass to remove them - manager.register_pass(); + parameter->set_element_type(needed_in_type); + parameter->validate_and_infer_types(); + } } - manager.run_passes(body_ptr()); } void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { @@ -602,24 +605,39 @@ snippets::Schedule snippets::op::Subgraph::generate(const BlockedShapeVector& ou snippets::Schedule snippets::op::Subgraph::generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, - ngraph::pass::Manager& opt, + ngraph::pass::Manager& pre_dialect, + ngraph::pass::Manager& post_dialect, + ngraph::pass::Manager& post_precision, const void* compile_params) { canonicalize(output_shapes, input_shapes); - return generate(opt, compile_params); + return generate(pre_dialect, post_dialect, post_precision, compile_params); } snippets::Schedule snippets::op::Subgraph::generate(const void* compile_params) { auto mngr = ngraph::pass::Manager(); - return generate(mngr, compile_params); + return generate(mngr, mngr, mngr, compile_params); } -snippets::Schedule snippets::op::Subgraph::generate(ngraph::pass::Manager& opt, const void* compile_params) { +snippets::Schedule snippets::op::Subgraph::generate( + ngraph::pass::Manager& pre_dialect, + ngraph::pass::Manager& post_dialect, + ngraph::pass::Manager& post_precision, + const void* compile_params) { INTERNAL_OP_SCOPE(Subgraph); OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::op::generate") NGRAPH_CHECK(m_generator != nullptr, "generate is called while generator is not set"); + pre_dialect.run_passes(body_ptr()); convert_to_snippet_dialect(); - opt.run_passes(body_ptr()); + post_dialect.run_passes(body_ptr()); + + ngraph::pass::Manager precision_manager; + precision_manager.register_pass(m_generator->get_target_machine()); + precision_manager.register_pass(); + precision_manager.register_pass(); + precision_manager.run_passes(body_ptr()); + + 
post_precision.run_passes(body_ptr()); // After all passes, when all optimizations are completed and all MemoryAccess ops are inserted, // we can calculate common buffer scratchpad size and propagate offset from Buffer to the corresponding MemoryAccess ops diff --git a/src/common/snippets/src/pass/align_element_type.cpp b/src/common/snippets/src/pass/align_element_type.cpp deleted file mode 100644 index abd50a9e44605c..00000000000000 --- a/src/common/snippets/src/pass/align_element_type.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include "snippets/snippets_isa.hpp" -#include "snippets/op/convert_saturation.hpp" -#include "snippets/pass/align_element_type.hpp" -#include "snippets/utils.hpp" -#include "ov_ops/type_relaxed.hpp" -#include "ngraph/op/util/op_types.hpp" - -#include - -namespace { - -inline auto is_in_op(const std::shared_ptr& n) -> bool { - return ov::is_type(n) - || ov::is_type(n); -} - -// At the moment Subgraph supports only Eltwise, Select, Convert, Broadcast and FQ (which is decomposed into Eltwises and Convert) with -// Softmax (which is decomposed into Eltwises as well) -// And only Eltwise and Select ops supports execution only in "exec_type". 
So we can check op type from the opposite -// NOTE: This check is only for executable which isn't Parameter/Constant/Result -inline auto op_supports_only_exec_type(const std::shared_ptr& n) -> bool { - return !is_in_op(n) && - !ov::is_type(n) && - !ov::is_type(n) && - !ov::is_type(n) && - !ov::is_type(n) && - !ov::is_type(n); -} - -} // namespace - -ngraph::snippets::pass::AlignElementType::AlignElementType(const ov::element::Type exec_type) : exec_type(exec_type) { } - -bool ngraph::snippets::pass::AlignElementType::run_on_model(const std::shared_ptr &m) { - RUN_ON_FUNCTION_SCOPE(AlignElementType); - - auto insertConvert = [](const std::shared_ptr& op, const size_t idx, const ov::element::Type& element_type) -> void { - auto convert = std::make_shared(op->input(idx).get_source_output(), element_type); - ngraph::copy_runtime_info(op->get_input_node_shared_ptr(idx), convert); - op->set_argument(idx, convert); - }; - - // NOTE: We don't call validate_and_infer_types() to avoid precision conflicts on inputs - bool rewritten = false; - auto ops = m->get_ordered_ops(); - for (auto& op : ops) { - if (is_in_op(op)) { - continue; - } - - if (op_supports_only_exec_type(op)) { - for (size_t i = 0; i < op->inputs().size(); i++) { - auto shared_input = op->get_input_node_shared_ptr(i); - auto existing_convert = ov::as_type_ptr(shared_input); - // We should insert Convert before Ops, which supports only exec element type, only when: - // - Input is Convert with unsupported destination type - // - Input is Op which support any element type - // We couldn't unite these conditions and just check that element type isn't supported exec type - // because we don't call validate_and_infer_types() so we don't know new precisions after setting of original - // input and output element types - if ((existing_convert && existing_convert->get_destination_type() != exec_type) || - (!op_supports_only_exec_type(shared_input))) { - insertConvert(op, i, exec_type); - rewritten |= true; - } - } - 
if (auto tr_node = std::dynamic_pointer_cast(op)) { - tr_node->set_overridden_output_type(exec_type, 0); - rewritten |= true; - } - } else { // branch for Movement ops, MatMul ops in the future and for the Convert, Result - for (size_t i = 0; i < op->inputs().size(); i++) { - auto shared_input = op->get_input_node_shared_ptr(i); - // it's original element type because we don't use validate_and_infer_type() anywhere - const auto original_eltype = op->input(i).get_element_type(); - // If before op there is another op that doesn't support execution on original element type, we know that - // before this op will be inserted reverse Convert to support execution on supported element type (first branch of condition). - // So we should return original element type for operations that can support low precision - if (op_supports_only_exec_type(shared_input) && original_eltype != exec_type) { - insertConvert(op, i, original_eltype); - rewritten |= true; - } - } - } - } - - return rewritten; -} - -bool ngraph::snippets::pass::AlignElementType::opNeedsAlignElementType(const std::shared_ptr& op, const ov::element::Type exec_type) { - // At the moment Snippets support only Eltwise/Convert/FQ/Select/Softmax/Broadcast which one output so we can just call get_element_type() - return op_supports_only_exec_type(op) && op->get_element_type() != exec_type; -} diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index cd3eb887481031..3325881834fd88 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -212,7 +212,11 @@ const std::set ngraph::snippets::pass::TokenizeSnippets:: { ngraph::element::f32, ngraph::element::bf16, ngraph::element::i8, ngraph::element::u8 }; bool TokenizeSnippets::AppropriateForSubgraph(const std::shared_ptr &node) { - return is_supported_op(node) && has_supported_in_out(node) && node->get_control_dependencies().empty(); + return + 
is_supported_op(node) && + has_supported_in_out(node) && + node->get_control_dependencies().empty() && + snippets::op::Subgraph::check_broadcast(node); } TokenizeSnippets::TokenizeSnippets() { diff --git a/src/common/snippets/src/pass/fq_decomposition.cpp b/src/common/snippets/src/pass/fq_decomposition.cpp index 5c2cfd6b0f82c3..9688e0a0e22940 100644 --- a/src/common/snippets/src/pass/fq_decomposition.cpp +++ b/src/common/snippets/src/pass/fq_decomposition.cpp @@ -36,11 +36,6 @@ bool isValidRangesInputs(const std::shared_ptr& fq }); } -bool is_scalar_constant(const std::shared_ptr& source_output_node) { - return ngraph::is_type(source_output_node) && - ngraph::shape_size(source_output_node->get_shape()) == 1; -} - } // namespace ngraph::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() { @@ -182,13 +177,6 @@ ngraph::snippets::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() { register_matcher(m, callback); } -bool ngraph::snippets::pass::FakeQuantizeDecomposition::isAllScalarConstant(const std::shared_ptr& node) { - return is_scalar_constant(node->get_input_node_shared_ptr(1)) && - is_scalar_constant(node->get_input_node_shared_ptr(2)) && - is_scalar_constant(node->get_input_node_shared_ptr(3)) && - is_scalar_constant(node->get_input_node_shared_ptr(4)); -} - bool ngraph::snippets::pass::FakeQuantizeDecomposition::getScalesAndShifts( const std::shared_ptr& fq_node, std::vector& cl, diff --git a/src/common/snippets/src/pass/propagate_precision.cpp b/src/common/snippets/src/pass/propagate_precision.cpp new file mode 100644 index 00000000000000..19be34b4e97648 --- /dev/null +++ b/src/common/snippets/src/pass/propagate_precision.cpp @@ -0,0 +1,293 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/pass/propagate_precision.hpp" + +#include +#include +#include "ov_ops/type_relaxed.hpp" +#include "snippets/itt.hpp" +#include "ngraph/rt_info.hpp" + +using namespace ngraph; + 
+ngraph::snippets::pass::PropagatePrecision::PropagatePrecision( + const std::shared_ptr& target_machine) : target_machine(target_machine) { +} + +bool ngraph::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr& f) { + RUN_ON_MODEL_SCOPE(PropagatePrecision); + OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::op::PropagatePrecision") + + std::unordered_map, element::Type> result_types; + auto results = f->get_results(); + for (auto& result : results) { + result_types.emplace(result, result->get_input_element_type(0)); + } + + bool was_updated = true; + for (const auto& op : f->get_ordered_ops()) { + auto type_info = op->get_type_info(); + OPENVINO_ASSERT( + target_machine->has(type_info), + "operation '" + std::string(type_info.version_id) + "::" + std::string(type_info.name) + "' was not found in target machine"); + + auto exec = target_machine->get_supported_precisions(type_info); + const auto supported_precisions = exec(op); + if (supported_precisions.empty()) { + continue; + } + + // There are two operation types which break precision propagation: + // 1) Existing convertion operations. Solution: remove convertion + // operation before general algo + // 2) Type relaxed based operations. Will be resolved by snippet opset. 
+ + auto input_precisions_were_changed = false; + + for (const auto& input : op->inputs()) { + const auto convert = ngraph::as_type(input.get_source_output().get_node()); + if (convert == nullptr) { + continue; + } + + const auto precision_before = convert->get_input_element_type(0); + const auto precision_after = convert->get_output_element_type(0); + if (can_be_removed(precision_before, precision_after, precision_before)) { + op->set_argument(input.get_index(), convert->input(0).get_source_output()); + input_precisions_were_changed = true; + } + } + + std::vector input_precisions; + for (const auto& input : op->inputs()) { + const auto input_precision = input.get_source_output().get_element_type(); + input_precisions.push_back(input_precision); + } + + assert(std::all_of( + supported_precisions.begin(), + supported_precisions.end(), + [&input_precisions](const std::vector& precisions) { + return precisions.size() == input_precisions.size(); + }) && "input precisions count is not equal for supported precisions"); + + // update input precisions + // if possible then convert precisions to supported + if (!supported_precisions.empty() && + std::all_of( + supported_precisions.begin(), + supported_precisions.end(), + [&input_precisions](const std::vector& precisions) { + return precisions != input_precisions; + })) { + auto precisions = get_precisions(input_precisions, + supported_precisions); + OPENVINO_ASSERT( + !precisions.empty(), + "there are no supported precisions for operation '" + std::string(type_info.version_id) + "::" + std::string(type_info.name) + "'"); + + auto find_convert = []( + const ngraph::Output parent_output, + const ngraph::element::Type convert_type) -> snippets::op::ConvertSaturation* { + for (const auto& input : parent_output.get_target_inputs()) { + const auto child = ngraph::as_type(input.get_node()); + if ((child != nullptr) && (child->get_output_element_type(0) == convert_type)) { + return child; + } + } + return nullptr; + }; + + for 
(size_t i = 0; i < op->get_input_size(); ++i) { + const auto& op_input = op->input(i); + const auto& required_after = precisions[i]; + auto parent_output = op_input.get_source_output(); + const auto actual_before = parent_output.get_element_type(); + if (actual_before != required_after) { + was_updated = true; + input_precisions_were_changed = true; + auto existing_convert = ngraph::as_type( + parent_output.get_node()); + + if (existing_convert == nullptr) { + existing_convert = find_convert(parent_output, required_after); + if (existing_convert != nullptr) { + // reuse existing convert + op->set_argument(op_input.get_index(), existing_convert->shared_from_this()); + continue; + } + } + + if (existing_convert == nullptr) { + // create new Convert + auto convert = std::make_shared( + parent_output, + required_after); + ngraph::copy_runtime_info(parent_output.get_node_shared_ptr(), convert); + op->set_argument(op_input.get_index(), convert); + continue; + } + + const auto actual_before = existing_convert->get_input_element_type(0); + const auto actual_after = existing_convert->get_output_element_type(0); + + if (can_be_removed(actual_before, actual_after, required_after)) { + // remove existing convert + existing_convert->output(0).replace(parent_output); + continue; + } + + if (can_be_fused(actual_after, required_after)) { + // fuse existing convert + auto convert = std::make_shared( + existing_convert->get_input_node_shared_ptr(0), + required_after); + ngraph::copy_runtime_info(parent_output.get_node_shared_ptr(), convert); + op->set_argument(op_input.get_index(), convert); + continue; + } + + // create new convert + auto convert = std::make_shared( + existing_convert->output(0), + required_after); + ngraph::copy_runtime_info(existing_convert->output(0).get_node()->shared_from_this(), convert); + op->set_argument(op_input.get_index(), convert); + } + } + } + + auto type_relaxed_node = std::dynamic_pointer_cast(op); + if (input_precisions_were_changed || 
(type_relaxed_node != nullptr)) { + // update output precision + std::vector op_output_types; + for (auto& output : op->outputs()) { + op_output_types.push_back(output.get_element_type()); + } + + if (type_relaxed_node != nullptr) { + // TODO: user story 104284 + // to keep previous functionality + // unary and binary element-wise operations are supported + // will be replaced to snippets opset later + const auto op_element_type = op->get_input_element_type(0); + if (type_relaxed_node->get_overridden_output_type(0) != op_element_type) { + was_updated = true; + OPENVINO_ASSERT(op->get_output_size() == 1ull, "operation with several output is not supported"); + + type_relaxed_node->set_overridden_output_type(op_element_type, 0); + op->validate_and_infer_types(); + } + } else { + op->validate_and_infer_types(); + } + + for (size_t i = 0; i < op->get_output_size(); ++i) { + auto output = op->output(i); + + if (output.get_element_type() != op_output_types[i]) { + was_updated = true; + auto convert = std::make_shared( + output, + op_output_types[i]); + ngraph::copy_runtime_info(output.get_node_shared_ptr(), convert); + + for (auto& input : output.get_target_inputs()) { + auto child = input.get_node(); + if (child == convert.get()) { + continue; + } + + input.replace_source_output(convert->output(0)); + + + if (ngraph::is_type(input.get_node())) { + input.get_tensor_ptr()->add_names(output.get_tensor_ptr()->get_names()); + + const std::string original_name = op->get_friendly_name(); + op->set_friendly_name(original_name + "_original"); + convert->set_friendly_name(original_name); + } + } + output.get_tensor_ptr()->set_names({}); + } + } + } + } + + for (auto it = result_types.begin(); it != result_types.end(); ++it) { + const auto result = it->first; + const auto actual_type = result->get_input_element_type(0); + const auto expected_type = it->second; + if (actual_type != it->second) { + was_updated = true; + auto convert = std::make_shared( + 
result->get_input_node_shared_ptr(0), + expected_type); + ngraph::copy_runtime_info(result->get_input_node_shared_ptr(0), convert); + result->set_argument(0, convert); + } + } + + return was_updated; +} + +bool ngraph::snippets::pass::PropagatePrecision::can_be_removed( + const element::Type& actual_before, + const element::Type& actual_after, + const element::Type& required_after) noexcept { + if (actual_before != required_after) { + return false; + } + + return can_be_fused(actual_after, actual_before); +} + +bool ngraph::snippets::pass::PropagatePrecision::can_be_fused( + const element::Type& actual, + const element::Type& required) noexcept { + // custom conditions: between int & float precisions + if (((actual == element::bf16) || (actual == element::f16) || (actual == element::f32)) && + ((required == element::u8) || (required == element::i8))) { + return true; + } + + if ((actual == element::f32) && ((required == element::u16) || (required == element::i16))) { + return true; + } + + // general conditions: any new added precision will support + return + (actual.is_real() == required.is_real()) && + (actual.bitwidth() >= required.bitwidth()); +} + +std::vector ngraph::snippets::pass::PropagatePrecision::get_precisions( + const std::vector& input_precisions, + const std::set>& supported_precisions_pack) noexcept { + bool was_found = false; + for (const auto& supported_precisions : supported_precisions_pack) { + for (size_t i = 0; i < supported_precisions.size(); ++i) { + const auto& supported_precision = supported_precisions[i]; + const auto& input_precision = input_precisions[i]; + if ((supported_precision.is_real() != input_precision.is_real()) || + (input_precision.bitwidth() > supported_precision.bitwidth())) { + was_found = false; + break; + } + + was_found = true; + } + if (was_found) { + return supported_precisions; + } + } + + if (!supported_precisions_pack.empty()) { + return *supported_precisions_pack.begin(); + } + + return {}; +} diff --git 
a/src/common/snippets/tests/include/lowering_utils.hpp b/src/common/snippets/tests/include/lowering_utils.hpp index be2e0f2e756044..b0b1bafb245308 100644 --- a/src/common/snippets/tests/include/lowering_utils.hpp +++ b/src/common/snippets/tests/include/lowering_utils.hpp @@ -16,7 +16,7 @@ using BlockedShapeVector = ngraph::snippets::op::Subgraph::BlockedShapeVector; class DummyEmitter : public ngraph::snippets::Emitter { public: // Here I pass Add to Emitter, but could be any other op, since it's ignored anyway. - DummyEmitter() : ngraph::snippets::Emitter(std::make_shared()) {} + DummyEmitter(const std::vector& custom_opset = {}) : ngraph::snippets::Emitter(std::make_shared()) {} void emit_code(const std::vector&, const std::vector&, const std::vector&, @@ -49,7 +49,9 @@ class LoweringTests : public TransformationTestsF { static std::shared_ptr getSubgraph(const std::shared_ptr& f); static std::shared_ptr getLoweredSubgraph(const std::shared_ptr& f, const ov::PartialShape& master_shape, - ov::pass::Manager target_optimizations = {}, + ov::pass::Manager pre_dialect = {}, + ov::pass::Manager post_dialect = {}, + ov::pass::Manager post_precision = {}, const std::shared_ptr generator = nullptr); static std::shared_ptr getTokenizedSubgraph(const std::shared_ptr& f); ov::PartialShape master_shape{}; diff --git a/src/common/snippets/tests/include/pass/precision_propagation.hpp b/src/common/snippets/tests/include/pass/precision_propagation.hpp new file mode 100644 index 00000000000000..a60b9161ab4fc4 --- /dev/null +++ b/src/common/snippets/tests/include/pass/precision_propagation.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "lowering_utils.hpp" +#include "snippets_helpers.hpp" + +namespace ov { +namespace test { +namespace snippets { + +class PrecisionPropagationParamsValues { +public: + class Actual { + public: + std::pair convertion_before_op1; + element::Type 
convertion_before_op2_1; + std::pair convertion_before_op2_2; + std::set> op1_supported_precisions; + std::set> op2_supported_precisions; + }; + + class Expected { + public: + std::pair convertion_before_op1; + element::Type convertion_before_op2_1; + std::pair convertion_before_op2_2; + element::Type convertion_after_op2; + }; + + std::vector input_types; + Actual actual; + Expected expected; +}; + +typedef std::tuple< + std::pair, // input shapes + PrecisionPropagationParamsValues +> PrecisionPropagationParams; + +class PrecisionPropagationTest : public TransformationTestsF, + public testing::WithParamInterface { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + +protected: + std::shared_ptr snippets_function; +}; + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/common/snippets/tests/src/lowering_utils.cpp b/src/common/snippets/tests/src/lowering_utils.cpp index a536a0317eae12..55480e95dae510 100644 --- a/src/common/snippets/tests/src/lowering_utils.cpp +++ b/src/common/snippets/tests/src/lowering_utils.cpp @@ -11,10 +11,12 @@ namespace ov { namespace test { namespace snippets { -DummyTargetMachine::DummyTargetMachine(const std::vector& custom_opset) { - auto dummy_functor = [](const std::shared_ptr& n) { - return std::make_shared(); +DummyTargetMachine::DummyTargetMachine(const std::vector&custom_opset) { + auto dummy_functor = ngraph::snippets::jitters_value { + [](const std::shared_ptr& n) { return std::make_shared(); }, + [](const std::shared_ptr& n) { return std::set>{};} }; + jitters[op::v0::Parameter::get_type_info_static()] = dummy_functor; jitters[op::v0::Constant::get_type_info_static()] = dummy_functor; jitters[op::v0::Result::get_type_info_static()] = dummy_functor; @@ -97,7 +99,9 @@ std::shared_ptr LoweringTests::getSubgraph(const std::shared_ptr LoweringTests::getLoweredSubgraph(const std::shared_ptr &f, const ov::PartialShape& master_shape, - ov::pass::Manager target_optimizations, + 
ov::pass::Manager pre_dialect, + ov::pass::Manager post_dialect, + ov::pass::Manager post_precision, const std::shared_ptr generator) { auto subgraph = getTokenizedSubgraph(f); subgraph->set_generator(generator == nullptr ? std::make_shared() : generator); @@ -119,7 +123,7 @@ std::shared_ptr LoweringTests::getLoweredSubgrap } body_rt_info["PluginShapesOverride"] = new_shapes; subgraph->set_tile_rank(2); - subgraph->generate(target_optimizations); + subgraph->generate(pre_dialect, post_dialect, post_precision); return subgraph; } diff --git a/src/common/snippets/tests/src/pass/precision_propagation.cpp b/src/common/snippets/tests/src/pass/precision_propagation.cpp new file mode 100644 index 00000000000000..3c7da4d06aa165 --- /dev/null +++ b/src/common/snippets/tests/src/pass/precision_propagation.cpp @@ -0,0 +1,294 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pass/precision_propagation.hpp" + +#include +#include "ngraph/pass/validate.hpp" +#include "snippets/pass/propagate_precision.hpp" +#include "snippets/op/convert_saturation.hpp" +#include "common_test_utils/common_utils.hpp" +#include "precision_propagation_function.hpp" + +namespace ov { +namespace test { +namespace snippets { + +namespace { + +class DummyPrecisionPropagationTargetMachine : public DummyTargetMachine { +public: + DummyPrecisionPropagationTargetMachine( + const std::set>& op1_supported_precisions, + const std::set>& op2_supported_precisions) + : DummyTargetMachine() { + jitters[DummyAdd::get_type_info_static()] = ngraph::snippets::jitters_value { + [](const std::shared_ptr& n) { return std::make_shared(); }, + [op1_supported_precisions](const std::shared_ptr& n) { return op1_supported_precisions; }}; + jitters[op::v1::Maximum::get_type_info_static()] = ngraph::snippets::jitters_value{ + [](const std::shared_ptr& n) { return std::make_shared(); }, + [op2_supported_precisions](const std::shared_ptr&n) { return op2_supported_precisions; }};
+ + auto default_jitter = ngraph::snippets::jitters_value{ + [](const std::shared_ptr& n) { return std::make_shared(); }, + [](const std::shared_ptr& n) { return std::set>{};} }; + jitters[ngraph::snippets::op::ConvertSaturation::get_type_info_static()] = default_jitter; + } +}; + +} // namespace + +std::string PrecisionPropagationTest::getTestCaseName(testing::TestParamInfo obj) { + std::pair shapes; + PrecisionPropagationParamsValues test_values; + std::tie(shapes, test_values) = obj.param; + + auto to_string = [](const std::set>& precisions_pack) noexcept { + std::ostringstream result; + result << "{"; + for (const auto& precisions : precisions_pack) { + result << CommonTestUtils::vec2str(precisions) << "_"; + } + result << "}"; + return result.str(); + }; + + std::ostringstream result; + result << "IN0_" << shapes.first << "_" << test_values.input_types[0] << "_" + << "IN1_" << shapes.second << "_" << test_values.input_types[1] << "_" + << "IN2_" << test_values.input_types[2] + << to_string(test_values.actual.op1_supported_precisions) << "_" + << to_string(test_values.actual.op2_supported_precisions) << "_" + << test_values.expected.convertion_before_op1.first << "_" << test_values.expected.convertion_before_op1.second << "_" + << test_values.expected.convertion_before_op2_1 << "_" + << test_values.expected.convertion_before_op2_2.first << "_" << test_values.expected.convertion_before_op2_2.second << "_" + << test_values.expected.convertion_after_op2 << "_"; + return result.str(); +} + +TEST_P(PrecisionPropagationTest, CompareFunctions) { + disable_rt_info_check(); + + const auto param = GetParam(); + const auto shapes = std::get<0>(param); + const auto test_values = std::get<1>(param); + + const auto input_shapes = std::vector({ shapes.first, shapes.second }); + PrecisionPropagationAddFunction function_stub( + input_shapes, + test_values.input_types[0], + test_values.input_types[1], + test_values.input_types[2], + { + test_values.actual.convertion_before_op1, 
+ test_values.actual.convertion_before_op2_1, + test_values.actual.convertion_before_op2_2 + }, + { + test_values.expected.convertion_before_op1, + test_values.expected.convertion_before_op2_1, + test_values.expected.convertion_before_op2_2, + test_values.expected.convertion_after_op2 + }); + function = function_stub.getOriginal(); + + const auto target_machine = std::make_shared( + test_values.actual.op1_supported_precisions, + test_values.actual.op2_supported_precisions); + + manager.register_pass(target_machine); + + function_ref = function_stub.getReference(); +} + +namespace PrecisionPropagationTestInstantiation { +// clang-format off + +std::vector> shapes { + {{1, 3, 16, 16}, {1, 3, 16, 16}} +}; + +std::vector test_cases { + { + {element::f32, element::f32, element::f32}, + { + {}, + {}, + {}, + {{element::f32, element::f32}}, + {{element::f32, element::f32}} + }, + {} + }, + // in: Parameter I8 => Op1 I32 => Convert I8 => Op1 I8 => Result + // out: Parameter I8 => Add I32 => Convert I8 => Convert FP32 => Op1 FP32 => Result + { + {element::i8, element::i8, element::i8}, + { + {}, + {}, + {}, + {{element::i8, element::i8}}, + {{element::f32, element::f32}} + }, + { + {}, + element::i8, + {element::f32, element::f32}, + {element::i8} + } + }, + { + {element::i8, element::i8, element::i8}, + { + {}, + {}, + {}, + {{element::i8, element::i8}}, + {{element::i8, element::i8}} + }, + { + {}, + {}, + {element::i8, element::undefined}, + {} + } + }, + { + {element::i8, element::i8, element::i8}, + { + {}, + {}, + {}, + {{element::i8, element::i8}}, + {{element::i32, element::i32}} + }, + { + {}, + {element::i8}, + {element::i32, element::i32}, + {element::i8} + } + }, + { + {element::bf16, element::bf16, element::f32}, + { + {element::f32, element::f32}, + {}, + {}, + { + {element::f32, element::f32}, + {element::i8, element::i8} + }, + { + {element::f32, element::f32}, + {element::i32, element::i32} + } + }, + { + {element::f32, element::f32}, + {}, + {}, + {} + } + 
}, + // propagate precision via operation #1 + { + {element::bf16, element::bf16, element::f32}, + { + {element::f32, element::f32}, + {}, + {}, + { + {element::f32, element::f32}, + {element::bf16, element::bf16} + }, + { + {element::f32, element::f32} + } + }, + { + {}, + {}, + {element::f32, element::undefined}, + {} + } + }, + // propagate precision via operation #1 + { + {element::bf16, element::bf16, element::bf16}, + { + {element::f32, element::f32}, + {}, + {element::undefined, element::f32}, + { + {element::f32, element::f32}, + {element::bf16, element::bf16} + }, + { + {element::f32, element::f32} + } + }, + { + {}, + {}, + {element::f32, element::f32}, + {} + } + }, + // propagate precision via both operations + { + {element::bf16, element::bf16, element::bf16}, + { + {element::f32, element::f32}, + {}, + {element::undefined, element::f32}, + { + {element::f32, element::f32}, + {element::bf16, element::bf16} + }, + { + {element::f32, element::f32}, + {element::bf16, element::bf16} + } + }, + { + {}, + {}, + {}, + {element::f32} + } + }, + { + {element::bf16, element::bf16, element::bf16}, + { + {}, + {}, + {}, + {{element::f32, element::f32}}, + {{element::f32, element::f32}} + }, + { + {{element::f32}, {element::f32}}, + {element::bf16}, + {{element::f32}, {element::f32}}, + {element::bf16} + } + }, +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_Snippets_PrecisionPropagationTest, + PrecisionPropagationTest, + ::testing::Combine( + ::testing::ValuesIn(shapes), + ::testing::ValuesIn(test_cases)), + PrecisionPropagationTest::getTestCaseName); + +// clang-format on +} // namespace PrecisionPropagationTestInstantiation + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/common/snippets/tests/src/pass/precision_propagation_convert_test.cpp b/src/common/snippets/tests/src/pass/precision_propagation_convert_test.cpp new file mode 100644 index 00000000000000..cc6c113cc3f671 --- /dev/null +++ 
b/src/common/snippets/tests/src/pass/precision_propagation_convert_test.cpp @@ -0,0 +1,153 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "snippets/pass/propagate_precision.hpp" + +namespace ov { +namespace test { +namespace snippets { + +class PrecisionPropagationConvertTest : public testing::Test {}; + +TEST_F(PrecisionPropagationConvertTest, smoke_Snippets_PrecisionPropagation_can_be_fused) { + const std::set> precisions_set = { + {element::u64, element::u64}, + {element::u64, element::u32}, + {element::u64, element::u16}, + {element::u64, element::u8}, + {element::u32, element::u32}, + {element::u32, element::u16}, + {element::u32, element::u8}, + {element::u16, element::u16}, + {element::u16, element::u8}, + {element::u8, element::u8}, + + {element::i64, element::i64}, + {element::i64, element::i32}, + {element::i64, element::i16}, + {element::i64, element::i8}, + {element::i32, element::i32}, + {element::i32, element::i16}, + {element::i32, element::i8}, + {element::i16, element::i16}, + {element::i16, element::i8}, + {element::i8, element::i8}, + + {element::f64, element::f64}, + {element::f64, element::f32}, + {element::f64, element::f16}, + {element::f32, element::f32}, + {element::f32, element::f16}, + {element::f16, element::f16}, + + {element::f32, element::bf16}, + {element::bf16, element::bf16}, + {element::f32, element::i8}, + {element::f16, element::i8}, + {element::bf16, element::i8}, + {element::f32, element::u8}, + {element::f16, element::u8}, + {element::bf16, element::u8} + }; + + for (const auto& precisions : precisions_set) { + ASSERT_TRUE(ngraph::snippets::pass::PropagatePrecision::can_be_fused( + precisions.first, + precisions.second)) << precisions.second << " can replace " << precisions.first; + + if (precisions.first == precisions.second) { + continue; + } + + ASSERT_FALSE(ngraph::snippets::pass::PropagatePrecision::can_be_fused( + precisions.second, + precisions.first)) << 
precisions.second << " can not replace " << precisions.first; + } +} + +TEST_F(PrecisionPropagationConvertTest, smoke_Snippets_PrecisionPropagation_can_not_be_fused) { + const std::set> precisions_set = { + {element::i64, element::f32}, + {element::i64, element::f16}, + {element::i64, element::bf16}, + + {element::i32, element::f32}, + {element::i32, element::f16}, + {element::i32, element::bf16}, + + {element::i16, element::f16}, + {element::i16, element::bf16}, + + {element::u64, element::f32}, + {element::u64, element::f16}, + {element::u64, element::bf16}, + + {element::u32, element::f32}, + {element::u32, element::f16}, + {element::u32, element::bf16}, + + {element::u16, element::f16}, + {element::u16, element::bf16} + }; + + for (const auto& precisions : precisions_set) { + ASSERT_FALSE(ngraph::snippets::pass::PropagatePrecision::can_be_fused( + precisions.first, + precisions.second)) << precisions.second << " can not replace " << precisions.first; + } +} + +TEST_F(PrecisionPropagationConvertTest, smoke_Snippets_PrecisionPropagation_can_be_removed) { + const std::set> precisions_set = { + {element::u64, element::u64, element::u64}, + {element::u32, element::u64, element::u32}, + {element::u16, element::u64, element::u16}, + {element::u8, element::u64, element::u8}, + {element::u32, element::u32, element::u32}, + {element::u16, element::u32, element::u16}, + {element::u8, element::u32, element::u8}, + {element::u16, element::u16, element::u16}, + {element::u8, element::u16, element::u8}, + {element::u8, element::u8, element::u8}, + + {element::i64, element::i64, element::i64}, + {element::i32, element::i64, element::i32}, + {element::i16, element::i64, element::i16}, + {element::i8, element::i64, element::i8}, + {element::i32, element::i32, element::i32}, + {element::i16, element::i32, element::i16}, + {element::i8, element::i32, element::i8}, + {element::i16, element::i16, element::i16}, + {element::i8, element::i16, element::i8}, + {element::i8, element::i8, 
element::i8}, + + {element::f64, element::f64, element::f64}, + {element::f32, element::f64, element::f32}, + {element::f16, element::f64, element::f16}, + {element::f32, element::f32, element::f32}, + {element::f16, element::f16, element::f16}, + + {element::bf16, element::f32, element::bf16}, + {element::bf16, element::bf16, element::bf16}, + }; + + for (const auto& precisions : precisions_set) { + const auto actual_before = std::get<0>(precisions); + const auto actual_after = std::get<1>(precisions); + const auto required_after = std::get<2>(precisions); + ASSERT_TRUE(ngraph::snippets::pass::PropagatePrecision::can_be_removed( + actual_before, + actual_after, + required_after)) << "can_be_removed: " << actual_before << " => " << actual_after << " => " << required_after; + + if ((actual_before == actual_after) && (actual_before == required_after)) { + continue; + } + } +} + +} // namespace snippets +} // namespace test +} // namespace ov \ No newline at end of file diff --git a/src/common/snippets/tests/src/pass/precision_propagation_get_precisions.cpp b/src/common/snippets/tests/src/pass/precision_propagation_get_precisions.cpp new file mode 100644 index 00000000000000..9e97fcc8ad4aa1 --- /dev/null +++ b/src/common/snippets/tests/src/pass/precision_propagation_get_precisions.cpp @@ -0,0 +1,45 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "snippets/pass/propagate_precision.hpp" + +namespace ov { +namespace test { +namespace snippets { + + +class PrecisionPropagationGetPrecisionsTest : public testing::Test {}; + +TEST_F(PrecisionPropagationGetPrecisionsTest, empty) { + ASSERT_EQ(std::vector{}, ngraph::snippets::pass::PropagatePrecision::get_precisions({}, {})); +} + +TEST_F(PrecisionPropagationGetPrecisionsTest, selected) { + ASSERT_EQ( + std::vector({element::f32, element::f32}), + ngraph::snippets::pass::PropagatePrecision::get_precisions( + { element::f32, element::f32 }, + { + {element::bf16, 
element::bf16}, + {element::f32, element::f32}, + {element::i8, element::i8}, + })); +} + +TEST_F(PrecisionPropagationGetPrecisionsTest, first) { + ASSERT_EQ( + std::vector({ element::bf16, element::bf16 }), + ngraph::snippets::pass::PropagatePrecision::get_precisions( + { element::i32, element::i32 }, + { + {element::bf16, element::bf16}, + {element::f32, element::f32}, + {element::i8, element::i8}, + })); +} + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/core/src/pass/visualize_tree.cpp b/src/core/src/pass/visualize_tree.cpp index 70ee298b547e5e..c89decb3f42121 100644 --- a/src/core/src/pass/visualize_tree.cpp +++ b/src/core/src/pass/visualize_tree.cpp @@ -503,7 +503,9 @@ string pass::VisualizeTree::get_node_name(shared_ptr node) { if (node->get_friendly_name() != node->get_name()) { rc += "\\n" + (nvtmn ? string("name: ") : "") + node->get_name(); } - rc += "\\n" + (nvtmn ? string("type_name: ") : "") + std::string(node->get_type_name()); + const auto type_info = node->get_type_info(); + rc += "\\n" + (nvtmn ? 
string("type_name: ") : "") + std::string(type_info.version_id) + + "::" + std::string(type_info.name); static const bool nvttn = getenv_bool("OV_VISUALIZE_TREE_TENSORS_NAME"); if (nvttn) { diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp index 8423a9bec9d611..8c2e666d6b6438 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp @@ -26,8 +26,14 @@ using namespace std; using namespace ngraph::snippets; -#define CREATE_EMITTER(e_type) [this](const std::shared_ptr& n) \ - -> std::shared_ptr {return std::make_shared(h.get(), isa, n);}; +#define CREATE_EMITTER(e_type) { \ + [this](const std::shared_ptr& n) -> std::shared_ptr { \ + return std::make_shared(h.get(), isa, n); \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + return e_type::get_supported_precisions(n); \ + } \ +}; class jit_snippet : public dnnl::impl::cpu::x64::jit_generator { public: diff --git a/src/plugins/intel_cpu/src/emitters/jit_dnnl_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_dnnl_emitters.cpp index 501cd934753b10..416218b92a3bb6 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_dnnl_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_dnnl_emitters.cpp @@ -13,6 +13,10 @@ using namespace Xbyak; namespace ov { namespace intel_cpu { +std::set> jit_dnnl_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + jit_dnnl_emitter::jit_dnnl_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr& node, InferenceEngine::Precision exec_prc) : jit_emitter(host, host_isa, node, exec_prc) { diff --git a/src/plugins/intel_cpu/src/emitters/jit_dnnl_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_dnnl_emitters.hpp index b9ea5ffd2339da..0b7165d2484580 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_dnnl_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_dnnl_emitters.hpp @@ -20,6 
+20,8 @@ class jit_dnnl_emitter : public jit_emitter { void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override {}; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + protected: jit_dnnl_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, dnnl_alg_kind_t algKind, float inpAlpha, float inpBeta, diff --git a/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp index d222f8345511dc..150d524ac04ce7 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp @@ -3,6 +3,7 @@ // #include "jit_eltwise_emitters.hpp" +#include "ie_ngraph_utils.hpp" using namespace InferenceEngine; using namespace dnnl::impl::utils; @@ -16,9 +17,26 @@ using namespace Xbyak; namespace ov { namespace intel_cpu { +namespace { +InferenceEngine::Precision get_arithmetic_binary_exec_precision(const std::shared_ptr& n) { + std::vector input_precisions; + for (const auto& input : n->inputs()) { + input_precisions.push_back( + InferenceEngine::details::convertPrecision(input.get_source_output().get_element_type())); + } + + assert(std::all_of( + input_precisions.begin(), + input_precisions.end(), + [&input_precisions](const InferenceEngine::Precision& precision) {return precision == input_precisions[0]; })); + + return input_precisions[0]; +} +} // namespace + /// ADD /// -jit_add_emitter::jit_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, Precision exec_prc) -: jit_emitter(host, host_isa, node, exec_prc) {} +jit_add_emitter::jit_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) +: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_add_emitter::jit_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, Precision exec_prc) : 
jit_emitter(host, host_isa, exec_prc) {} @@ -59,13 +77,13 @@ void jit_add_emitter::emit_isa(const std::vector &in_vec_idxs, const std } } -std::set jit_add_emitter::get_supported_precisions() { - return {Precision::FP32, Precision::I32}; +std::set> jit_add_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// MUL_ADD /// -jit_mul_add_emitter::jit_mul_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, Precision exec_prc) -: jit_emitter(host, host_isa, node, exec_prc) {} +jit_mul_add_emitter::jit_mul_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) +: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_mul_add_emitter::jit_mul_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, Precision exec_prc) : jit_emitter(host, host_isa, exec_prc) {} @@ -150,13 +168,13 @@ size_t jit_mul_add_emitter::aux_vecs_count() const { return 1; } -std::set jit_mul_add_emitter::get_supported_precisions() { - return {Precision::FP32, Precision::I32}; +std::set> jit_mul_add_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32, element::f32}, {element::i32, element::i32, element::i32}}; } /// SUB /// -jit_subtract_emitter::jit_subtract_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, Precision exec_prc) -: jit_emitter(host, host_isa, node, exec_prc) {} +jit_subtract_emitter::jit_subtract_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) +: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_subtract_emitter::jit_subtract_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, Precision exec_prc) : jit_emitter(host, host_isa, exec_prc) {} @@ -197,13 +215,13 @@ void jit_subtract_emitter::emit_isa(const std::vector 
&in_vec_idxs, cons } } -std::set jit_subtract_emitter::get_supported_precisions() { - return {Precision::FP32, Precision::I32}; +std::set> jit_subtract_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// MULTIPLY /// -jit_multiply_emitter::jit_multiply_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, Precision exec_prc) -: jit_emitter(host, host_isa, node, exec_prc) {} +jit_multiply_emitter::jit_multiply_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) +: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_multiply_emitter::jit_multiply_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, Precision exec_prc) : jit_emitter(host, host_isa, exec_prc) {} @@ -244,13 +262,13 @@ void jit_multiply_emitter::emit_isa(const std::vector &in_vec_idxs, cons } } -std::set jit_multiply_emitter::get_supported_precisions() { - return {Precision::FP32, Precision::I32}; +std::set> jit_multiply_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// DIVIDE /// jit_divide_emitter::jit_divide_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, Precision exec_prc) -: jit_emitter(host, host_isa, node, exec_prc) {} +: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_divide_emitter::jit_divide_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, Precision exec_prc) : jit_emitter(host, host_isa, exec_prc) {} @@ -305,8 +323,8 @@ void jit_divide_emitter::emit_isa(const std::vector &in_vec_idxs, const } } -std::set jit_divide_emitter::get_supported_precisions() { - return {Precision::FP32, Precision::I32}; +std::set> jit_divide_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}, 
{element::i32, element::i32}}; } size_t jit_divide_emitter::aux_vecs_count() const { @@ -321,7 +339,11 @@ jit_floor_emitter::jit_floor_emitter(x64::jit_generator *host, x64::cpu_isa_t ho size_t jit_floor_emitter::get_inputs_num() const { return 1; } -void jit_floor_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_floor_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + +void jit_floor_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -344,11 +366,15 @@ void jit_floor_emitter::emit_isa(const std::vector &in_vec_idxs, const s /// CEILING /// jit_ceiling_emitter::jit_ceiling_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, Precision exec_prc) : jit_emitter(host, host_isa, node, exec_prc) {} -jit_ceiling_emitter::jit_ceiling_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc) +jit_ceiling_emitter::jit_ceiling_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, Precision exec_prc) : jit_emitter(host, host_isa, exec_prc) {} size_t jit_ceiling_emitter::get_inputs_num() const { return 1; } +std::set> jit_ceiling_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + void jit_ceiling_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { @@ -378,7 +404,11 @@ jit_floor_mod_emitter::jit_floor_mod_emitter(x64::jit_generator *host, x64::cpu_ size_t jit_floor_mod_emitter::get_inputs_num() const { return 2; } -void jit_floor_mod_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_floor_mod_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} 
+ +void jit_floor_mod_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -428,7 +458,11 @@ jit_mod_emitter::jit_mod_emitter(x64::jit_generator *host, x64::cpu_isa_t host_i size_t jit_mod_emitter::get_inputs_num() const { return 2; } -void jit_mod_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_mod_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_mod_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -471,8 +505,8 @@ size_t jit_mod_emitter::aux_vecs_count() const { } /// MAXIMUM /// -jit_maximum_emitter::jit_maximum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, Precision exec_prc) -: jit_emitter(host, host_isa, node, exec_prc) {} +jit_maximum_emitter::jit_maximum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) +: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_maximum_emitter::jit_maximum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, Precision exec_prc) : jit_emitter(host, host_isa, exec_prc) {} @@ -514,13 +548,13 @@ void jit_maximum_emitter::emit_isa(const std::vector &in_vec_idxs, const } } -std::set jit_maximum_emitter::get_supported_precisions() { - return {Precision::FP32, Precision::I32}; +std::set> jit_maximum_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// MINIMUM /// -jit_minimum_emitter::jit_minimum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, Precision exec_prc) -: 
jit_emitter(host, host_isa, node, exec_prc) {} +jit_minimum_emitter::jit_minimum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) +: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {} jit_minimum_emitter::jit_minimum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, Precision exec_prc) : jit_emitter(host, host_isa, exec_prc) {} @@ -562,8 +596,8 @@ void jit_minimum_emitter::emit_isa(const std::vector &in_vec_idxs, const } } -std::set jit_minimum_emitter::get_supported_precisions() { - return {Precision::FP32, Precision::I32}; +std::set> jit_minimum_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// SQUARED_DIFFERENCE /// @@ -617,8 +651,8 @@ void jit_squared_difference_emitter::emit_isa(const std::vector &in_vec_ } } -std::set jit_squared_difference_emitter::get_supported_precisions() { - return {Precision::FP32, Precision::I32}; +std::set> jit_squared_difference_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// POWER_DYNAMIC /// @@ -630,7 +664,11 @@ jit_power_dynamic_emitter::jit_power_dynamic_emitter(x64::jit_generator *host, x size_t jit_power_dynamic_emitter::get_inputs_num() const { return 2; } -void jit_power_dynamic_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_power_dynamic_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_power_dynamic_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -741,7 +779,11 @@ jit_equal_emitter::jit_equal_emitter(x64::jit_generator *host, x64::cpu_isa_t ho size_t jit_equal_emitter::get_inputs_num() 
const { return 2; } -void jit_equal_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_equal_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_equal_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -800,7 +842,11 @@ jit_not_equal_emitter::jit_not_equal_emitter(x64::jit_generator *host, x64::cpu_ size_t jit_not_equal_emitter::get_inputs_num() const { return 2; } -void jit_not_equal_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_not_equal_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_not_equal_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -859,7 +905,11 @@ jit_greater_emitter::jit_greater_emitter(x64::jit_generator *host, x64::cpu_isa_ size_t jit_greater_emitter::get_inputs_num() const { return 2; } -void jit_greater_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_greater_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_greater_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -919,7 +969,11 @@ jit_greater_equal_emitter::jit_greater_equal_emitter(x64::jit_generator *host, x size_t jit_greater_equal_emitter::get_inputs_num() const { return 2; } -void jit_greater_equal_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector 
&out_vec_idxs) const { +std::set> jit_greater_equal_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_greater_equal_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -978,7 +1032,11 @@ jit_less_emitter::jit_less_emitter(x64::jit_generator *host, x64::cpu_isa_t host size_t jit_less_emitter::get_inputs_num() const { return 2; } -void jit_less_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_less_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_less_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1037,7 +1095,11 @@ jit_less_equal_emitter::jit_less_equal_emitter(x64::jit_generator *host, x64::cp size_t jit_less_equal_emitter::get_inputs_num() const { return 2; } -void jit_less_equal_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_less_equal_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_less_equal_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1097,7 +1159,11 @@ jit_logical_and_emitter::jit_logical_and_emitter(x64::jit_generator *host, x64:: size_t jit_logical_and_emitter::get_inputs_num() const { return 2; } -void jit_logical_and_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_logical_and_emitter::get_supported_precisions(const 
std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_logical_and_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1177,7 +1243,11 @@ jit_logical_or_emitter::jit_logical_or_emitter(x64::jit_generator *host, x64::cp size_t jit_logical_or_emitter::get_inputs_num() const { return 2; } -void jit_logical_or_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_logical_or_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_logical_or_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1256,7 +1326,11 @@ jit_logical_xor_emitter::jit_logical_xor_emitter(x64::jit_generator *host, x64:: size_t jit_logical_xor_emitter::get_inputs_num() const { return 2; } -void jit_logical_xor_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_logical_xor_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_logical_xor_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1335,7 +1409,11 @@ jit_logical_not_emitter::jit_logical_not_emitter(x64::jit_generator *host, x64:: size_t jit_logical_not_emitter::get_inputs_num() const { return 1; } -void jit_logical_not_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_logical_not_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + +void 
jit_logical_not_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1405,7 +1483,11 @@ jit_power_static_emitter::jit_power_static_emitter(x64::jit_generator *host, x64 size_t jit_power_static_emitter::get_inputs_num() const { return 1; } -void jit_power_static_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_power_static_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + +void jit_power_static_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1579,7 +1661,11 @@ jit_prelu_emitter::jit_prelu_emitter(x64::jit_generator *host, x64::cpu_isa_t ho } size_t jit_prelu_emitter::get_inputs_num() const { return 2; } -void jit_prelu_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_prelu_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_prelu_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1634,7 +1720,11 @@ jit_sqrt_emitter::jit_sqrt_emitter(x64::jit_generator *host, x64::cpu_isa_t host size_t jit_sqrt_emitter::get_inputs_num() const { return 1; } -void jit_sqrt_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_sqrt_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + +void jit_sqrt_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { 
emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1661,7 +1751,11 @@ jit_negative_emitter::jit_negative_emitter(x64::jit_generator *host, x64::cpu_is size_t jit_negative_emitter::get_inputs_num() const { return 1; } -void jit_negative_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_negative_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + +void jit_negative_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -1695,6 +1789,10 @@ jit_erf_emitter::jit_erf_emitter(x64::jit_generator *host, x64::cpu_isa_t host_i size_t jit_erf_emitter::get_inputs_num() const { return 1; } +std::set> jit_erf_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + void jit_erf_emitter::emit_impl( const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { @@ -1875,7 +1973,11 @@ jit_soft_sign_emitter::jit_soft_sign_emitter(x64::jit_generator *host, x64::cpu_ size_t jit_soft_sign_emitter::get_inputs_num() const { return 1; } -void jit_soft_sign_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +std::set> jit_soft_sign_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32}}; +} + +void jit_soft_sign_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2086,6 +2188,10 @@ jit_select_emitter::jit_select_emitter(x64::jit_generator *host, x64::cpu_isa_t size_t jit_select_emitter::get_inputs_num() const { return 3; } +std::set> jit_select_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32, 
element::f32}}; +} + size_t jit_select_emitter::aux_vecs_count() const { if (host_isa_ == x64::avx512_core) return 0; diff --git a/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.hpp index 138ba513eda71a..5c00e4584b4274 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.hpp @@ -13,11 +13,10 @@ class jit_add_emitter : public jit_emitter { public: jit_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); - jit_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); + jit_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n); size_t get_inputs_num() const override; - static std::set get_supported_precisions(); + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -30,11 +29,10 @@ class jit_mul_add_emitter : public jit_emitter { public: jit_mul_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); - jit_mul_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); + jit_mul_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n); size_t get_inputs_num() const override; - static std::set get_supported_precisions(); + static std::set> 
get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -50,11 +48,10 @@ class jit_subtract_emitter : public jit_emitter { public: jit_subtract_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); - jit_subtract_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); + jit_subtract_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n); size_t get_inputs_num() const override; - static std::set get_supported_precisions(); + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -68,11 +65,10 @@ class jit_multiply_emitter : public jit_emitter { public: jit_multiply_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); - jit_multiply_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); + jit_multiply_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n); size_t get_inputs_num() const override; - static std::set get_supported_precisions(); + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -90,7 +86,7 @@ class jit_divide_emitter : public jit_emitter { 
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; - static std::set get_supported_precisions(); + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -108,6 +104,7 @@ class jit_floor_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -124,6 +121,7 @@ class jit_ceiling_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -140,6 +138,7 @@ class jit_floor_mod_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -158,6 +157,7 @@ class jit_mod_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -172,11 +172,10 @@ class jit_maximum_emitter : public jit_emitter { public: jit_maximum_emitter(dnnl::impl::cpu::x64::jit_generator *host, 
dnnl::impl::cpu::x64::cpu_isa_t host_isa, InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); - jit_maximum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); + jit_maximum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n); size_t get_inputs_num() const override; - static std::set get_supported_precisions(); + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -190,11 +189,10 @@ class jit_minimum_emitter : public jit_emitter { public: jit_minimum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); - jit_minimum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); + jit_minimum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n); size_t get_inputs_num() const override; - static std::set get_supported_precisions(); + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -213,7 +211,7 @@ class jit_squared_difference_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; - static std::set get_supported_precisions(); + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const 
std::vector &out_vec_idxs) const override; @@ -231,6 +229,7 @@ class jit_power_dynamic_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -248,6 +247,7 @@ class jit_equal_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -268,6 +268,7 @@ class jit_not_equal_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -288,6 +289,7 @@ class jit_greater_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -308,6 +310,7 @@ class jit_greater_equal_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -328,6 +331,7 @@ class jit_less_emitter : public jit_emitter { 
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -349,6 +353,7 @@ class jit_less_equal_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -369,6 +374,7 @@ class jit_logical_and_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -389,6 +395,7 @@ class jit_logical_or_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -409,6 +416,7 @@ class jit_logical_xor_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -428,6 +436,7 @@ class jit_logical_not_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static 
std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -448,6 +457,8 @@ class jit_power_static_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -471,6 +482,7 @@ class jit_prelu_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -489,6 +501,7 @@ class jit_sqrt_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -503,6 +516,7 @@ class jit_negative_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector& in, const std::vector& out) const override; @@ -520,6 +534,7 @@ class jit_erf_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl( @@ -541,6 +556,7 @@ class jit_soft_sign_emitter : public 
jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; @@ -563,6 +579,9 @@ class jit_is_finite_emitter : public jit_emitter { } size_t get_inputs_num() const override { return 1; }; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + return {{element::f32}}; + } protected: size_t aux_gprs_count() const override { return (entry_map_.empty() ? 0 : 1) + 1; } @@ -588,6 +607,9 @@ class jit_is_inf_emitter : public jit_emitter { } size_t get_inputs_num() const override { return 1; }; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + return {{element::f32}}; + } protected: size_t aux_gprs_count() const override { return (entry_map_.empty() ? 0 : 1) + 1; } @@ -615,6 +637,9 @@ class jit_is_nan_emitter : public jit_emitter { } size_t get_inputs_num() const override { return 1; } + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + return {{element::f32}}; + } protected: size_t aux_gprs_count() const override { return (entry_map_.empty() ? 
0 : 1) + 1; } @@ -635,6 +660,7 @@ class jit_select_emitter : public jit_emitter { InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32); size_t get_inputs_num() const override; + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); size_t aux_vecs_count() const override; private: diff --git a/src/plugins/intel_cpu/src/emitters/jit_emitter.cpp b/src/plugins/intel_cpu/src/emitters/jit_emitter.cpp index 3bbd03935563f0..7d9ab0d0994315 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_emitter.cpp @@ -3,8 +3,8 @@ // #include "jit_emitter.hpp" -#include "utils/general_utils.h" #include +#include "utils/general_utils.h" using namespace dnnl::impl::cpu; using namespace dnnl::impl; @@ -55,8 +55,8 @@ size_t jit_emitter::aux_gprs_count() const { return entry_map_.empty() ? 0 : 1; } -std::set jit_emitter::get_supported_precisions() { - return {InferenceEngine::Precision::FP32}; +std::set> jit_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {}; } void jit_emitter::emitter_preamble(const std::vector &in_idxs, const std::vector &out_idxs, diff --git a/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp index be548c614e0aa2..eb3309de32d8c5 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp @@ -49,7 +49,13 @@ class jit_emitter : public ngraph::snippets::Emitter { virtual size_t get_inputs_num() const = 0; virtual size_t aux_vecs_count() const; emitter_in_out_map get_in_out_type() const; - static std::set get_supported_precisions(); + + /** + * @brief Returns supported precisions. + * Precisions are ordered, the first bigger bitness precision with the same type will be selected. + * Empty collection means the emitter supports any input precisions. 
+ */ + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); protected: virtual size_t aux_gprs_count() const; diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp index af583e804b157f..4f63dd641f6295 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp @@ -479,7 +479,20 @@ void BroadcastMoveEmitter::emit_isa(const std::vector &in, const std::ve ScalarEmitter::ScalarEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : jit_emitter(h, isa, n) { - value = dnnl::impl::cpu::x64::float2int(ov::as_type_ptr(n)->cast_vector()[0]); + const auto precision = n->get_output_element_type(0); + switch (precision) { + case element::i32: { + value = ov::as_type_ptr(n)->cast_vector()[0]; + break; + } + case element::f32: { + value = dnnl::impl::cpu::x64::float2int(ov::as_type_ptr(n)->cast_vector()[0]); + break; + } + default: { + IE_THROW() << "Scalar emitter doesn't support " << precision; + } + } push_arg_entry_of("scalar", value, true); prepare_table(); } diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp index caeab227ad4b44..cae08b3fe43ac8 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp @@ -322,6 +322,9 @@ class BrgemmEmitter : public jit_emitter { BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); size_t get_inputs_num() const override {return 2;} + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + return {{element::f32, element::f32}}; + } private: void emit_impl(const std::vector& in, @@ -369,6 +372,9 @@ class HorizonMaxEmitter : public 
jit_emitter { HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); size_t get_inputs_num() const override {return 1;} + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + return {{element::f32}}; + } protected: size_t aux_gprs_count() const override {return 1;} @@ -387,6 +393,9 @@ class HorizonSumEmitter : public jit_emitter { HorizonSumEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); size_t get_inputs_num() const override {return 1;} + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + return {{element::f32}}; + } protected: size_t aux_gprs_count() const override {return 1;} diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.cpp b/src/plugins/intel_cpu/src/nodes/eltwise.cpp index 4ef400ae601a2f..5bc46c00b40b7e 100644 --- a/src/plugins/intel_cpu/src/nodes/eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/eltwise.cpp @@ -9,6 +9,7 @@ #include "cpu_types.h" #include "utils/bfloat16.hpp" +#include "ie_ngraph_utils.hpp" #include #include @@ -58,7 +59,7 @@ namespace { template struct SupportedPrecisions { - void operator()(std::set &precisions) { + void operator()(std::set> &precisions) { precisions = T::get_supported_precisions(); } }; @@ -105,7 +106,7 @@ struct EltwiseEmitter { /** * Implements Eltwise shape inference algorithm. The algorithm is based on broadcasting all the input shapes * according to the NUMPY broadcast rule. This implementation is more lightweight than the ngraph one. 
- * + * */ class EltwiseShapeInfer : public ShapeInferEmptyPads { public: @@ -176,10 +177,31 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener void generate() override { Precision exec_prc = Precision::UNSPECIFIED; - std::set supported_precision_intersection = get_supported_precisions(eltwise_data_.front().algo); + std::set> supported_precision_intersection = get_supported_precisions(eltwise_data_.front().algo); + + // for element-wise operations all inputs must to have the same precisions + assert(std::all_of( + supported_precision_intersection.begin(), + supported_precision_intersection.end(), + [&supported_precision_intersection](const std::vector& precisions) { + return std::all_of( + precisions.begin(), + precisions.end(), + [&precisions](const element::Type precision) { return precision == precisions[0]; }); + })); + for (size_t i = 1; i < eltwise_data_.size(); ++i) { - std::set prcs = get_supported_precisions(eltwise_data_[i].algo); - std::set prcs_intersect = {}; + std::set> prcs = get_supported_precisions(eltwise_data_[i].algo); + std::set> prcs_intersect = {}; + + // to support previous functionality + if (!std::all_of( + prcs.begin(), + prcs.end(), + [&supported_precision_intersection](const std::vector& types) { + return types.size() == supported_precision_intersection.size(); })) { + continue; + } std::set_intersection(supported_precision_intersection.begin(), supported_precision_intersection.end(), prcs.begin(), prcs.end(), std::inserter(prcs_intersect, prcs_intersect.begin())); @@ -187,19 +209,22 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener supported_precision_intersection = prcs_intersect; } - static const Precision exec_precisions_priority[] = { - Precision::U8, - Precision::I8, - Precision::U16, - Precision::I16, - Precision::BF16, - Precision::I32, - Precision::FP32 + static const element::Type exec_precisions_priority[] = { + element::u8, + element::i8, + element::u16, + 
element::i16, + element::bf16, + element::i32, + element::f32 }; - for (auto prc : exec_precisions_priority) { - if (std::find(supported_precision_intersection.begin(), supported_precision_intersection.end(), prc) != supported_precision_intersection.end()) { - exec_prc = prc; + for (const auto prc : exec_precisions_priority) { + if (std::any_of( + supported_precision_intersection.begin(), + supported_precision_intersection.end(), + [&prc](const std::vector& precisions) { return std::find(precisions.begin(), precisions.end(), prc) != precisions.end(); })) { + exec_prc = InferenceEngine::details::convertPrecision(prc); break; } } @@ -482,8 +507,8 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener const std::vector& ops_list_; const dnnl::post_ops& post_ops_; - std::set get_supported_precisions(Algorithm algo) { - std::set precisions; + std::set> get_supported_precisions(Algorithm algo) { + std::set> precisions; OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo, OV_CASE(Algorithm::EltwiseRelu, jit_dnnl_aux_emitter), diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index d11fc50d33edfe..8eb425e7ec4921 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -25,6 +25,7 @@ #include "utils/cpu_utils.hpp" #include "snippets_transformations/fuse_load_store_and_convert.hpp" #include "snippets_transformations/mul_add_to_fma.hpp" +#include "snippets_transformations/remove_converts.hpp" #include "ngraph_transformations/convert_to_swish_cpu.hpp" using namespace InferenceEngine; @@ -39,7 +40,7 @@ namespace node { namespace { /* This class implementation is a temporal WA - TODO: revise the implementation to remove the node reference*/ + TODO: revise the implementation to remove the node reference*/ class SnippetShapeInfer : public ShapeInferEmptyPads { public: SnippetShapeInfer(Snippet* node) : m_node(node) {} @@ -531,28 +532,36 
@@ bool Snippet::created() const { } void Snippet::generate(const jit_snippets_compile_args* jcp) { - ov::pass::Manager optManager; - optManager.register_pass(); - optManager.register_pass(); - optManager.register_pass(); - optManager.register_pass(); + ov::pass::Manager pre_dialect; + pre_dialect.register_pass(); + ov::pass::Manager post_dialect; + + ov::pass::Manager post_precision; + post_precision.register_pass(); + post_precision.register_pass(); + post_precision.register_pass(); // LoadConvert uses Load emitter that support conversion from any type to only f32 - optManager.get_pass_config()->set_callback( + post_precision.get_pass_config()->set_callback( [](const std::shared_ptr& n) -> bool { if (const auto& convert = std::dynamic_pointer_cast(n)) return convert->get_destination_type() != ov::element::f32; return true; }); - // StoreConvert uses Store emitter that support conversion from only f32 to any types - optManager.get_pass_config()->set_callback( + post_precision.get_pass_config()->set_callback( [](const std::shared_ptr& n) -> bool { if (const auto& convert = std::dynamic_pointer_cast(n)) return convert->get_input_element_type(0) != ov::element::f32; return true; }); - schedule = snippet->generate(optManager, reinterpret_cast(jcp)); + post_precision.register_pass(); + + schedule = snippet->generate( + pre_dialect, + post_dialect, + post_precision, + reinterpret_cast(jcp)); } void Snippet::update_ptrs(jit_snippets_call_args& call_args) { diff --git a/src/plugins/intel_cpu/src/snippets_transformations/remove_converts.cpp b/src/plugins/intel_cpu/src/snippets_transformations/remove_converts.cpp new file mode 100644 index 00000000000000..238fadaa47e897 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/remove_converts.cpp @@ -0,0 +1,38 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "remove_converts.hpp" + +#include "snippets/itt.hpp" +#include "ngraph/opsets/opset1.hpp" +#include 
"ngraph/rt_info.hpp" +#include "ngraph/pattern/op/wrap_type.hpp" + +#include "snippets/op/convert_saturation.hpp" + +ov::intel_cpu::pass::RemoveConverts::RemoveConverts() { + MATCHER_SCOPE(RemoveConverts); + auto parent_convert_wrap = ngraph::pattern::wrap_type(); + auto child_convert_wrap = ngraph::pattern::wrap_type({ parent_convert_wrap }); + + auto callback = [=](ngraph::pattern::Matcher& m) { + OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::RemoveConverts") + const auto& pm = m.get_pattern_value_map(); + const auto parent_convert = pm.at(parent_convert_wrap).get_node_shared_ptr(); + const auto child_convert = pm.at(child_convert_wrap).get_node_shared_ptr(); + if ( + (parent_convert->get_input_element_type(0) != element::f32) || + (parent_convert->get_output_target_inputs(0).size() != 1ull) || + (parent_convert->get_output_element_type(0) != element::bf16) || + (child_convert->get_output_element_type(0) != element::f32)) { + return false; + } + + replace_output_update_name(child_convert->output(0), parent_convert->get_input_source_output(0)); + return true; + }; + + auto m = std::make_shared(child_convert_wrap, matcher_name); + register_matcher(m, callback); +} diff --git a/src/plugins/intel_cpu/src/snippets_transformations/remove_converts.hpp b/src/plugins/intel_cpu/src/snippets_transformations/remove_converts.hpp new file mode 100644 index 00000000000000..b1fc6d4503d606 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/remove_converts.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/pass/graph_rewrite.hpp" +#include "ngraph/pattern/matcher.hpp" + +namespace ov { +namespace intel_cpu { +namespace pass { + +/** + * @interface RemoveConverts + * @brief Remove sequence of two ConvertSaturation operations for specific precisions: FP32 => BF16 => FP32 + * @ingroup snippets + */ +class RemoveConverts : public 
ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("RemoveConverts", "0"); + RemoveConverts(); +}; + +} // namespace pass +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/check_broadcast.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/check_broadcast.cpp new file mode 100644 index 00000000000000..9469bc9607141a --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/check_broadcast.cpp @@ -0,0 +1,81 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/check_broadcast.hpp" +#include "common_test_utils/test_constants.hpp" + +namespace ov { +namespace test { +namespace snippets { + + +namespace { + +const std::vector input_types = { + // TODO: 105804 + //ov::element::i32, + ov::element::f32 +}; + +const std::vector test_cases = { + // broadcast is neccessary + { + {{1, 3, 4, 4}, {4, 4}}, + ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::PDPD, -1), + 1, + 0 + }, + { + {{1, 3, 4, 4}, {4, 4}}, + ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::PDPD, 2), + 1, + 0 + }, + + // broadcast is not neccessary + { + {{1, 3, 4, 4}, {1, 3, 4, 4}}, + ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::PDPD, -1), + 1, + 1 + }, + { + {{1, 3, 4, 4}, {1, 3, 4, 4}}, + ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::PDPD, 0), + 1, + 1 + }, + + // any other PDPD + { + {{1, 3, 4, 4}, {4, 4}}, + ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::NUMPY, -1), + 1, + 1 + }, + { + {{1, 3, 4, 4}, {4, 4}}, + ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::NUMPY, 0), + 1, + 1 + }, + { + {{1, 3, 4, 4}, {4, 4}}, + ov::op::AutoBroadcastSpec(ov::op::AutoBroadcastType::NUMPY, 2), + 1, + 1 + }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_CheckBroadcast, CheckBroadcast, + ::testing::Combine( + ::testing::ValuesIn(input_types), + ::testing::ValuesIn(test_cases), + 
::testing::Values(CommonTestUtils::DEVICE_CPU)), + CheckBroadcast::getTestCaseName); + +} // namespace +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/precision_propagation_convertion.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/precision_propagation_convertion.cpp new file mode 100644 index 00000000000000..5c93badbd3c9e9 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/precision_propagation_convertion.cpp @@ -0,0 +1,37 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/precision_propagation_convertion.hpp" +#include +#include + +namespace ov { +namespace test { +namespace snippets { + + +namespace { + +const std::vector> input_shapes = { + {{ 1, 3, 16, 16 }, { 1, 1, 1, 16 }}, +}; + +const std::vector> fake_quantize_intervals = { + {0.f, 2.55f, 0.f, 2.55f}, + {-1.28f, 1.27f, -1.28f, 1.27f} +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_PrecisionPropagation_Convertion, PrecisionPropagationConvertion, + ::testing::Combine( + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(fake_quantize_intervals), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + PrecisionPropagationConvertion::getTestCaseName); + +} // namespace +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/plugins/intel_cpu/tests/unit/ngraph_transformations/mul_add_to_fma.cpp b/src/plugins/intel_cpu/tests/unit/ngraph_transformations/mul_add_to_fma.cpp index 0fcaaceadd70ab..5431cbb2626a55 100644 --- a/src/plugins/intel_cpu/tests/unit/ngraph_transformations/mul_add_to_fma.cpp +++ b/src/plugins/intel_cpu/tests/unit/ngraph_transformations/mul_add_to_fma.cpp @@ -155,7 +155,7 @@ class MulAddToFMATests : public LoweringTests, public testing::WithParamInterfac }; TEST_P(MulAddToFMATests, 
MulAddToFMATests) { - auto subgraph = getLoweredSubgraph(snippets_function->getOriginal(), master_shape, cpu_manager, generator); + auto subgraph = getLoweredSubgraph(snippets_function->getOriginal(), master_shape, {}, {}, cpu_manager, generator); model = subgraph->body_ptr(); model_ref = snippets_function->getLowered(); } diff --git a/src/tests/functional/plugin/shared/include/snippets/check_broadcast.hpp b/src/tests/functional/plugin/shared/include/snippets/check_broadcast.hpp new file mode 100644 index 00000000000000..1c33792cd328ec --- /dev/null +++ b/src/tests/functional/plugin/shared/include/snippets/check_broadcast.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "shared_test_classes/base/snippets_test_utils.hpp" + +namespace ov { +namespace test { +namespace snippets { + +class CheckBroadcastTestCaseParams { +public: + std::pair input_shapes; + ov::op::AutoBroadcastSpec broadcast; + size_t num_nodes; + size_t num_subgraphs; +}; + +typedef std::tuple < + ov::element::Type, // input types + CheckBroadcastTestCaseParams, // test case details + std::string // target device +> CheckBroadcastParams; + +class CheckBroadcast : public testing::WithParamInterface, + virtual public ov::test::SnippetsTestsCommon { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + +protected: + void SetUp() override; +}; + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/tests/functional/plugin/shared/include/snippets/precision_propagation_convertion.hpp b/src/tests/functional/plugin/shared/include/snippets/precision_propagation_convertion.hpp new file mode 100644 index 00000000000000..3ab24d7cf299f3 --- /dev/null +++ b/src/tests/functional/plugin/shared/include/snippets/precision_propagation_convertion.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include 
"shared_test_classes/base/snippets_test_utils.hpp" + +namespace ov { +namespace test { +namespace snippets { + +typedef std::tuple< + std::vector, // Input shapes + std::vector, // FakeQuantize intervals + size_t, // Expected num nodes + size_t, // Expected num subgraphs + std::string // Target Device +> PrecisionPropagationParams; + +class PrecisionPropagationConvertion : + public testing::WithParamInterface, + virtual public ov::test::SnippetsTestsCommon { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + +protected: + void SetUp() override; +}; + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/fuse_fake_quantize_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/fuse_fake_quantize_transformation.cpp index 0dc3d899f7988a..8c4109c439365d 100644 --- a/src/tests/functional/plugin/shared/src/low_precision_transformations/fuse_fake_quantize_transformation.cpp +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/fuse_fake_quantize_transformation.cpp @@ -20,7 +20,7 @@ std::string FuseFakeQuantizeTransformation::getTestCaseName(const testing::TestP std::tie(targetDevice, testValues) = obj.param; std::ostringstream result; - result << targetDevice << "_" << + result << "targetDevice=" << targetDevice << "_" << testValues.actual.precisionBeforeAdd << "_" << testValues.actual.add.values.size() << "_" << testValues.actual.add.outPrecision << "_" << diff --git a/src/tests/functional/plugin/shared/src/snippets/check_broadcast.cpp b/src/tests/functional/plugin/shared/src/snippets/check_broadcast.cpp new file mode 100644 index 00000000000000..3730771a1a44d5 --- /dev/null +++ b/src/tests/functional/plugin/shared/src/snippets/check_broadcast.cpp @@ -0,0 +1,89 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/check_broadcast.hpp" + 
+#include "common_test_utils/common_utils.hpp" +#include "subgraph_converts.hpp" +#include "common_test_utils/ov_tensor_utils.hpp" + +namespace ov { +namespace test { +namespace snippets { + +class CheckBroadcastFunction { +public: + static std::shared_ptr get( + const PartialShape& input_shape1, + const PartialShape& input_shape2, + const ov::element::Type input_type, + const ov::op::AutoBroadcastSpec broadcast) { + const auto parameter1 = std::make_shared(input_type, input_shape1); + parameter1->set_friendly_name("parameter1"); + + const auto parameter2 = std::make_shared(input_type, input_shape2); + parameter2->set_friendly_name("parameter2"); + + std::shared_ptr parent = std::make_shared( + parameter1, + parameter2, + broadcast); + parent->set_friendly_name("multiply"); + + const auto result = std::make_shared(parent); + result->set_friendly_name("result"); + + return std::make_shared( + ngraph::ResultVector{ result }, + ngraph::ParameterVector{ parameter1, parameter2 }, + "CheckBroadcastFunction"); + } +}; + +std::string CheckBroadcast::getTestCaseName(testing::TestParamInfo obj) { + ov::element::Type input_type; + CheckBroadcastTestCaseParams test_case_params; + std::string target_device; + + std::tie(input_type, test_case_params, target_device) = obj.param; + + std::ostringstream result; + result << "IS=" << test_case_params.input_shapes.first.get_shape() << "_" << + test_case_params.input_shapes.second.get_shape() << "_"; + result << "IT=" << input_type << "_"; + result << "BCT=" << test_case_params.broadcast.m_type << "_"; + result << "BCA=" << test_case_params.broadcast.m_axis << "_"; + result << "#N=" << test_case_params.num_nodes << "_"; + result << "#S=" << test_case_params.num_subgraphs << "_"; + result << "targetDevice=" << target_device; + return result.str(); +} + +void CheckBroadcast::SetUp() { + ov::element::Type input_type; + CheckBroadcastTestCaseParams test_case_params; + + std::tie(input_type, test_case_params, targetDevice) = 
this->GetParam(); + ref_num_nodes = test_case_params.num_nodes; + ref_num_subgraphs = test_case_params.num_subgraphs; + + init_input_shapes(static_partial_shapes_to_test_representation({ + test_case_params.input_shapes.first, + test_case_params.input_shapes.second})); + + function = CheckBroadcastFunction::get( + test_case_params.input_shapes.first, + test_case_params.input_shapes.second, + input_type, + test_case_params.broadcast); +} + +TEST_P(CheckBroadcast, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/tests/functional/plugin/shared/src/snippets/convert.cpp b/src/tests/functional/plugin/shared/src/snippets/convert.cpp index 60419d28b2f96f..95749f32da1272 100644 --- a/src/tests/functional/plugin/shared/src/snippets/convert.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/convert.cpp @@ -106,8 +106,8 @@ parameters ConvertInput::generate_params_random() const { break; case ov::element::i32: case ov::element::i8: - startFrom = -10; - range = 20; + startFrom = -32; + range = 64; break; case ov::element::u8: startFrom = 10; diff --git a/src/tests/functional/plugin/shared/src/snippets/precision_propagation_convertion.cpp b/src/tests/functional/plugin/shared/src/snippets/precision_propagation_convertion.cpp new file mode 100644 index 00000000000000..570fa4b44dac70 --- /dev/null +++ b/src/tests/functional/plugin/shared/src/snippets/precision_propagation_convertion.cpp @@ -0,0 +1,48 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/precision_propagation_convertion.hpp" + +#include "common_test_utils/common_utils.hpp" +#include "precision_propagation_convertion_function.hpp" + +namespace ov { +namespace test { +namespace snippets { + +std::string PrecisionPropagationConvertion::getTestCaseName(testing::TestParamInfo obj) { + std::vector input_shapes; + std::vector fake_quantize_intervals; + std::string 
targetDevice; + size_t num_nodes, num_subgraphs; + std::tie(input_shapes, fake_quantize_intervals, num_nodes, num_subgraphs, targetDevice) = obj.param; + + std::ostringstream result; + for (size_t i = 0; i < input_shapes.size(); ++i) + result << "IS[" << i << "]=" << input_shapes[i] << "_"; + for (size_t i = 0; i < fake_quantize_intervals.size(); ++i) + result << "FQ[" << i << "]=" << fake_quantize_intervals[i] << "_"; + result << "#N=" << num_nodes << "_"; + result << "#S=" << num_subgraphs << "_"; + result << "targetDevice=" << targetDevice; + return result.str(); +} + +void PrecisionPropagationConvertion::SetUp() { + std::vector input_shapes; + std::vector fake_quantize_intervals; + std::tie(input_shapes, fake_quantize_intervals, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + function = PrecisionPropagationConvertionFunction(input_shapes, ov::element::f32, fake_quantize_intervals).getOriginal(); +} + +TEST_P(PrecisionPropagationConvertion, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_convertion_function.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_convertion_function.hpp new file mode 100644 index 00000000000000..554d7b08fc5134 --- /dev/null +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_convertion_function.hpp @@ -0,0 +1,49 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "openvino/core/model.hpp" +#include "snippets_helpers.hpp" + +namespace ov { +namespace test { +namespace snippets { + +/** + * @class PrecisionPropagationConvertionFunction + * @brief PrecisionPropagationConvertionFunction instance returns reference and 
original functions. + * + * Input arguments are used to create function in getOriginal methods only. + * Dont use getReference and getLowered method, they are not implemented and throw std::runtime_error exception. + * Note, ov::element::Type_t precision base type input argument is not used. + */ +class PrecisionPropagationConvertionFunction : public SnippetsFunctionBase { +public: + PrecisionPropagationConvertionFunction( + const std::vector& input_shapes, + const element::Type input_type, + const std::vector& fake_quantize_intervals); + + /* + * Don't call this method explicity. You should create the instance of PrecisionPropagationConvertionFunction before. + * After the method will be called implicitly in getOriginal. + * Note, please, getReference and getLowered methods are not implemented and throw exception. + */ + static std::shared_ptr get( + const std::vector& input_shapes, + const element::Type input_type, + const std::vector& fake_quantize_intervals); + +protected: + std::shared_ptr initOriginal() const override; + +private: + const std::vector fake_quantize_intervals; +}; + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_function.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_function.hpp new file mode 100644 index 00000000000000..b32099cf3020de --- /dev/null +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/precision_propagation_function.hpp @@ -0,0 +1,131 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include +#include "ngraph/opsets/opset1.hpp" +#include "snippets/op/convert_saturation.hpp" +#include "snippets_helpers.hpp" + +namespace ov { +namespace test { +namespace snippets { + +/** + * @class DummyAdd + * @brief DummyAdd operation has custom validate_and_infer_types method implementation. 
+ */ +class DummyAdd : public ngraph::opset1::Add { +public: + OPENVINO_OP("DummyAdd", "test::snippets"); + + DummyAdd(const Output& arg0, + const Output& arg1, + const ngraph::op::AutoBroadcastSpec& auto_broadcast = + ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY)) + : ngraph::opset1::Add(arg0, arg1, auto_broadcast) { + constructor_validate_and_infer_types(); + } + + DummyAdd(const ngraph::opset1::Add& add) + : Add(add.get_input_source_output(0), add.get_input_source_output(1), add.get_autob()) { + constructor_validate_and_infer_types(); + } + + DummyAdd() = default; + + void validate_and_infer_types() override { + const auto input_type1 = get_input_element_type(0); + const auto input_type2 = get_input_element_type(1); + + const element::Type output_type = (input_type1 == element::i8) || (input_type2 == element::i8) ? + element::i32 : + get_input_element_type(0); + + set_output_type(0, output_type, get_input_partial_shape(0)); + } + + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override { + return std::make_shared(new_args.at(0), new_args.at(1), this->get_autob()); + } +}; + +class PrecisionPropagationAddFunctionParams { +public: + class Actual { + public: + std::pair convertion_before_op1; + element::Type convertion_before_op2_1; + std::pair convertion_before_op2_2; + }; + + class Expected { + public: + std::pair convertion_before_op1; + element::Type convertion_before_op2_1; + std::pair convertion_before_op2_2; + element::Type convertion_after_op2; + }; +}; + +/** + * @class PrecisionPropagationAddFunction + * @brief PrecisionPropagationAddFunction instance returns reference and original functions. + * + * Input arguments are used to create function in getOriginal or getReference methods only. + * Dont use getLowered method, it is not implemented and throw std::runtime_error exception. + * Note, ov::element::Type_t precision base type input argument is not used. 
+ */ +class PrecisionPropagationAddFunction : public SnippetsFunctionBase { +public: + explicit PrecisionPropagationAddFunction( + const std::vector input_shapes, + const ngraph::element::Type precision1, + const ngraph::element::Type precision2, + const ngraph::element::Type constant_precision, + PrecisionPropagationAddFunctionParams::Actual actual, + PrecisionPropagationAddFunctionParams::Expected expected) : + SnippetsFunctionBase(input_shapes), + precision1(precision1), + precision2(precision2), + constant_precision(constant_precision), + actual(actual), + expected(expected) { + OPENVINO_ASSERT(input_shapes.size() == 2ull, "input_shapes size has to be equal to 2"); + } + + /* + * Don't call this method explicity. You should create the instance of PrecisionPropagationAddFunction before. + * After the method will be called implicitly in getOriginal or getReference methods. + * Note, please, getLowered method is not implemented and throws exception. + */ + static std::shared_ptr get( + const ngraph::element::Type precision1, + const ngraph::PartialShape& inputShape1, + const ngraph::element::Type precision2, + const ngraph::PartialShape& inputShape2, + const ngraph::element::Type constant_precision, + const std::pair& convertion_before_op1 = std::pair(), + const element::Type convertion_before_op2_1 = element::undefined, + const std::pair& convertion_before_op2_2 = std::pair(), + const element::Type convertion_after_op2 = {}); + +protected: + std::shared_ptr initOriginal() const override; + std::shared_ptr initReference() const override; + + const ngraph::element::Type precision1; + const ngraph::element::Type precision2; + const ngraph::element::Type constant_precision; + const PrecisionPropagationAddFunctionParams::Actual actual; + const PrecisionPropagationAddFunctionParams::Expected expected; +}; + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/snippets_helpers.hpp 
b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/snippets_helpers.hpp index b4073b2d065ae0..9d3edad4b55339 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/snippets_helpers.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/snippets_helpers.hpp @@ -17,6 +17,7 @@ using ov::Model; class SnippetsFunctionBase { public: SnippetsFunctionBase() = delete; + virtual ~SnippetsFunctionBase() = default; explicit SnippetsFunctionBase(const std::vector& inputShapes, ov::element::Type_t precision = element::f32) : precision{precision}, input_shapes{inputShapes} {} diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/precision_propagation_convertion_function.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/precision_propagation_convertion_function.cpp new file mode 100644 index 00000000000000..20f517b16dfceb --- /dev/null +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/precision_propagation_convertion_function.cpp @@ -0,0 +1,92 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "precision_propagation_convertion_function.hpp" +#include +#include + +namespace ov { +namespace test { +namespace snippets { + +namespace { +std::shared_ptr make_fake_quantize( + const Output& parent, + const ngraph::PartialShape& inputShape, + const element::Type inputType, + const std::vector& fake_quantize_intervals) { + auto generate = [](const ov::element::Type precision, + const ngraph::Shape& shape, + const float initialValue, + const std::string& name) { + const auto size = ngraph::shape_size(shape); + std::vector values(size); + for (auto i = 0; i < size; ++i) { + values[i] = static_cast(initialValue + i); + } + auto constant = std::make_shared(precision, shape, values); + constant->set_friendly_name(name); + return constant; + }; + + const auto fakeQuantize = std::make_shared( + parent, + generate(inputType, {}, fake_quantize_intervals[0], "inputLow"), + 
generate(inputType, {}, fake_quantize_intervals[1], "inputHigh"), + generate(inputType, {}, fake_quantize_intervals[2], "outputLow"), + generate(inputType, {}, fake_quantize_intervals[3], "outputHigh"), + 256ul); + fakeQuantize->set_friendly_name("fakeQuantize"); + + return fakeQuantize; +} +} // namespace + +PrecisionPropagationConvertionFunction::PrecisionPropagationConvertionFunction( + const std::vector& input_shapes, + const element::Type input_type, + const std::vector& fake_quantize_intervals) : + SnippetsFunctionBase(input_shapes, input_type), + fake_quantize_intervals(fake_quantize_intervals) { +} + +std::shared_ptr PrecisionPropagationConvertionFunction::get( + const std::vector& input_shapes, + const element::Type input_type, + const std::vector& fake_quantize_intervals) { + assert(2ull == input_shapes.size()); + assert(4ull == fake_quantize_intervals.size()); + const auto parameter1 = std::make_shared(input_type, input_shapes[0]); + parameter1->set_friendly_name("parameter1"); + + const auto parameter2 = std::make_shared(input_type, input_shapes[1]); + parameter2->set_friendly_name("parameter2"); + + std::shared_ptr parent = make_fake_quantize( + parameter1, + input_shapes[0], + input_type, + fake_quantize_intervals); + parent->set_friendly_name("fakeQuantize"); + + parent = std::make_shared(parent, parameter2); + parent->set_friendly_name("add"); + + const auto result = std::make_shared(parent); + result->set_friendly_name("result"); + + auto function = std::make_shared( + ngraph::ResultVector{ result }, + ParameterVector{ parameter1, parameter2 }, + "PrecisionPropagationConvertionFunction"); + return function; +} + +std::shared_ptr PrecisionPropagationConvertionFunction::initOriginal() const { + return get(input_shapes, precision, fake_quantize_intervals); +} + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/precision_propagation_function.cpp 
b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/precision_propagation_function.cpp new file mode 100644 index 00000000000000..6a9ef600409e84 --- /dev/null +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/precision_propagation_function.cpp @@ -0,0 +1,105 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "precision_propagation_function.hpp" +#include +#include + +namespace ov { +namespace test { +namespace snippets { + +std::shared_ptr PrecisionPropagationAddFunction::get( + const ngraph::element::Type precision1, + const ngraph::PartialShape& inputShape1, + const ngraph::element::Type precision2, + const ngraph::PartialShape& inputShape2, + const ngraph::element::Type constant_precision, + const std::pair& convertion_before_op1, + const element::Type convertion_before_op2_1, + const std::pair& convertion_before_op2_2, + const element::Type convertion_after_op2) { + const auto create_convert = [](std::shared_ptr parent, const element::Type convertion_type) -> std::shared_ptr { + return convertion_type == element::undefined + ? 
std::dynamic_pointer_cast(parent) + : std::make_shared(parent, convertion_type); + }; + + const auto make_branch = [&create_convert]( + const ngraph::element::Type precision, + const ngraph::PartialShape& inputShape, + const size_t index, + const element::Type convertion_type) -> std::pair, std::shared_ptr> { + const auto parameter = std::make_shared(precision, inputShape); + parameter->set_friendly_name("parameter" + std::to_string(index)); + + std::shared_ptr parent = create_convert(parameter, convertion_type); + + return { parameter, parent }; + }; + + const auto branch1 = make_branch(precision1, inputShape1, 1, convertion_before_op1.first); + const auto branch2 = make_branch(precision2, inputShape2, 2, convertion_before_op1.second); + + std::shared_ptr parent = std::make_shared(branch1.second, branch2.second); + parent->set_friendly_name("add"); + + parent = create_convert(parent, convertion_before_op2_1); + + const auto maximum_in2_type = convertion_before_op2_2.second == element::undefined ? 
+ constant_precision : + convertion_before_op2_2.second; + if ((convertion_before_op2_2.first == element::undefined) && + (parent->get_output_element_type(0) != maximum_in2_type)) { + parent = std::make_shared(parent, maximum_in2_type); + } + + parent = std::make_shared( + create_convert(parent, convertion_before_op2_2.first), + create_convert( + std::make_shared(constant_precision, Shape{}, std::vector{0.f}), + convertion_before_op2_2.second)); + parent->set_friendly_name("maximum"); + + parent = create_convert(parent, convertion_after_op2); + + const auto result = std::make_shared(parent); + auto& result_out_tensor = result->get_output_tensor(0); + result_out_tensor.set_names({ "result_tensor" }); + result->set_friendly_name("result"); + + const ngraph::ResultVector results{ result }; + const ngraph::ParameterVector parameters{ branch1.first, branch2.first }; + const auto model = std::make_shared(results, parameters, "SnippetsPrecisionPropagation"); + return model; +} + +std::shared_ptr PrecisionPropagationAddFunction::initOriginal() const { + return get( + precision1, + input_shapes[0], + precision2, + input_shapes[1], + constant_precision, + actual.convertion_before_op1, + actual.convertion_before_op2_1, + actual.convertion_before_op2_2); +} + +std::shared_ptr PrecisionPropagationAddFunction::initReference() const { + return get( + precision1, + input_shapes[0], + precision2, + input_shapes[1], + constant_precision, + expected.convertion_before_op1, + expected.convertion_before_op2_1, + expected.convertion_before_op2_2, + expected.convertion_after_op2); +} + +} // namespace snippets +} // namespace test +} // namespace ov