Skip to content

Commit

Permalink
Update compress_quantize_weights transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Aug 21, 2024
1 parent 52f6fe7 commit 6febbc1
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 42 deletions.
158 changes: 121 additions & 37 deletions src/common/offline_transformations/src/compress_quantize_weigths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ ov::pass::CompressWeightsWithFakeQuantize::CompressWeightsWithFakeQuantize() {
if (levels <= 2 || levels > 256)
return false;
auto low_precision_type = element::undefined;
// Currently we support two weights quantize types: i4 and i8
// Currently we support two weights quantize types: i4, u4, i8, u8
// we determine that weights should be cast to u4, u8 inside compress_quantized_weights_internal
if (levels <= 16) {
low_precision_type = element::i4;
} else if (levels <= 256) {
Expand Down Expand Up @@ -670,15 +671,24 @@ static void numpy_broadcast_6inputs(const T* weights,
for (size_t i = 0; i < shape_size(weights_shape); i++) {
std::tie(in_low_stride, in_high_stride, out_low_stride, out_high_stride, zero_point_stride) =
get_outer_strides(i);
*new_weights++ = f(*weights++,
*(in_low + in_low_stride),
*(in_high + in_high_stride),
*(out_low + out_low_stride),
*(out_high + out_high_stride),
*(zero_point + zero_point_stride));
*new_weights = f(*weights++,
*(in_low + in_low_stride),
*(in_high + in_high_stride),
*(out_low + out_low_stride),
*(out_high + out_high_stride),
*(zero_point + zero_point_stride));
new_weights++;
}
}

static inline uint8_t convert_to_uint8(float val) {
return static_cast<uint8_t>(std::nearbyint(val));
}

static inline uint8_t convert_to_uint4(float val) {
return static_cast<uint8_t>(std::nearbyint(val)) & 0x0f;
}

static inline int8_t convert_to_int8(float val) {
return static_cast<int8_t>(std::nearbyint(val));
}
Expand Down Expand Up @@ -713,22 +723,95 @@ static std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights_internal
const ov::Shape& zero_point_shape,
size_t levels,
bool& can_fuse_zero_point) {
ov::Tensor compressed_weights_tensor(ov::element::i8, weights_shape);
int8_t* compressed_weights = compressed_weights_tensor.data<int8_t>();
ov::Tensor compressed_weights_with_fused_zero_point_tensor(ov::element::i8, weights_shape);
int8_t* compressed_weights_with_fused_zero_point = compressed_weights_with_fused_zero_point_tensor.data<int8_t>();
T levels_minus_one = static_cast<T>(levels - 1);
can_fuse_zero_point = true;
const auto convert_to_low_precision = low_precision_type == ov::element::i4 ? convert_to_int4 : convert_to_int8;
ov::element::Type new_low_precision_type = low_precision_type;
bool all_low_not_neg = true;
for (int i = 0; i < ov::shape_size(output_low_shape); ++i) {
all_low_not_neg = all_low_not_neg && (output_low[i] >= 0);
if (!all_low_not_neg) {
break;
}
}

if (low_precision_type == ov::element::i8 && all_low_not_neg) {
new_low_precision_type = ov::element::u8;
} else if (low_precision_type == ov::element::i4 && all_low_not_neg) {
new_low_precision_type = ov::element::u4;
}

ov::element::Type tensor_el_type;
if (new_low_precision_type == ov::element::i8 || new_low_precision_type == ov::element::i4) {
tensor_el_type = ov::element::i8;
} else if (new_low_precision_type == ov::element::u8 || new_low_precision_type == ov::element::u4) {
tensor_el_type = ov::element::u8;
}

ov::Tensor compressed_weights_tensor(tensor_el_type, weights_shape);
ov::Tensor compressed_weights_with_fused_zero_point_tensor(tensor_el_type, weights_shape);

// TODO: reuse the common code parts
if (tensor_el_type == ov::element::u8) {
auto* compressed_weights = compressed_weights_tensor.data<uint8_t>();
auto* compressed_weights_with_fused_zero_point =
compressed_weights_with_fused_zero_point_tensor.data<uint8_t>();
T levels_minus_one = static_cast<T>(levels - 1);
can_fuse_zero_point = true;
const auto convert_to_low_precision =
low_precision_type == ov::element::u4 ? convert_to_uint4 : convert_to_uint8;

auto f = [compressed_weights_with_fused_zero_point,
levels_minus_one,
convert_to_low_precision,
&can_fuse_zero_point](T weights_value,
T input_low,
T input_high,
T output_low,
T output_high,
T zero_point) mutable {
uint8_t compressed_weights_value =
convert_to_low_precision(ov::reference::fake_quantize_details::quantize(weights_value,
input_low,
input_high,
output_low,
output_high,
levels_minus_one));
T weights_minus_zero_point = static_cast<T>(compressed_weights_value) - zero_point;
uint8_t compressed_weights_with_fused_zero_point_value = convert_to_low_precision(weights_minus_zero_point);
can_fuse_zero_point &=
std::fabs(compressed_weights_with_fused_zero_point_value - weights_minus_zero_point) < 1e-4;
*compressed_weights_with_fused_zero_point++ = compressed_weights_with_fused_zero_point_value;
return compressed_weights_value;
};

auto f =
[compressed_weights_with_fused_zero_point, levels_minus_one, convert_to_low_precision, &can_fuse_zero_point](
T weights_value,
T input_low,
T input_high,
T output_low,
T output_high,
T zero_point) mutable {
numpy_broadcast_6inputs(weights,
weights_shape,
input_low,
input_low_shape,
input_high,
input_high_shape,
output_low,
output_low_shape,
output_high,
output_high_shape,
zero_point,
zero_point_shape,
compressed_weights,
f);
} else if (tensor_el_type == ov::element::i8) {
auto* compressed_weights = compressed_weights_tensor.data<int8_t>();
auto* compressed_weights_with_fused_zero_point = compressed_weights_with_fused_zero_point_tensor.data<int8_t>();
T levels_minus_one = static_cast<T>(levels - 1);
can_fuse_zero_point = true;
const auto convert_to_low_precision = low_precision_type == ov::element::i4 ? convert_to_int4 : convert_to_int8;

auto f = [compressed_weights_with_fused_zero_point,
levels_minus_one,
convert_to_low_precision,
&can_fuse_zero_point](T weights_value,
T input_low,
T input_high,
T output_low,
T output_high,
T zero_point) mutable {
int8_t compressed_weights_value =
convert_to_low_precision(ov::reference::fake_quantize_details::quantize(weights_value,
input_low,
Expand All @@ -744,24 +827,25 @@ static std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights_internal
return compressed_weights_value;
};

numpy_broadcast_6inputs(weights,
weights_shape,
input_low,
input_low_shape,
input_high,
input_high_shape,
output_low,
output_low_shape,
output_high,
output_high_shape,
zero_point,
zero_point_shape,
compressed_weights,
f);
numpy_broadcast_6inputs(weights,
weights_shape,
input_low,
input_low_shape,
input_high,
input_high_shape,
output_low,
output_low_shape,
output_high,
output_high_shape,
zero_point,
zero_point_shape,
compressed_weights,
f);
}

return create_weights_constant(
can_fuse_zero_point ? compressed_weights_with_fused_zero_point_tensor : compressed_weights_tensor,
low_precision_type);
new_low_precision_type);
}

std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,7 @@ class CompressQuantizeWeightsTests
}
};

#ifdef OPENVINO_ARCH_ARM64
// Ticket: CVS-122397
TEST_P(CompressQuantizeWeightsTests, DISABLED_FusionTest) {}
#else
TEST_P(CompressQuantizeWeightsTests, FusionTest) {}
#endif

static std::vector<CompressQuantizeWeightsParams> params = {
{Shape{2, 3, 1, 1},
Expand Down Expand Up @@ -138,6 +133,54 @@ static std::vector<CompressQuantizeWeightsParams> params = {
0.0313725f,
-64.25f,
false},
{Shape{2, 3, 1, 1},
{-1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f},
0.0f,
10.0f,
1.0f,
5.0f,
3,
element::u4,
{1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
3.0f,
-0.666667f,
false},
{Shape{2, 3, 1, 1},
{-1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f},
0.0f,
10.0f,
1.0f,
4.0f,
16,
element::u4,
{8.0f, 5.0f, 4.0f, 2.0f, 0.0f, 7.0f},
0.333333f,
-5.0f,
false},
{Shape{2, 4, 1, 1},
{-1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f},
1.0f,
9.0f,
2.0f,
6.0f,
17,
element::u8,
{4.0f, 4.0f, 4.0f, 2.0f, 0.0f, 2.0f, 4.0f, 12.0f},
0.5f,
-4.0f,
true},
{Shape{2, 4, 1, 1},
{-1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f},
1.0f,
9.0f,
2.0f,
6.0f,
256,
element::u8,
{128.0f, 128.0f, 128.0f, 96.0f, 64.0f, 32.0f, 0.0f, 127.0f},
0.0313725f,
-64.25f,
false},
};

static element::TypeVector data_precisions = {element::f32, element::f16};
Expand Down

0 comments on commit 6febbc1

Please sign in to comment.