Skip to content

Commit

Permalink
code style
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 17, 2023
1 parent 270839e commit e35c80b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class TRANSFORMATIONS_API FuseU4WeightsAndZeroPoint;

/**
* @ingroup ie_transformation_common_api
* @brief Applies zero point to U4 weights and fuses the result to the I4 constant if the result values are inside I4 range.
* If some values are out of I4 range, converts zero point constant to scalar.
* Limitations: works only in case when zero point is equal to 8
* @brief Applies zero point to U4 weights and fuses the result to the I4 constant if the result values are inside I4
* range. If some values are out of I4 range, converts zero point constant to scalar. Limitations: works only in case
* when zero point is equal to 8
*/

class ov::pass::FuseU4WeightsAndZeroPoint : public ov::pass::MatcherPass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ ov::pass::FuseU4WeightsAndZeroPoint::FuseU4WeightsAndZeroPoint() {

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
const auto& subtract = pattern_map.at(subtract_m);
const auto& subtract_out = pattern_map.at(subtract_m);
const auto& zero_point_out = pattern_map.at(zero_point_m);
const int8_t zero_point_value = 8;
const auto zero_point = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(zero_point_m).get_node_shared_ptr());
const auto zero_point = ov::as_type_ptr<ov::op::v0::Constant>(zero_point_out.get_node_shared_ptr());
if (!zero_point || !ov::op::util::constantIsEqualTo(zero_point, zero_point_value))
return false;

Expand All @@ -50,7 +51,7 @@ ov::pass::FuseU4WeightsAndZeroPoint::FuseU4WeightsAndZeroPoint() {
if (can_be_fused) {
auto new_weights = ov::op::v0::Constant::create(ov::element::i4, weights->get_shape(), new_weights_values);
ov::replace_node_update_name(weights, new_weights);
ov::replace_output_update_name(subtract, subtract.get_node()->input_value(0));
ov::replace_output_update_name(subtract_out, subtract_out.get_node()->input_value(0));
} else {
auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zero_point_value});
ov::replace_node_update_name(zero_point, new_zp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
#include <gtest/gtest.h>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/model.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/core/model.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/manager.hpp"

using namespace testing;
Expand Down Expand Up @@ -112,8 +112,9 @@ TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointSeveralWeightsConsumers) {

auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto additional_consumer = std::make_shared<ov::op::v0::Convert>(weights, ov::element::i32);
auto additional_op = std::make_shared<ov::op::v1::Multiply>(additional_consumer, ov::op::v0::Constant::create(ov::element::i32, {}, {-1}));
auto additional_consumer = std::make_shared<ov::op::v0::Convert>(weights, ov::element::i32);
auto additional_const = ov::op::v0::Constant::create(ov::element::i32, {}, {-1});
auto additional_op = std::make_shared<ov::op::v1::Multiply>(additional_consumer, additional_const);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
Expand Down

0 comments on commit e35c80b

Please sign in to comment.