From fb883db737018ad70775b076e132089bb99c75e5 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Fri, 26 Apr 2024 12:01:04 +0200
Subject: [PATCH] Refactor apply function

---
 .../weight_compression/scale_estimation.py   | 261 ++++++++++--------
 1 file changed, 147 insertions(+), 114 deletions(-)

diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py
index 44fc36baa7b..57aea0ec30a 100644
--- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py
+++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -96,6 +96,112 @@ def _set_backend_entity(self, model: TModel) -> None:
             "Cannot return backend-specific AWQ entity because {} is not supported!".format(model_backend.value)
         )
 
+    def _get_importance_and_x(self, stats, original_weight, zero_mask, reduction_axis, config, eps):
+        X = fns.stack([fns.mean(stat, axis=0) for stat in stats])
+        X_full = fns.transpose(X)
+
+        # prevent high memory and time consumption
+        if X_full.shape[1] > self._subset_size:
+            lens = [stat.shape[0] for stat in stats]
+            step = X_full.shape[1] // self._subset_size
+            idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
+            X = X_full[:, idxs]
+        else:
+            X = X_full
+
+        s = fns.max(fns.abs(X_full), axis=1)
+        # all weights in a group share an importance based on the corresponding input activations
+
+        s = fns.unsqueeze(s, 0)
+        s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, config.group_size)
+
+        importance = fns.ones_like(original_weight)
+        importance = importance * s
+        importance = fns.where(zero_mask, 0.0, importance)
+
+        # normalize importances so that every group of weights sums to 1.0
+        denom = fns.sum(importance, axis=2, keepdims=True)
+        importance = importance / (denom + eps)
+        X, _ = reshape_weight_for_grouped_quantization(X, 0, config.group_size)
+        return importance, X
+
+    def _get_compress_decompress_model(self, wp, q_weights, scale, zp, compress_decompress_cashe):
+        key = (wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape
+        if key in compress_decompress_cashe:
+            compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"]
+            compress_model = compress_decompress_cashe[key]["compress_model"]
+        else:
+            compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
+                wp, q_weights.shape, scale.shape, zp.shape
+            )
+            compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape)
+            compress_decompress_cashe[key] = {
+                "compress_decompress_model": compress_decompress_model,
+                "compress_model": compress_model,
+            }
+        return compress_model, compress_decompress_model
+
+    def _get_min_max_scale_diffs(self, fp_outs, q_outs, original_weight, q_weights):
+        # metric for minimization with shape [C_OUT, N_GROUPS], N_GROUPS = C_IN / GROUP_SIZE
+        min_max_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
+        min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0))
+        if self._weight_penalty > 0.0:
+            min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1)
+        return min_max_scale_diffs
+
+    def _get_weight(self, wp, model, graph):
+        weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
+        if len(weight_data) != 1:  # not supported by the algorithm
+            return None
+        _, weight_port_id = weight_data[0]
+
+        weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)
+        weight = weight.astype(TensorDataType.float32)
+        return weight
+
+    def _rectification(
+        self,
+        result_scale,
+        best_diffs,
+        X,
+        original_weight,
+        target,
+        fp_outs,
+        zero_mask,
+        importance,
+        min_max_scale_diffs,
+        compress_decompress_model,
+        zp,
+        scale,
+    ):
+        ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
+        weighted_scale = ideal_scale * importance
+        near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
+
+        out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
+        q_weights_ = fns.zeros_like(original_weight) + out
+        q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
+
+        ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
+        ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0))
+        if self._weight_penalty > 0.0:
+            ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1)
+
+        if best_diffs is None:
+            best_diffs = min_max_scale_diffs
+
+        mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype)
+
+        best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs
+
+        mask = fns.unsqueeze(mask, axis=2)
+
+        if result_scale is None:
+            near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale
+        else:
+            near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
+        result_scale = near_to_ideal_scale
+        # best_diffs must be threaded through successive calls, so return it alongside the scale
+        return result_scale, best_diffs
+
     def apply(
         self,
         model: TModel,
@@ -129,34 +235,17 @@ def apply(
                 continue
 
             stats = self._activations[k]
-            reduction_axis = wp.reduction_axes[0]
 
             cur_config = deepcopy(config)
             cur_config.group_size = -1
 
-            weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph)
-            if len(weight_data) != 1:  # not supported by the algorithm
+            weight = self._get_weight(wp, model, graph)
+            if weight is None:
                 continue
-            _, weight_port_id = weight_data[0]
-
-            X = fns.stack([fns.mean(stat, axis=0) for stat in stats])
-            X_full = fns.transpose(X)
-
-            # prevent high memory and time consumption
-            if X_full.shape[1] > self._subset_size:
-                lens = [stat.shape[0] for stat in stats]
-                step = X_full.shape[1] // self._subset_size
-                idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
-                X = X_full[:, idxs]
-            else:
-                X = X_full
-            s = fns.max(fns.abs(X_full), axis=1)
-
-            weight = self._backend_entity.get_weight(wp.node_with_weight, weight_port_id, model, graph)
-            weight = weight.astype(TensorDataType.float32)
 
             eps = fns.finfo(weight).eps
+            reduction_axis = wp.reduction_axes[0]
             if reduction_axis == 0:
                 weight = fns.transpose(weight)
                 reduction_axis = 1
@@ -166,94 +255,49 @@ def apply(
 
             compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, config)
             zp = zp.astype(scale.dtype)
-            q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis)
-
-            s = fns.unsqueeze(s, 0)
-            s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, config.group_size)
-
+            target = compressed_weights.astype(dtype=zp.dtype) - zp
             original_weight, _ = reshape_weight_for_grouped_quantization(
                 original_weight, reduction_axis, config.group_size
             )
 
-            # all weight in group has importance based on corresponding input activations
-            importance = fns.ones_like(original_weight)
-            importance = importance * s
-
-            target = compressed_weights.astype(dtype=zp.dtype) - zp
             zero_mask = compressed_weights == zp
+            importance, X = self._get_importance_and_x(stats, original_weight, zero_mask, reduction_axis, config, eps)
 
-            importance = fns.where(zero_mask, 0.0, importance)
-
-            # normalize importances for every group of weights to make sum of them equal to 1.0
-            denum = fns.sum(importance, axis=2, keepdims=True)
-            importance = importance / (denum + eps)
-
-            X, _ = reshape_weight_for_grouped_quantization(X, 0, config.group_size)
+            q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis)
             q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, config.group_size)
 
-            best_diffs = None
-            result_scale = None
-
             fp_outs = fns.matmul(fns.transpose(original_weight, (1, 0, 2)), X)
             q_outs = fns.matmul(fns.transpose(q_weights, (1, 0, 2)), X)
+            min_max_scale_diffs = self._get_min_max_scale_diffs(fp_outs, q_outs, original_weight, q_weights)
 
-            # metric for minimization with shape [C_OUT, N_GROUPS], N_GROUPS = C_IN / GROUP_SIZE
-            min_max_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
-            min_max_scale_diffs = fns.transpose(min_max_scale_diffs, (1, 0))
-            if self._weight_penalty > 0.0:
-                min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1)
-
-            key = (
-                (wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape + zp.shape
+            compress_model, compress_decompress_model = self._get_compress_decompress_model(
+                wp, q_weights, scale, zp, compress_decompress_cashe
             )
-            if key in compress_decompress_cashe:
-                compress_decompress_model = compress_decompress_cashe[key]["compress_decompress_model"]
-                compress_model = compress_decompress_cashe[key]["compress_model"]
-            else:
-                compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
-                    wp, q_weights.shape, scale.shape, zp.shape
-                )
-                compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp.shape)
-                compress_decompress_cashe[key] = {
-                    "compress_decompress_model": compress_decompress_model,
-                    "compress_model": compress_model,
-                }
 
             zero_scale = 0.001
             zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
 
+            best_diffs = None
+            result_scale = None
+
             # iterative rectification of initial scale
             for i in range(self._initial_steps):
-                ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
-                weighted_scale = ideal_scale * importance
-
-                near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
-
-                out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
-                q_weights_ = fns.zeros_like(original_weight) + out
-                q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
-
-                ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
-                ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0))
-                if self._weight_penalty > 0.0:
-                    ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1)
-
-                if best_diffs is None:
-                    best_diffs = min_max_scale_diffs
-
-                mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype)
-
-                best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs
-
-                mask = fns.unsqueeze(mask, axis=2)
-
-                if result_scale is None:
-                    near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale
-                else:
-                    near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
-                result_scale = near_to_ideal_scale
-
+                result_scale, best_diffs = self._rectification(
+                    result_scale,
+                    best_diffs,
+                    X,
+                    original_weight,
+                    target,
+                    fp_outs,
+                    zero_mask,
+                    importance,
+                    min_max_scale_diffs,
+                    compress_decompress_model,
+                    zp,
+                    scale,
+                )
                 if i < self._initial_steps - 1:
-                    out = compress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
+                    out = compress_model(original_weight.data, result_scale.data, zp.data)
                     compressed_weights = fns.zeros_like(original_weight) + out
                     target = compressed_weights - zp
                     zero_mask = compressed_weights == zp
@@ -271,31 +315,20 @@ def apply(
                 zero_mask = compressed_weights == zp
                 zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
 
-                ideal_scale = fns.abs(original_weight) / (fns.abs(target) + zero_mask)
-                weighted_scale = ideal_scale * importance
-                near_to_ideal_scale = fns.sum(weighted_scale, axis=2, keepdims=True)
-
-                out = compress_decompress_model(original_weight.data, near_to_ideal_scale.data, zp.data)
-                q_weights_ = fns.zeros_like(original_weight) + out
-
-                q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
-                ideal_scale_diffs = fns.mean((fp_outs - q_outs) ** 2, axis=-1)
-                ideal_scale_diffs = fns.transpose(ideal_scale_diffs, (1, 0))
-                if self._weight_penalty > 0.0:
-                    ideal_scale_diffs += self._weight_penalty * fns.mean((q_weights_ - original_weight) ** 2, axis=-1)
-
-                mask = (ideal_scale_diffs > best_diffs).astype(best_diffs.dtype)
-
-                best_diffs = mask * best_diffs + (1.0 - mask) * ideal_scale_diffs
-
-                mask = fns.unsqueeze(mask, axis=2)
-
-                if result_scale is None:
-                    near_to_ideal_scale = mask * scale + (1.0 - mask) * near_to_ideal_scale
-                else:
-                    near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
-                result_scale = near_to_ideal_scale
-
+                result_scale, best_diffs = self._rectification(
+                    result_scale,
+                    best_diffs,
+                    X,
+                    original_weight,
+                    target,
+                    fp_outs,
+                    zero_mask,
+                    importance,
+                    min_max_scale_diffs,
+                    compress_decompress_model,
+                    zp,
+                    scale,
+                )
 
             res[k] = result_scale
 
         return res
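---

Below is a minimal, self-contained NumPy sketch of the rectification loop that the
new `_rectification` helper encapsulates, intended only to make the refactored
control flow easier to review. Plain arrays stand in for NNCF's `fns` tensor API,
and the `quantize`/`dequantize`/`group_mse` helpers are simplified, hypothetical
stand-ins for the backend compress/decompress pipelines (a symmetric 4-bit grid
with no zero point); all shapes and constants are toy values, not the actual
implementation.

    import numpy as np

    rng = np.random.default_rng(0)

    # Toy shapes: weights [C_OUT, N_GROUPS, GROUP_SIZE], activations [N_GROUPS, GROUP_SIZE, N_SAMPLES].
    c_out, n_groups, group_size, n_samples = 4, 3, 8, 16
    w = rng.standard_normal((c_out, n_groups, group_size))
    x = rng.standard_normal((n_groups, group_size, n_samples))

    level_high = 7  # symmetric signed 4-bit grid, purely illustrative

    def quantize(w, scale):
        # round-to-nearest onto the integer grid (stand-in for compress_model)
        return np.clip(np.round(w / scale), -level_high - 1, level_high)

    def dequantize(q, scale):
        # stand-in for compress_decompress_model
        return q * scale

    def group_mse(q, scale):
        # output-space error per [C_OUT, N_GROUPS], cf. min_max_scale_diffs
        fp_outs = np.einsum("ogk,gks->gos", w, x)
        q_outs = np.einsum("ogk,gks->gos", dequantize(q, scale), x)
        return np.mean((fp_outs - q_outs) ** 2, axis=-1).T

    # initial min/max scale and its error
    result_scale = np.max(np.abs(w), axis=2, keepdims=True) / level_high
    best_diffs = group_mse(quantize(w, result_scale), result_scale)

    # per-group importance from activation magnitudes
    importance = np.broadcast_to(np.max(np.abs(x), axis=-1), w.shape).copy()

    for _ in range(5):  # cf. self._initial_steps
        target = quantize(w, result_scale)
        zero_mask = 0.001 * (target == 0)

        # zero out and renormalize importance so each group sums to 1.0
        imp = np.where(target == 0, 0.0, importance)
        imp = imp / (imp.sum(axis=2, keepdims=True) + 1e-12)

        # importance-weighted "ideal" scale that would reproduce each weight exactly
        candidate = np.sum(np.abs(w) / (np.abs(target) + zero_mask) * imp, axis=2, keepdims=True)
        diffs = group_mse(quantize(w, candidate), candidate)

        # keep the candidate scale only for groups where it lowers the error
        worse = (diffs > best_diffs).astype(w.dtype)
        best_diffs = worse * best_diffs + (1.0 - worse) * diffs
        result_scale = worse[..., None] * result_scale + (1.0 - worse[..., None]) * candidate

    print("per-group MSE after rectification:", best_diffs.mean())

The detail the refactor has to preserve is that both `result_scale` and
`best_diffs` carry state across iterations: the mask keeps a group's previous
scale whenever the importance-weighted candidate increases the output-space MSE,
which is why `_rectification` returns both values to its callers.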