diff --git a/src/common/transformations/src/transformations/common_optimizations/fq_reshape_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/fq_reshape_fusion.cpp
index 8840a93e07c7b9..b9bafeeff90ff0 100644
--- a/src/common/transformations/src/transformations/common_optimizations/fq_reshape_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/fq_reshape_fusion.cpp
@@ -9,23 +9,29 @@
 #include "itt.hpp"
 #include "openvino/core/rt_info.hpp"
+#include "openvino/core/validation_util.hpp"
 #include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
 #include "openvino/op/fake_quantize.hpp"
 #include "openvino/op/group_conv.hpp"
 #include "openvino/op/reshape.hpp"
+#include "openvino/pass/pattern/op/optional.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 
 ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
     MATCHER_SCOPE(FakeQuantizeReshapeFusion);
-    const auto fq_node_p = ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>(
-        {ov::pass::pattern::wrap_type<ov::op::v0::Constant>(),  // for weights only
-         pattern::any_input(),
-         pattern::any_input(),
-         pattern::any_input(),
-         pattern::any_input()},
-        pattern::consumers_count(1));
+    // for weights only
+    const auto data_p = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(pattern::has_static_shape());
+    const auto convert_p = ov::pass::pattern::optional<ov::op::v0::Convert>(data_p, pattern::consumers_count(1));
+    const auto fq_node_p =
+        ov::pass::pattern::wrap_type<ov::op::v0::FakeQuantize>({convert_p,
+                                                                pattern::any_input(pattern::has_static_shape()),
+                                                                pattern::any_input(pattern::has_static_shape()),
+                                                                pattern::any_input(pattern::has_static_shape()),
+                                                                pattern::any_input(pattern::has_static_shape())},
+                                                               pattern::consumers_count(1));
 
     const auto reshape_node_p = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>(
-        {fq_node_p, pattern::any_input()},
+        {fq_node_p, ov::pass::pattern::wrap_type<ov::op::v0::Constant>()},
         [](const Output<Node>& output) {
             // WA: check that all Reshape node consumers are not GroupConvolution operations
             const auto& target_inputs = output.get_target_inputs();
@@ -36,13 +42,11 @@ ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
 
     ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
-        const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
-        if (fq_node->is_dynamic())
-            return false;
+        const auto& fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
         const auto& reshape_node = pattern_map.at(reshape_node_p).get_node_shared_ptr();
         const auto& original_data_rank = fq_node->get_input_shape(0).size();
-        OutputVector renewed_inputs = {
-            reshape_node->clone_with_new_inputs({fq_node->input_value(0), reshape_node->input_value(1)})};
+
+        OutputVector renewed_inputs = {};
         for (auto i = 1; i < 5; ++i) {
             Output<Node> limit_input = fq_node->input_value(i);
             auto limit_shape = limit_input.get_shape();
@@ -62,21 +66,41 @@ ov::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
             });
             const auto& new_limit_size = shape_size(new_limit_shape);
             if (new_limit_size == limit_size) {  // we tracked future channel placement
-                if (new_limit_shape == limit_input.get_shape())
+                if (new_limit_shape == limit_input.get_shape()) {
                     renewed_inputs.push_back(limit_input);
-                else
-                    renewed_inputs.push_back(reshape_node->clone_with_new_inputs(
+                } else {
+                    auto reshaped_input = reshape_node->clone_with_new_inputs(
                         {limit_input,
-                         ov::op::v0::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)}));
+                         ov::op::v0::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)});
+                    if (auto constant = ov::util::get_constant_from_source(reshaped_input)) {
+                        reshaped_input = constant;
+                    }
+                    renewed_inputs.push_back(reshaped_input);
+                }
                 continue;
             }
         }  // resulting FQ will become or already is more than per-tensor / per-channel
            return false;
        }
+
+        auto reshaped_input =
+            reshape_node->clone_with_new_inputs({pattern_map.at(data_p), reshape_node->input_value(1)});
+        if (auto constant = ov::util::get_constant_from_source(reshaped_input)) {
+            reshaped_input = constant;
+        }
+        if (pattern_map.count(convert_p)) {
+            const auto& convert_node = pattern_map.at(convert_p).get_node_shared_ptr();
+            convert_node->input(0).replace_source_output(reshaped_input);
+            convert_node->validate_and_infer_types();
+            reshaped_input = convert_node;
+        }
+        renewed_inputs.insert(renewed_inputs.begin(), reshaped_input);
+
         for (auto& new_input : renewed_inputs)
             copy_runtime_info({reshape_node, fq_node}, new_input.get_node_shared_ptr());
         const auto new_fq_node = fq_node->clone_with_new_inputs(renewed_inputs);
+        register_new_node(new_fq_node);
         replace_node(reshape_node, new_fq_node);
         new_fq_node->set_friendly_name(reshape_node->get_friendly_name());
         copy_runtime_info({fq_node, reshape_node}, new_fq_node);
diff --git a/src/common/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp b/src/common/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp
index 1fdd69711e3af5..0d021c55ca140d 100644
--- a/src/common/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/pull_transpose_through_fq.cpp
@@ -14,13 +14,15 @@
 #include "openvino/op/fake_quantize.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
+#include "openvino/pass/pattern/op/optional.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "transformations/utils/utils.hpp"
 
 ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
     MATCHER_SCOPE(PullTransposeThroughFQUp);
     const auto weights = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
-    auto m_fq = pattern::wrap_type<ov::op::v0::FakeQuantize>({weights,
+    const auto convert_p = ov::pass::pattern::optional<ov::op::v0::Convert>(weights, pattern::consumers_count(1));
+    auto m_fq = pattern::wrap_type<ov::op::v0::FakeQuantize>({convert_p,
                                                               pattern::any_input(pattern::has_static_shape()),
                                                               pattern::any_input(pattern::has_static_shape()),
                                                               pattern::any_input(pattern::has_static_shape()),
@@ -33,25 +35,15 @@ ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
         auto& pattern_map = m.get_pattern_value_map();
         auto transpose = pattern_map[m_transpose].get_node_shared_ptr();
         auto fq = pattern_map[m_fq].get_node_shared_ptr();
-
-        auto are_inputs_scalars =
-            shape_size(fq->input_value(1).get_shape()) == 1 && shape_size(fq->input_value(2).get_shape()) == 1 &&
-            shape_size(fq->input_value(3).get_shape()) == 1 && shape_size(fq->input_value(4).get_shape()) == 1;
-        if (!are_inputs_scalars) {
-            auto perm = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map[m_transpose_perm].get_node_shared_ptr());
-            if (!perm)
-                return false;
-            auto perm_val = perm->cast_vector<int64_t>();
-            if (!(perm_val[0] == 0 && perm_val[1] == 1))
-                return false;
-        }
-
         auto input_rank = fq->input(0).get_partial_shape().rank().get_length();
 
         ov::NodeVector new_ops;
         ov::OutputVector fq_inputs;
         for (size_t i = 0; i < fq->inputs().size(); ++i) {
             auto fq_input = fq->input_value(i);
+            if (i == 0) {
+                fq_input = pattern_map[weights];
+            }
             auto fq_input_rank = fq_input.get_partial_shape().rank().get_length();
             std::vector<int64_t> unsqueeze_axes;
             for (int64_t j = 0; j < input_rank - fq_input_rank; ++j) {
@@ -68,10 +60,17 @@ ov::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
                 fq_input = constant;
             }
             ov::copy_runtime_info(transpose, fq_input.get_node_shared_ptr());
+            if (i == 0 && pattern_map.count(convert_p)) {
+                const auto& convert_node = pattern_map.at(convert_p).get_node_shared_ptr();
+                convert_node->input(0).replace_source_output(fq_input);
+                convert_node->validate_and_infer_types();
+                fq_input = convert_node;
+            }
             fq_inputs.push_back(fq_input);
         }
 
         auto new_fq = fq->clone_with_new_inputs(fq_inputs);
+        register_new_node(new_fq);
         new_ops.push_back(new_fq);
         new_fq->set_friendly_name(transpose->get_friendly_name());
         ov::copy_runtime_info({fq, transpose}, new_ops);
diff --git a/src/common/transformations/tests/common_optimizations/fq_reshape_fusion.cpp b/src/common/transformations/tests/common_optimizations/fq_reshape_fusion.cpp
index cc4ac2981b6799..940a5b29b8d702 100644
--- a/src/common/transformations/tests/common_optimizations/fq_reshape_fusion.cpp
+++ b/src/common/transformations/tests/common_optimizations/fq_reshape_fusion.cpp
@@ -13,7 +13,10 @@
 #include "common_test_utils/ov_test_utils.hpp"
 #include "openvino/core/model.hpp"
 #include "openvino/opsets/opset4.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
 #include "openvino/pass/manager.hpp"
+#include "transformations/common_optimizations/fq_mul_fusion.hpp"
+#include "transformations/common_optimizations/pull_transpose_through_fq.hpp"
 #include "transformations/init_node_info.hpp"
 
 using namespace ov;
@@ -66,13 +69,8 @@ class FQReshapeFusionTests : public ov::test::TestsCommon,
     }
 
     std::shared_ptr<ov::Model> get_reference_function(const FQReshapeFusionTestCase& test_case) {
-        const auto& data = std::make_shared<opset4::Constant>(element::f32, test_case.data_shape, 0);
-        const auto& reshaped_data = std::make_shared<opset4::Reshape>(
-            data,
-            std::make_shared<opset4::Constant>(element::i64,
-                                               Shape{test_case.reshape_pattern.size()},
-                                               test_case.reshape_pattern),
-            true);
+        auto shape = PartialShape(test_case.reshape_pattern).to_shape();
+        const auto& data = std::make_shared<opset4::Constant>(element::f32, shape, 0);
 
         const auto& p_il = std::make_shared<opset4::Parameter>(element::f32, test_case.il_shape);
         Output<Node> il = p_il;
@@ -104,7 +102,7 @@ class FQReshapeFusionTests : public ov::test::TestsCommon,
             opset4::Constant::create(element::i64, {test_case.new_oh_shape.size()}, test_case.new_oh_shape),
             true);
 
-        auto fq = std::make_shared<opset4::FakeQuantize>(reshaped_data, il, ih, ol, oh, 42);
+        auto fq = std::make_shared<opset4::FakeQuantize>(data, il, ih, ol, oh, 42);
 
         auto result = std::make_shared<opset4::Result>(fq);
         ParameterVector params = {p_il, p_ih, p_ol, p_oh};
@@ -213,3 +211,77 @@ TEST_F(TransformationTestsF, FQReshapeGroupConvolution) {
     manager.register_pass<ov::pass::InitNodeInfo>();
     manager.register_pass<ov::pass::FakeQuantizeReshapeFusion>();
 }
+
+TEST_F(TransformationTestsF, FQOptimizations) {
+    {
+        const auto& data = std::make_shared<opset4::Constant>(element::u8, Shape{9, 32}, 0);
+        const auto& convert = std::make_shared<opset4::Convert>(data, element::f32);
+
+        const auto& il = op::v0::Constant::create(element::f32, Shape{1}, {0});
+        const auto& ih = op::v0::Constant::create(element::f32, Shape{1}, {254});
+        const auto& ol = op::v0::Constant::create(element::f32, Shape{32}, {-14.22});
+        const auto& oh = op::v0::Constant::create(element::f32, Shape{32}, {14.22});
+
+        const auto& fq = std::make_shared<opset4::FakeQuantize>(convert, il, ih, ol, oh, 255);
+
+        const auto& reshape =
+            std::make_shared<opset4::Reshape>(fq,
+                                              op::v0::Constant::create(element::i64, Shape{4}, {3, 3, 32, 1}),
+                                              true);
+
+        const auto& multiply =
+            std::make_shared<opset4::Multiply>(reshape,
+                                               op::v0::Constant::create(element::f32, Shape{1, 1, 32, 1}, {0.1140}));
+
+        const auto& transpose =
+            std::make_shared<opset4::Transpose>(multiply,
+                                                op::v0::Constant::create(element::i64, Shape{4}, {2, 3, 0, 1}));
+
+        const auto& reshape_to_weight =
+            std::make_shared<opset4::Reshape>(transpose,
+                                              op::v0::Constant::create(element::i64, Shape{5}, {32, 1, 1, 3, 3}),
+                                              true);
+
+        const auto& input = std::make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(4));
+        const auto& group_conv = std::make_shared<opset4::GroupConvolution>(input,
+                                                                            reshape_to_weight,
+                                                                            Strides{1, 1},
+                                                                            CoordinateDiff{0, 0},
+                                                                            CoordinateDiff{0, 0},
+                                                                            Strides{1, 1});
+
+        model = std::make_shared<ov::Model>(OutputVector{group_conv}, ParameterVector{input});
+
+        auto fq_fusions = manager.register_pass<ov::pass::GraphRewrite>();
+        fq_fusions->add_matcher<ov::pass::FakeQuantizeMulFusion>();
+        fq_fusions->add_matcher<ov::pass::FakeQuantizeReshapeFusion>();
+        fq_fusions->add_matcher<ov::pass::PullTransposeThroughFQUp>();
+        fq_fusions->set_name("ov::pass::FakeQuantizeFusions");
+    }
+    {
+        const auto& data = std::make_shared<opset4::Constant>(element::u8, Shape{32, 1, 3, 3}, 0);
+        const auto& convert = std::make_shared<opset4::Convert>(data, element::f32);
+
+        const auto& il = op::v0::Constant::create(element::f32, Shape{1, 1, 1, 1}, {0});
+        const auto& ih = op::v0::Constant::create(element::f32, Shape{1, 1, 1, 1}, {254});
+        const auto& ol = op::v0::Constant::create(element::f32, Shape{32, 1, 1, 1}, {-14.22 * 0.1140});
+        const auto& oh = op::v0::Constant::create(element::f32, Shape{32, 1, 1, 1}, {14.22 * 0.1140});
+
+        const auto& fq = std::make_shared<opset4::FakeQuantize>(convert, il, ih, ol, oh, 255);
+
+        const auto& reshape_to_weight =
+            std::make_shared<opset4::Reshape>(fq,
+                                              op::v0::Constant::create(element::i64, Shape{5}, {32, 1, 1, 3, 3}),
+                                              true);
+
+        const auto& input = std::make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(4));
+        const auto& group_conv = std::make_shared<opset4::GroupConvolution>(input,
+                                                                            reshape_to_weight,
+                                                                            Strides{1, 1},
+                                                                            CoordinateDiff{0, 0},
+                                                                            CoordinateDiff{0, 0},
+                                                                            Strides{1, 1});
+
+        model_ref = std::make_shared<ov::Model>(OutputVector{group_conv}, ParameterVector{input});
+    }
+}
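
Reviewer note, not part of the patch: a minimal sketch of how the Convert-aware fusions touched here can be driven from a pass::Manager, mirroring the way the new FQOptimizations test groups them into a single GraphRewrite. The header paths and the helper name fuse_weight_fq are assumptions made for illustration only, not APIs introduced by this change.

    // Hedged sketch: assumes the transformation headers are exposed at these paths.
    #include <memory>

    #include "openvino/core/model.hpp"
    #include "openvino/pass/graph_rewrite.hpp"
    #include "openvino/pass/manager.hpp"
    #include "transformations/common_optimizations/fq_reshape_fusion.hpp"
    #include "transformations/common_optimizations/pull_transpose_through_fq.hpp"

    // Registers both weight-oriented FQ fusions in one GraphRewrite and runs them on the model,
    // similar to how the FQOptimizations test above configures its manager.
    inline void fuse_weight_fq(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        auto fq_fusions = manager.register_pass<ov::pass::GraphRewrite>();
        fq_fusions->add_matcher<ov::pass::FakeQuantizeReshapeFusion>();
        fq_fusions->add_matcher<ov::pass::PullTransposeThroughFQUp>();
        // The matchers themselves fold reshaped/transposed weight constants (see the
        // get_constant_from_source calls added in this patch); run_passes just applies them.
        manager.run_passes(model);
    }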