diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index d74a656169e..8b8f1b8dbdf 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -34,6 +34,7 @@ from nncf.parameters import SensitivityMetric from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import convert_to_dict_recursively from nncf.quantization.algorithms.accuracy_control.algorithm import QuantizationAccuracyRestorer @@ -407,6 +408,8 @@ def compress_weights_impl( sensitivity_metric: SensitivityMetric, awq: bool, subset_size: int, + scale_estimation: bool, + advanced_parameters: Optional[AdvancedCompressionParameters] = None, ) -> ov.Model: """ Implementation of the `compress_weights()` method for the OpenVINO backend. @@ -414,7 +417,16 @@ model = remove_friendly_name_duplicates(model) compression_algorithm = WeightCompression( - mode, ratio, group_size, ignored_scope, all_layers, sensitivity_metric, awq, subset_size + mode, + ratio, + group_size, + ignored_scope, + all_layers, + sensitivity_metric, + awq, + subset_size, + scale_estimation, + advanced_parameters, ) graph = NNCFGraphFactory.create(model) return compression_algorithm.apply(model, graph, dataset=dataset) diff --git a/nncf/quantization/advanced_parameters.py b/nncf/quantization/advanced_parameters.py index dd74741eecc..94a04f5bef9 100644 --- a/nncf/quantization/advanced_parameters.py +++ b/nncf/quantization/advanced_parameters.py @@ -246,6 +246,81 @@ class AdvancedQuantizationParameters: backend_params: Dict[str, Any] = field(default_factory=dict) +@api() +@dataclass +class AdvancedAWQParameters: + """ + Contains advanced parameters for AWQ algorithm. + It regulates the data-aware search for the smoothing scales applied before 4-bit quantization: the number of + calibration samples, the percent of outliers to correct and the grid of smoothness values (alpha) + searched between alpha_min and alpha_max. + + :param subset_size: The number of samples for AWQ. + :type subset_size: int + :param percent_to_apply: The percent of outliers for correction. + :type percent_to_apply: float + :param alpha_min: Minimum value of smoothness parameter for grid search. + :type alpha_min: float + :param alpha_max: Maximum value of smoothness parameter for grid search. + :type alpha_max: float + :param steps: The number of the steps in grid search. + :type steps: int + """ + + subset_size: int = 32 + percent_to_apply: float = 0.002 + alpha_min: float = 0.0 + alpha_max: float = 1.0 + steps: int = 100 + + +@api() +@dataclass +class AdvancedScaleEstimationParameters: + """ + Contains advanced parameters for scale estimation algorithm. + It regulates the iterative refinement of per-group quantization scales that minimizes the difference between + the outputs of the original MatMul and the MatMul with compressed weights. + + :param subset_size: The number of samples for scale estimation. + :type subset_size: int + :param initial_steps: The number of the steps for absmax scale rectification.
+ :type initial_steps: int + :param scale_steps: The number of the steps for grid search scale rectification + from 1.0 to 1.0 - 0.05 * scale_step. + :type scale_steps: int + :param weight_penalty: coefficient for penalty between fp and compressed weights. If -1 then doesn't apply. + :type weight_penalty: float + """ + + subset_size: int = 32 + initial_steps: int = 5 + scale_steps: int = 10 + weight_penalty: float = -1.0 + + +@api() +@dataclass +class AdvancedCompressionParameters: + """ + Contains advanced parameters for compression algorithms. + + :param awq_params: Advanced parameters for AWQ algorithm. + :type awq_params: AdvancedAWQParameters + :param scale_estimation_params: Advanced parameters for scale estimation algorithm. + :type scale_estimation_params: AdvancedScaleEstimationParameters + """ + + # Advanced AWQ algorithm parameters + awq_params: AdvancedAWQParameters = field(default_factory=AdvancedAWQParameters) + + # Advanced scale estimation algorithm parameters + scale_estimation_params: AdvancedScaleEstimationParameters = field( + default_factory=AdvancedScaleEstimationParameters + ) + + @api() @dataclass class AdvancedAccuracyRestorerParameters: diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 6d354303aa1..e8b151e6d1c 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -30,10 +30,12 @@ from nncf.experimental.tensor.definitions import TensorDataType from nncf.parameters import CompressWeightsMode from nncf.parameters import SensitivityMetric +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.weight_compression.awq import AWQ from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA +from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation from nncf.quantization.algorithms.weight_compression.weight_lowering import WeightCompressionConfig from nncf.scopes import IgnoredScope from nncf.scopes import get_ignored_node_names_from_ignored_scope @@ -60,6 +62,8 @@ def __init__( sensitivity_metric: SensitivityMetric, awq: bool, subset_size: int, + scale_estimation: bool, + advanced_parameters: Optional[AdvancedCompressionParameters] = None, ): """ :param mode: Defines a mode for weight compression. @@ -88,6 +92,8 @@ def __init__( :param awq: determines whether to use or not modified AWQ algorithm. :param subset_size: Number of data samples to calculate activation statistics used for assigning different quantization precision. + :param scale_estimation: determines whether to use or not scale estimation for 4 bit layers. + :param advanced_parameters: advanced parameters for algorithms in compression pipeline. 
""" super().__init__() self._mode = mode @@ -101,6 +107,10 @@ def __init__( self._sensitivity_metric = sensitivity_metric self._awq = awq self._subset_size = subset_size + self._scale_estimation = scale_estimation + self._advanced_parameters = ( + advanced_parameters if advanced_parameters is not None else AdvancedCompressionParameters() + ) @property def available_backends(self) -> List[BackendType]: @@ -339,14 +349,40 @@ def do_compression( nncf_logger.info(self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params)) if self._awq and activations is not None and self._mode != CompressWeightsMode.NF4: + awq_params = self._advanced_parameters.awq_params awq_algo = AWQ( - model, self._backend_entity.name_to_node_mapping, all_weight_params, nodes_to_compress, activations + model, + self._backend_entity.name_to_node_mapping, + all_weight_params, + nodes_to_compress, + activations, + awq_params.subset_size, + awq_params.percent_to_apply, + awq_params.alpha_min, + awq_params.alpha_max, + awq_params.steps, ) awq_algo.apply(model, graph) + precomputed_scales = {wp.node_with_weight.node_name: None for wp in all_weight_params} + if self._scale_estimation and activations is not None and self._mode != CompressWeightsMode.NF4: + scale_estimation_params = self._advanced_parameters.scale_estimation_params + scale_algo = ScaleEstimation( + model, + self._backend_entity.name_to_node_mapping, + all_weight_params, + nodes_to_compress, + activations, + scale_estimation_params.subset_size, + scale_estimation_params.initial_steps, + scale_estimation_params.scale_steps, + scale_estimation_params.weight_penalty, + ) + precomputed_scales = scale_algo.apply(model, graph) + # Compress model using weight compression parameters transformed_model = self._backend_entity.transform_model( - model, graph, track(all_weight_params, description="Applying Weight Compression") + model, graph, track(all_weight_params, description="Applying Weight Compression"), precomputed_scales ) self._backend_entity.dump_parameters( diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index 91c8dbcce36..dff55ce5f25 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -58,7 +58,7 @@ def __init__( activations: Optional[Dict[str, TTensor]] = None, subset_size: int = 32, percent_to_apply=0.002, - alpha_min=0.01, + alpha_min=0.0, alpha_max=1.0, steps=100, ): @@ -107,8 +107,7 @@ def _set_backend_entity(self, model: TModel) -> None: if model_backend == BackendType.OPENVINO: from nncf.quantization.algorithms.weight_compression.openvino_backend import OVAWQAlgoAlgoBackend - self._backend_entity = OVAWQAlgoAlgoBackend(model) - self._backend_entity.name_to_node_mapping = self.name_to_node_mapping + self._backend_entity = OVAWQAlgoAlgoBackend(model, self.name_to_node_mapping) self._patterns = self._backend_entity.get_awq_patterns() else: raise RuntimeError( @@ -181,11 +180,15 @@ def apply( stats = self._activations[k] X = fns.stack([fns.mean(stat, axis=0) for stat in stats]) X = fns.transpose(X) - if X.shape[1] > self._subset_size: - X = X[:, : self._subset_size] s = fns.max(fns.abs(X), axis=1) + if X.shape[1] > self._subset_size: + lens = [stat.shape[0] for stat in stats] + step = X.shape[1] // self._subset_size + idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step] + X = X[:, idxs] + top_k = max(int(s.shape[0] * self._percent_to_apply), 1) topk_idxs = 
fns.argsort(-s)[:top_k] @@ -263,6 +266,12 @@ def apply( merge_weight = merge_weight * a_scale self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight) + # update activations for next usage + a_scale_t = fns.transpose(a_scale) + for i, stat in enumerate(self._activations[k]): + stat = stat * a_scale_t + self._activations[k][i] = stat + return model def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py index cfe5564dd11..c610b079d9a 100644 --- a/nncf/quantization/algorithms/weight_compression/backend.py +++ b/nncf/quantization/algorithms/weight_compression/backend.py @@ -109,7 +109,11 @@ def set_weight( @abstractmethod def transform_model( - self, model: TModel, graph: NNCFGraph, weight_compression_parameters: Iterable[WeightCompressionParameters] + self, + model: TModel, + graph: NNCFGraph, + weight_compression_parameters: Iterable[WeightCompressionParameters], + precomputed_scales: Dict[str, Tensor] = None, ) -> TModel: """ Applies weight compression transformations to the model. @@ -117,6 +121,7 @@ def transform_model( :param model: Model in which the weights will be compressed according to the weight compression description. :param graph: The graph associated with the model. :param weight_compression_parameters: List of weight compression parameters. + :param precomputed_scales: Precomputed scales for compressed nodes. :return: The transformed model. """ diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 5a5649a7d45..855660e8610 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -35,8 +35,11 @@ class OVWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): - def __init__(self, model: ov.Model): - self.name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model) + def __init__(self, model: ov.Model, name_to_node_mapping: Dict = None): + if name_to_node_mapping is None: + self.name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model) + else: + self.name_to_node_mapping = name_to_node_mapping @property def matmul_metatypes(self) -> List[OperatorMetatype]: @@ -119,7 +122,11 @@ def set_weight( del const_node def transform_model( - self, model: ov.Model, graph: NNCFGraph, weight_compression_parameters: Iterable[WeightCompressionParameters] + self, + model: ov.Model, + graph: NNCFGraph, + weight_compression_parameters: Iterable[WeightCompressionParameters], + precomputed_scales: Dict[str, Tensor] = None, ) -> ov.Model: for wc_params in weight_compression_parameters: compression_config = wc_params.compression_config @@ -146,7 +153,12 @@ def transform_model( weight = Tensor(get_const_value(const_node)) original_shape = weight.shape - compressed_weight = compress_weight(weight, wc_params.reduction_axes, compression_config) + compressed_weight = compress_weight( + weight, + wc_params.reduction_axes, + compression_config, + precomputed_scales[wc_params.node_with_weight.node_name], + ) compressed_const = opset.constant( compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name @@ -195,6 +207,53 @@ def dump_parameters( ) -> None: dump_parameters(model, parameters, algo_name, path) + @staticmethod + def get_compress_decompress_pipeline( + 
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape + ): + ( + w, + s, + zp, + clamp, + ) = OVWeightCompressionAlgoBackend.get_compress_pipeline( + weight_compression_parameter, w_shape, s_shape, z_p_shape, True + ) + + result = (clamp - zp) * s + model = ov.Model([result], [w, s, zp]) + + compiled_model = ov.compile_model(model) + + return lambda w, s, zp: compiled_model([w, s, zp])[0] + + @staticmethod + def get_compress_pipeline( + weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape, return_nodes=False + ): + config = weight_compression_parameter.compression_config + mode = config.mode + assert mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM] + num_bits = config.num_bits + + level_low = 0 + level_high = 2**num_bits - 1 + + w = opset.parameter(w_shape, name="w") + s = opset.parameter(s_shape, name="s") + zp = opset.parameter(z_p_shape, name="zp") + + result = opset.clamp(opset.round(w / s + zp), level_low, level_high, name="compressed_weights") + + if return_nodes: + return w, s, zp, result + + model = ov.Model([result], [w, s, zp]) + + compiled_model = ov.compile_model(model) + + return lambda w, s, zp: compiled_model([w, s, zp])[0] + class OVAWQAlgoAlgoBackend(OVWeightCompressionAlgoBackend): @staticmethod diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py new file mode 100644 index 00000000000..465feee59f3 --- /dev/null +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -0,0 +1,301 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, TypeVar + +from nncf import Dataset +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.logging.track_progress import track +from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.common.utils.backend import BackendType +from nncf.common.utils.backend import get_backend +from nncf.experimental.tensor import TensorDataType +from nncf.experimental.tensor import functions as fns +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization +from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization + +TModel = TypeVar("TModel") +TTensor = TypeVar("TTensor") +TWeightType = TypeVar("TWeightType") + + +class ScaleEstimation: + """ + Scale estimation algorithm implementation. 
+ """ + + def __init__( + self, + model: TModel, + name_to_node_mapping: Dict[str, Any], + all_weight_params: List[WeightCompressionParameters], + nodes_to_compress: List[NNCFNode], + activations: Optional[Dict[str, TTensor]] = None, + subset_size: int = 32, + initial_steps: int = 5, + scale_steps: int = 10, + weight_penalty: float = -1.0, + ): + """ + :param model: Model for applying algorithm. + :param name_to_node_mapping: Name to node mapping for updating node weights. + :param all_weight_params: List of all weight parameters. + :param nodes_to_compress: List of nodes for processing. + :param activations: The input activations of the layers considered for compression. + :param subset_size: The number of samples for scale estimation. + :param initial_steps: The number of the steps for absmax scale rectification. + :param scale_steps: The number of the steps for grid search scale rectification + from 1.0 to 1.0 - 0.05 * scale_step. + :param weight_penalty: coefficient for penalty between fp and compressed weights. If -1 then doesn't apply. + """ + super().__init__() + self.name_to_node_mapping = name_to_node_mapping + self._all_weight_params = all_weight_params + self._nodes_to_compress = nodes_to_compress + self._activations = activations + self._subset_size = subset_size + self._initial_steps = initial_steps + self._scale_steps = scale_steps + self._weight_penalty = weight_penalty + + self._set_backend_entity(model) + + @property + def available_backends(self) -> List[BackendType]: + return [BackendType.OPENVINO] + + def _set_backend_entity(self, model: TModel) -> None: + """ + Creates a helper class with a backed-specific logic of the algorithm. + + :param model: Backend-specific input model. + :param all_weight_params: List of all weight parameters. + :param nodes_to_compress: List of nodes for processing. + :param activations: The input activations of the layers considered for compression. + """ + + model_backend = get_backend(model) + if model_backend == BackendType.OPENVINO: + from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend + + self._backend_entity = OVWeightCompressionAlgoBackend(model, self.name_to_node_mapping) + else: + raise RuntimeError( + "Cannot return backend-specific AWQ entity because {} is not supported!".format(model_backend.value) + ) + + def apply( + self, + model: TModel, + graph: NNCFGraph, + statistic_points: Optional[StatisticPointsContainer] = None, + dataset: Optional[Dataset] = None, + ) -> Dict[str, TTensor]: + """ + Estimates better scale for the int4 nodes in the model. + Minimizes per-group difference between floating point MatMul and + MatMul with compressed weights. + The algorithm computes weighted scale for the group of weights in MatMul, which + shared the same scale. + + :param model: Model for applying algorithm. + :param graph: Model graph. + :param statistic_points: Statistic points with collected statistics values. + :param dataset: A representative dataset for the calibration process. + :return: Dict with pairs (node name, estimated scale). 
+ """ + + compress_decompress_cashe = {} + res = dict() + + for wp in track(self._all_weight_params, description="Applying Scale Estimation"): + k = wp.node_with_weight.node_name + config = wp.compression_config + + if config.num_bits != 4 or k not in self._activations: + res[k] = None + continue + + stats = self._activations[k] + reduction_axis = wp.reduction_axes[0] + + cur_config = deepcopy(config) + cur_config.group_size = -1 + + weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) + if len(weight_data) != 1: # not supported by the algorithm + continue + _, weight_port_id = weight_data[0] + + X = fns.stack([fns.mean(stat, axis=0) for stat in stats]) + X_full = fns.transpose(X) + + # prevent high memory and time consumption + if X_full.shape[1] > self._subset_size: + lens = [stat.shape[0] for stat in stats] + step = X_full.shape[1] // self._subset_size + idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step] + X = X_full[:, idxs] + else: + X = X_full + + s = fns.max(fns.abs(X_full), axis=1) + + weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph) + weight = weight.astype(TensorDataType.float32) + eps = fns.finfo(weight).eps + + if reduction_axis == 0: + weight = fns.transpose(weight) + reduction_axis = 1 + + original_weight = fns.zeros_like(weight) + weight + + compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, config) + zp = zp.astype(scale.dtype) + + q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis) + + s = fns.unsqueeze(s, 0) + s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, config.group_size) + + original_weight, _ = reshape_weight_for_grouped_quantization( + original_weight, reduction_axis, config.group_size + ) + + # all weight in group has importance based on corresponding input activations + importance = fns.ones_like(original_weight) + importance = importance * s + + target = compressed_weights.astype(dtype=zp.dtype) - zp + zero_mask = compressed_weights == zp + + importance = fns.where(zero_mask, 0.0, importance) + + # normalize importances for every group of weights to make sum of them equal to 1.0 + denum = fns.sum(importance, axis=2, keepdims=True) + importance = importance / (denum + eps) + + X, _ = reshape_weight_for_grouped_quantization(X, 0, config.group_size) + q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, config.group_size) + best_diffs = None + result_scale = None + + fp_outs = fns.matmul(fns.transpose(original_weight, (1, 0, 2)), X) + q_outs = fns.matmul(fns.transpose(q_weights, (1, 0, 2)), X) + + # metric for minimization with shape [C_OUT, N_GROUPS], N_GROUPS = C_IN / GROUP_SIZE + min_max_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) + min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0)) + if self._weight_penalty > 0.0: + min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1) + + key = ( + (wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape + ) + if key in compress_decompress_cashe: + compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"] + compress_model = compress_decompress_cashe[key]["compress_model"] + else: + compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline( + wp, q_weights.shape, scale.shape, zp.shape + ) + compress_model = 
self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape) + compress_decompress_cashe[key] = { + "compress_decompress_model": compress_decompress_model, + "compress_model": compress_model, + } + + zero_scale = 0.001 + zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + + # iterative rectification of initial scale + for i in range(self._initial_steps): + ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask) + weighted_scale = ideal_scale * importance + + near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True) + + out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data) + q_weights_ = fns.zeros_like(original_weight) + out + q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) + + ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) + ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0)) + if self._weight_penalty > 0.0: + ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1) + + if best_diffs is None: + best_diffs = min_max_scale_diffs + + mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype) + + best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs + + mask = fns.unsqueeze(mask, axis=2) + + if result_scale is None: + near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale + else: + near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale + result_scale = near_to_ideal_scale + + if i < self._initial_steps - 1: + out = compress_model(original_weight.data, near_to_ideal_scale.data, zp.data) + compressed_weights = fns.zeros_like(original_weight) + out + target = compressed_weights - zp + zero_mask = compressed_weights == zp + zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + + # iterative rectification of scale based on grid search + for scale_steps in range(self._scale_steps): + factor = 1.0 - 0.05 * scale_steps + scaled_scale = factor * scale + + out = compress_model(original_weight.data, scaled_scale.data, zp.data) + compressed_weights = fns.zeros_like(original_weight) + out + + target = compressed_weights - zp + zero_mask = compressed_weights == zp + zero_mask = zero_scale * zero_mask.astype(original_weight.dtype) + + ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask) + weighted_scale = ideal_scale * importance + near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True) + + out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data) + q_weights_ = fns.zeros_like(original_weight) + out + + q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X) + ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1) + ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0)) + if self._weight_penalty > 0.0: + ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1) + + mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype) + + best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs + + mask = fns.unsqueeze(mask, axis=2) + + if result_scale is None: + near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale + else: + near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale + result_scale = near_to_ideal_scale + + res[k] = result_scale + + return res diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py 
b/nncf/quantization/algorithms/weight_compression/torch_backend.py index ee78b7c5899..b9beb52b70c 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple import torch @@ -174,7 +174,11 @@ def set_weight( pass def transform_model( - self, model: NNCFNetwork, graph: NNCFGraph, weight_compression_parameters: Iterable[WeightCompressionParameters] + self, + model: NNCFNetwork, + graph: NNCFGraph, + weight_compression_parameters: Iterable[WeightCompressionParameters], + precomputed_scales: Dict[str, Tensor] = None, ) -> NNCFNetwork: transformation_layout = TransformationLayout() diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index 117640511fc..733789c7d79 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -101,7 +101,7 @@ def calculate_normalized_weight_and_nf4_scale( def do_integer_quantization( - weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig + weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig, precomputed_scale: Tensor = None ) -> Tuple[Tensor, Tensor, Tensor]: """ The method quantizes the given weights to integer data type in accordance with the compression config. @@ -122,6 +122,7 @@ def do_integer_quantization( :param weight: Weight array to compress. :param reduction_axes: Axes, along which to reduce (collect) different statistics (e.g. min, max). :param config: Information on how to compress (quantize) a specific weight. + :param precomputed_scale: Precomputed scale for better performance. :return: The compressed weights tensor of uint8 type, scale tensor of float32 type and zero point tensor of int32 type that was used for its quantization. """ @@ -147,15 +148,19 @@ def do_integer_quantization( min_values, max_values, level_low, level_high, narrow_range=False ) else: - scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2] level_low_sym = -(2 ** (num_bits - 1)) level_high_sym = 2 ** (num_bits - 1) - 1 + + scale = fns.max(fns.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2] scale = scale / level_high_sym - zero_point = fns.as_tensor_like(scale, [-level_low_sym]) + zero_point = fns.as_tensor_like(scale, [-level_low_sym]).astype(TensorDataType.int32) eps = fns.finfo(scale).eps # NOTE: adding machine epsilon to avoid division by zero scale = fns.where(fns.abs(scale) < eps, eps, scale) + if precomputed_scale is not None: + scale = precomputed_scale + compressed_weights = fns.round(weight / scale + zero_point.astype(weight.dtype)) compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(TensorDataType.uint8) return compressed_weights, scale, zero_point @@ -189,24 +194,29 @@ def get_integer_quantization_error( return val.item() -def compress_weight(weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig): +def compress_weight( + weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig, precomputed_scale: Tensor = None +): """ Compress weight using compression configuration. 
:param weight: The weight to compress. :param reduction_axes: Axes, along which to reduce (collect) different statistics (e.g. min, max). :param config: Compression configuration. + :param precomputed_scale: Precomputed scale for better performance. :return: The compressed weight and decompression parameters as instance of CompressedWeight """ if config.mode == CompressWeightsMode.NF4: compressed_weight, scale = calculate_normalized_weight_and_nf4_scale(weight, reduction_axes, config.group_size) return CompressedWeight(compressed_weight, scale) + compressed_weight, scale, zero_point = do_integer_quantization(weight, reduction_axes, config, precomputed_scale) - compressed_weight, scale, zero_point = do_integer_quantization(weight, reduction_axes, config) return CompressedWeight(compressed_weight, scale, zero_point) -def do_dequantization(compressed_weights: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor: +def do_dequantization( + compressed_weights: Tensor, scale: Tensor, zero_point: Tensor, reduction_axis: int = -1 +) -> Tensor: """ The method dequantizes the given weights to float point data type in accordance with the scale and zero_point data type. @@ -214,8 +224,18 @@ def do_dequantization(compressed_weights: Tensor, scale: Tensor, zero_point: Ten :param compressed_weights: compressed weights. :param scale: scale in compression/quantization. :param zero_point: zero point in compression/quantization. + :param reduction_axis: axis for return back for group compression. :return: dequantized/decompressed weights. """ decompressed_weight = compressed_weights.astype(dtype=scale.dtype) decompressed_weight = (decompressed_weight - zero_point) * scale + + if reduction_axis > -1: + shape = list(decompressed_weight.shape) # [a1, r, a2] - "r" refers to number of channels along reduction axis + shape[reduction_axis] = shape[reduction_axis] * shape[reduction_axis + 1] + shape[reduction_axis + 1] = 1 + reshaped_weight = decompressed_weight.reshape(shape) + reshaped_weight = fns.squeeze(reshaped_weight) + decompressed_weight = reshaped_weight + return decompressed_weight diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index fe8a69ace20..3efa010c180 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -29,6 +29,7 @@ from nncf.parameters import SensitivityMetric from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.accuracy_control.evaluator import MetricResults from nncf.quantization.algorithms.hyperparameter_tuner.algorithm import HyperparameterTuner @@ -337,6 +338,8 @@ def compress_weights( *, subset_size: Optional[int] = 128, awq: Optional[bool] = None, + scale_estimation: Optional[bool] = None, + advanced_parameters: Optional[AdvancedCompressionParameters] = None, ) -> TModel: """ Compress model weights. @@ -390,8 +393,11 @@ def compress_weights( f"but given {mode.value} mode." ) - if awq is True: - raise AttributeError("Torch backend doesn`t supports AWQ algorithm, but awq=True is specified.") + if True in [awq, scale_estimation]: + raise AttributeError( + "Torch backend doesn`t supports scale estimation and AWQ algorithm, " + "but awq=True or scale_estimation=True is specified." 
+ ) if is_wrapped_model(model): if not model.nncf.trace_parameters: @@ -412,6 +418,11 @@ def compress_weights( if backend == BackendType.OPENVINO: from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl + if any((awq, scale_estimation)) and (dataset is None or mode == CompressWeightsMode.NF4 or group_size == -1): + raise AttributeError( + "Scale estimation or AWQ algorithm defined, but dataset is None or mode is NF4 or group_size < 0." + ) + compression_weights_impl = ov_compress_weights_impl if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]: @@ -424,7 +435,7 @@ def compress_weights( "INT8 mode assumes per-channel quantization of all layers in 8 bit. " "Default values of `ratio` (1) and `group_size` (-1) parameters can not be overridden" ) - options = [all_layers, sensitivity_metric, dataset, awq] + options = [all_layers, sensitivity_metric, dataset, awq, scale_estimation] if any(option is not None for option in options): raise AttributeError( "INT8 modes do not support `all_layers`, `sensitivity_metric`, `awq` and `dataset` options. " @@ -439,6 +450,8 @@ def compress_weights( all_layers = False if awq is None: awq = False + if scale_estimation is None: + scale_estimation = False if ignored_scope is None: ignored_scope = IgnoredScope() if sensitivity_metric is None: @@ -461,7 +474,18 @@ def compress_weights( raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") return compression_weights_impl( - model, dataset, mode, ratio, group_size, ignored_scope, all_layers, sensitivity_metric, awq, subset_size + model, + dataset, + mode, + ratio, + group_size, + ignored_scope, + all_layers, + sensitivity_metric, + awq, + subset_size, + scale_estimation, + advanced_parameters, ) diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index 57764163e59..0893c179b17 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -23,6 +23,7 @@ from nncf.parameters import QuantizationMode from nncf.parameters import SensitivityMetric from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression @@ -90,13 +91,24 @@ def compress_weights_impl( sensitivity_metric: SensitivityMetric, awq: bool, subset_size: int, + scale_estimation: bool, + advanced_parameters: Optional[AdvancedCompressionParameters] = None, ) -> torch.nn.Module: """ Implementation of the `compress_weights()` method for the PyTorch backend. 
""" compression_algorithm = WeightCompression( - mode, ratio, group_size, ignored_scope, all_layers, sensitivity_metric, awq, subset_size + mode, + ratio, + group_size, + ignored_scope, + all_layers, + sensitivity_metric, + awq, + subset_size, + scale_estimation, + advanced_parameters, ) graph = NNCFGraphFactory.create(model) return compression_algorithm.apply(model, graph, dataset=dataset) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index f339d316d0d..cfdbb579b36 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -62,6 +62,7 @@ INT8_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM) INT4_NF4_MODES = (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM, CompressWeightsMode.NF4) +INT4_MODES = (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM) def get_next_node(node): @@ -683,7 +684,7 @@ def test_call_max_var_criterion_with_dataset_by_default(mocker, mode): scores_spy.assert_called() -@pytest.mark.parametrize("mode", INT4_NF4_MODES) +@pytest.mark.parametrize("mode", INT4_MODES) def test_call_max_var_criterion_with_dataset_by_default_awq(mode): model = AWQMatmulModel().ov_model dataset = Dataset([np.ones([8, 8])]) @@ -770,3 +771,11 @@ def test_duplicate_names_generation(): name = op.get_friendly_name() assert name not in op_names op_names.add(name) + + +@pytest.mark.parametrize("mode", INT4_MODES) +def test_call_max_var_criterion_with_dataset_by_default_scale_estimation(mode): + model = AWQMatmulModel().ov_model + dataset = Dataset([np.ones([8, 8])]) + + compress_weights(model, mode=mode, ratio=1.0, group_size=2, dataset=dataset, scale_estimation=True) diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml index 5235d155244..eed84087771 100644 --- a/tests/post_training/data/wc_reference_data.yaml +++ b/tests/post_training/data/wc_reference_data.yaml @@ -3,14 +3,18 @@ tinyllama_data_free_backend_OV: num_int4: 228 num_int8: 84 tinyllama_data_aware_backend_OV: - metric_value: 0.83084 - num_int4: 184 - num_int8: 128 -tinyllama_data_aware_awq_backend_OV: - metric_value: 0.81237 - num_int4: 184 - num_int8: 128 + metric_value: 0.83853 + num_int4: 188 + num_int8: 124 tinyllama_data_aware_awq_stateful_backend_OV: - metric_value: 0.81237 - num_int4: 184 - num_int8: 128 \ No newline at end of file + metric_value: 0.85259 + num_int4: 188 + num_int8: 124 +tinyllama_data_aware_awq_scale_estimation_backend_OV: + metric_value: 0.83795 + num_int4: 188 + num_int8: 124 +tinyllama_data_aware_awq_scale_estimation_stateful_backend_OV: + metric_value: 0.83795 + num_int4: 188 + num_int8: 124 diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py index 0ad69a54317..a63a4087f95 100644 --- a/tests/post_training/model_scope.py +++ b/tests/post_training/model_scope.py @@ -17,7 +17,9 @@ from nncf import QuantizationPreset from nncf.parameters import CompressWeightsMode from nncf.parameters import SensitivityMetric +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters +from nncf.quantization.advanced_parameters import AdvancedScaleEstimationParameters from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters from tests.post_training.pipelines.base import 
ALL_PTQ_BACKENDS from tests.post_training.pipelines.base import NNCF_PTQ_BACKENDS @@ -328,17 +330,43 @@ "backends": [BackendType.OV], }, { - "reported_name": "tinyllama_data_aware_awq", + "reported_name": "tinyllama_data_aware_awq_stateful", "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b", "pipeline_cls": LMWeightCompression, "compression_params": {"group_size": 64, "ratio": 0.8, "mode": CompressWeightsMode.INT4_SYM, "awq": True}, + "params": {"is_stateful": True}, "backends": [BackendType.OV], }, { - "reported_name": "tinyllama_data_aware_awq_stateful", + "reported_name": "tinyllama_data_aware_awq_scale_estimation", "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b", "pipeline_cls": LMWeightCompression, - "compression_params": {"group_size": 64, "ratio": 0.8, "mode": CompressWeightsMode.INT4_SYM, "awq": True}, + "compression_params": { + "group_size": 64, + "ratio": 0.8, + "mode": CompressWeightsMode.INT4_SYM, + "awq": True, + "scale_estimation": True, + "advanced_parameters": AdvancedCompressionParameters( + scale_estimation_params=AdvancedScaleEstimationParameters(32, 5, 10, 1.0) + ), + }, + "backends": [BackendType.OV], + }, + { + "reported_name": "tinyllama_data_aware_awq_scale_estimation_stateful", + "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b", + "pipeline_cls": LMWeightCompression, + "compression_params": { + "group_size": 64, + "ratio": 0.8, + "mode": CompressWeightsMode.INT4_SYM, + "awq": True, + "scale_estimation": True, + "advanced_parameters": AdvancedCompressionParameters( + scale_estimation_params=AdvancedScaleEstimationParameters(32, 5, 10, 1.0) + ), + }, "params": {"is_stateful": True}, "backends": [BackendType.OV], }, diff --git a/tests/post_training/pipelines/lm_weight_compression.py b/tests/post_training/pipelines/lm_weight_compression.py index 84829e63288..1147f64504d 100644 --- a/tests/post_training/pipelines/lm_weight_compression.py +++ b/tests/post_training/pipelines/lm_weight_compression.py @@ -91,14 +91,21 @@ def prepare_preprocessor(self) -> None: self.preprocessor = AutoTokenizer.from_pretrained(self.model_id) def get_transform_calibration_fn(self): - def transform_fn(data): + def transform_fn(data, max_tokens=128): tokenized_text = self.preprocessor(data["text"], return_tensors="np") - input_ids = tokenized_text["input_ids"] - attention_mask = tokenized_text["attention_mask"] + + bad_tokens = self.preprocessor("", return_tensors="np")["input_ids"] + raw_tokens = tokenized_text["input_ids"][0, :] + filtered_tokens = np.array(list(filter(lambda x: x not in bad_tokens, raw_tokens))) + tokenized_text["input_ids"] = np.expand_dims(filtered_tokens, 0) + tokenized_text["attention_mask"] = tokenized_text["attention_mask"][:, : filtered_tokens.shape[0]] + + input_ids = tokenized_text["input_ids"][:, :max_tokens] + attention_mask = tokenized_text["attention_mask"][:, :max_tokens] inputs = {} inputs["input_ids"] = input_ids - inputs["attention_mask"] = tokenized_text["attention_mask"] + inputs["attention_mask"] = attention_mask position_ids = np.cumsum(attention_mask, axis=1) - 1 position_ids[attention_mask == 0] = 1 @@ -130,7 +137,7 @@ def transform_fn(data): def prepare_calibration_dataset(self): dataset = load_dataset("wikitext", "wikitext-2-v1", split="train", revision="b08601e") - dataset = dataset.filter(lambda example: len(example["text"]) > 80) + dataset = dataset.filter(lambda example: len(example["text"]) > 128) self.calibration_dataset = nncf.Dataset(dataset, self.get_transform_calibration_fn()) def cleanup_cache(self):
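
Usage sketch (illustrative, not part of the changeset above): how the new `scale_estimation` flag and `AdvancedCompressionParameters` introduced in this patch could be passed to `nncf.compress_weights()`. The model and calibration data below are placeholders, and the advanced-parameter values mirror the `tinyllama_data_aware_awq_scale_estimation` entry added to tests/post_training/model_scope.py.

import numpy as np

import nncf
from nncf import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedScaleEstimationParameters

# Scale estimation is data-aware: it requires a calibration dataset, an INT4 mode
# (NF4 is rejected) and group_size > 0, otherwise compress_weights() raises AttributeError.
# The Torch backend rejects awq=True and scale_estimation=True as well.
model = ...  # placeholder: an ov.Model with 4-bit-compressible MatMul layers
calibration_dataset = nncf.Dataset([np.ones([8, 8])])  # toy inputs, as in the new unit test

compressed_model = nncf.compress_weights(
    model,
    mode=CompressWeightsMode.INT4_SYM,
    ratio=0.8,
    group_size=64,
    dataset=calibration_dataset,
    awq=True,
    scale_estimation=True,
    advanced_parameters=AdvancedCompressionParameters(
        scale_estimation_params=AdvancedScaleEstimationParameters(
            subset_size=32, initial_steps=5, scale_steps=10, weight_penalty=1.0
        )
    ),
)

Passing advanced_parameters is optional; when omitted, the dataclass defaults apply (subset_size=32, initial_steps=5, scale_steps=10, weight_penalty=-1.0, i.e. the weight penalty is disabled).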