Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weight compression via Lora Correction Algorithm #2816

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
24b1f9e
Lora Correction Algorithm for int4/nf4 weight compression
ljaljushkin Jul 22, 2024
cd5d3de
conformance test values
ljaljushkin Jul 22, 2024
42580d9
gelsy
ljaljushkin Jul 30, 2024
bff4e2b
unsupported options
ljaljushkin Jul 30, 2024
c099e4b
renaming and comments in LoRA algorithm
ljaljushkin Jul 30, 2024
4639097
less transpose, should be faster and less memory
ljaljushkin Jul 30, 2024
d26540f
renaming, typehint, gptq+lora error
ljaljushkin Jul 31, 2024
4479bc3
no wc_params in functions
ljaljushkin Jul 31, 2024
523861b
rename
ljaljushkin Jul 31, 2024
5b1e98e
changed defaults
ljaljushkin Aug 8, 2024
4747ba4
test lora with mixed precision
ljaljushkin Aug 9, 2024
258e91a
Merge remote-tracking branch 'fork/nl/lora_comments_tmp' into nl/lora…
ljaljushkin Aug 9, 2024
3cd49be
removed copy-paste in nf4 quant/dequant
ljaljushkin Aug 12, 2024
85d9c8a
tests for unsupported options
ljaljushkin Aug 12, 2024
f481f6c
Merge remote-tracking branch 'origin/develop' into nl/lora_comments_tmp
ljaljushkin Aug 19, 2024
af0c7f5
new reference for lora conformance test
ljaljushkin Aug 19, 2024
470fbd8
fixed pre-commit
ljaljushkin Aug 19, 2024
a3b26c7
fixed pre-commit
ljaljushkin Aug 19, 2024
6395026
Merge remote-tracking branch 'origin/develop' into nl/lora_correct_pr…
ljaljushkin Aug 19, 2024
565e299
dump advanced, test for transpose_b=False, expose lora params
ljaljushkin Aug 20, 2024
6021a2b
Merge remote-tracking branch 'origin/develop' into nl/lora_correct_pr…
ljaljushkin Aug 20, 2024
c73b2d6
Corrected debug output
ljaljushkin Aug 21, 2024
6bc070f
Merge remote-tracking branch 'origin/develop' into nl/lora_correct_pr…
ljaljushkin Aug 21, 2024
471f3cc
Merge remote-tracking branch 'fork/nl/lora_correct_prod_squash' into …
ljaljushkin Aug 21, 2024
4c38115
Merge remote-tracking branch 'origin/develop' into nl/lora_correct_pr…
ljaljushkin Aug 22, 2024
eefdcda
Merge remote-tracking branch 'origin/develop' into nl/lora_correct_pr…
ljaljushkin Aug 23, 2024
a296518
renaming
ljaljushkin Aug 27, 2024
4282bd5
Merge remote-tracking branch 'origin/develop' into nl/lora_correct_pr…
ljaljushkin Aug 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ nncf_dataset = nncf.Dataset(data_source, transform_fn)
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, ratio=0.8, dataset=nncf_dataset) # model is openvino.Model object
```

- Accuracy of the 4-bit compressed models also can be improved by using AWQ, Scale Estimation or GPTQ algorithms over data-based mixed-precision algorithm. These algorithms work by equalizing a subset of weights to minimize the difference between the original precision and the 4-bit precision. The AWQ algorithm can be used in conjunction with either the Scale Estimation or GPTQ algorithm. However, Scale Estimation and GPTQ algorithms are mutually exclusive and cannot be used together. Below are examples demonstrating how to enable the AWQ, Scale Estimation or GPTQ algorithms:
- Accuracy of the 4-bit compressed models also can be improved by using AWQ, Scale Estimation, GPTQ or Lora Correction algorithms over data-based mixed-precision algorithm. These algorithms work by equalizing a subset of weights to minimize the difference between the original precision and the 4-bit precision.
Unlike all others, the Lora Correction algorithm inserts an additional Linear layers for reducing quantization noise and further accuracy improvement. Inevitably, this approach introduces a memory and a runtime overheads, but they are negligible, since the inserted weight much smaller and can be quantized to 8-bit. The AWQ, Scale Estimation (SE) and Lora Correction (LC) algo can be used in any combination together: AWQ + SE, AWQ + LC, SE + LC, AWQ + SE + LC. The GPTQ algorithm can be combined with AWQ only. Below are examples demonstrating how to enable the AWQ, Scale Estimation, GPTQ or Lora Correction algorithms:

Prepare the calibration dataset for data-based algorithms:

Expand Down Expand Up @@ -135,6 +136,16 @@ model.model = compress_weights(model.model,
gptq=True)
```

