Scale estimation/rectification for int4 compression (openvinotoolkit#2549)

### Changes

Added scale estimation for weight compression, which minimizes the L2 error between
the outputs of the original MatMul and the compressed one.
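
For reference, a minimal usage sketch of the new option. This is a sketch, not documentation: `model` (an `ov.Model`) and `calibration_dataset` (an `nncf.Dataset`) are placeholders, and a dataset is required because scale estimation consumes activation statistics.

```python
import nncf

# Sketch only: `model` and `calibration_dataset` are prepared by the user.
compressed = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=1.0,
    group_size=128,
    dataset=calibration_dataset,  # activations are needed for scale estimation
    scale_estimation=True,
)
```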

### Reason for changes

Increases the accuracy of models compressed to 4 bits.

### Related tickets

CVS-129177

### Tests

In progress.

---------

Co-authored-by: Lyalyushkin Nikolay <[email protected]>
Co-authored-by: Daniil Lyakhov <[email protected]>
3 people authored Apr 30, 2024
1 parent 489cc09 commit 9c00000
Showing 15 changed files with 650 additions and 45 deletions.
14 changes: 13 additions & 1 deletion nncf/openvino/quantization/quantize_model.py
@@ -34,6 +34,7 @@
from nncf.parameters import SensitivityMetric
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import convert_to_dict_recursively
from nncf.quantization.algorithms.accuracy_control.algorithm import QuantizationAccuracyRestorer
@@ -407,14 +408,25 @@ def compress_weights_impl(
sensitivity_metric: SensitivityMetric,
awq: bool,
subset_size: int,
scale_estimation: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> ov.Model:
"""
Implementation of the `compress_weights()` method for the OpenVINO backend.
"""

model = remove_friendly_name_duplicates(model)
compression_algorithm = WeightCompression(
mode, ratio, group_size, ignored_scope, all_layers, sensitivity_metric, awq, subset_size
mode,
ratio,
group_size,
ignored_scope,
all_layers,
sensitivity_metric,
awq,
subset_size,
scale_estimation,
advanced_parameters,
)
graph = NNCFGraphFactory.create(model)
return compression_algorithm.apply(model, graph, dataset=dataset)
75 changes: 75 additions & 0 deletions nncf/quantization/advanced_parameters.py
@@ -246,6 +246,81 @@ class AdvancedQuantizationParameters:
backend_params: Dict[str, Any] = field(default_factory=dict)


@api()
@dataclass
class AdvancedAWQParameters:
"""
Contains advanced parameters for the AWQ algorithm.
:param subset_size: The number of samples for AWQ.
:type subset_size: int
:param percent_to_apply: The fraction of outliers to correct.
:type percent_to_apply: float
:param alpha_min: Minimum value of the smoothness parameter for grid search.
:type alpha_min: float
:param alpha_max: Maximum value of the smoothness parameter for grid search.
:type alpha_max: float
:param steps: The number of steps in the grid search.
:type steps: int
"""

subset_size: int = 32
percent_to_apply: float = 0.002
alpha_min: float = 0.0
alpha_max: float = 1.0
steps: int = 100


@api()
@dataclass
class AdvancedScaleEstimationParameters:
"""
Contains advanced parameters for the scale estimation algorithm.
:param subset_size: The number of samples for scale estimation.
:type subset_size: int
:param initial_steps: The number of steps for absmax scale rectification.
:type initial_steps: int
:param scale_steps: The number of steps for grid-search scale rectification,
with scale multipliers ranging from 1.0 down to 1.0 - 0.05 * scale_steps.
:type scale_steps: int
:param weight_penalty: Coefficient of the penalty on the difference between the floating-point
and compressed weights. If -1, the penalty is not applied.
:type weight_penalty: float
"""

subset_size: int = 32
initial_steps: int = 5
scale_steps: int = 10
weight_penalty: float = -1.0


@api()
@dataclass
class AdvancedCompressionParameters:
"""
Contains advanced parameters for compression algorithms.
:param awq_params: Advanced parameters for AWQ algorithm.
:type awq_params: AdvancedAWQParameters
:param scale_estimation_params: Advanced parameters for scale estimation algorithm.
:type scale_estimation_params: AdvancedScaleEstimationParameters
"""

# Advanced AWQ algorithm parameters
awq_params: AdvancedAWQParameters = field(default_factory=AdvancedAWQParameters)

# Advanced scale estimation algorithm parameters
scale_estimation_params: AdvancedScaleEstimationParameters = field(
default_factory=AdvancedScaleEstimationParameters
)


@api()
@dataclass
class AdvancedAccuracyRestorerParameters:
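A short sketch of how the new dataclasses compose; values are illustrative, not recommendations. The resulting object is passed through the `advanced_parameters` argument added in quantize_model.py above.

```python
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedScaleEstimationParameters

# Only the scale estimation block is overridden; AWQ keeps its defaults.
advanced = AdvancedCompressionParameters(
    scale_estimation_params=AdvancedScaleEstimationParameters(
        subset_size=64,       # samples used for estimation (default is 32)
        initial_steps=5,      # absmax scale rectification steps
        scale_steps=10,       # grid-search steps, multipliers 1.0 down to 0.5
        weight_penalty=-1.0,  # -1 disables the fp-vs-compressed penalty
    )
)
```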
40 changes: 38 additions & 2 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -30,10 +30,12 @@
from nncf.experimental.tensor.definitions import TensorDataType
from nncf.parameters import CompressWeightsMode
from nncf.parameters import SensitivityMetric
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.awq import AWQ
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA
from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation
from nncf.quantization.algorithms.weight_compression.weight_lowering import WeightCompressionConfig
from nncf.scopes import IgnoredScope
from nncf.scopes import get_ignored_node_names_from_ignored_scope
@@ -60,6 +62,8 @@ def __init__(
sensitivity_metric: SensitivityMetric,
awq: bool,
subset_size: int,
scale_estimation: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
):
"""
:param mode: Defines a mode for weight compression.
@@ -88,6 +92,8 @@
:param awq: Determines whether to apply the modified AWQ algorithm.
:param subset_size: Number of data samples to calculate activation statistics used for assigning different
quantization precision.
:param scale_estimation: Determines whether to apply scale estimation for 4-bit layers.
:param advanced_parameters: Advanced parameters for the algorithms in the compression pipeline.
"""
super().__init__()
self._mode = mode
@@ -101,6 +107,10 @@ def __init__(
self._sensitivity_metric = sensitivity_metric
self._awq = awq
self._subset_size = subset_size
self._scale_estimation = scale_estimation
self._advanced_parameters = (
advanced_parameters if advanced_parameters is not None else AdvancedCompressionParameters()
)

@property
def available_backends(self) -> List[BackendType]:
@@ -339,14 +349,40 @@ def do_compression(
nncf_logger.info(self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params))

if self._awq and activations is not None and self._mode != CompressWeightsMode.NF4:
awq_params = self._advanced_parameters.awq_params
awq_algo = AWQ(
model, self._backend_entity.name_to_node_mapping, all_weight_params, nodes_to_compress, activations
model,
self._backend_entity.name_to_node_mapping,
all_weight_params,
nodes_to_compress,
activations,
awq_params.subset_size,
awq_params.percent_to_apply,
awq_params.alpha_min,
awq_params.alpha_max,
awq_params.steps,
)
awq_algo.apply(model, graph)

precomputed_scales = {wp.node_with_weight.node_name: None for wp in all_weight_params}
if self._scale_estimation and activations is not None and self._mode != CompressWeightsMode.NF4:
scale_estimation_params = self._advanced_parameters.scale_estimation_params
scale_algo = ScaleEstimation(
model,
self._backend_entity.name_to_node_mapping,
all_weight_params,
nodes_to_compress,
activations,
scale_estimation_params.subset_size,
scale_estimation_params.initial_steps,
scale_estimation_params.scale_steps,
scale_estimation_params.weight_penalty,
)
precomputed_scales = scale_algo.apply(model, graph)

# Compress model using weight compression parameters
transformed_model = self._backend_entity.transform_model(
model, graph, track(all_weight_params, description="Applying Weight Compression")
model, graph, track(all_weight_params, description="Applying Weight Compression"), precomputed_scales
)

self._backend_entity.dump_parameters(
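For intuition, a toy NumPy sketch of the grid-search rectification idea that `ScaleEstimation` implements. This is not the NNCF code; the symmetric int4 range and the shapes are assumptions.

```python
import numpy as np

def rectify_scale(w, x, scale_steps=10):
    """Toy grid-search scale rectification (illustration only).

    w: weight group of shape (in_features, out_features), x: calibration
    activations of shape (samples, in_features). Tries scale multipliers from
    1.0 down to 1.0 - 0.05 * scale_steps and keeps the one that minimizes the
    L2 error of the MatMul output, as described in this PR.
    """
    ref = x @ w
    base = np.abs(w).max() / 7.0  # absmax scale for symmetric int4 ([-8, 7])
    best_scale, best_err = base, np.inf
    for i in range(scale_steps + 1):
        s = base * (1.0 - 0.05 * i)
        q = np.clip(np.round(w / s), -8, 7)      # quantize
        err = np.linalg.norm(x @ (q * s) - ref)  # L2 error after dequantize
        if err < best_err:
            best_scale, best_err = s, err
    return best_scale
```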
19 changes: 14 additions & 5 deletions nncf/quantization/algorithms/weight_compression/awq.py
@@ -58,7 +58,7 @@ def __init__(
activations: Optional[Dict[str, TTensor]] = None,
subset_size: int = 32,
percent_to_apply=0.002,
alpha_min=0.01,
alpha_min=0.0,
alpha_max=1.0,
steps=100,
):
@@ -107,8 +107,7 @@ def _set_backend_entity(self, model: TModel) -> None:
if model_backend == BackendType.OPENVINO:
from nncf.quantization.algorithms.weight_compression.openvino_backend import OVAWQAlgoAlgoBackend

self._backend_entity = OVAWQAlgoAlgoBackend(model)
self._backend_entity.name_to_node_mapping = self.name_to_node_mapping
self._backend_entity = OVAWQAlgoAlgoBackend(model, self.name_to_node_mapping)
self._patterns = self._backend_entity.get_awq_patterns()
else:
raise RuntimeError(
@@ -181,11 +180,15 @@ def apply(
stats = self._activations[k]
X = fns.stack([fns.mean(stat, axis=0) for stat in stats])
X = fns.transpose(X)
if X.shape[1] > self._subset_size:
X = X[:, : self._subset_size]

s = fns.max(fns.abs(X), axis=1)

if X.shape[1] > self._subset_size:
lens = [stat.shape[0] for stat in stats]
step = X.shape[1] // self._subset_size
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
X = X[:, idxs]

top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
topk_idxs = fns.argsort(-s)[:top_k]

@@ -263,6 +266,12 @@ def apply(
merge_weight = merge_weight * a_scale
self._backend_entity.set_weight(merge_node, port_id, model, graph, merge_weight)

# Update the stored activations so that subsequent algorithms use the scaled values
a_scale_t = fns.transpose(a_scale)
for i, stat in enumerate(self._activations[k]):
stat = stat * a_scale_t
self._activations[k][i] = stat

return model

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
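The new subset selection in `apply` above replaces truncation to the first `subset_size` columns: samples are ranked by sequence length and picked with a stride, so the longest sequences dominate. A standalone NumPy sketch of the same index arithmetic on synthetic data:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic per-sample activation stats: (tokens_i, hidden_dim) arrays.
stats = [rng.normal(size=(n, 16)) for n in (128, 8, 512, 64, 256, 32)]
subset_size = 3

X = np.stack([s.mean(axis=0) for s in stats]).T  # (hidden_dim, num_samples)
if X.shape[1] > subset_size:
    lens = [s.shape[0] for s in stats]
    step = X.shape[1] // subset_size
    # Sample indices ordered longest-first, then strided down to subset_size.
    idxs = [i for i, _ in sorted(enumerate(lens), key=lambda p: -p[1])][::step]
    X = X[:, idxs]
print(idxs)  # [2, 0, 5]
```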
7 changes: 6 additions & 1 deletion nncf/quantization/algorithms/weight_compression/backend.py
@@ -109,14 +109,19 @@ def set_weight(

@abstractmethod
def transform_model(
self, model: TModel, graph: NNCFGraph, weight_compression_parameters: Iterable[WeightCompressionParameters]
self,
model: TModel,
graph: NNCFGraph,
weight_compression_parameters: Iterable[WeightCompressionParameters],
precomputed_scales: Optional[Dict[str, Tensor]] = None,
) -> TModel:
"""
Applies weight compression transformations to the model.
:param model: Model in which the weights will be compressed according to the weight compression description.
:param graph: The graph associated with the model.
:param weight_compression_parameters: List of weight compression parameters.
:param precomputed_scales: Precomputed scales for compressed nodes.
:return: The transformed model.
"""

nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -35,8 +35,11 @@


class OVWeightCompressionAlgoBackend(WeightCompressionAlgoBackend):
def __init__(self, model: ov.Model):
self.name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model)
def __init__(self, model: ov.Model, name_to_node_mapping: Optional[Dict] = None):
if name_to_node_mapping is None:
self.name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model)
else:
self.name_to_node_mapping = name_to_node_mapping

@property
def matmul_metatypes(self) -> List[OperatorMetatype]:
@@ -119,7 +122,11 @@ def set_weight(
del const_node

def transform_model(
self, model: ov.Model, graph: NNCFGraph, weight_compression_parameters: Iterable[WeightCompressionParameters]
self,
model: ov.Model,
graph: NNCFGraph,
weight_compression_parameters: Iterable[WeightCompressionParameters],
precomputed_scales: Optional[Dict[str, Tensor]] = None,
) -> ov.Model:
for wc_params in weight_compression_parameters:
compression_config = wc_params.compression_config
@@ -146,7 +153,12 @@

weight = Tensor(get_const_value(const_node))
original_shape = weight.shape
compressed_weight = compress_weight(weight, wc_params.reduction_axes, compression_config)
compressed_weight = compress_weight(
weight,
wc_params.reduction_axes,
compression_config,
precomputed_scales[wc_params.node_with_weight.node_name],
)

compressed_const = opset.constant(
compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name
Expand Down Expand Up @@ -195,6 +207,53 @@ def dump_parameters(
) -> None:
dump_parameters(model, parameters, algo_name, path)

@staticmethod
def get_compress_decompress_pipeline(
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape
):
(
w,
s,
zp,
clamp,
) = OVWeightCompressionAlgoBackend.get_compress_pipeline(
weight_compression_parameter, w_shape, s_shape, z_p_shape, True
)

result = (clamp - zp) * s
model = ov.Model([result], [w, s, zp])

compiled_model = ov.compile_model(model)

return lambda w, s, zp: compiled_model([w, s, zp])[0]

@staticmethod
def get_compress_pipeline(
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape, return_nodes=False
):
config = weight_compression_parameter.compression_config
mode = config.mode
assert mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]
num_bits = config.num_bits

level_low = 0
level_high = 2**num_bits - 1

w = opset.parameter(w_shape, name="w")
s = opset.parameter(s_shape, name="s")
zp = opset.parameter(z_p_shape, name="zp")

result = opset.clamp(opset.round(w / s + zp), level_low, level_high, name="compressed_weights")

if return_nodes:
return w, s, zp, result

model = ov.Model([result], [w, s, zp])

compiled_model = ov.compile_model(model)

return lambda w, s, zp: compiled_model([w, s, zp])[0]


class OVAWQAlgoAlgoBackend(OVWeightCompressionAlgoBackend):
@staticmethod
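In plain math, the two helpers above compile tiny OpenVINO graphs for `clamp(round(w / s + zp), 0, 2**num_bits - 1)` and its dequantization `(q - zp) * s`. A NumPy sketch of the same round trip (values are illustrative):

```python
import numpy as np

def quantize(w, s, zp, num_bits=4):
    # Same math as the `get_compress_pipeline` graph.
    return np.clip(np.round(w / s + zp), 0, 2**num_bits - 1)

def quantize_dequantize(w, s, zp, num_bits=4):
    # Same math as the `get_compress_decompress_pipeline` graph: (q - zp) * s.
    return (quantize(w, s, zp, num_bits) - zp) * s

w = np.array([0.31, -0.12, 0.88, -0.54], dtype=np.float32)
s, zp = np.float32(0.1), np.float32(8.0)  # illustrative scale and zero point
print(quantize(w, s, zp))             # [11.  7. 15.  3.]
print(quantize_dequantize(w, s, zp))  # approx [ 0.3 -0.1  0.7 -0.5]
```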