diff --git a/src/common/offline_transformations/src/compress_quantize_weigths.cpp b/src/common/offline_transformations/src/compress_quantize_weigths.cpp index f2d3526a9bd73c..9ae1e7118ac5ea 100644 --- a/src/common/offline_transformations/src/compress_quantize_weigths.cpp +++ b/src/common/offline_transformations/src/compress_quantize_weigths.cpp @@ -165,7 +165,8 @@ ov::pass::CompressWeightsWithFakeQuantize::CompressWeightsWithFakeQuantize() { if (levels <= 2 || levels > 256) return false; auto low_precision_type = element::undefined; - // Currently we support two weights quantize types: i4 and i8 + // Currently we support four weights quantize types: i4, u4, i8, u8 + // whether weights are cast to the unsigned types (u4, u8) is determined inside compress_quantized_weights_internal if (levels <= 16) { low_precision_type = element::i4; } else if (levels <= 256) { @@ -670,15 +671,24 @@ static void numpy_broadcast_6inputs(const T* weights, for (size_t i = 0; i < shape_size(weights_shape); i++) { std::tie(in_low_stride, in_high_stride, out_low_stride, out_high_stride, zero_point_stride) = get_outer_strides(i); - *new_weights++ = f(*weights++, - *(in_low + in_low_stride), - *(in_high + in_high_stride), - *(out_low + out_low_stride), - *(out_high + out_high_stride), - *(zero_point + zero_point_stride)); + *new_weights = f(*weights++, + *(in_low + in_low_stride), + *(in_high + in_high_stride), + *(out_low + out_low_stride), + *(out_high + out_high_stride), + *(zero_point + zero_point_stride)); + new_weights++; } } +static inline uint8_t convert_to_uint8(float val) { + return static_cast(std::nearbyint(val)); +} + +static inline uint8_t convert_to_uint4(float val) { + return static_cast(std::nearbyint(val)) & 0x0f; +} + static inline int8_t convert_to_int8(float val) { return static_cast(std::nearbyint(val)); } @@ -713,22 +723,95 @@ static std::shared_ptr compress_quantized_weights_internal const ov::Shape& zero_point_shape, size_t levels, bool& can_fuse_zero_point) { - ov::Tensor 
compressed_weights_tensor(ov::element::i8, weights_shape); - int8_t* compressed_weights = compressed_weights_tensor.data(); - ov::Tensor compressed_weights_with_fused_zero_point_tensor(ov::element::i8, weights_shape); - int8_t* compressed_weights_with_fused_zero_point = compressed_weights_with_fused_zero_point_tensor.data(); - T levels_minus_one = static_cast(levels - 1); - can_fuse_zero_point = true; - const auto convert_to_low_precision = low_precision_type == ov::element::i4 ? convert_to_int4 : convert_to_int8; + ov::element::Type new_low_precision_type = low_precision_type; + bool all_low_not_neg = true; + for (int i = 0; i < ov::shape_size(output_low_shape); ++i) { + all_low_not_neg = all_low_not_neg && (output_low[i] >= 0); + if (!all_low_not_neg) { + break; + } + } + + if (low_precision_type == ov::element::i8 && all_low_not_neg) { + new_low_precision_type = ov::element::u8; + } else if (low_precision_type == ov::element::i4 && all_low_not_neg) { + new_low_precision_type = ov::element::u4; + } + + ov::element::Type tensor_el_type; + if (new_low_precision_type == ov::element::i8 || new_low_precision_type == ov::element::i4) { + tensor_el_type = ov::element::i8; + } else if (new_low_precision_type == ov::element::u8 || new_low_precision_type == ov::element::u4) { + tensor_el_type = ov::element::u8; + } + + ov::Tensor compressed_weights_tensor(tensor_el_type, weights_shape); + ov::Tensor compressed_weights_with_fused_zero_point_tensor(tensor_el_type, weights_shape); + + // TODO: reuse the common code parts + if (tensor_el_type == ov::element::u8) { + auto* compressed_weights = compressed_weights_tensor.data(); + auto* compressed_weights_with_fused_zero_point = + compressed_weights_with_fused_zero_point_tensor.data(); + T levels_minus_one = static_cast(levels - 1); + can_fuse_zero_point = true; + const auto convert_to_low_precision = + low_precision_type == ov::element::u4 ? 
convert_to_uint4 : convert_to_uint8; + + auto f = [compressed_weights_with_fused_zero_point, + levels_minus_one, + convert_to_low_precision, + &can_fuse_zero_point](T weights_value, + T input_low, + T input_high, + T output_low, + T output_high, + T zero_point) mutable { + uint8_t compressed_weights_value = + convert_to_low_precision(ov::reference::fake_quantize_details::quantize(weights_value, + input_low, + input_high, + output_low, + output_high, + levels_minus_one)); + T weights_minus_zero_point = static_cast(compressed_weights_value) - zero_point; + uint8_t compressed_weights_with_fused_zero_point_value = convert_to_low_precision(weights_minus_zero_point); + can_fuse_zero_point &= + std::fabs(compressed_weights_with_fused_zero_point_value - weights_minus_zero_point) < 1e-4; + *compressed_weights_with_fused_zero_point++ = compressed_weights_with_fused_zero_point_value; + return compressed_weights_value; + }; - auto f = - [compressed_weights_with_fused_zero_point, levels_minus_one, convert_to_low_precision, &can_fuse_zero_point]( - T weights_value, - T input_low, - T input_high, - T output_low, - T output_high, - T zero_point) mutable { + numpy_broadcast_6inputs(weights, + weights_shape, + input_low, + input_low_shape, + input_high, + input_high_shape, + output_low, + output_low_shape, + output_high, + output_high_shape, + zero_point, + zero_point_shape, + compressed_weights, + f); + } else if (tensor_el_type == ov::element::i8) { + auto* compressed_weights = compressed_weights_tensor.data(); + auto* compressed_weights_with_fused_zero_point = compressed_weights_with_fused_zero_point_tensor.data(); + T levels_minus_one = static_cast(levels - 1); + can_fuse_zero_point = true; + const auto convert_to_low_precision = low_precision_type == ov::element::i4 ? 
convert_to_int4 : convert_to_int8; + + auto f = [compressed_weights_with_fused_zero_point, + levels_minus_one, + convert_to_low_precision, + &can_fuse_zero_point](T weights_value, + T input_low, + T input_high, + T output_low, + T output_high, + T zero_point) mutable { int8_t compressed_weights_value = convert_to_low_precision(ov::reference::fake_quantize_details::quantize(weights_value, input_low, @@ -744,24 +827,25 @@ static std::shared_ptr compress_quantized_weights_internal return compressed_weights_value; }; - numpy_broadcast_6inputs(weights, - weights_shape, - input_low, - input_low_shape, - input_high, - input_high_shape, - output_low, - output_low_shape, - output_high, - output_high_shape, - zero_point, - zero_point_shape, - compressed_weights, - f); + numpy_broadcast_6inputs(weights, + weights_shape, + input_low, + input_low_shape, + input_high, + input_high_shape, + output_low, + output_low_shape, + output_high, + output_high_shape, + zero_point, + zero_point_shape, + compressed_weights, + f); + } return create_weights_constant( can_fuse_zero_point ? 
compressed_weights_with_fused_zero_point_tensor : compressed_weights_tensor, - low_precision_type); + new_low_precision_type); } std::shared_ptr compress_quantized_weights( diff --git a/src/common/transformations/tests/utils/compress_quantize_weights.cpp b/src/common/transformations/tests/utils/compress_quantize_weights.cpp index fe0918cc3425c3..b38d2fa682d483 100644 --- a/src/common/transformations/tests/utils/compress_quantize_weights.cpp +++ b/src/common/transformations/tests/utils/compress_quantize_weights.cpp @@ -82,12 +82,7 @@ class CompressQuantizeWeightsTests } }; -#ifdef OPENVINO_ARCH_ARM64 -// Ticket: CVS-122397 -TEST_P(CompressQuantizeWeightsTests, DISABLED_FusionTest) {} -#else TEST_P(CompressQuantizeWeightsTests, FusionTest) {} -#endif static std::vector params = { {Shape{2, 3, 1, 1}, @@ -138,6 +133,54 @@ static std::vector params = { 0.0313725f, -64.25f, false}, + {Shape{2, 3, 1, 1}, + {-1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f}, + 0.0f, + 10.0f, + 1.0f, + 5.0f, + 3, + element::u4, + {1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, + 3.0f, + -0.666667f, + false}, + {Shape{2, 3, 1, 1}, + {-1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f}, + 0.0f, + 10.0f, + 1.0f, + 4.0f, + 16, + element::u4, + {8.0f, 5.0f, 4.0f, 2.0f, 0.0f, 7.0f}, + 0.333333f, + -5.0f, + false}, + {Shape{2, 4, 1, 1}, + {-1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f}, + 1.0f, + 9.0f, + 2.0f, + 6.0f, + 17, + element::u8, + {4.0f, 4.0f, 4.0f, 2.0f, 0.0f, 2.0f, 4.0f, 12.0f}, + 0.5f, + -4.0f, + true}, + {Shape{2, 4, 1, 1}, + {-1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f}, + 1.0f, + 9.0f, + 2.0f, + 6.0f, + 256, + element::u8, + {128.0f, 128.0f, 128.0f, 96.0f, 64.0f, 32.0f, 0.0f, 127.0f}, + 0.0313725f, + -64.25f, + false}, }; static element::TypeVector data_precisions = {element::f32, element::f16};