diff --git a/src/common/transformations/include/transformations/common_optimizations/fuse_u4_weights_zero_point.hpp b/src/common/transformations/include/transformations/common_optimizations/fuse_u4_weights_zero_point.hpp new file mode 100644 index 00000000000000..9a37a009fc13a6 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/fuse_u4_weights_zero_point.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API FuseU4WeightsAndZeroPoint; + +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief Applies zero point to U4 weights and fuses the result to the I4 constant if the result values are inside I4 range. + * If some values are out of I4 range, converts zero point constant to scalar. 
+ * Limitation: works only when the zero point is equal to 8 + */ + +class ov::pass::FuseU4WeightsAndZeroPoint : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("FuseU4WeightsAndZeroPoint", "0"); + FuseU4WeightsAndZeroPoint(); +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/weights_zero_point_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/weights_zero_point_fusion.hpp deleted file mode 100644 index ff4b1aee548e71..00000000000000 --- a/src/common/transformations/include/transformations/common_optimizations/weights_zero_point_fusion.hpp +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "transformations_visibility.hpp" - -namespace ov { -namespace pass { - -class TRANSFORMATIONS_API WeightsZeroPointFusion; - -} // namespace pass -} // namespace ov - -/** - * @ingroup ie_transformation_common_api - * @brief Fuses zero point on weights to the weights constant if result values are inside requested precision range - * Also tries to convert zero point to scalar constant if all values are equal - */ - -// using original_et_to_acceptable = std::map -class ov::pass::WeightsZeroPointFusion : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("WeightsZeroPointFusion", "0"); - WeightsZeroPointFusion(); -}; diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_u4_weights_zero_point.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_u4_weights_zero_point.cpp new file mode 100644 index 00000000000000..ebf5519b5d867f --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_u4_weights_zero_point.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include 
"transformations/common_optimizations/fuse_u4_weights_zero_point.hpp" + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/reference/autobroadcast_binop.hpp" +#include "transformations/utils/utils.hpp" + +ov::pass::FuseU4WeightsAndZeroPoint::FuseU4WeightsAndZeroPoint() { + MATCHER_SCOPE(FuseU4WeightsAndZeroPoint); + auto weights_m = pattern::wrap_type(pattern::type_matches(ov::element::u4)); + auto convert_m = pattern::wrap_type({weights_m}); + auto zero_point_m = pattern::wrap_type(); + auto subtract_m = pattern::wrap_type({convert_m, zero_point_m}); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + auto& pattern_map = m.get_pattern_value_map(); + const auto& subtract = pattern_map.at(subtract_m); + const int8_t zero_point_value = 8; + const auto zero_point = ov::as_type_ptr(pattern_map.at(zero_point_m).get_node_shared_ptr()); + if (!zero_point || !ov::op::util::constantIsEqualTo(zero_point, zero_point_value)) + return false; + + bool can_be_fused = true; + auto apply_zero_point = [&can_be_fused](int8_t weights_val, int8_t zp_val) mutable { + auto result_value = weights_val - zp_val; + can_be_fused &= -8 <= result_value && result_value <= 7; + return static_cast(result_value); + }; + const auto weights = ov::as_type_ptr(pattern_map.at(weights_m).get_node_shared_ptr()); + auto weights_values = weights->cast_vector(); + std::vector zero_point_values{8}; + std::vector new_weights_values(ov::shape_size(weights->get_shape())); + ov::reference::autobroadcast_binop(weights_values.data(), + zero_point_values.data(), + new_weights_values.data(), + weights->get_shape(), + ov::Shape{}, + ov::op::AutoBroadcastType::NUMPY, + apply_zero_point); + if (can_be_fused) { + auto new_weights = ov::op::v0::Constant::create(ov::element::i4, weights->get_shape(), new_weights_values); + 
ov::replace_node_update_name(weights, new_weights); + ov::replace_output_update_name(subtract, subtract.get_node()->input_value(0)); + } else { + auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zero_point_value}); + ov::replace_node_update_name(zero_point, new_zp); + } + return true; + }; + + auto m = std::make_shared(subtract_m, matcher_name); + register_matcher(m, callback); +} diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index d3f7bd601e37cc..a283920650fb72 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -32,6 +32,7 @@ #include "transformations/common_optimizations/fold_subgraph_empty_inputs.hpp" #include "transformations/common_optimizations/fq_mul_fusion.hpp" #include "transformations/common_optimizations/fq_reshape_fusion.hpp" +#include "transformations/common_optimizations/fuse_u4_weights_zero_point.hpp" #include "transformations/common_optimizations/gelu_fusion.hpp" #include "transformations/common_optimizations/gru_cell_fusion.hpp" #include "transformations/common_optimizations/hsigmoid_fusion.hpp" @@ -76,7 +77,6 @@ #include "transformations/common_optimizations/swish_fusion.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" #include "transformations/common_optimizations/transpose_to_reshape.hpp" -#include "transformations/common_optimizations/weights_zero_point_fusion.hpp" #include "transformations/init_node_info.hpp" #include "transformations/low_precision/mark_dequantization_subgraph.hpp" #include "transformations/op_conversions/batch_norm_decomposition.hpp" @@ -213,7 +213,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr ADD_MATCHER(common_fusions, 
ShuffleChannelsFusion, !m_use_shapes) ADD_MATCHER(common_fusions, NonZeroHorizontalFusion) ADD_MATCHER(common_fusions, AdaptivePoolToReduce) - ADD_MATCHER(common_fusions, WeightsZeroPointFusion) + ADD_MATCHER(common_fusions, FuseU4WeightsAndZeroPoint) common_fusions->set_name("ov::pass::CommonFusions"); REGISTER_PASS(manager, BinarizeWeights) diff --git a/src/common/transformations/src/transformations/common_optimizations/weights_zero_point_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/weights_zero_point_fusion.cpp deleted file mode 100644 index b6408807e48532..00000000000000 --- a/src/common/transformations/src/transformations/common_optimizations/weights_zero_point_fusion.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "transformations/common_optimizations/weights_zero_point_fusion.hpp" - -#include "itt.hpp" -#include "openvino/core/rt_info.hpp" -#include "openvino/op/constant.hpp" -#include "openvino/op/subtract.hpp" -#include "openvino/pass/pattern/op/wrap_type.hpp" -#include "openvino/reference/autobroadcast_binop.hpp" -#include "transformations/utils/utils.hpp" - -namespace { -std::shared_ptr fuse_zero_points(const std::shared_ptr& weights, - const std::shared_ptr& zero_point, - int8_t min_value, - int8_t max_value) { - bool can_be_fused = true; - auto functor = [&](int8_t weights_val, int8_t zp_val) mutable { - auto result_value = weights_val - zp_val; - can_be_fused &= min_value <= result_value && result_value <= max_value; - return static_cast(result_value); - }; - - auto weights_ptr = weights->cast_vector(); - auto zp_ptr = zero_point->cast_vector(); - std::vector new_weights_values(ov::shape_size(weights->get_shape())); - ov::reference::autobroadcast_binop(weights_ptr.data(), - zp_ptr.data(), - new_weights_values.data(), - weights->get_shape(), - zero_point->get_shape(), - ov::op::AutoBroadcastType::NUMPY, - functor); - return 
can_be_fused ? ov::op::v0::Constant::create(ov::element::i4, weights->get_shape(), new_weights_values) : nullptr; -} -} // namespace - -ov::pass::WeightsZeroPointFusion::WeightsZeroPointFusion() { - MATCHER_SCOPE(WeightsZeroPointFusion); - auto weights_m = pattern::wrap_type(pattern::type_matches(ov::element::u4)); - auto convert_m = pattern::wrap_type({weights_m}); - auto zero_point_m = pattern::wrap_type(); - auto subtract_m = pattern::wrap_type({convert_m, zero_point_m}); - - ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { - auto& pattern_map = m.get_pattern_value_map(); - bool graph_changed = false; - const auto& subtract = pattern_map.at(subtract_m); - auto zero_point = ov::as_type_ptr(pattern_map.at(zero_point_m).get_node_shared_ptr()); - if (!zero_point) - return graph_changed; - - if (ov::shape_size(zero_point->get_shape()) > 1) { - float value; - if (ov::op::util::get_single_value(zero_point, value)) { - const auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {value}); - ov::replace_node_update_name(zero_point, new_zp); - zero_point = new_zp; - graph_changed = true; - } - } - - const auto weights = ov::as_type_ptr(pattern_map.at(weights_m).get_node_shared_ptr()); - if (auto new_weights = fuse_zero_points(weights, zero_point, -8, 7)) { - ov::replace_node_update_name(weights, new_weights); - ov::replace_output_update_name(subtract, subtract.get_node()->input_value(0)); - graph_changed = true; - } - - return graph_changed; - }; - - auto m = std::make_shared(subtract_m, matcher_name); - register_matcher(m, callback); -}