Skip to content

Commit

Permalink
Hardcoded zero point
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 16, 2023
1 parent ac00c30 commit eb258cd
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 109 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API FuseU4WeightsAndZeroPoint;

} // namespace pass
} // namespace ov

/**
* @ingroup ie_transformation_common_api
* @brief Applies zero point to U4 weights and fuses the result to the I4 constant if the result values are inside I4 range.
* If some values are out of I4 range, converts zero point constant to scalar.
* Limitations: works only in case when zero point is equal to 8
*/

class ov::pass::FuseU4WeightsAndZeroPoint : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("FuseU4WeightsAndZeroPoint", "0");
FuseU4WeightsAndZeroPoint();
};

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/fuse_u4_weights_zero_point.hpp"

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/reference/autobroadcast_binop.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::FuseU4WeightsAndZeroPoint::FuseU4WeightsAndZeroPoint() {
MATCHER_SCOPE(FuseU4WeightsAndZeroPoint);
auto weights_m = pattern::wrap_type<ov::op::v0::Constant>(pattern::type_matches(ov::element::u4));
auto convert_m = pattern::wrap_type<ov::op::v0::Convert>({weights_m});
auto zero_point_m = pattern::wrap_type<ov::op::v0::Constant>();
auto subtract_m = pattern::wrap_type<ov::op::v1::Subtract>({convert_m, zero_point_m});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
const auto& subtract = pattern_map.at(subtract_m);
const int8_t zero_point_value = 8;
const auto zero_point = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(zero_point_m).get_node_shared_ptr());
if (!zero_point || !ov::op::util::constantIsEqualTo(zero_point, zero_point_value))
return false;

bool can_be_fused = true;
auto apply_zero_point = [&can_be_fused](int8_t weights_val, int8_t zp_val) mutable {
auto result_value = weights_val - zp_val;
can_be_fused &= -8 <= result_value && result_value <= 7;
return static_cast<int8_t>(result_value);
};
const auto weights = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(weights_m).get_node_shared_ptr());
auto weights_values = weights->cast_vector<int8_t>();
std::vector<int8_t> zero_point_values{8};
std::vector<int8_t> new_weights_values(ov::shape_size(weights->get_shape()));
ov::reference::autobroadcast_binop(weights_values.data(),
zero_point_values.data(),
new_weights_values.data(),
weights->get_shape(),
ov::Shape{},
ov::op::AutoBroadcastType::NUMPY,
apply_zero_point);
if (can_be_fused) {
auto new_weights = ov::op::v0::Constant::create(ov::element::i4, weights->get_shape(), new_weights_values);
ov::replace_node_update_name(weights, new_weights);
ov::replace_output_update_name(subtract, subtract.get_node()->input_value(0));
} else {
auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zero_point_value});
ov::replace_node_update_name(zero_point, new_zp);
}
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(subtract_m, matcher_name);
register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "transformations/common_optimizations/fold_subgraph_empty_inputs.hpp"
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
#include "transformations/common_optimizations/fq_reshape_fusion.hpp"
#include "transformations/common_optimizations/fuse_u4_weights_zero_point.hpp"
#include "transformations/common_optimizations/gelu_fusion.hpp"
#include "transformations/common_optimizations/gru_cell_fusion.hpp"
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
Expand Down Expand Up @@ -76,7 +77,6 @@
#include "transformations/common_optimizations/swish_fusion.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/transpose_to_reshape.hpp"
#include "transformations/common_optimizations/weights_zero_point_fusion.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/batch_norm_decomposition.hpp"
Expand Down Expand Up @@ -213,7 +213,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(common_fusions, ShuffleChannelsFusion, !m_use_shapes)
ADD_MATCHER(common_fusions, NonZeroHorizontalFusion)
ADD_MATCHER(common_fusions, AdaptivePoolToReduce)
ADD_MATCHER(common_fusions, WeightsZeroPointFusion)
ADD_MATCHER(common_fusions, FuseU4WeightsAndZeroPoint)
common_fusions->set_name("ov::pass::CommonFusions");

REGISTER_PASS(manager, BinarizeWeights)
Expand Down

This file was deleted.

0 comments on commit eb258cd

Please sign in to comment.