[WC] Align compression subgraphs for both weight input data types #2537

Merged
@@ -153,51 +153,37 @@ def transform_model(
         compressed_const = opset.constant(
             compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name
         )
-        converted_const = opset.convert(compressed_const, ov.Type.f32)
+        converted_const = opset.convert(compressed_const, ov.Type.f16)
         if compressed_weight.zero_point is not None:
             zero_point_const = opset.constant(
                 compressed_weight.zero_point.data,
                 dtype=compression_dtype,
                 name=f"{const_node_name}/zero_point",
             )
-            converted_zero_point = opset.convert(zero_point_const, ov.Type.f32)
+            converted_zero_point = opset.convert(zero_point_const, ov.Type.f16)
             converted_const = opset.subtract(
                 converted_const, converted_zero_point, name=f"{const_node_name}/zero_point/subtract"
             )

         scale_const = opset.constant(
             compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale"
         )
-        scale_const = opset.convert(scale_const, ov.Type.f32, name=f"{const_node_name}/scale_convert")
         mul = opset.multiply(
             converted_const,
             scale_const,
             name=f"{const_node_name}/fq_weights_{wc_params.weight_port_id}",
         )
+        if const_dtype == ov.Type.f32:
+            mul = opset.convert(
+                mul, ov.Type.f32, name=f"{mul.get_friendly_name()}/convert"
+            )

         if compression_config.group_size != -1:
             mul = opset.reshape(mul, output_shape=original_shape, special_zero=False)

         mul_output = mul.output(0)
         for target_input in const_node.output(0).get_target_inputs():
-            target_input_node = target_input.get_node()
-            if const_dtype == ov.Type.f16:
-                target_input_node_attrs = target_input_node.get_attributes()
-                if (
-                    target_input_node.get_type_name() == "Convert"
-                    and target_input_node_attrs["destination_type"] == "f32"
-                ):
-                    # Before compression, there was a f16 -> f32 Convert node after the weight. Now, scale multiply
-                    # node is in f32, and this Convert node is not needed.
-                    next_node_target_input = next(iter(target_input_node.output(0).get_target_inputs()))
-                    next_node_target_input.replace_source_output(mul_output)
-                else:
-                    # Both weight and activation are in f16. After the addition of f32 scale multiply node we have
-                    # to add a Convert node.
-                    mul_converted = opset.convert(mul, ov.Type.f16, name=f"{mul.get_friendly_name()}/convert")
-                    target_input.replace_source_output(mul_converted.output(0))
-            else:
-                target_input.replace_source_output(mul_output)
+            target_input.replace_source_output(mul_output)

         # reset name_to_node_mapping
         self.name_to_node_mapping = None
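
For readers skimming the diff, here is a minimal, self-contained sketch of the decompression pattern the simplified code above produces for both weight dtypes. It is not the NNCF source; the imports, node names and tensor values are illustrative, and availability of openvino.runtime.opset13 is an assumption.

import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset


def build_aligned_decompression(q_weight, zero_point, scale, original_weight_dtype):
    # Core chain, identical for f16 and f32 weights:
    # u8 constant -> Convert(f16) -> Subtract(zero point) -> Multiply(f16 scale)
    compressed_const = opset.constant(q_weight, dtype=ov.Type.u8, name="weights")
    converted_const = opset.convert(compressed_const, ov.Type.f16)
    zero_point_const = opset.constant(zero_point, dtype=ov.Type.u8, name="weights/zero_point")
    converted_zero_point = opset.convert(zero_point_const, ov.Type.f16)
    converted_const = opset.subtract(converted_const, converted_zero_point, name="weights/zero_point/subtract")
    scale_const = opset.constant(scale, dtype=ov.Type.f16, name="weights/scale")
    mul = opset.multiply(converted_const, scale_const, name="weights/fq_weights_1")
    if original_weight_dtype == ov.Type.f32:
        # Only a weight that was originally stored in f32 gets an explicit up-cast back to f32;
        # an f16 weight feeds its consumer directly.
        mul = opset.convert(mul, ov.Type.f32, name=f"{mul.get_friendly_name()}/convert")
    return mul


# Example: an 8x16 8-bit weight that was originally f16, so no trailing Convert is added.
w = build_aligned_decompression(
    np.random.randint(0, 255, (8, 16)).astype(np.uint8),
    np.full((8, 1), 128, dtype=np.uint8),
    np.ones((8, 1), dtype=np.float16),
    ov.Type.f16,
)

The intent of the PR is visible in the last branch: the core u8 -> f16 -> Subtract -> Multiply chain is now the same for both weight input data types, and only the optional final Convert differs.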
Expand Down
tests/openvino/native/quantization/test_weights_compression.py (43 changes: 22 additions & 21 deletions)
@@ -95,8 +95,7 @@ def check_int8_node(op: ov.Node, mode: CompressWeightsMode = CompressWeightsMode

     mul_node = get_next_node(sub_node)
     assert mul_node.get_type_name() == "Multiply"
-    convert_node = mul_node.input_value(1).get_node()
-    scale_node = convert_node.input_value(0).get_node()
+    scale_node = mul_node.input_value(1).get_node()
     scale = get_const_value(scale_node)

     return {
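
The checks in this file walk the compressed graph with a get_next_node test helper whose body is outside this diff. A hedged sketch of the behaviour the assertions assume (exactly one consumer per output) is:

import openvino.runtime as ov


def get_next_node(node: ov.Node) -> ov.Node:
    # Follow the first output of `node` to its single consumer node.
    target_inputs = node.output(0).get_target_inputs()
    assert len(target_inputs) == 1, "helper assumes exactly one consumer"
    return next(iter(target_inputs)).get_node()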
@@ -132,11 +131,13 @@ def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int =

     mul_node = get_next_node(sub_node)
     assert mul_node.get_type_name() == "Multiply"
-    convert_node = mul_node.input_value(1).get_node()
-    scale_node = convert_node.input_value(0).get_node()
+    scale_node = mul_node.input_value(1).get_node()
     assert list(scale_node.shape) == reduced_weight_shape

-    reshape_node = get_next_node(mul_node)
+    convert_node = get_next_node(mul_node)
+    assert convert_node.get_type_name() == "Convert"
+
+    reshape_node = get_next_node(convert_node)
     assert reshape_node.get_type_name() == "Reshape"

     return {
@@ -157,11 +158,13 @@ def check_nf4_grouped(op: ov.Node, group_size: int = 7):

     mul_node = get_next_node(convert_node)
     assert mul_node.get_type_name() == "Multiply"
-    convert_node = mul_node.input_value(1).get_node()
-    scale_node = convert_node.input_value(0).get_node()
+    scale_node = mul_node.input_value(1).get_node()
     assert list(scale_node.shape) == reduced_weight_shape

-    reshape_node = get_next_node(mul_node)
+    convert_node = get_next_node(mul_node)
+    assert convert_node.get_type_name() == "Convert"
+
+    reshape_node = get_next_node(convert_node)
     assert reshape_node.get_type_name() == "Reshape"

     return {
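
The grouped INT4 and NF4 subgraphs these two helpers validate come out of the public nncf.compress_weights entry point. A hedged usage sketch follows; the toy model, group size and ratio are illustrative rather than the actual test fixtures, and all_layers=True is passed only so the single MatMul is not kept in 8-bit backup precision.

import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset
from nncf import CompressWeightsMode, compress_weights

# Toy single-MatMul model with an f32 weight constant (input channels divisible by the group size).
inp = opset.parameter([1, 28], dtype=np.float32, name="input")
weight = opset.constant(np.random.rand(16, 28).astype(np.float32), name="weights")
matmul = opset.matmul(inp, weight, transpose_a=False, transpose_b=True, name="matmul")
model = ov.Model([matmul], [inp], name="toy")

# Request NF4 compression with a group size of 7; grouped INT4_SYM / INT4_ASYM are requested the same way.
compressed = compress_weights(model, mode=CompressWeightsMode.NF4, ratio=1.0, group_size=7, all_layers=True)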
@@ -705,21 +708,19 @@ def test_compression_for_different_dtypes():
     compressed_model = compress_weights(model)
     name_to_node_map = {op.get_friendly_name(): op for op in compressed_model.get_ops()}

-    # Scale should always be converted from f16 to f32
-    assert "weights/scale_convert" in name_to_node_map
+    # Weight scale should be in fp16 nevertheless the weight data type
     scale_multiply_node = name_to_node_map["weights/fq_weights_1"]
-    convert_node = scale_multiply_node.input_value(1).get_node()
-    scale_node = convert_node.input_value(0).get_node()
-    assert scale_node.get_element_type() == ov.Type.f16
-    assert convert_node.get_element_type() == ov.Type.f32
-
-    node_after_scale = get_next_node(scale_multiply_node)
-    if activation_dtype == np.float16 and weight_dtype == np.float16:
-        # If both weights and activations are in f16, there should be a f32 -> f16 convert after scale multiply
-        assert node_after_scale.get_type_name() == "Convert"
+    assert scale_multiply_node.input_value(1).get_node().get_element_type() == ov.Type.f16
+
+    next_node = get_next_node(scale_multiply_node)
+    if activation_dtype == np.float16:
+        # There should be no convert node after multiply if both weights and activations are in f16
+        assert next_node.get_type_name() != "Convert"
     else:
-        # Otherwise there should be no Convert node after scale multiply
-        assert node_after_scale.get_type_name() == "MatMul"
+        assert next_node.get_type_name() == "Convert"
+        # In case weight is in fp32, the convert node is manually inserted
+        if weight_dtype == np.float32:
+            assert next_node.get_friendly_name() == "weights/fq_weights_1/convert"


 DATASET_SIZE = 129
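
Finally, a hedged, self-contained sketch of how the new expectations can be reproduced outside the test suite with default 8-bit compression. The model is illustrative; the node names simply follow the "<weight name>/fq_weights_<port id>" convention seen in the backend diff.

import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset
from nncf import compress_weights

# Toy single-MatMul model with an f32 weight constant named "weights".
inp = opset.parameter([1, 32], dtype=np.float32, name="input")
weight = opset.constant(np.random.rand(16, 32).astype(np.float32), name="weights")
matmul = opset.matmul(inp, weight, transpose_a=False, transpose_b=True, name="matmul")
model = ov.Model([matmul], [inp], name="toy")

compressed = compress_weights(model)  # default 8-bit weight compression
name_to_node = {op.get_friendly_name(): op for op in compressed.get_ops()}

# The scale input of the Multiply node stays in f16 regardless of the weight dtype ...
scale_multiply = name_to_node["weights/fq_weights_1"]
assert scale_multiply.input_value(1).get_node().get_element_type() == ov.Type.f16

# ... the old "weights/scale_convert" node is gone, and the up-cast back to f32 is
# present only because this weight was originally stored in f32.
assert "weights/scale_convert" not in name_to_node
assert "weights/fq_weights_1/convert" in name_to_node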