From 371961266e9610eea9091c6c08ef924469f901f1 Mon Sep 17 00:00:00 2001 From: Nikolay Date: Mon, 9 Oct 2023 15:28:46 +0200 Subject: [PATCH] Initial unification of weight compression --- nncf/openvino/quantization/quantize_model.py | 27 -- .../weight_compression/algorithm.py | 131 ++++++++++ .../algorithms/weight_compression/backend.py | 82 ++++++ .../weight_compression/openvino_backend.py} | 244 ++++++++---------- .../weight_compression/torch_backend.py | 128 +++++++++ nncf/quantization/quantize_model.py | 21 +- nncf/torch/quantization/quantize_model.py | 34 --- .../torch/quantization/weights_compression.py | 102 -------- .../quantization/test_weights_compression.py | 10 +- 9 files changed, 468 insertions(+), 311 deletions(-) create mode 100644 nncf/quantization/algorithms/weight_compression/algorithm.py create mode 100644 nncf/quantization/algorithms/weight_compression/backend.py rename nncf/{openvino/quantization/weights_compression.py => quantization/algorithms/weight_compression/openvino_backend.py} (70%) create mode 100644 nncf/quantization/algorithms/weight_compression/torch_backend.py delete mode 100644 nncf/torch/quantization/weights_compression.py diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index ad07f6fdcf4..cbc4e03a842 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -24,8 +24,6 @@ from nncf.openvino.quantization.backend_parameters import BackendParameters from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed from nncf.openvino.quantization.quantize_ifmodel import apply_algorithm_if_bodies -from nncf.openvino.quantization.weights_compression import insert_pre_compression_operations -from nncf.parameters import CompressWeightsMode from nncf.parameters import DropType from nncf.parameters import ModelType from nncf.parameters import TargetDevice @@ -437,28 +435,3 @@ def quantize_with_accuracy_control_impl( advanced_quantization_parameters, advanced_accuracy_restorer_parameters, ) - - -def compress_weights_impl( - model: ov.Model, - mode: CompressWeightsMode = CompressWeightsMode.INT8, - ratio: Optional[float] = None, - group_size: Optional[int] = None, -) -> ov.Model: - """ - Implementation of the `compress_weights()` method for the OpenVINO backend. - - :param model: an OpenVINO model for compression. - :param mode: Defines a mode for weight compression. - INT8 stands for 8-bit integer quantization of all weights. - NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers - are always compressed to a backup precision which is 8-bit integer by default. All others are quantized whether - to NF4 or to a backup precision depending on criteria and the given ratio. - :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 and - the rest to INT8). - :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). - The value -1 means no grouping. - :return: The non-trainable model with compressed weights and dequantization operations. 
-    """
-    insert_pre_compression_operations(model, mode, ratio, group_size)
-    return model
diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
new file mode 100644
index 00000000000..91ebc2ad074
--- /dev/null
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2023 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, TypeVar
+
+from nncf import Dataset
+from nncf.common.graph.graph import NNCFGraph
+from nncf.common.graph.graph import NNCFNode
+from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
+from nncf.common.utils.backend import BackendType
+from nncf.common.utils.backend import get_backend
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.algorithms.algorithm import Algorithm
+from nncf.quantization.algorithms.smooth_quant.backend import ALGO_BACKENDS
+
+TModel = TypeVar("TModel")
+TTensor = TypeVar("TTensor")
+
+
+class WeightCompression(Algorithm):
+    """
+    Post-training Weight Compression algorithm implementation.
+
+    Compresses weights of Linear and Embedding layers to 8-bit integer or
+    to nf4 depending on mode, ratio and group size.
+    """
+
+    def __init__(
+        self,
+        mode: CompressWeightsMode,
+        ratio: float = None,
+        group_size: int = None,
+    ):
+        """
+        :param mode: Defines a mode for weight compression.
+            INT8 stands for 8-bit integer quantization of all weights.
+            NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers
+            are always compressed to a backup precision which is 8-bit integer by default. All others are quantized
+            whether to NF4 or to a backup precision depending on criteria and the given ratio.
+        :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
+            and the rest to INT8).
+        :param group_size: number of weights (e.g. 128) in the channel dimension
+            that share quantization parameters (scale). The value -1 means no grouping.
+ """ + super().__init__() + self._mode = mode + self._group_size = group_size + self._ratio = ratio + self._backend_entity = None + self._algorithm_key = f"CW_{hash(self)}" + + @property + def available_backends(self) -> Dict[str, BackendType]: + return ALGO_BACKENDS.registry_dict + + def _set_backend_entity(self, model: TModel) -> None: + """ + Creates a helper class with a backed-specific logic of the algorithm. + + :param model: Backend-specific input model. + """ + model_backend = get_backend(model) + if model_backend == BackendType.OPENVINO: + from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend + + self._backend_entity = OVWeightCompressionAlgoBackend() + elif model_backend == BackendType.TORCH: + from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend + + self._backend_entity = PTWeightCompressionAlgoBackend() + else: + raise RuntimeError( + "Cannot return backend-specific entity because {} is not supported!".format(model_backend) + ) + + def apply( + self, + model: TModel, + graph: NNCFGraph, + statistic_points: Optional[StatisticPointsContainer] = None, + dataset: Optional[Dataset] = None, + ) -> TModel: + self._set_backend_entity(model) + self._backend_entity.validate_params(self._mode) + nodes_to_compress = self._get_nodes_to_compress(graph) + transformed_model = self._backend_entity.do_compression( + model, nodes_to_compress, self._mode, self._ratio, self._group_size + ) + return transformed_model + + def _get_nodes_to_compress(self, nncf_graph: NNCFGraph) -> List[NNCFNode]: + """ + Collects nodes in the model's graph corresponding to the layers for weight compression. + + :param nncf_graph: NNCFGraph instance. + :return: List with the data for each layer. + """ + weighted_metatypes = self._backend_entity.weighted_metatypes + ordered_nodes_to_compress = [] + for node in nncf_graph.topological_sort(): + is_node_with_weights = self._backend_entity.is_node_with_weights(node) + if node.metatype in weighted_metatypes and is_node_with_weights: + ordered_nodes_to_compress.append(node) + return ordered_nodes_to_compress + + def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + """ + Returns statistic points, for which StatisticsCollector should collect statistics. + + :param model: Model for statistics collection. + :param graph: Model graph. + :return: Statistic points, for which StatisticsCollector should collect statistics. + """ diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py new file mode 100644 index 00000000000..e8bc2c8f911 --- /dev/null +++ b/nncf/quantization/algorithms/weight_compression/backend.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC +from abc import abstractmethod +from typing import List, Optional, TypeVar + +from nncf.common.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.utils.registry import Registry +from nncf.parameters import CompressWeightsMode + +TModel = TypeVar("TModel") +ALGO_BACKENDS = Registry("algo_backends") + + +class WeightCompressionAlgoBackend(ABC): + @property + @abstractmethod + def weighted_metatypes(self) -> List[OperatorMetatype]: + """ + Property for the backend-specific metatypes. + """ + + @staticmethod + @abstractmethod + def is_node_with_weights(node: NNCFNode) -> bool: + """ + Checks whether the node with weights or not. + + :param node: NNCFNode to check. + :return: boolean indicating whether the node has weights or not. + """ + + @staticmethod + @abstractmethod + def validate_params(mode: CompressWeightsMode) -> None: + """ + Performs validation of the algorithm's parameters and raises an error for unsupported configuration of + parameters. Should be called on early algorithm steps to prevent execution of time-consuming operations. + + :param mode: Defines a mode for weight compression. + INT8 stands for 8-bit integer quantization of all weights. + NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers + are always compressed to a backup precision which is 8-bit integer by default. All others are quantized + whether to NF4 or to a backup precision depending on criteria and the given ratio. + """ + + @staticmethod + @abstractmethod + def do_compression( + model: TModel, + nodes_to_compress: List[NNCFNode], + mode: CompressWeightsMode, + ratio: float = None, + group_size: int = None, + ) -> TModel: + """ + Compress weights of Linear and Embedding layers to 8-bit integer or to nf4 + depending on mode, ratio and group size. + + :param model: Model for applying weight compression. + :param nodes_to_compress: List of nodes in the model's graph, + corresponding to the layers for weight compression. + :param mode: Defines a mode for weight compression. + INT8 stands for 8-bit integer quantization of all weights. + NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers + are always compressed to a backup precision which is 8-bit integer by default. All others are quantized + whether to NF4 or to a backup precision depending on criteria and the given ratio. + :param ratio: The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 + and the rest to INT8). + :param group_size: Number of weights (e.g. 128) in the channel dimension + that share quantization parameters (scale). The value -1 means no grouping. + """ diff --git a/nncf/openvino/quantization/weights_compression.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py similarity index 70% rename from nncf/openvino/quantization/weights_compression.py rename to nncf/quantization/algorithms/weight_compression/openvino_backend.py index fa588cd6c63..7670cd08db9 100644 --- a/nncf/openvino/quantization/weights_compression.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -9,26 +9,134 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Type, TypeVar, Union +from typing import List, Optional, Tuple, TypeVar, Union import numpy as np import openvino.runtime as ov from openvino.runtime import opset9 as opset +from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.logging import nncf_logger from nncf.common.quantization.statistics import _proportion_str +from nncf.common.utils.backend import BackendType from nncf.common.utils.helpers import create_table from nncf.openvino.graph.metatypes.openvino_metatypes import OVEmbeddingMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype -from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype -from nncf.openvino.graph.metatypes.openvino_metatypes import get_operation_const_op +from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_axes from nncf.openvino.graph.node_utils import get_const_value -from nncf.openvino.graph.node_utils import get_matmul_channel_axes +from nncf.openvino.graph.node_utils import get_weight_channel_axes from nncf.parameters import CompressWeightsMode +from nncf.quantization.algorithms.smooth_quant.backend import ALGO_BACKENDS +from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.fake_quantize import calculate_scale_zero_point + +@ALGO_BACKENDS.register(BackendType.OPENVINO) +class OVWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): + @property + def weighted_metatypes(self) -> List[OperatorMetatype]: + return [OVMatMulMetatype, OVEmbeddingMetatype] + + @staticmethod + def is_node_with_weights(node: NNCFNode) -> bool: + return node.layer_attributes and node.layer_attributes.constant_attributes + + @staticmethod + def validate_params(mode: CompressWeightsMode) -> None: + pass + + @staticmethod + def do_compression( + model: ov.Model, + nodes_to_compress: List[NNCFNode], + mode: CompressWeightsMode, + ratio: float = None, + group_size: int = None, + ) -> ov.Model: + """ + Compresses weights of Linear and Embedding layers to 8-bit integer or to nf4 + depending on mode, ratio and group size. + + :param model: The OpenVINO model for applying weight compression. + :param nodes_to_compress: List of nodes in the model's graph, + corresponding to the layers for weight compression. + :param mode: Defines a mode for weight compression. + INT8 stands for 8-bit integer quantization of all weights. + NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers + are always compressed to a backup precision which is 8-bit integer by default. All others are quantized + whether to NF4 or to a backup precision depending on criteria and the given ratio. + :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 + and the rest to INT8). + :param group_size: number of weights (e.g. 128) in the channel dimension that share + quantization parameters (scale). The value -1 means no grouping. 
+ """ + all_weight_params: List[WeightNodeParams] = [] + quantized_nodes_ids = set() + + friendly_name_to_op_map = {op.get_friendly_name(): op for op in model.get_ops()} + + for nncf_node in nodes_to_compress: + weight_port_ids = nncf_node.layer_attributes.get_const_port_ids() + for weight_port_id in weight_port_ids: + weight_op_friendly_name = nncf_node.layer_attributes.constant_attributes[weight_port_id]["name"] + weight_node = friendly_name_to_op_map[weight_op_friendly_name] + if weight_node is None: + continue + if id(weight_node) in quantized_nodes_ids: + continue + weight_output = weight_node.output(0) + + original_weight_dtype = weight_output.get_element_type().to_dtype() + if original_weight_dtype not in [np.float32, np.float16, np.float64]: + continue + const_shape = nncf_node.layer_attributes.constant_attributes[weight_port_id]["shape"] + channel_axes = get_weight_channel_axes(nncf_node, weight_port_id) + axes = get_channel_agnostic_reduction_axes(channel_axes, const_shape) + fq_name = f"{weight_op_friendly_name}/fq_weights_{weight_port_id}" + num_weights = math.prod(const_shape) + weight_params = WeightNodeParams(axes, num_weights, fq_name, weight_node, original_weight_dtype) + all_weight_params.append(weight_params) + quantized_nodes_ids.add(id(weight_node)) + + if mode == CompressWeightsMode.NF4: + _assign_mixed_precision(all_weight_params, ratio, group_size) + + for wp in all_weight_params: + weight_node = wp.weight_node + original_weight_dtype = wp.original_weight_dtype + + weight_output = weight_node.output(0) + weight_name = weight_node.get_friendly_name() + target_inputs = weight_output.get_target_inputs() + + weight = get_const_value(weight_node) + config = wp.compression_config + + if config.is_nf4: + original_shape = weight.shape + norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, wp.reduction_axes, group_size) + compressed_const = opset.constant(norm_weight, dtype=ov.Type.nf4, name=weight_name) + convert = opset.convert(compressed_const, original_weight_dtype) + mul = opset.multiply(convert, scale.astype(original_weight_dtype), name=wp.fq_name) + if config.group_size != -1: + mul = opset.reshape(mul, output_shape=original_shape, special_zero=False) + last_output = mul.output(0) + else: + compressed_weights, scale, zero_point = _int8_compress(weight, wp.reduction_axes) + compressed_const = opset.constant(compressed_weights, dtype=np.uint8, name=weight_name) + convert = opset.convert(compressed_const, original_weight_dtype) + sub = opset.subtract(convert, zero_point.astype(original_weight_dtype)) + mul = opset.multiply(sub, scale.astype(original_weight_dtype), name=wp.fq_name) + last_output = mul.output(0) + + for target_input in target_inputs: + target_input.replace_source_output(last_output) + return model + + TWeightType = TypeVar("TWeightType") NF4_QUANTILES = np.array( @@ -52,26 +160,7 @@ ] ) - -CENTER_OF_NF4_QUANTILES = np.array( - [ - -0.8480964004993439, - -0.6106329262256622, - -0.4599952697753906, - -0.33967943489551544, - -0.23460740596055984, - -0.13791173323988914, - -0.045525018125772476, - 0.03979014977812767, - 0.1202552504837513, - 0.2035212516784668, - 0.2920137718319893, - 0.3893125355243683, - 0.5016634166240692, - 0.6427869200706482, - 0.8614784181118011, - ] -) +CENTER_OF_NF4_QUANTILES = (NF4_QUANTILES[1:] + NF4_QUANTILES[:-1]) / 2 @dataclass @@ -322,110 +411,3 @@ def _assign_mixed_precision(all_weight_params: List[WeightNodeParams], ratio: fl for weight_param in all_weight_params[1:-1]: weight_param.compression_config = nf4_config 
nncf_logger.info(_get_bitwidth_distribution_str(all_weight_params)) - - -def insert_pre_compression_operations( - model: ov.Model, - mode: CompressWeightsMode, - ratio: float, - group_size: int, -) -> None: - """ - Compress weights of Linear and Embedding layers to 8-bit integer or to nf4 depending on mode, ratio and group size. - - :param model: The model to be transformed. - :param mode: Defines a mode for weight compression. - INT8 stands for 8-bit integer quantization of all weights. - NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers - are always compressed to a backup precision which is 8-bit integer by default. All others are quantized whether - to NF4 or to a backup precision depending on criteria and the given ratio. - :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 - and the rest to INT8). - :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). - The value -1 means no grouping. - """ - allowed_metatypes_to_const_port = {OVEmbeddingMetatype: [0], OVMatMulMetatype: [0, 1]} - - all_weight_params: List[WeightNodeParams] = [] - quantized_nodes_ids = set() - for node in model.get_ordered_ops(): - metatype = get_node_metatype(node) - if metatype not in allowed_metatypes_to_const_port: - continue - - for const_port_id in allowed_metatypes_to_const_port[metatype]: - weight_node = get_operation_const_op(node, const_port_id) - if weight_node is None: - continue - if id(weight_node) in quantized_nodes_ids: - continue - weight_output = weight_node.output(0) - weight_name = weight_node.get_friendly_name() - target_inputs = weight_output.get_target_inputs() - - original_weight_dtype = weight_output.get_element_type().to_dtype() - if original_weight_dtype not in [np.float32, np.float16, np.float64]: - continue - axes = _get_reduction_axes(metatype, node, const_port_id) - fq_name = f"{node.get_friendly_name()}/fq_weights_{const_port_id}" - weight = get_const_value(weight_node) - num_weights = weight.size - weight_params = WeightNodeParams(axes, num_weights, fq_name, weight_node, original_weight_dtype) - all_weight_params.append(weight_params) - quantized_nodes_ids.add(id(weight_node)) - - if mode == CompressWeightsMode.NF4: - _assign_mixed_precision(all_weight_params, ratio, group_size) - - for wp in all_weight_params: - weight_node = wp.weight_node - original_weight_dtype = wp.original_weight_dtype - - weight_output = weight_node.output(0) - weight_name = weight_node.get_friendly_name() - target_inputs = weight_output.get_target_inputs() - - weight = get_const_value(weight_node) - config = wp.compression_config - - if config.is_nf4: - original_shape = weight.shape - norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, wp.reduction_axes, group_size) - compressed_const = opset.constant(norm_weight, dtype=ov.Type.nf4, name=weight_name) - convert = opset.convert(compressed_const, original_weight_dtype) - mul = opset.multiply(convert, scale.astype(original_weight_dtype), name=wp.fq_name) - if config.group_size != -1: - mul = opset.reshape(mul, output_shape=original_shape, special_zero=False) - last_output = mul.output(0) - else: - compressed_weights, scale, zero_point = _int8_compress(weight, wp.reduction_axes) - compressed_const = opset.constant(compressed_weights, dtype=np.uint8, name=weight_name) - convert = opset.convert(compressed_const, original_weight_dtype) - sub = opset.subtract(convert, 
zero_point.astype(original_weight_dtype))
-            mul = opset.multiply(sub, scale.astype(original_weight_dtype), name=wp.fq_name)
-            last_output = mul.output(0)
-
-        for target_input in target_inputs:
-            target_input.replace_source_output(last_output)
-
-
-def _get_reduction_axes(metatype: Type[OperatorMetatype], node: ov.Node, weight_port_id: int) -> Union[int, Tuple[int]]:
-    """
-    Determines reduction axes by given metatype and node information.
-
-    :param metatype: The metatype of the operator.
-    :param node: The OpenVINO node.
-    :param weight_port_id: The weight port ID.
-
-    :return: The reduction axes as an integer or a tuple of integers.
-    """
-    if metatype is OVMatMulMetatype:
-        transpose = node.get_attributes()[f"transpose_{'a' if weight_port_id == 0 else 'b'}"]
-        ndims = node.input(weight_port_id).get_partial_shape().rank.get_max_length()
-        channel_axes = get_matmul_channel_axes(weight_port_id, ndims, transpose)
-        axes = tuple(i for i in range(ndims) if i not in channel_axes)
-    elif metatype is OVEmbeddingMetatype:
-        axes = (metatype.const_channel_axis[0] + 1) % 2
-    else:
-        RuntimeError("Unsupported metatype to find reduction axes.")
-    return axes
diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py
new file mode 100644
index 00000000000..1c28f6703d7
--- /dev/null
+++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2023 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import torch
+from torch import nn
+
+from nncf.common.graph import NNCFNode
+from nncf.common.graph.operator_metatypes import OperatorMetatype
+from nncf.common.utils.backend import BackendType
+from nncf.parameters import CompressWeightsMode
+from nncf.quantization.algorithms.smooth_quant.backend import ALGO_BACKENDS
+from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
+from nncf.torch.graph.operator_metatypes import PTModuleEmbeddingMetatype
+from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
+from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT
+from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
+from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high
+
+
+class WeightsDecompressor(nn.Module):
+    """
+    Applies decompression of compressed weights on the forward pass.
+ + Attributes: + zero_point: zero point in quantization scheme + scale: scale in quantization scheme + """ + + def __init__(self, zero_point, scale): + super().__init__() + self.zero_point = zero_point + self.scale = scale + + def forward(self, layer, op_arg): + w = layer.weight.type(dtype=self.scale.dtype) + layer.weight = (w - self.zero_point) * self.scale + + +@ALGO_BACKENDS.register(BackendType.TORCH) +class PTWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): + @property + def weighted_metatypes(self) -> List[OperatorMetatype]: + return [PTModuleLinearMetatype, PTModuleEmbeddingMetatype] + + @staticmethod + def is_node_with_weights(_: NNCFNode) -> bool: + return True + + @staticmethod + def do_compression( + model: nn.Module, + nodes_to_compress: List[NNCFNode], + mode: CompressWeightsMode, + ratio: float = None, + group_size: int = None, + ) -> nn.Module: + """ + Compress weights of Linear and Embedding layers to 8-bit integer. + + :param model: The Torch model for applying weight compression. + :param nodes_to_compress: List of nodes in the model's graph, + corresponding to the layers for weight compression. + :param mode: Defines a mode for weight compression. + INT8 stands for 8-bit integer quantization of all weights. + NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers + are always compressed to a backup precision which is 8-bit integer by default. All others are quantized + whether to NF4 or to a backup precision depending on criteria and the given ratio. + :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 + and the rest to INT8). + :param group_size: number of weights (e.g. 128) in the channel dimension + that share quantization parameters (scale). The value -1 means no grouping. + :return: The non-trainable module with inserted operations. 
+        """
+        model, _ = replace_modules_by_nncf_modules(model)
+
+        bits = 8
+        level_high = 2**bits - 1
+        assert level_high < 256
+
+        user_types = list(NNCF_WRAPPED_USER_MODULES_DICT.values())
+
+        # Cache decompressor pre-ops so that layers sharing a weight reuse the same operation.
+        compression_hist = {}
+        for node in nodes_to_compress:
+            layer = model.nncf.get_containing_module(node.node_name)
+
+            if not type(layer) in user_types:
+                continue
+
+            if layer.weight.dtype in [torch.uint8, torch.int8]:
+                if layer.weight in compression_hist:
+                    layer.register_pre_forward_operation(compression_hist[layer.weight])
+                continue
+
+            target_dim = layer.target_weight_dim_for_compression
+            stat_dim = (target_dim + 1) % 2
+            input_low = torch.min(layer.weight, dim=stat_dim).values.detach()
+            input_high = torch.max(layer.weight, dim=stat_dim).values.detach()
+            scale, zero_point = get_scale_zp_from_input_low_input_high(0, level_high, input_low, input_high)
+
+            scale = scale.unsqueeze(stat_dim)
+            zero_point = zero_point.unsqueeze(stat_dim)
+            key = layer.register_pre_forward_operation(WeightsDecompressor(zero_point, scale))
+
+            compressed_weight = layer.weight.data / scale + zero_point
+            compressed_weight = torch.clamp(torch.round(compressed_weight), 0, level_high)
+
+            layer.weight.requires_grad = False
+            layer.weight.data = compressed_weight.type(dtype=torch.uint8)
+
+            compression_hist[layer.weight] = layer.get_pre_op(key)
+
+        return model
+
+    @staticmethod
+    def validate_params(mode: CompressWeightsMode) -> None:
+        if mode != CompressWeightsMode.INT8:
+            raise AttributeError(f"Torch backend supports only INT8 mode for weight compression, but given {mode} mode")
diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index c87479fa401..dadc25cde6a 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -12,6 +12,7 @@
 from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
 
 from nncf.api.compression import TModel
+from nncf.common.factory import NNCFGraphFactory
 from nncf.common.quantization.structs import QuantizationPreset
 from nncf.common.utils.api_marker import api
 from nncf.common.utils.backend import BackendType
@@ -27,6 +28,7 @@
 from nncf.quantization.algorithms.hyperparameter_tuner.algorithm import HyperparameterTuner
 from nncf.quantization.algorithms.hyperparameter_tuner.param_grid import get_quantization_param_grid
 from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
+from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
 from nncf.scopes import IgnoredScope
 
 TTensor = TypeVar("TTensor")
@@ -228,7 +230,10 @@ def quantize_with_accuracy_control(
 
 @api(canonical_alias="nncf.compress_weights")
 def compress_weights(
-    model: TModel, mode=CompressWeightsMode.INT8, ratio: Optional[float] = None, group_size: Optional[int] = None
+    model: TModel,
+    mode=CompressWeightsMode.INT8,
+    ratio: Optional[float] = None,
+    group_size: Optional[int] = None,
 ) -> TModel:
     """
     Compress model weights.
@@ -245,7 +250,6 @@ def compress_weights(
         The value -1 means no grouping.
     :return: The non-trainable model with compressed weights.
""" - backend = get_backend(model) if mode == CompressWeightsMode.INT8: if ratio is None: ratio = 1 @@ -262,16 +266,9 @@ def compress_weights( if group_size is None: group_size = 128 - if backend == BackendType.TORCH: - from nncf.torch.quantization.quantize_model import compress_weights_impl - - return compress_weights_impl(model, mode, ratio, group_size) - if backend == BackendType.OPENVINO: - from nncf.openvino.quantization.quantize_model import compress_weights_impl - - return compress_weights_impl(model, mode, ratio, group_size) - - raise RuntimeError(f"Unsupported type of backend: {backend}") + compression_algorithm = WeightCompression(mode, ratio, group_size) + graph = NNCFGraphFactory.create(model) + return compression_algorithm.apply(model, graph) def quantize_with_tune_hyperparams( diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index 621f3bc6ebb..830db170644 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -19,7 +19,6 @@ from nncf.config.structures import BNAdaptationInitArgs from nncf.config.structures import QuantizationRangeInitArgs from nncf.data import Dataset -from nncf.parameters import CompressWeightsMode from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters @@ -33,8 +32,6 @@ from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.model_creation import create_compressed_model from nncf.torch.nested_objects_traversal import objwalk -from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules -from nncf.torch.quantization.weights_compression import insert_pre_compression_operations from nncf.torch.utils import get_model_device from nncf.torch.utils import is_tensor @@ -259,34 +256,3 @@ def send_to_device(tensor): compressed_model.nncf.disable_dynamic_graph_building() return compressed_model - - -def compress_weights_impl( - model: torch.nn.Module, - mode=CompressWeightsMode.INT8, - ratio: Optional[float] = None, - group_size: Optional[int] = None, -) -> torch.nn.Module: - """ - Implementation of the `compress_weights()` method for the PyTorch backend. Currently it supports INT8 - mode only with default ratio and group_size. - - :param model: a Torch model for compression. - :param mode: Defines a mode for weight compression. - INT8 stands for 8-bit integer quantization of all weights. - NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers - are always compressed to a backup precision which is 8-bit integer by default. All others are quantized whether - to NF4 or to a backup precision depending on criteria and the given ratio. - :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 - and the rest to INT8). - :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). - The value -1 means no grouping. - :return: The non-trainable model with compressed weights and dequantization operations. 
- """ - - if mode != CompressWeightsMode.INT8: - raise AttributeError(f"Torch backend supports only INT8 mode for weight compression, but given {mode} mode") - compressed_model, _ = replace_modules_by_nncf_modules(model) - insert_pre_compression_operations(model) - - return compressed_model diff --git a/nncf/torch/quantization/weights_compression.py b/nncf/torch/quantization/weights_compression.py deleted file mode 100644 index 2cf8947fd89..00000000000 --- a/nncf/torch/quantization/weights_compression.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List, Optional - -import torch -from torch import nn - -from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT -from nncf.torch.layers import NNCFEmbedding -from nncf.torch.layers import NNCFLinear -from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high - - -class WeightsDecompressor(nn.Module): - """Applies decompression of compressed weights in forward pass - - Attributes: - zero_point: zero point in quantization scheme - scale: scale in quantizatin scheme - """ - - def __init__(self, zero_point, scale): - super().__init__() - self.zero_point = zero_point - self.scale = scale - - def forward(self, layer, op_arg): - w = layer.weight.type(dtype=self.scale.dtype) - layer.weight = (w - self.zero_point) * self.scale - - -def _insert_pre_compression_operations( - module: nn.Module, allowed_types: List, level_high: int = 255, compression_hist: Dict = None -) -> Optional[nn.Module]: - """ - Inserts weights compression with dequantization for layers in `allowed_types`. - - :param module: The module to insert the weights compression. - :param allowed_types: list of allowed types for weights compression. - :param level_high: highest possible value of compressed weights (lower is 0 in assymetric quantization). - :param compression_hist: mapping between layer weight and corresponding WeightsDecompressor for finding - shared weights. - :return: The non-trainable module with inserted operations. 
- """ - if compression_hist is None: - compression_hist = {} - for _, layer in module.named_children(): - if not type(layer) in allowed_types: - _insert_pre_compression_operations(layer, allowed_types, level_high, compression_hist) - continue - - if layer.weight.dtype in [torch.uint8, torch.int8]: - if layer.weight in compression_hist: - layer.register_pre_forward_operation(compression_hist[layer.weight]) - continue - - target_dim = layer.target_weight_dim_for_compression - stat_dim = (target_dim + 1) % 2 - input_low = torch.min(layer.weight, dim=stat_dim).values.detach() - input_high = torch.max(layer.weight, dim=stat_dim).values.detach() - scale, zero_point = get_scale_zp_from_input_low_input_high(0, level_high, input_low, input_high) - - scale = scale.unsqueeze(stat_dim) - zero_point = zero_point.unsqueeze(stat_dim) - key = layer.register_pre_forward_operation(WeightsDecompressor(zero_point, scale)) - - compressed_weight = layer.weight.data / scale + zero_point - compressed_weight = torch.clamp(torch.round(compressed_weight), 0, level_high) - - layer.weight.requires_grad = False - layer.weight.data = compressed_weight.type(dtype=torch.uint8) - - compression_hist[layer.weight] = layer.get_pre_op(key) - - -def insert_pre_compression_operations(module: nn.Module, bits: int = 8) -> Optional[nn.Module]: - """ - Inserts weights compression with dequantization for Linear and Embedding layers. - - :param module: The module to insert the weights compression. - :param bits: number of bits for compression. Note: type of compressed weights is 8-bit integer. - :return: The non-trainable module with inserted operations. - """ - user_types = list(NNCF_WRAPPED_USER_MODULES_DICT.values()) - allowed_types = [NNCFEmbedding, NNCFLinear] - level_high = 2**bits - 1 - - assert level_high < 256 - - for user_type in user_types: - allowed_types.append(user_type) - - _insert_pre_compression_operations(module, allowed_types, level_high) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 52d1d32229a..766561efe11 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -19,10 +19,10 @@ from nncf import CompressWeightsMode from nncf.openvino.graph.node_utils import get_const_value -from nncf.openvino.quantization.weights_compression import _calculate_scale_per_group -from nncf.openvino.quantization.weights_compression import _get_int8_err -from nncf.openvino.quantization.weights_compression import _get_nf4_error from nncf.quantization import compress_weights +from nncf.quantization.algorithms.weight_compression.openvino_backend import _calculate_scale_per_group +from nncf.quantization.algorithms.weight_compression.openvino_backend import _get_int8_err +from nncf.quantization.algorithms.weight_compression.openvino_backend import _get_nf4_error from tests.openvino.native.models import IntegerModel from tests.openvino.native.models import SequentialMatmulModel from tests.openvino.native.models import WeightsModel @@ -345,10 +345,10 @@ def test_raise_error_with_incorrect_group_size(): def test_raise_error_with_int8_and_non_default_ratio(mocker): - with pytest.raises(RuntimeError): + with pytest.raises(AttributeError): compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, ratio=0.5) def test_raise_error_with_int8_and_non_default_group_size(mocker): - with pytest.raises(RuntimeError): + with pytest.raises(AttributeError): 
compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, group_size=64)
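-- 
Usage sketch (not part of the patch itself; the IR path and the toy tensors below are
placeholders, and only the call signatures that this change defines are assumed):

    import numpy as np
    import openvino.runtime as ov

    import nncf
    from nncf.parameters import CompressWeightsMode

    ov_model = ov.Core().read_model("model.xml")  # placeholder path to an OpenVINO IR

    # Default mode: INT8 compression of all Linear/Embedding weights.
    # This is the entry point now shared by the OpenVINO and Torch backends.
    int8_model = nncf.compress_weights(ov_model)

    # NF4 mixed precision (OpenVINO backend only in this patch): roughly 90% of the
    # weighted layers go to NF4 with group-wise scales over 128 weights, the rest stay INT8.
    nf4_model = nncf.compress_weights(
        ov_model, mode=CompressWeightsMode.NF4, ratio=0.9, group_size=128
    )

A toy NumPy illustration of the asymmetric INT8 scheme both backends implement
(compress to uint8 with a per-channel scale and zero point, decompress on the fly);
the exact rounding in NNCF's scale/zero-point helpers may differ slightly:

    def toy_int8_compress(w: np.ndarray, level_high: int = 255):
        # Per-output-channel min/max over the reduction axis (axis=1 here).
        w_min = w.min(axis=1, keepdims=True)
        w_max = w.max(axis=1, keepdims=True)
        scale = np.maximum((w_max - w_min) / level_high, np.finfo(np.float32).eps)
        zero_point = np.clip(np.round(-w_min / scale), 0, level_high)
        q = np.clip(np.round(w / scale + zero_point), 0, level_high).astype(np.uint8)
        return q, scale, zero_point

    w = np.random.randn(4, 16).astype(np.float32)
    q, scale, zp = toy_int8_compress(w)
    w_restored = (q.astype(np.float32) - zp) * scale  # what WeightsDecompressor does on forward
    print(np.abs(w - w_restored).max())  # quantization error, on the order of the largest scale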