From af406c88e465baf48d4dc12421e96a46033f0cb6 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Mon, 27 Mar 2023 20:09:58 +0400 Subject: [PATCH] Fixed Load+Broadcast and added FuseLoadStoreConvert support --- .../snippets/include/snippets/generator.hpp | 3 + .../pass/lowered/linear_IR_transformation.hpp | 19 +++ .../load_movebroadcast_to_broadcastload.hpp | 2 +- src/common/snippets/src/generator.cpp | 53 ++++---- .../pass/lowered/linear_IR_transformation.cpp | 28 ++++ .../load_movebroadcast_to_broadcastload.cpp | 5 +- .../intel_cpu/src/emitters/cpu_generator.cpp | 7 + .../intel_cpu/src/emitters/cpu_generator.hpp | 3 + src/plugins/intel_cpu/src/nodes/subgraph.cpp | 17 --- .../fuse_load_store_and_convert.cpp | 117 ----------------- .../fuse_load_store_and_convert.hpp | 40 ------ .../lowered/fuse_load_store_and_convert.cpp | 120 ++++++++++++++++++ .../lowered/fuse_load_store_and_convert.hpp | 36 ++++++ 13 files changed, 249 insertions(+), 201 deletions(-) create mode 100644 src/common/snippets/src/pass/lowered/linear_IR_transformation.cpp delete mode 100644 src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp delete mode 100644 src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.hpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/lowered/fuse_load_store_and_convert.cpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/lowered/fuse_load_store_and_convert.hpp diff --git a/src/common/snippets/include/snippets/generator.hpp b/src/common/snippets/include/snippets/generator.hpp index 57d273310225aa..1bcdb4f6035e7f 100644 --- a/src/common/snippets/include/snippets/generator.hpp +++ b/src/common/snippets/include/snippets/generator.hpp @@ -12,6 +12,7 @@ #include "emitter.hpp" #include "target_machine.hpp" #include "lowered_expr.hpp" +#include "pass/lowered/linear_IR_transformation.hpp" namespace ngraph { namespace snippets { @@ -86,6 +87,8 @@ class Generator { std::shared_ptr get_target_machine() const; protected: + virtual pass::lowered::LinearIRTransformationPipeline target_specific_transformations() const; + std::shared_ptr target; // todo: we need to save lowered code to access compiled brgemm kernels on execution time (normally lowered is destructed by then). // This is temporary solution, remove this when kernel caching is implemented. Don't forget to make generate const method. diff --git a/src/common/snippets/include/snippets/pass/lowered/linear_IR_transformation.hpp b/src/common/snippets/include/snippets/pass/lowered/linear_IR_transformation.hpp index 87667d514482c3..ff9fccba676445 100644 --- a/src/common/snippets/include/snippets/pass/lowered/linear_IR_transformation.hpp +++ b/src/common/snippets/include/snippets/pass/lowered/linear_IR_transformation.hpp @@ -41,6 +41,25 @@ class LinearIRTransformation { virtual bool run(LoweredExprIR& linear_ir) = 0; }; +class LinearIRTransformationPipeline { +public: + LinearIRTransformationPipeline() = default; + + void register_transformation(const std::shared_ptr& transformation); + + template + void register_transformation(Args&&... args) { + static_assert(std::is_base_of::value, "Transformation not derived from LinearIRTransformation"); + auto transformation = std::make_shared(std::forward(args)...); + register_transformation(transformation); + } + + void run(LoweredExprIR& linear_ir); + +private: + std::vector> m_transformations; +}; + } // namespace lowered } // namespace pass } // namespace snippets diff --git a/src/common/snippets/include/snippets/pass/lowered/load_movebroadcast_to_broadcastload.hpp b/src/common/snippets/include/snippets/pass/lowered/load_movebroadcast_to_broadcastload.hpp index 85f77d3842b409..f11d8c215ff261 100644 --- a/src/common/snippets/include/snippets/pass/lowered/load_movebroadcast_to_broadcastload.hpp +++ b/src/common/snippets/include/snippets/pass/lowered/load_movebroadcast_to_broadcastload.hpp @@ -18,7 +18,7 @@ namespace lowered { */ class LoadMoveBroadcastToBroadcastLoad: public LinearIRTransformation { public: - LoadMoveBroadcastToBroadcastLoad(); + LoadMoveBroadcastToBroadcastLoad() = default; OPENVINO_RTTI("LoadMoveBroadcastToBroadcastLoad", "LinearIRTransformation") bool run(LoweredExprIR& linear_ir) override; }; diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index 1d0441c303a94e..fbbd6d1afe452b 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -41,30 +41,33 @@ Generator::LoweringResult Generator::generate(std::shared_ptr& m, con // Note: The pass LoopInit uses LoopInfo that contains entry and exit points of the corresponding Loop. // To avoid the Loop information corruption, we should call the passes with Load/Store work // (for example, LoadMoveBroadcastToBroadcastLoad()) after explicit Loop insertion (LoopInit()) - auto propagate_buffer_offsets = std::make_shared(); - std::vector> transformation_pipeline { - std::make_shared(vector_size), - std::make_shared(vector_size), - std::make_shared(), - std::make_shared(), - std::make_shared(buffer_allocation_rank), - std::make_shared(vector_size), - std::make_shared(), - std::make_shared(), - std::make_shared(), - std::make_shared(), - std::make_shared(), - propagate_buffer_offsets, - std::make_shared(), - std::make_shared(), - std::make_shared() - }; - for (const auto& transform : transformation_pipeline) { - transform->run(linear_ir); - } + const auto propagate_buffer_offsets = std::make_shared(); + pass::lowered::LinearIRTransformationPipeline common_pipeline; + common_pipeline.register_transformation(vector_size); + common_pipeline.register_transformation(vector_size); + common_pipeline.register_transformation(); + common_pipeline.register_transformation(); + common_pipeline.register_transformation(buffer_allocation_rank); + common_pipeline.register_transformation(vector_size); + common_pipeline.register_transformation(); + common_pipeline.register_transformation(); + common_pipeline.register_transformation(); + common_pipeline.register_transformation(); + common_pipeline.register_transformation(); + common_pipeline.register_transformation(propagate_buffer_offsets); + common_pipeline.register_transformation(); + common_pipeline.run(linear_ir); + + pass::lowered::LinearIRTransformationPipeline target_pipeline = target_specific_transformations(); + target_pipeline.run(linear_ir); + + pass::lowered::LinearIRTransformationPipeline final_pipeline; + final_pipeline.register_transformation(); + final_pipeline.register_transformation(); + final_pipeline.run(linear_ir); - const auto buffer_scratchpad_size = propagate_buffer_offsets->get_scratchpad_size(); linear_ir.init_emitters(target); + OV_ITT_TASK_NEXT(GENERATE, "::EmitCode") auto loops2DKernel = std::make_shared(linear_ir); loops2DKernel->compile_params = compile_params; @@ -83,12 +86,16 @@ Generator::LoweringResult Generator::generate(std::shared_ptr& m, con if (config.m_save_lowered_code) lowered_saved = linear_ir; - return {target->get_snippet(), buffer_scratchpad_size}; + return {target->get_snippet(), propagate_buffer_offsets->get_scratchpad_size()}; } std::shared_ptr Generator::get_target_machine() const { return target; } +pass::lowered::LinearIRTransformationPipeline Generator::target_specific_transformations() const { + return pass::lowered::LinearIRTransformationPipeline(); +} + }// namespace snippets }// namespace ngraph diff --git a/src/common/snippets/src/pass/lowered/linear_IR_transformation.cpp b/src/common/snippets/src/pass/lowered/linear_IR_transformation.cpp new file mode 100644 index 00000000000000..c9d4f9b379b0d2 --- /dev/null +++ b/src/common/snippets/src/pass/lowered/linear_IR_transformation.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/pass/lowered/linear_IR_transformation.hpp" +#include "snippets/snippets_isa.hpp" +#include "snippets/itt.hpp" + + +namespace ngraph { +namespace snippets { +namespace pass { +namespace lowered { + +void LinearIRTransformationPipeline::register_transformation(const std::shared_ptr& transformation) { + m_transformations.push_back(transformation); +} + +void LinearIRTransformationPipeline::run(LoweredExprIR& linear_ir) { + for (const auto& transformation : m_transformations) { + transformation->run(linear_ir); + } +} + +} // namespace lowered +} // namespace pass +} // namespace snippets +} // namespace ngraph diff --git a/src/common/snippets/src/pass/lowered/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/pass/lowered/load_movebroadcast_to_broadcastload.cpp index 2e22cd4530f055..5e8a980bfcc679 100644 --- a/src/common/snippets/src/pass/lowered/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/pass/lowered/load_movebroadcast_to_broadcastload.cpp @@ -13,8 +13,6 @@ namespace pass { namespace lowered { -LoadMoveBroadcastToBroadcastLoad::LoadMoveBroadcastToBroadcastLoad() {} - bool LoadMoveBroadcastToBroadcastLoad::run(LoweredExprIR& linear_ir) { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::LoadMoveBroadcastToBroadcastLoad") bool modified = false; @@ -47,9 +45,10 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LoweredExprIR& linear_ir) { const auto in_td = std::vector{ parent_expr->get_inputs().front() }; const auto out_td = std::vector{ (*expr_it)->get_outputs().front() }; const auto mv_expr_it = expr_it; - expr_it = linear_ir.insert(std::next(expr_it), std::make_shared(broadcastload, in_td, out_td)); + const auto insertion_pos = std::next(expr_it); linear_ir.erase(std::find(linear_ir.begin(), mv_expr_it, parent_expr)); linear_ir.erase(mv_expr_it); + expr_it = linear_ir.insert(insertion_pos, std::make_shared(broadcastload, in_td, out_td)); modified |= true; } } diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp index 8c2e666d6b6438..15ed1b7f4ea3d1 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp @@ -15,6 +15,7 @@ #include "jit_dnnl_ext_emitters.hpp" #include "jit_conversion_emitters.hpp" +#include "snippets_transformations/lowered/fuse_load_store_and_convert.hpp" #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" #include "snippets_transformations/op/fused_mul_add.hpp" @@ -169,3 +170,9 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const { ov::intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_) : Generator(std::make_shared(isa_)) { } + +ngraph::snippets::pass::lowered::LinearIRTransformationPipeline ov::intel_cpu::CPUGenerator::target_specific_transformations() const { + ngraph::snippets::pass::lowered::LinearIRTransformationPipeline target_specific_transformation; + target_specific_transformation.register_transformation(); + return target_specific_transformation; +} diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp index 090dbfb31cad11..00d42bc2731b09 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp @@ -29,6 +29,9 @@ class CPUTargetMachine : public ngraph::snippets::TargetMachine { class CPUGenerator : public ngraph::snippets::Generator { public: CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa); + +protected: + ngraph::snippets::pass::lowered::LinearIRTransformationPipeline target_specific_transformations() const override; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 9656ba8e71f976..bff7fb5ccba960 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -23,7 +23,6 @@ #include #include "emitters/cpu_generator.hpp" #include "utils/cpu_utils.hpp" -#include "snippets_transformations/fuse_load_store_and_convert.hpp" #include "snippets_transformations/mul_add_to_fma.hpp" #include "snippets_transformations/remove_converts.hpp" #include "ngraph_transformations/convert_to_swish_cpu.hpp" @@ -552,22 +551,6 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) { ov::pass::Manager post_precision; post_precision.register_pass(); - post_precision.register_pass(); - post_precision.register_pass(); - // LoadConvert uses Load emitter that support conversion from any type to only f32 - post_precision.get_pass_config()->set_callback( - [](const std::shared_ptr& n) -> bool { - if (const auto& convert = std::dynamic_pointer_cast(n)) - return convert->get_destination_type() != ov::element::f32; - return true; - }); - // StoreConvert uses Store emitter that support conversion from only f32 to any types - post_precision.get_pass_config()->set_callback( - [](const std::shared_ptr& n) -> bool { - if (const auto& convert = std::dynamic_pointer_cast(n)) - return convert->get_input_element_type(0) != ov::element::f32; - return true; - }); post_precision.register_pass(); schedule = snippet->generate( diff --git a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp b/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp deleted file mode 100644 index b47fcfe73da808..00000000000000 --- a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "snippets/itt.hpp" - -#include "fuse_load_store_and_convert.hpp" -#include "snippets/snippets_isa.hpp" - -#include "snippets_transformations/op/load_convert.hpp" -#include "snippets_transformations/op/store_convert.hpp" - -#include "ngraph/opsets/opset1.hpp" -#include "ngraph/rt_info.hpp" -#include "ngraph/pattern/op/wrap_type.hpp" - -ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() { - MATCHER_SCOPE(FuseLoadConvert); - auto param_pattern = ngraph::pattern::wrap_type(); - auto load_pattern = ngraph::pattern::wrap_type({param_pattern}); - auto convert_pattern = ngraph::pattern::wrap_type({load_pattern}); - - auto callback = [=](ngraph::pattern::Matcher& m) { - OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::FuseLoadConvert") - auto& pm = m.get_pattern_value_map(); - const auto param = pm.at(param_pattern).get_node_shared_ptr(); - const auto load_shared = pm.at(load_pattern).get_node_shared_ptr(); - if (!load_shared || load_shared->output(0).get_target_inputs().size() != 1) { - return false; - } - - const auto load = std::dynamic_pointer_cast(load_shared); - if (!load) - return false; - - const auto convert = pm.at(convert_pattern).get_node_shared_ptr(); - if (transformation_callback(convert)) - return false; - - std::shared_ptr load_convert = nullptr; - if (const auto convert_saturation = - std::dynamic_pointer_cast(convert)) { - load_convert = std::make_shared(param, - convert_saturation->get_destination_type(), - load->get_count(), load->get_offset()); - } else if (const auto convert_truncation = - std::dynamic_pointer_cast(convert)) { - load_convert = std::make_shared(param, - convert_truncation->get_destination_type(), - load->get_count(), load->get_offset()); - } else { - throw ngraph::ngraph_error( - "Type of Convert op is undefined. Supports only fusing Load and ConvertTruncation or ConvertSaturation ops"); - } - - if (!load_convert) - return false; - - ngraph::copy_runtime_info(convert, load_convert); - ngraph::replace_node(convert, load_convert); - - return true; - }; - - auto m = std::make_shared(convert_pattern, matcher_name); - register_matcher(m, callback); -} - - -ov::intel_cpu::pass::FuseStoreConvert::FuseStoreConvert() { - MATCHER_SCOPE(FuseStoreConvert); - auto input_pattern = ngraph::pattern::any_input(); - auto convert_pattern = ngraph::pattern::wrap_type({input_pattern}); - auto store_pattern = ngraph::pattern::wrap_type({convert_pattern}); - - auto callback = [=](ngraph::pattern::Matcher& m) { - OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::FuseStoreConvert") - auto& pm = m.get_pattern_value_map(); - const auto input = pm.at(input_pattern).get_node_shared_ptr(); - - const auto store = std::dynamic_pointer_cast(pm.at(store_pattern).get_node_shared_ptr()); - if (!store) - return false; - - const auto convert = pm.at(convert_pattern).get_node_shared_ptr(); - if (convert->output(0).get_target_inputs().size() != 1 || transformation_callback(convert)) - return false; - - std::shared_ptr store_convert = nullptr; - if (const auto convert_saturation = - std::dynamic_pointer_cast(convert)) { - store_convert = std::make_shared(input, - convert_saturation->get_destination_type(), - store->get_count(), store->get_offset()); - } else if (const auto convert_truncation = - std::dynamic_pointer_cast(convert)) { - store_convert = std::make_shared(input, - convert_truncation->get_destination_type(), - store->get_count(), store->get_offset()); - } else { - throw ngraph::ngraph_error( - "Type of Convert op is undefined. Supports only fusing Store and ConvertTruncation or ConvertSaturation ops"); - } - - - if (!store_convert) - return false; - - ngraph::copy_runtime_info(store, store_convert); - ngraph::replace_node(store, store_convert); - - return true; - }; - - auto m = std::make_shared(store_pattern, matcher_name); - register_matcher(m, callback); -} diff --git a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.hpp b/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.hpp deleted file mode 100644 index 6d49bd65983802..00000000000000 --- a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "ngraph/pass/graph_rewrite.hpp" -#include "ngraph/pattern/matcher.hpp" - -namespace ov { -namespace intel_cpu { -namespace pass { - -/** - * @interface FuseLoadConvert - * @brief Fuse Load and ConvertSaturation into one op LoadConvertSaturation - * Fuse Load and ConvertTruncation into one op LoadConvertTruncation - * @ingroup snippets - */ -class FuseLoadConvert: public ngraph::pass::MatcherPass { -public: - OPENVINO_RTTI("FuseLoadConvert", "0"); - FuseLoadConvert(); -}; - -/** - * @interface FuseStoreConvert - * @brief Fuse Store and ConvertSaturation into one op StoreConvertSaturation - * Fuse Store and ConvertTruncation into one op StoreConvertTruncation - * @ingroup snippets - */ -class FuseStoreConvert: public ngraph::pass::MatcherPass { -public: - OPENVINO_RTTI("FuseStoreConvert", "0"); - FuseStoreConvert(); -}; - -} // namespace pass -} // namespace intel_cpu -} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/lowered/fuse_load_store_and_convert.cpp b/src/plugins/intel_cpu/src/snippets_transformations/lowered/fuse_load_store_and_convert.cpp new file mode 100644 index 00000000000000..8adab5cef29f6c --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/lowered/fuse_load_store_and_convert.cpp @@ -0,0 +1,120 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" + +#include "fuse_load_store_and_convert.hpp" +#include "snippets/snippets_isa.hpp" + +#include "snippets_transformations/op/load_convert.hpp" +#include "snippets_transformations/op/store_convert.hpp" + + +bool ov::intel_cpu::pass::FuseLoadStoreConvert::fuse_load_convert(ngraph::snippets::LoweredExprIR& linear_ir, + ngraph::snippets::LoweredExprIR::constExprIt& convert_it) { + const auto& convert_expr = *convert_it; + const auto& convert = ov::as_type_ptr(convert_expr->get_node()); + const auto input_td = convert_expr->get_inputs().front(); + const auto output_td = convert_expr->get_outputs().front(); + if (convert->get_destination_type() != ov::element::f32 && convert->get_destination_type() != ov::element::i32) + return false; + + const auto& load_output = linear_ir.get_expr_by_output(input_td); + const auto& load_expr = load_output.expr; + const auto load = ov::as_type_ptr(load_expr->get_node()); + if (!load) + return false; + + const auto consumers = linear_ir.get_exprs_by_input(input_td); + if (consumers.size() != 1) + return false; + + std::shared_ptr load_convert = nullptr; + if (const auto convert_saturation = ov::as_type_ptr(convert)) { + load_convert = std::make_shared(load->input_value(0), + convert_saturation->get_destination_type(), + load->get_count(), load->get_offset()); + } else if (const auto convert_truncation = ov::as_type_ptr(convert)) { + load_convert = std::make_shared(load->input_value(0), + convert_truncation->get_destination_type(), + load->get_count(), load->get_offset()); + } else { + throw ov::Exception("Type of Convert op is undefined. Supports only fusing Load and ConvertTruncation or ConvertSaturation ops"); + } + + const auto in_td = std::vector{ load_expr->get_inputs().front() }; + const auto out_td = std::vector{ output_td }; + const auto mv_expr_it = convert_it; + const auto& insertion_pos = std::next(convert_it); + linear_ir.erase(std::find(linear_ir.cbegin(), mv_expr_it, load_expr)); + linear_ir.erase(mv_expr_it); + convert_it = linear_ir.insert(insertion_pos, std::make_shared(load_convert, in_td, out_td)); + return true; +} + +bool ov::intel_cpu::pass::FuseLoadStoreConvert::fuse_store_convert(ngraph::snippets::LoweredExprIR& linear_ir, + ngraph::snippets::LoweredExprIR::constExprIt& convert_it) { + const auto& convert_expr = *convert_it; + const auto& convert = convert_expr->get_node(); + const auto input_td = convert_expr->get_inputs().front(); + const auto output_td = convert_expr->get_outputs().front(); + if (convert->get_input_element_type(0) != ov::element::f32 && convert->get_input_element_type(0) != ov::element::i32) + return false; + + const auto consumers = linear_ir.get_exprs_by_input(output_td); + if (consumers.size() != 1) + return false; + + const auto store_input = *(consumers.begin()); + const auto store_expr = store_input.expr; + const auto store = ov::as_type_ptr(store_expr->get_node()); + if (!store) + return false; + + std::shared_ptr store_convert = nullptr; + if (const auto convert_saturation = ov::as_type_ptr(convert)) { + store_convert = std::make_shared(convert->input_value(0), + convert_saturation->get_destination_type(), + store->get_count(), store->get_offset()); + } else if (const auto convert_truncation = ov::as_type_ptr(convert)) { + store_convert = std::make_shared(convert->input_value(0), + convert_truncation->get_destination_type(), + store->get_count(), store->get_offset()); + } else { + throw ov::Exception("Type of Convert op is undefined. Supports only fusing Store and ConvertTruncation or ConvertSaturation ops"); + } + + const auto in_td = std::vector{ input_td }; + const auto out_td = std::vector{ store_expr->get_outputs().front() }; + const auto store_it = std::find(convert_it, linear_ir.cend(), store_expr); + const auto& insertion_pos = std::next(store_it); + linear_ir.erase(store_it); + convert_it = linear_ir.erase(convert_it); + linear_ir.insert(insertion_pos, std::make_shared(store_convert, in_td, out_td)); + return true; +} + +bool ov::intel_cpu::pass::FuseLoadStoreConvert::run(ngraph::snippets::LoweredExprIR& linear_ir) { + OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::FuseLoadStoreConvert") + + bool modified = false; + + for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) { + const auto& expr = *expr_it; + const auto& convert = expr->get_node(); + if (!ov::is_type(convert)) + continue; + + if (fuse_load_convert(linear_ir, expr_it)) { + modified = true; + continue; + } + if (fuse_store_convert(linear_ir, expr_it)) { + modified = true; + continue; + } + } + + return modified; +} diff --git a/src/plugins/intel_cpu/src/snippets_transformations/lowered/fuse_load_store_and_convert.hpp b/src/plugins/intel_cpu/src/snippets_transformations/lowered/fuse_load_store_and_convert.hpp new file mode 100644 index 00000000000000..ef7d4e87d088ff --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/lowered/fuse_load_store_and_convert.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "snippets/pass/lowered/linear_IR_transformation.hpp" + +namespace ov { +namespace intel_cpu { +namespace pass { + +/** + * @interface FuseLoadStoreConvert + * @brief Fuse Load and ConvertSaturation into one op LoadConvertSaturation + * Fuse Load and ConvertTruncation into one op LoadConvertTruncation + * Fuse Store and ConvertSaturation into one op StoreConvertSaturation + * Fuse Store and ConvertTruncation into one op StoreConvertTruncation + * @ingroup snippets + */ +class FuseLoadStoreConvert: public ngraph::snippets::pass::lowered::LinearIRTransformation { +public: + FuseLoadStoreConvert() = default; + OPENVINO_RTTI("FuseLoadStoreConvert", "LinearIRTransformation"); + bool run(ngraph::snippets::LoweredExprIR& linear_ir) override; + +private: + bool fuse_load_convert(ngraph::snippets::LoweredExprIR& linear_ir, + ngraph::snippets::LoweredExprIR::constExprIt& convert_it); + bool fuse_store_convert(ngraph::snippets::LoweredExprIR& linear_ir, + ngraph::snippets::LoweredExprIR::constExprIt& convert_it); +}; + +} // namespace pass +} // namespace intel_cpu +} // namespace ov