- How to compress 80% of layers to 4-bit integer with a default data-based mixed precision algorithm and Lora Correction algorithm. It requires setting `lora_correction` to `True` additionally to data-based mixed-precision algorithm.

```python
model.model = compress_weights(model.model,
mode=CompressWeightsMode.INT4_SYM,
ratio=0.8,
dataset=nncf_dataset,
lora_correction=True)
```

- `NF4` mode can be considered for improving accuracy, but currently models quantized to nf4 should not be faster models
quantized to 8-bit asymmetric integer. Here's the example how to compress weights to nf4 data type with group size = 128.
Different `group_size` and `ratio` are also supported.
Expand Down
5 changes: 5 additions & 0 deletions nncf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@
from nncf.quantization.advanced_parameters import (
AdvancedAccuracyRestorerParameters as AdvancedAccuracyRestorerParameters,
)
from nncf.quantization.advanced_parameters import AdvancedAWQParameters as AdvancedAWQParameters
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters as AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters as AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedGPTQParameters as AdvancedGPTQParameters
from nncf.quantization.advanced_parameters import AdvancedLoraCorrectionParameters as AdvancedLoraCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters as AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import AdvancedScaleEstimationParameters as AdvancedScaleEstimationParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters as AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import OverflowFix as OverflowFix
from nncf.scopes import IgnoredScope as IgnoredScope
Expand Down
2 changes: 2 additions & 0 deletions nncf/openvino/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def compress_weights_impl(
subset_size: int,
scale_estimation: bool,
gptq: bool,
lora_correction: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> ov.Model:
"""
Expand All @@ -455,6 +456,7 @@ def compress_weights_impl(
subset_size,
scale_estimation,
gptq,
lora_correction,
advanced_parameters,
)
graph = NNCFGraphFactory.create(model)
Expand Down
30 changes: 30 additions & 0 deletions nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,33 @@ class AdvancedGPTQParameters:
subset_size: int = 128


@api()
@dataclass
class AdvancedLoraCorrectionParameters:
"""
Contains advanced parameters for lora correction algorithm.

:param adapter_rank: rank of lora adapters. Defaults to 16.
:type adapter_rank: int
:param num_iterations: number of correction iterations. Defaults to 3.
:type num_iterations: int
:param apply_regularization: Whether to add a regularization during the correction process. Defaults to True.
Helpful for big rank values to avoid overfitting.
:type apply_regularization: bool
:param subset_size: Number of data samples for lora correction algorithm. Defaults to 128.
:type subset_size: int
:param use_int8_adapters: Whether to 8-bit quantize lora adapters, otherwise they kept in the original weights
precision. Defaults to True.
:type use_int8_adapters: bool
"""

adapter_rank: int = 8
num_iterations: int = 3
apply_regularization: bool = True
subset_size: int = 128
use_int8_adapters: bool = True


@api()
@dataclass
class AdvancedCompressionParameters:
Expand All @@ -337,6 +364,9 @@ class AdvancedCompressionParameters:
# Advanced GPTQ algorithm parameters
gptq_params: AdvancedGPTQParameters = field(default_factory=AdvancedGPTQParameters)

# Advanced Lora Correction algorithm parameters
lora_correction_params: AdvancedLoraCorrectionParameters = field(default_factory=AdvancedLoraCorrectionParameters)


@api()
@dataclass
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, TypeVar

from nncf.tensor import functions as fns

TTensor = TypeVar("TTensor")


def process_stats(stats: List[TTensor], subset_size: int) -> Tuple[TTensor, TTensor]:
ljaljushkin marked this conversation as resolved.
Show resolved Hide resolved
"""
It's a processing of activations shared between AWQ, Scale Estimation and LoRA Correction algorithms.

:param stats: list of activation statistics for a layer that contains N tensors with shape [SeqLen, HiddenDim]
:type stats: List[TTensor]
:param subset_size: The number of samples for AWQ.
:type subset_size: int
:return: tuple of the following tensors:
s - maximum channel magnitude across samples [HiddenDim]
X - average channel magnitude across tokens in the sequence [HiddenDim, SampleSize]
:rtype: Tuple[TTensor, TTensor]
"""
X = fns.stack([fns.mean(stat, axis=0) for stat in stats]) # [Batch, HiddenDim]
X_full = fns.transpose(X) # [HiddenDim, Batch]

# prevent high memory and time consumption
if X_full.shape[1] > subset_size:
lens = [stat.shape[0] for stat in stats]
step = X_full.shape[1] // subset_size
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
X = X_full[:, idxs] # [HiddenDim, SampleSize]
else:
X = X_full
s = fns.max(fns.abs(X_full), axis=1) # [HiddenDim]
return s, X
17 changes: 16 additions & 1 deletion nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
from nncf.parameters import CompressWeightsMode
from nncf.parameters import SensitivityMetric
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import convert_to_dict_recursively
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.awq import AWQ
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.gptq import GPTQ
from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm
from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA
from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation
from nncf.quantization.algorithms.weight_compression.weight_lowering import WeightCompressionConfig
Expand Down Expand Up @@ -65,6 +67,7 @@ def __init__(
subset_size: int,
scale_estimation: bool,
gptq: bool,
lora_correction: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
):
"""
Expand Down Expand Up @@ -97,6 +100,7 @@ def __init__(
quantization precision.
:param scale_estimation: determines whether to use or not scale estimation for 4 bit layers.
:param gptq: determines whether to use or not GPTQ algorithm.
:param lora_correction: determines whether to use or not LoRA Correction algorithm.
:param advanced_parameters: advanced parameters for algorithms in compression pipeline.
"""
super().__init__()
Expand All @@ -113,6 +117,7 @@ def __init__(
self._subset_size = subset_size
self._scale_estimation = scale_estimation
self._gptq = gptq
self._lora_correction = lora_correction
self._advanced_parameters = (
advanced_parameters if advanced_parameters is not None else AdvancedCompressionParameters()
)
Expand Down Expand Up @@ -403,6 +408,13 @@ def apply(
backend_entity=self._backend_entity,
)

lora_correction_algo = None
description = "Applying Weight Compression"
if self._lora_correction:
lora_correction_params = self._advanced_parameters.lora_correction_params
lora_correction_algo = LoraCorrectionAlgorithm(activations, lora_correction_params)
description += " with correction of low-rank adapters"

# Sort weight params to start compression with the bigger constants. This lowers peak memory footprint.
all_weight_params = sorted(all_weight_params, key=lambda wp: wp.num_weights, reverse=True)
all_weight_sizes = [wp.num_weights for wp in all_weight_params]
Expand All @@ -411,9 +423,10 @@ def apply(
transformed_model = self._backend_entity.transform_model(
model,
graph,
track(all_weight_params, description="Applying Weight Compression", weights=all_weight_sizes),
track(all_weight_params, description=description, weights=all_weight_sizes),
scales,
zero_points,
lora_correction_algo,
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
)

self._backend_entity.dump_parameters(
Expand All @@ -428,6 +441,8 @@ def apply(
"awq": self._awq,
"scale_estimation": self._scale_estimation,
"gptq": self._gptq,
"lora_correction": self._lora_correction,
"advanced_parameters": convert_to_dict_recursively(self._advanced_parameters),
},
algo_name="weight_compression",
)
Expand Down
24 changes: 6 additions & 18 deletions nncf/quantization/algorithms/weight_compression/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization
from nncf.quantization.passes import transform_to_inference_graph
from nncf.tensor import functions as fns

Expand Down Expand Up @@ -101,9 +102,6 @@ def _set_backend_entity(self, model: TModel) -> None:
Creates a helper class with a backed-specific logic of the algorithm.

:param model: Backend-specific input model.
:param all_weight_params: List of all weight parameters.
:param nodes_to_compress: List of nodes for processing.
:param activations: The input activations of the layers considered for compression.
"""

model_backend = get_backend(model)
Expand Down Expand Up @@ -197,17 +195,7 @@ def apply(

config = wp.compression_config

stats = self._activations[k]
X = fns.stack([fns.mean(stat, axis=0) for stat in stats])
X = fns.transpose(X)

s = fns.max(fns.abs(X), axis=1)

if X.shape[1] > self._subset_size:
lens = [stat.shape[0] for stat in stats]
step = X.shape[1] // self._subset_size
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
X = X[:, idxs]
s, X = process_stats(self._activations[k], self._subset_size)

top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
topk_idxs = fns.argsort(-s)[:top_k]
Expand Down Expand Up @@ -257,10 +245,10 @@ def apply(
for _ in range(self._steps):
cur_scale = gscale**alpha

g_compressed_weighs, g_c_scale, g_c_zp = do_integer_quantization(
g_compressed_weighs, g_c_scale, g_c_zp = do_int_quantization(
gweight * cur_scale, reduction_axis, awq_config
)
g_decompressed_weighs = do_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
g_decompressed_weighs = do_int_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
sacts = gacts / fns.unsqueeze(cur_scale, 1)

cur_out = fns.matmul(g_decompressed_weighs, sacts)
Expand Down
28 changes: 28 additions & 0 deletions nncf/quantization/algorithms/weight_compression/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,34 @@ def transform_model(
:return: The transformed model.
"""

@abstractmethod
def insert_adapters(
self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
) -> None:
"""
Expands a model's execution graph following the Low-Rank Adaptation (LoRA) concept.

It inserts two additional Linear layers with weight matrices of low rank that are executed in parallel to the
target Linear layer.

Before insertion:

----INPUT
\
orig.MM--------------------------------OUTPUT

After insertion:

----INPUT ----lora_A.MM----lora_B.MM----\
\ add----OUTPUT
orig.MM--------------------------/

:param wc_params: Parameters for weight compression.
:param lora_A: weights for the first LoRA matrix.
:param lora_B: weights for the second LoRA matrix.
:param int8_lora: indicates whether the LoRA matrices should be compressed to 8-bit.
"""

@staticmethod
@abstractmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint:
Expand Down
14 changes: 8 additions & 6 deletions nncf/quantization/algorithms/weight_compression/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_quantized_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import decompress_nf4_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization
from nncf.tensor import Tensor
from nncf.tensor import functions as fns
from nncf.tensor.definitions import TensorDataType
Expand Down Expand Up @@ -266,13 +266,15 @@ def _quantize_weights(
scales.append(scale)
zero_points.append(zero_point)
if block_compression_config.mode == CompressWeightsMode.NF4:
compressed_weights = calculate_nf4_weight(fns.unsqueeze(weight_col, 1), scales[-1])
quantized_col = decompress_nf4_weight(compressed_weights, scales[-1])
compressed_weights = do_nf4_quantization(
fns.unsqueeze(weight_col, 1), scales[-1], is_normalized_weight=False
)
quantized_col = do_nf4_dequantization(compressed_weights, scales[-1], reduction_axis=-1)
else:
compressed_weights = calculate_quantized_weight(
fns.unsqueeze(weight_col, 1), block_compression_config, scales[-1], zero_points[-1]
)
quantized_col = do_dequantization(compressed_weights, scales[-1], zero_points[-1])
quantized_col = do_int_dequantization(compressed_weights, scales[-1], zero_points[-1])
quantized_col = fns.flatten(quantized_col)
quantized_block[:, i] = quantized_col
loss_block[:, i] = (weight_col - quantized_col) ** 2 / hessian_diag_val**2
Expand Down
Loading
Loading