diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index fd05ba09a88..9b727a66b2c 100644
--- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -8,6 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, TypeVar

@@ -58,7 +59,9 @@ def do_compression(

     friendly_name_to_op_map = {op.get_friendly_name(): op for op in model.get_ops()}

-    for nncf_node in nodes_to_compress:
+    is_last_layer_compressed = False
+    n = len(nodes_to_compress)
+    for i, nncf_node in enumerate(nodes_to_compress):
         weight_port_ids = nncf_node.layer_attributes.get_const_port_ids()
         for weight_port_id in weight_port_ids:
             weight_op_friendly_name = nncf_node.layer_attributes.constant_attributes[weight_port_id]["name"]
@@ -66,6 +69,8 @@ def do_compression(
             if weight_node is None:
                 continue
             if id(weight_node) in quantized_nodes_ids:
+                if i == n - 1:
+                    is_last_layer_compressed = True
                 continue

             weight_output = weight_node.output(0)
@@ -87,15 +92,24 @@ def do_compression(
             fq_name = f"{weight_op_friendly_name}/fq_weights_{weight_port_id}"
             num_weights = np.prod(const_shape)
             weight_params = WeightNodeParams(
-                reduction_axis, num_weights, fq_name, weight_node, original_weight_dtype
+                reduction_axis,
+                num_weights,
+                fq_name,
+                weight_node,
+                original_weight_dtype,
+                metatype=nncf_node.metatype,
             )
             all_weight_params.append(weight_params)
             quantized_nodes_ids.add(id(weight_node))
+
+    internal_weight_params = all_weight_params
     if mode != CompressWeightsMode.INT8:
+        internal_weight_params = list(filter(lambda wp: wp.metatype != OVEmbeddingMetatype, all_weight_params))
+        if not is_last_layer_compressed:
+            internal_weight_params = internal_weight_params[:-1]
         primary_config = WeightCompressionConfig(mode=mode, group_size=group_size)
-        _assign_mixed_precision(all_weight_params, ratio, primary_config)
-
-    nncf_logger.info(_get_bitwidth_distribution_str(all_weight_params))
+        _assign_mixed_precision(internal_weight_params, ratio, primary_config)
+    nncf_logger.info(_get_bitwidth_distribution_str(all_weight_params, internal_weight_params))

     for wp in track(all_weight_params, description="Applying Weight Compression"):
         weight_node = wp.weight_node
""" reduction_axis: int @@ -188,6 +203,7 @@ class WeightNodeParams: weight_node: ov.Node original_weight_dtype: TWeightType compression_config = WeightCompressionConfig() + metatype: OperatorMetatype = None def _do_integer_quantization( @@ -325,29 +341,31 @@ def _proportion_str(num_weights_list: List[int], total_num_weights: int, total_n return f"{percentage:.0f}% ({len(num_weights_list)} / {total_num_params})" -def _get_bitwidth_distribution_str(all_weight_params: List[WeightNodeParams]) -> str: +def _get_bitwidth_distribution_str(all_params: List[WeightNodeParams], internal_params: List[WeightNodeParams]) -> str: """ Generates a table that shows the ratio of weights quantized to different number of bits. - :param all_weight_params: List of information about each weight node. + :param all_params: List of information about each weight node. + :param internal_params: List of information about weight nodes that are considered for mixed precision. :return: A string containing the table. """ - total_num_weights = sum(ws.num_weights for ws in all_weight_params) - num_internal_weights = 0 - num_params = len(all_weight_params) - num_internal_params = 0 - if num_params > 2: - num_internal_params = num_params - 2 + not_internal_params = [wp for wp in all_params if wp not in internal_params] num_bits_vs_num_weights_map = {} - for i, data in enumerate(all_weight_params): + for data in internal_params: + num_bits = data.compression_config.num_bits + n_internal, n_internal = num_bits_vs_num_weights_map.get(num_bits, ([], [])) + n_internal.append(data.num_weights) + num_bits_vs_num_weights_map[num_bits] = (n_internal, n_internal) + for data in not_internal_params: num_bits = data.compression_config.num_bits n_total, n_internal = num_bits_vs_num_weights_map.get(num_bits, ([], [])) - if i not in (0, num_params - 1): - n_internal.append(data.num_weights) - num_internal_weights += data.num_weights n_total.append(data.num_weights) num_bits_vs_num_weights_map[num_bits] = (n_total, n_internal) - + num_internal_weights = sum(ws.num_weights for ws in internal_params) + num_internal_params = len(internal_params) + total_num_weights = num_internal_weights + sum(ws.num_weights for ws in not_internal_params) + num_params = len(all_params) + num_bits_vs_num_weights_map = OrderedDict(sorted(num_bits_vs_num_weights_map.items(), reverse=True)) # Table creation header = ["Num bits (N)", "% all parameters (layers)", "% internal parameters (layers)"] rows = [] @@ -366,25 +384,25 @@ def _get_bitwidth_distribution_str(all_weight_params: List[WeightNodeParams]) -> def _assign_mixed_precision( - all_weight_params: List[WeightNodeParams], ratio: float, primary_config: WeightCompressionConfig + internal_weight_params: List[WeightNodeParams], ratio: float, primary_config: WeightCompressionConfig ) -> None: """ Assigns mixed quantization scheme (e.g. uniform int8 or non-uniform nf4) for weights based on some criteria. - - :param all_weight_params: List of information about each weight node. The quantization scheme is added to this info. + :param internal_weight_params: List of information about internal weight nodes. Only internal nodes are considered + for mixed precision. The quantization scheme is added to this info. :param ratio: The ratio between primary and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 and the rest to INT8). :param primary_config: Information on how to compress (quantize) weights to primary precision. :return: None. 
""" if ratio == 1: - for weight_param in all_weight_params[1:-1]: + for weight_param in internal_weight_params: weight_param.compression_config = primary_config return errors = [] num_internal_weights = 0 # NOTE: first and last layers are always in 8 bit: no need to calculate error for them - for weight_param in track(all_weight_params[1:-1], description="Searching for Mixed-Precision Configuration"): + for weight_param in track(internal_weight_params, description="Searching for Mixed-Precision Configuration"): weight = get_const_value(weight_param.weight_node) backup_config = weight_param.compression_config reduction_axis = weight_param.reduction_axis @@ -393,14 +411,12 @@ def _assign_mixed_precision( error = 1 / (backup_error + eps) errors.append(error) num_internal_weights += weight_param.num_weights - # NOTE: index is defined in the array of all weight params by taking into account that errors were not - # calculated for first and last layers. indexes_of_layers_in_ascending_order_of_errors = [ - i[0] + 1 for i in sorted(enumerate(errors), reverse=False, key=lambda x: x[1]) + i[0] for i in sorted(enumerate(errors), reverse=False, key=lambda x: x[1]) ] num_weights_in_4bit = 0 for index in indexes_of_layers_in_ascending_order_of_errors: - weight_param = all_weight_params[index] + weight_param = internal_weight_params[index] current_ratio = (num_weights_in_4bit + weight_param.num_weights) / num_internal_weights if current_ratio >= ratio: break diff --git a/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_asym.json b/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_asym.json index 783cad5b51d..4d8fbc6318c 100644 --- a/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_asym.json +++ b/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_asym.json @@ -1,99 +1,35 @@ { "matmul_2_data": { - "compressed_weight": [ - [ - 115, - 51, - 154, - 255, - 79, - 18, - 139 - ], - [ - 59, - 27, - 174, - 89, - 201, - 60, - 255 - ], - [ - 110, - 32, - 189, - 255, - 132, - 255, - 150 - ], - [ - 190, - 255, - 255, - 255, - 206, - 255, - 223 - ], - [ - 165, - 245, - 129, - 229, - 222, - 255, - 36 - ], - [ - 192, - 245, - 255, - 4, - 228, - 255, - 253 - ] - ], - "zero_point": [ - [ - 0 - ], - [ - 0 - ], - [ - 0 - ], - [ - 0 - ], - [ - 0 - ], - [ - 0 - ] - ], "scale": [ [ - 0.0029188350308686495 + [ + 0.04962019622325897 + ] ], [ - 0.0033386670984327793 + [ + 0.05675733834505081 + ] ], [ - 0.003329785307869315 + [ + 0.05660634860396385 + ] ], [ - 0.0022347758058458567 + [ + 0.03799118846654892 + ] ], [ - 0.003204419743269682 + [ + 0.05447513610124588 + ] ], [ - 0.0037901517935097218 + [ + 0.06443258374929428 + ] ] ] }, @@ -190,30 +126,52 @@ ] }, "gather_2_data": { + "compressed_weight": [ + [ + 181, + 77, + 12, + 5, + 231, + 255 + ], + [ + 166, + 200, + 149, + 255, + 223, + 1 + ], + [ + 255, + 10, + 224, + 54, + 255, + 166 + ] + ], + "zero_point": [ + [ + 0 + ], + [ + 0 + ], + [ + 0 + ] + ], "scale": [ [ - [ - 0.039732541888952255 - ], - [ - 0.05974852666258812 - ] + 0.0035146193113178015 ], [ - [ - 0.012391435913741589 - ], - [ - 0.062155596911907196 - ] + 0.003656211541965604 ], [ - [ - 0.05492125079035759 - ], - [ - 0.04583488777279854 - ] + 0.003253307193517685 ] ] } diff --git a/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_sym.json 
diff --git a/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_sym.json b/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_sym.json
index 35cf3ff1c15..d529ce90e47 100644
--- a/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_sym.json
+++ b/tests/openvino/native/data/2023.1/reference_scales/IntegerModel_compressed_weights_int4_sym.json
@@ -1,99 +1,35 @@
 {
     "matmul_2_data": {
-        "compressed_weight": [
-            [
-                115,
-                51,
-                154,
-                255,
-                79,
-                18,
-                139
-            ],
-            [
-                59,
-                27,
-                174,
-                89,
-                201,
-                60,
-                255
-            ],
-            [
-                110,
-                32,
-                189,
-                255,
-                132,
-                255,
-                150
-            ],
-            [
-                190,
-                255,
-                255,
-                255,
-                206,
-                255,
-                223
-            ],
-            [
-                165,
-                245,
-                129,
-                229,
-                222,
-                255,
-                36
-            ],
-            [
-                192,
-                245,
-                255,
-                4,
-                228,
-                255,
-                253
-            ]
-        ],
-        "zero_point": [
-            [
-                0
-            ],
-            [
-                0
-            ],
-            [
-                0
-            ],
-            [
-                0
-            ],
-            [
-                0
-            ],
-            [
-                0
-            ]
-        ],
         "scale": [
             [
-                0.0029188350308686495
+                [
+                    0.11376060545444489
+                ]
             ],
             [
-                0.0033386670984327793
+                [
+                    0.1345875859260559
+                ]
             ],
             [
-                0.003329785307869315
+                [
+                    0.13637007772922516
+                ]
             ],
             [
-                0.0022347758058458567
+                [
+                    0.14215664565563202
+                ]
             ],
             [
-                0.003204419743269682
+                [
+                    0.13315138220787048
+                ]
             ],
             [
-                0.0037901517935097218
+                [
+                    0.14017072319984436
+                ]
             ]
         ]
     },
@@ -190,30 +126,52 @@
         ]
     },
     "gather_2_data": {
+        "compressed_weight": [
+            [
+                181,
+                77,
+                12,
+                5,
+                231,
+                255
+            ],
+            [
+                166,
+                200,
+                149,
+                255,
+                223,
+                1
+            ],
+            [
+                255,
+                10,
+                224,
+                54,
+                255,
+                166
+            ]
+        ],
+        "zero_point": [
+            [
+                0
+            ],
+            [
+                0
+            ],
+            [
+                0
+            ]
+        ],
         "scale": [
             [
-                [
-                    0.09099452942609787
-                ],
-                [
-                    0.13039365410804749
-                ]
+                0.0035146193113178015
             ],
             [
-                [
-                    0.10421378910541534
-                ],
-                [
-                    0.13358177244663239
-                ]
+                0.003656211541965604
             ],
             [
-                [
-                    0.12248633056879044
-                ],
-                [
-                    0.12331127375364304
-                ]
+                0.003253307193517685
             ]
         ]
     }
diff --git a/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_nf4.json b/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_nf4.json
index 9f80217f056..5402cda275b 100644
--- a/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_nf4.json
+++ b/tests/openvino/native/data/2023.2/reference_scales/IntegerModel_compressed_weights_nf4.json
@@ -1,99 +1,35 @@
 {
     "matmul_2_data": {
-        "compressed_weight": [
-            [
-                115,
-                51,
-                154,
-                255,
-                79,
-                18,
-                139
-            ],
-            [
-                59,
-                27,
-                174,
-                89,
-                201,
-                60,
-                255
-            ],
-            [
-                110,
-                32,
-                189,
-                255,
-                132,
-                255,
-                150
-            ],
-            [
-                190,
-                255,
-                255,
-                255,
-                206,
-                255,
-                223
-            ],
-            [
-                165,
-                245,
-                129,
-                229,
-                222,
-                255,
-                36
-            ],
-            [
-                192,
-                245,
-                255,
-                4,
-                228,
-                255,
-                253
-            ]
-        ],
-        "zero_point": [
-            [
-                0
-            ],
-            [
-                0
-            ],
-            [
-                0
-            ],
-            [
-                0
-            ],
-            [
-                0
-            ],
-            [
-                0
-            ]
-        ],
         "scale": [
             [
-                0.0029188350308686495
+                [
+                    0.7963242530822754
+                ]
             ],
             [
-                0.0033386670984327793
+                [
+                    0.9421131014823914
+                ]
             ],
             [
-                0.003329785307869315
+                [
+                    0.9545904994010925
+                ]
             ],
             [
-                0.0022347758058458567
+                [
+                    0.9950965046882629
+                ]
             ],
             [
-                0.003204419743269682
+                [
+                    0.9320597052574158
+                ]
             ],
             [
-                0.0037901517935097218
+                [
+                    0.9811950325965881
+                ]
             ]
         ]
     },
@@ -190,30 +126,52 @@
         ]
     },
     "gather_2_data": {
+        "compressed_weight": [
+            [
+                181,
+                77,
+                12,
+                5,
+                231,
+                255
+            ],
+            [
+                166,
+                200,
+                149,
+                255,
+                223,
+                1
+            ],
+            [
+                255,
+                10,
+                224,
+                54,
+                255,
+                166
+            ]
+        ],
+        "zero_point": [
+            [
+                0
+            ],
+            [
+                0
+            ],
+            [
+                0
+            ]
+        ],
         "scale": [
             [
-                [
-                    0.6369616985321045
-                ],
-                [
-                    0.91275554895401
-                ]
+                0.0035146193113178015
             ],
             [
-                [
-                    0.7294965386390686
-                ],
-                [
-                    0.9350724220275879
-                ]
+                0.003656211541965604
             ],
             [
-                [
-                    0.8574042916297913
-                ],
-                [
-                    0.8631789088249207
-                ]
+                0.003253307193517685
             ]
         ]
     }
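The extra nesting level in the reference `scale` arrays above comes from group-wise quantization: a (6, 7) weight with group_size=7 is reshaped to (6, 1, 7) and reduced over the last axis, so each per-row scale becomes `[[x]]` instead of `[x]`. A rough sketch under those assumed shapes (not the exact NNCF kernels):

```python
import numpy as np

rng = np.random.default_rng(0)
weight = rng.random((6, 7)).astype(np.float32)  # same shape as matmul_2_data

group_size = 7
grouped = weight.reshape(weight.shape[0], -1, group_size)  # (6, 1, 7)

# One common symmetric 4-bit scheme: map each group's max |w| onto 7.
scale = np.max(np.abs(grouped), axis=-1, keepdims=True) / 7  # (6, 1, 1)
quantized = np.clip(np.round(grouped / scale), -8, 7)

print(scale.shape)  # (6, 1, 1), the same nesting as the "scale" arrays above
```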
diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py
index 89033891f0e..4f4ac61a82c 100644
--- a/tests/openvino/native/models.py
+++ b/tests/openvino/native/models.py
@@ -794,12 +794,35 @@ def _create_ov_model(self):
         input_1 = opset.parameter([2, 3], name="Input")
         convert_1 = opset.convert(input_1, destination_type="i64", name="Convert_1")

-        gather_2_data = opset.constant(self._rng.random((3, 2, 1)), dtype=np.float32, name="gather_2_data")
+        gather_1_data = opset.constant(self._rng.random((3, 2, 1)), dtype=np.float32, name="gather_1_data")
+        gather_1 = opset.gather(gather_1_data, convert_1, axis=0, batch_dims=0)
+        gather_1.set_friendly_name("Gather_1")
+
+        result = opset.result(gather_1, name="Result")
+        model = ov.Model([result], [input_1])
+        return model
+
+
+class GatherAndMatmulShareData(OVReferenceModel):
+    def _create_ov_model(self):
+        input_1 = opset.parameter([2, 3], name="Input")
+        convert_1 = opset.convert(input_1, destination_type="i64", name="Convert_1")
+
+        shared_data = opset.constant(self._rng.random((2, 2)), dtype=np.float32, name="shared_data")
+        gather_1 = opset.gather(shared_data, convert_1, axis=0, batch_dims=0)
+        gather_1.set_friendly_name("Gather_1")
+
+        gather_2_data = opset.constant(self._rng.random((2, 1)), dtype=np.float32, name="gather_2_data")
         gather_2 = opset.gather(gather_2_data, convert_1, axis=0, batch_dims=0)
         gather_2.set_friendly_name("Gather_2")

-        result = opset.result(gather_2, name="Result")
-        model = ov.Model([result], [input_1])
+        matmul_1_data = opset.constant(self._rng.random((2, 3)), dtype=np.float32, name="matmul_1_data")
+        matmul_1 = opset.matmul(input_1, matmul_1_data, transpose_a=False, transpose_b=True, name="MatMul_1")
+
+        matmul_2 = opset.matmul(matmul_1, shared_data, transpose_a=False, transpose_b=True, name="MatMul_2")
+
+        result = opset.result(matmul_2, name="Result")
+        model = ov.Model([result, gather_2, gather_1], [input_1])
         return model

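The new `GatherAndMatmulShareData` model gives the tests a constant that feeds both a Gather and a MatMul. A usage sketch mirroring `test_shared_gather` below (assuming an NNCF dev checkout so the test models are importable): since embeddings are excluded from mixed precision and the shared constant acts as one, it should stay in 8 bit while the plain MatMul weight may drop to 4 bit.

```python
from nncf import CompressWeightsMode, compress_weights
from tests.openvino.native.models import GatherAndMatmulShareData

model = GatherAndMatmulShareData().ov_model
compressed = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, group_size=3)
for op in compressed.get_ordered_ops():
    if op.get_type_name() == "Constant":
        # Expect u8 for shared_data/gather_2_data and u4 for matmul_1_data.
        print(op.get_friendly_name(), op.get_element_type())
```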
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index ecd8b9fe3e4..ab154a3453a 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -24,6 +24,7 @@
 from nncf.quantization.algorithms.weight_compression.openvino_backend import _get_integer_quantization_error
 from nncf.quantization.algorithms.weight_compression.openvino_backend import _reshape_weights_for_grouped_quantization
 from nncf.scopes import IgnoredScope
+from tests.openvino.native.models import GatherAndMatmulShareData
 from tests.openvino.native.models import GatherWithTwoReductionAxes
 from tests.openvino.native.models import IntegerModel
 from tests.openvino.native.models import SequentialMatmulModel
@@ -74,7 +75,7 @@ def check_int8_node(op: ov.Node):
     }


-def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int = 3):
+def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int = 7):
     assert op.get_element_type() == ov.Type.u4
     weight_shape = op.shape
     # NOTE: get_const_value doesn't work for 4-bit types
@@ -111,7 +112,7 @@ def check_int4_grouped(op: ov.Node, mode: CompressWeightsMode, group_size: int =
     }


-def check_nf4_grouped(op: ov.Node, group_size: int = 3):
+def check_nf4_grouped(op: ov.Node, group_size: int = 7):
     assert op.get_element_type() == ov.Type.nf4
     weight_shape = op.shape
     # NOTE: get_const_value doesn't work for 4-bit types
@@ -145,9 +146,8 @@ def check_int4_asym_grouped(op: ov.Node):

 def get_mixed_mapping(primary_fn: Callable, list_layers: List[str]):
     mapping = {node_name: check_int8_node for node_name in list_layers}
-
-    for node_name in TEST_MODELS[IntegerModel][1:-1]:
-        mapping[node_name] = primary_fn
+    primary_node_name = TEST_MODELS[IntegerModel][0]
+    mapping[primary_node_name] = primary_fn
     return mapping


@@ -155,9 +155,9 @@ def get_mixed_mapping(primary_fn: Callable, list_layers: List[str]):
     ("mode", "group_size", "check_fn_per_node_map"),
     (
         (CompressWeightsMode.INT8, -1, {node_name: check_int8_node for node_name in TEST_MODELS[IntegerModel]}),
-        (CompressWeightsMode.INT4_SYM, 3, get_mixed_mapping(check_int4_sym_grouped, TEST_MODELS[IntegerModel])),
-        (CompressWeightsMode.INT4_ASYM, 3, get_mixed_mapping(check_int4_asym_grouped, TEST_MODELS[IntegerModel])),
-        (CompressWeightsMode.NF4, 3, get_mixed_mapping(check_nf4_grouped, TEST_MODELS[IntegerModel])),
+        (CompressWeightsMode.INT4_SYM, 7, get_mixed_mapping(check_int4_sym_grouped, TEST_MODELS[IntegerModel])),
+        (CompressWeightsMode.INT4_ASYM, 7, get_mixed_mapping(check_int4_asym_grouped, TEST_MODELS[IntegerModel])),
+        (CompressWeightsMode.NF4, 7, get_mixed_mapping(check_nf4_grouped, TEST_MODELS[IntegerModel])),
     ),
 )
 def test_compare_compressed_weights(mode, group_size, check_fn_per_node_map):
@@ -201,10 +201,25 @@ def test_not_quantize_with_multiple_reduction_axes():
     model = GatherWithTwoReductionAxes().ov_model
     compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8)
     for op in compressed_model.get_ordered_ops():
-        if op.get_type_name() == "Constant" and op.get_friendly_name() == "gather_2_data":
+        if op.get_type_name() == "Constant" and op.get_friendly_name() == "gather_1_data":
             assert op.get_element_type() == ov.Type(np.float32)


+@pytest.mark.parametrize("mode", (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM))
+def test_shared_gather(mode):
+    weight_name_vs_type = {
+        "gather_2_data": ov.Type(np.uint8),
+        "shared_data": ov.Type(np.uint8),
+        "matmul_1_data": ov.Type.u4,
+    }
+    model = GatherAndMatmulShareData().ov_model
+    compressed_model = compress_weights(model, mode, group_size=3)
+    for op in compressed_model.get_ordered_ops():
+        op_name = op.get_friendly_name()
+        if op.get_type_name() == "Constant" and op_name in weight_name_vs_type:
+            assert op.get_element_type() == weight_name_vs_type[op_name]
+
+
 @dataclass
 class QuantErrorDesc:
     weight: List[float]
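For reference, the distribution table that `_get_bitwidth_distribution_str` logs now reports all vs. internal parameters from explicit lists rather than by skipping the first and last indices. An illustrative sketch of the math behind one row (layout assumed, not NNCF's exact output):

```python
# Illustrative only: how each table row's two percentage columns are formed.
def proportion_str(num_weights_list, total_num_weights, total_num_params):
    percentage = sum(num_weights_list) / total_num_weights * 100
    return f"{percentage:.0f}% ({len(num_weights_list)} / {total_num_params})"

# num_bits -> ([all layer sizes], [internal layer sizes]); the 8-bit bucket
# also holds the excluded embedding/last layers.
buckets = {4: ([600], [600]), 8: ([300, 100], [300])}
total_all, n_all = 1000, 3
total_internal, n_internal = 900, 2
for bits, (n_total, n_int) in sorted(buckets.items(), reverse=True):
    print(bits, proportion_str(n_total, total_all, n_all), proportion_str(n_int, total_internal, n_internal))
# 8 40% (2 / 3) 33% (1 / 2)
# 4 60% (1 / 3) 67% (1 / 2)
```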