diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_u4_weights_zero_point.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_u4_weights_zero_point.cpp
index ebf5519b5d867f..2931c76d18346a 100644
--- a/src/common/transformations/src/transformations/common_optimizations/fuse_u4_weights_zero_point.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/fuse_u4_weights_zero_point.cpp
@@ -14,10 +14,15 @@
 
 ov::pass::FuseU4WeightsAndZeroPoint::FuseU4WeightsAndZeroPoint() {
     MATCHER_SCOPE(FuseU4WeightsAndZeroPoint);
-    auto weights_m = pattern::wrap_type<ov::op::v0::Constant>(pattern::type_matches(ov::element::u4));
-    auto convert_m = pattern::wrap_type<ov::op::v0::Convert>({weights_m});
-    auto zero_point_m = pattern::wrap_type<ov::op::v0::Constant>();
-    auto subtract_m = pattern::wrap_type<ov::op::v1::Subtract>({convert_m, zero_point_m});
+    // Every node of the pattern (weights, Convert, zero point, Subtract) is additionally required
+    // to have exactly one consumer; the weights must also be of u4 element type.
+    auto weights_match = [](ov::Output<ov::Node> output) -> bool {
+        return pattern::type_matches(ov::element::u4)(output) && pattern::consumers_count(1)(output);
+    };
+    auto weights_m = pattern::wrap_type<ov::op::v0::Constant>(weights_match);
+    auto convert_m = pattern::wrap_type<ov::op::v0::Convert>({weights_m}, pattern::consumers_count(1));
+    auto zero_point_m = pattern::wrap_type<ov::op::v0::Constant>(pattern::consumers_count(1));
+    auto subtract_m = pattern::wrap_type<ov::op::v1::Subtract>({convert_m, zero_point_m}, pattern::consumers_count(1));
 
     ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
         auto& pattern_map = m.get_pattern_value_map();
diff --git a/src/common/transformations/tests/common_optimizations/fuse_u4_weights_zero_point.cpp b/src/common/transformations/tests/common_optimizations/fuse_u4_weights_zero_point.cpp
new file mode 100644
index 00000000000000..c163e7e9f732fb
--- /dev/null
+++ b/src/common/transformations/tests/common_optimizations/fuse_u4_weights_zero_point.cpp
@@ -0,0 +1,133 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/common_optimizations/fuse_u4_weights_zero_point.hpp"
+
+#include <gtest/gtest.h>
+
+#include "common_test_utils/ov_test_utils.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/core/model.hpp"
+#include "openvino/pass/manager.hpp"
+
+using namespace testing;
+using namespace ov;
+
+// u4 weights (i % 16) with a zero point of 8: the reference model holds the same values
+// shifted by -8 in an i4 constant that feeds Convert -> Multiply, with no Subtract.
+TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPoint) {
+    auto decompression_precision = ov::element::f32;
+    ov::Shape weights_shape{32, 128, 64};
+    ov::Shape decompression_shape{32, 1, 64};
+    {
+        std::vector<uint8_t> weights_values(ov::shape_size(weights_shape));
+        for (size_t i = 0; i < weights_values.size(); ++i) {
+            weights_values[i] = i % 16;
+        }
+        auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, weights_values);
+        auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
+        auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8});
+        auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
+        auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
+        auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
+        model = std::make_shared<ov::Model>(NodeVector{multiply}, ParameterVector{});
+        manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
+    }
+    {
+        std::vector<int8_t> weights_values(ov::shape_size(weights_shape));
+        for (size_t i = 0; i < weights_values.size(); ++i) {
+            weights_values[i] = static_cast<int8_t>(static_cast<int>(i % 16) - 8);
+        }
+        auto weights = ov::op::v0::Constant::create(ov::element::i4, weights_shape, weights_values);
+        auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
+        auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
+        auto multiply = std::make_shared<ov::op::v1::Multiply>(convert, scale);
+        model_ref = std::make_shared<ov::Model>(NodeVector{multiply}, ParameterVector{});
+    }
+}
+
+// The zero point is 4 instead of 8. No model_ref is provided.
+// NOTE(review): presumably the fixture then expects the model to stay unchanged - confirm.
+TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointUnacceptableZeroPoint) {
+    auto decompression_precision = ov::element::f32;
+    ov::Shape weights_shape{32, 128, 64};
+    ov::Shape decompression_shape{32, 1, 64};
+
+    std::vector<uint8_t> weights_values(ov::shape_size(weights_shape));
+    for (size_t i = 0; i < weights_values.size(); ++i) {
+        weights_values[i] = i % 8;
+    }
+    auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, weights_values);
+    auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
+    auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {4});
+    auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
+    auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
+    auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
+    model = std::make_shared<ov::Model>(NodeVector{multiply}, ParameterVector{});
+    manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
+}
+
+// The zero point is 0 everywhere except its last element, which is 16. No model_ref is provided.
+// NOTE(review): presumably the fixture then expects the model to stay unchanged - confirm.
+TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointOutOfI4Range) {
+    auto decompression_precision = ov::element::f32;
+    ov::Shape weights_shape{32, 128, 64};
+    ov::Shape decompression_shape{32, 1, 64};
+
+    std::vector<uint8_t> weights_values(ov::shape_size(weights_shape));
+    for (size_t i = 0; i < weights_values.size(); ++i) {
+        weights_values[i] = i % 16;
+    }
+    auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, weights_values);
+    auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
+    std::vector<float> zero_point_values(ov::shape_size(decompression_shape));
+    zero_point_values.back() = 16;
+    auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, zero_point_values);
+    auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
+    auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
+    auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
+    model = std::make_shared<ov::Model>(NodeVector{multiply}, ParameterVector{});
+    manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
+}
+
+// The weights constant is u8 rather than u4. No model_ref is provided.
+// NOTE(review): presumably the fixture then expects the model to stay unchanged - confirm.
+TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointWrongWeightsPrecision) {
+    auto decompression_precision = ov::element::f32;
+    ov::Shape weights_shape{32, 128, 64};
+
+    std::vector<uint8_t> weights_values(ov::shape_size(weights_shape));
+    for (size_t i = 0; i < weights_values.size(); ++i) {
+        weights_values[i] = i % 256;
+    }
+    auto weights = ov::op::v0::Constant::create(ov::element::u8, weights_shape, weights_values);
+    auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
+    auto zero_point = ov::op::v0::Constant::create(decompression_precision, {32, 1, 64}, {8});
+    auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
+    auto scale = ov::op::v0::Constant::create(decompression_precision, {32, 1, 64}, {3.f});
+    auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
+    model = std::make_shared<ov::Model>(NodeVector{multiply}, ParameterVector{});
+    manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
+}
+
+// The u4 weights constant feeds a second Convert, so it has two consumers. No model_ref is provided.
+// NOTE(review): presumably the fixture then expects the model to stay unchanged - confirm.
+TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointSeveralWeightsConsumers) {
+    auto decompression_precision = ov::element::f32;
+    ov::Shape weights_shape{32, 128, 64};
+    ov::Shape decompression_shape{32, 1, 64};
+
+    auto weights = ov::op::v0::Constant::create(ov::element::u4, weights_shape, {4});
+    auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
+    auto additional_consumer = std::make_shared<ov::op::v0::Convert>(weights, ov::element::i32);
+    auto additional_op = std::make_shared<ov::op::v1::Multiply>(additional_consumer, ov::op::v0::Constant::create(ov::element::i32, {}, {-1}));
+    auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8});
+    auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
+    auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
+    auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
+    model = std::make_shared<ov::Model>(NodeVector{multiply, additional_op}, ParameterVector{});
+    manager.register_pass<ov::pass::FuseU4WeightsAndZeroPoint>();
+}