Added weight compression algorithm for OpenVINO backend (#2059)
### Changes

Extended the data-free int8 weight compression algorithm to the OpenVINO backend.
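
A minimal usage sketch of the new path (illustrative; `ov_model` is a hypothetical `openvino.runtime.Model` and the IR path is made up):

```python
import openvino.runtime as ov

import nncf

# Read a hypothetical OpenVINO IR model; the path is illustrative only.
ov_model = ov.Core().read_model("model.xml")

# Compresses the weights of Linear and Embedding layers to uint8 in place:
# each original FP weight constant is replaced by a
# Constant(u8) -> Convert -> Subtract(zero_point) -> Multiply(scale) subgraph.
compressed_model = nncf.compress_weights(ov_model)
```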

Example (WeightsModel):

![image](https://github.com/openvinotoolkit/nncf/assets/22346860/02138cce-290a-40aa-b997-f83815400a6c)

Companion PR to Optimum Intel: huggingface/optimum-intel#415

### Reason for changes

Optimize the model footprint and performance of large models, such as LLMs, where the weights are significantly larger than the activations.

### Related tickets

117412

### Tests

`tests/openvino/native/quantization/test_weights_compression.py`

Swin Transformer support verified.

Results (task: `lambada_openai`):

| Model | Metric | Value | Stderr |
|:------|:-------|------:|-------:|
| dolly-v2-3b_original | ppl | 5.0144 | ±0.1510 |
| dolly-v2-3b_original | acc | 0.6297 | ±0.0067 |
| dolly-v2-3b_compressed | ppl | 4.9868 | ±0.1498 |
| dolly-v2-3b_compressed | acc | 0.6313 | ±0.0067 |
| Llama-2-7b-chat-hf_original | ppl | 3.2788 | ±0.0866 |
| Llama-2-7b-chat-hf_original | acc | 0.7058 | ±0.0063 |
| Llama-2-7b-chat-hf_compressed | ppl | 3.2856 | ±0.0869 |
| Llama-2-7b-chat-hf_compressed | acc | 0.7054 | ±0.0064 |
l-bat authored Sep 1, 2023
1 parent e831b97 commit 8a01b01
Showing 24 changed files with 863 additions and 12,263 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -34,7 +34,7 @@ learning frameworks.
| Compression algorithm |OpenVINO|PyTorch| TensorFlow | ONNX |
|:----------------------------------------------------------------------------| :---: | :---: |:--------:|:------------------:|
| [Post-Training Quantization](./docs/compression_algorithms/post_training/Quantization.md) | Supported | Supported |Supported| Supported |
| [Weights Compression](./docs/compression_algorithms/CompressWeights.md) | Not supported | Supported |Not supported| Not supported |
| [Weights Compression](./docs/compression_algorithms/CompressWeights.md) | Supported | Supported |Not supported| Not supported |

### Training-Time Compression Algorithms

13 changes: 4 additions & 9 deletions docs/compression_algorithms/CompressWeights.md
@@ -1,8 +1,10 @@
### Weights Compression

[OpenVINO](https://github.com/openvinotoolkit/openvino) is the preferred backend to run Weights Compression with, and PyTorch is also supported.

#### The algorithm description

The Weights Compression algorithm is aimed at compressing the weights of the models and can be used to optimize the model footprint and performance of large models where the size of weights is relatively larger than the size of activations, for example, Large Language Models (LLM). The algorithm compresses weights only for Linear and Embedding layers. It is also possible to keep the precision of the original weights and insert FakeQuantize operations by setting `use_fake_quantize` parameter to `True`.
The Weights Compression algorithm is aimed at compressing the weights of the models and can be used to optimize the model footprint and performance of large models where the size of weights is relatively larger than the size of activations, for example, Large Language Models (LLM). The algorithm compresses weights only for Linear and Embedding layers.

#### User guide

@@ -13,15 +15,8 @@ from nncf import compress_weights
compressed_model = compress_weights(model)
```

- Insert FakeQuantize layers for weights of linear layers and embeddings

```python
from nncf import compress_weights
model_with_fake_quantize = compress_weights(model, use_fake_quantize=True)
```

##### Limitations

- The algorithm is supported for PyTorch only.
- The algorithm is supported for OpenVINO and PyTorch models.
- The compression applies in-place.
- The compressed model is not trainable.
2 changes: 1 addition & 1 deletion nncf/common/utils/backend.py
@@ -131,7 +131,7 @@ def get_backend(model: TModel) -> BackendType:
raise RuntimeError(
"Could not infer the backend framework from the model type because "
"the framework is not available or the model type is unsupported. "
"The available frameworks found: {}.".format(", ".join(available_backends))
"The available frameworks found: {}.".format(", ".join([b.value for b in available_backends]))
)


29 changes: 1 addition & 28 deletions nncf/onnx/quantization/quantizer_parameters.py
@@ -15,6 +15,7 @@
import numpy as np

from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.quantization.fake_quantize import calculate_scale_zero_point


@dataclass
@@ -75,31 +76,3 @@ def get_level_low_level_high(tensor_type: np.dtype) -> Tuple[int, int]:
:return: Minimum level and maximum level of the quantizer.
"""
return (0, 255) if tensor_type == np.uint8 else (-128, 127)


def calculate_scale_zero_point(
input_low: np.ndarray, input_high: np.ndarray, level_low: int, level_high: int, narrow_range: bool
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calculates Quantizer/Dequantizer layer scale level.
Returns scale and zero_point values for the quantizer.
:param input_low: The minimum limit for an input value based on collected statistics.
:param input_high: The maximum limit for an input value based on collected statistics.
:param level_low: The minimum level in the integer range to quantize.
The default is "0" for an unsigned range, and "-2^(bits-1)" for a signed one.
:param level_high: The maximum level in the integer range to quantize.
The default is "2^bits-1" for an unsigned range, and "2^(bits-1)-1" for a signed one.
:param narrow_range: True if the range of quantized values is narrowed as compared to the
naive case, False otherwise.
:return: Scale and Zero point values.
"""
levels = level_high - level_low if narrow_range else level_high - level_low + 1
scale = np.array((input_high - input_low) / (levels - 1))
expected_level_low = level_low + 1 if narrow_range else level_low
zero_point = expected_level_low - np.round(input_low / scale)
zero_point = np.minimum(np.maximum(zero_point.astype(np.int32), level_low), level_high)
scale = np.array(np.squeeze(scale).astype(np.float32))
zero_point = np.array(np.squeeze(zero_point))

return scale, zero_point
18 changes: 18 additions & 0 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
@@ -18,6 +18,7 @@
from nncf.common.graph.operator_metatypes import OUTPUT_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.hardware.opset import HWConfigOpName

OV_OPERATOR_METATYPES = OperatorMetatypeRegistry("openvino_operator_metatypes")
@@ -772,3 +773,20 @@ def _is_embedding(node: ov.Node) -> bool:
return True

return False


def get_node_metatype(node: ov.Node) -> Type[OperatorMetatype]:
"""
Determine NNCF meta type for OpenVINO node.
:param node: OpenVINO node.
:return: NNCF meta type which corresponds to OpenVINO node.
"""
node_type = node.get_type_name()
metatype = OV_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
if metatype is not UnknownMetatype:
if metatype.get_subtypes():
subtype = metatype.determine_subtype(node)
if subtype is not None:
metatype = subtype
return metatype
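
A short usage sketch for the relocated helper (illustrative; the model object and IR path are hypothetical):

```python
import openvino.runtime as ov

from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype

ov_model = ov.Core().read_model("model.xml")  # hypothetical IR path

# Collect all MatMul operations in the model by their resolved NNCF metatype.
matmul_nodes = [node for node in ov_model.get_ops() if get_node_metatype(node) is OVMatMulMetatype]
```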
24 changes: 3 additions & 21 deletions nncf/openvino/graph/nncf_graph_builder.py
@@ -16,16 +16,15 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.layer_attributes import get_weighted_layer_attributes
from nncf.openvino.graph.metatypes.openvino_metatypes import METATYPES_WITH_CONST_PORT_ID
from nncf.openvino.graph.metatypes.openvino_metatypes import OV_OPERATOR_METATYPES
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionBackpropDataMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionBackpropDataMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVGRUSequenceMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVLSTMSequenceMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_operation_const_op


@@ -78,23 +77,6 @@ def _filter_weight_input_ports(inputs: List[ov.Input], metatype: Type[OperatorMetatype]) -> List[ov.Input]:
return inputs[:6]
return inputs

@staticmethod
def _get_node_metatype(node: ov.Node) -> Type[OperatorMetatype]:
"""
Determine NNCF meta type for OpenVINO node.
:param node: OpenVINO node.
:return: NNCF meta type which corresponds to OpenVINO node.
"""
node_type = node.get_type_name()
metatype = OV_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
if metatype is not UnknownMetatype:
if metatype.get_subtypes():
subtype = metatype.determine_subtype(node)
if subtype is not None:
metatype = subtype
return metatype

@staticmethod
def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None:
"""
@@ -130,7 +112,7 @@ def _add_nncf_node(node: ov.Node, graph: NNCFGraph) -> None:
:param graph: NNCFGraph.
"""
node_type = node.get_type_name()
metatype = GraphConverter._get_node_metatype(node)
metatype = get_node_metatype(node)
graph.add_nncf_node(node_name=node.get_friendly_name(), node_type=node_type, node_metatype=metatype)

@staticmethod
@@ -159,7 +141,7 @@ def create_nncf_graph(model: ov.Model) -> NNCFGraph:
inference_nodes.append(inp.get_node())

for node in model.get_ops():
metatype = GraphConverter._get_node_metatype(node)
metatype = get_node_metatype(node)
# Add nodes from constant subgraphs
node_name = node.get_friendly_name()
if node_name not in visited:
31 changes: 22 additions & 9 deletions nncf/openvino/graph/node_utils.py
@@ -325,20 +325,33 @@ def get_weight_channel_axes(node: NNCFNode, weights_port_id: int) -> List[int]:
if node.metatype == OVMatMulMetatype:
assert isinstance(node.layer_attributes, OVLayerAttributes)
assert len(channel_axes) == 1
matmul_channel_axis = channel_axes[0]
const_attrs = node.layer_attributes.constant_attributes[weights_port_id]
if (weights_port_id == 1) == const_attrs["transpose"]:
matmul_channel_axis -= 1
shape = const_attrs["shape"]
ndims = len(shape)
channel_axes = list(range(ndims - 2)) if ndims > 2 else []
matmul_channel_axis = max(ndims, 2) + matmul_channel_axis
if matmul_channel_axis < ndims:
channel_axes.append(matmul_channel_axis)
transpose = const_attrs["transpose"]
ndims = len(const_attrs["shape"])
channel_axes = get_matmul_channel_axes(weights_port_id, ndims, transpose)

return channel_axes


def get_matmul_channel_axes(weights_port_id: int, ndims: int, transpose: bool) -> List[int]:
"""
Calculate channel axes for the MatMul operation.
:param weights_port_id: Weight port id of the target node.
:param ndims: The number of MatMul dimensions.
:param transpose: Whether the transpose is applied to weights.
:return: List of channel axes for the MatMul operation.
"""
matmul_channel_axis = OVMatMulMetatype.const_channel_axis[0]
if (weights_port_id == 1) == transpose:
matmul_channel_axis -= 1
matmul_channel_axis = max(ndims, 2) + matmul_channel_axis
channel_axes = list(range(ndims - 2))
if matmul_channel_axis < ndims:
channel_axes.append(matmul_channel_axis)
return channel_axes


def get_channel_agnostic_reduction_shape(channel_axes: List[int], shape: List[int]) -> Tuple[int]:
"""
Returns the filtered reduction shape without axes that correspond to channels.
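For intuition, a few hand-checked results of the new `get_matmul_channel_axes` helper (a sketch; it assumes `OVMatMulMetatype.const_channel_axis` is `[-1]`, which the decrement logic above implies):

```python
from nncf.openvino.graph.node_utils import get_matmul_channel_axes

# 2D weight on port 1, transpose_b=False: shape [K, N], output channels on axis 1.
assert get_matmul_channel_axes(weights_port_id=1, ndims=2, transpose=False) == [1]
# 2D weight on port 1, transpose_b=True: shape [N, K], channels move to axis 0.
assert get_matmul_channel_axes(weights_port_id=1, ndims=2, transpose=True) == [0]
# 4D weight on port 1: leading batch axes [0, 1] are kept, plus the output axis 3.
assert get_matmul_channel_axes(weights_port_id=1, ndims=4, transpose=False) == [0, 1, 3]
# 1D weight on port 1: the computed axis falls outside the shape, so no channel axes.
assert get_matmul_channel_axes(weights_port_id=1, ndims=1, transpose=False) == []
```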
9 changes: 9 additions & 0 deletions nncf/openvino/quantization/quantize_model.py
@@ -22,6 +22,7 @@
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.quantization.backend_parameters import BackendParameters
from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed
from nncf.openvino.quantization.weights_compression import insert_pre_compression_operations
from nncf.parameters import DropType
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
@@ -369,3 +370,11 @@ def quantize_with_accuracy_control_impl(
advanced_quantization_parameters,
advanced_accuracy_restorer_parameters,
)


def compress_weights_impl(model: ov.Model) -> ov.Model:
"""
Implementation of the `compress_weights()` method for the OpenVINO backend.
"""
insert_pre_compression_operations(model)
return model
99 changes: 99 additions & 0 deletions nncf/openvino/quantization/weights_compression.py
@@ -0,0 +1,99 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Type, Union

import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset9 as opset

from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVEmbeddingMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_operation_const_op
from nncf.openvino.graph.node_utils import get_const_value
from nncf.openvino.graph.node_utils import get_matmul_channel_axes
from nncf.quantization.fake_quantize import calculate_scale_zero_point


def insert_pre_compression_operations(model: ov.Model, bits: int = 8) -> None:
"""
Compress weights of Linear and Embedding layers to uint8.
The result of compression is the same as asymmetric weight quantization.
:param model: The model to be transformed.
:param bits: Number of bits for quantization.
"""
allowed_metatypes_to_const_port = {OVEmbeddingMetatype: [0], OVMatMulMetatype: [0, 1]}
level_low = 0
level_high = 2**bits - 1

for node in model.get_ops():
metatype = get_node_metatype(node)
if metatype not in allowed_metatypes_to_const_port:
continue

for const_port_id in allowed_metatypes_to_const_port[metatype]:
weight_node = get_operation_const_op(node, const_port_id)
if weight_node is None:
continue

weight_output = weight_node.output(0)
weight_name = weight_node.get_friendly_name()
target_inputs = weight_output.get_target_inputs()

original_weight_dtype = weight_output.get_element_type().to_dtype()
if original_weight_dtype not in [np.float32, np.float16, np.float64]:
continue

weight = get_const_value(weight_node)
axes = _get_reduction_axes(metatype, node, const_port_id)
min_values = np.min(weight, axis=axes, keepdims=True)
max_values = np.max(weight, axis=axes, keepdims=True)

scale, zero_point = calculate_scale_zero_point(
min_values, max_values, level_low, level_high, narrow_range=False
)

compressed_weights = np.round(weight / scale + zero_point)
compressed_weights = np.clip(compressed_weights, level_low, level_high).astype(np.uint8)

compressed_const = opset.constant(compressed_weights, dtype=np.uint8, name=weight_name)
convert = opset.convert(compressed_const, original_weight_dtype)
sub = opset.subtract(convert, zero_point.astype(original_weight_dtype))
fq_name = f"{node.get_friendly_name()}/fq_weights_{const_port_id}"
mul = opset.multiply(sub, scale.astype(original_weight_dtype), name=fq_name)

for target_input in target_inputs:
target_input.replace_source_output(mul.output(0))


def _get_reduction_axes(metatype: Type[OperatorMetatype], node: ov.Node, weight_port_id: int) -> Union[int, Tuple[int]]:
"""
Determines reduction axes by given metatype and node information.
:param metatype: The metatype of the operator.
:param node: The OpenVINO node.
:param weight_port_id: The weight port ID.
:return: The reduction axes as an integer or a tuple of integers.
"""
if metatype is OVMatMulMetatype:
transpose = node.get_attributes()[f"transpose_{'a' if weight_port_id == 0 else 'b'}"]
ndims = node.input(weight_port_id).get_partial_shape().rank.get_max_length()
channel_axes = get_matmul_channel_axes(weight_port_id, ndims, transpose)
axes = tuple(i for i in range(ndims) if i not in channel_axes)
elif metatype is OVEmbeddingMetatype:
axes = (metatype.const_channel_axis[0] + 1) % 2
else:
raise RuntimeError("Unsupported metatype to find reduction axes.")
return axes
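
End to end, the emitted `Constant(u8) -> Convert -> Subtract -> Multiply` chain realizes asymmetric dequantization, `weight ≈ (compressed - zero_point) * scale`. A NumPy sketch of the round trip mirroring the code above (8 bits, per-channel over axis 1 of a 2D MatMul weight on port 1):

```python
import numpy as np

from nncf.quantization.fake_quantize import calculate_scale_zero_point

weight = np.random.rand(16, 32).astype(np.float32)  # [K, N]; channel axis is 1
min_values = np.min(weight, axis=0, keepdims=True)  # reduce over the non-channel axis
max_values = np.max(weight, axis=0, keepdims=True)
scale, zero_point = calculate_scale_zero_point(min_values, max_values, 0, 255, narrow_range=False)

compressed = np.clip(np.round(weight / scale + zero_point), 0, 255).astype(np.uint8)
# Convert -> Subtract -> Multiply, as built by insert_pre_compression_operations:
decompressed = (compressed.astype(np.float32) - zero_point) * scale
assert np.all(np.abs(weight - decompressed) <= scale / 2 + 1e-6)  # error within half a step
```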
24 changes: 24 additions & 0 deletions nncf/quantization/fake_quantize.py
@@ -287,3 +287,27 @@ def _calculate_scaled_parameters(
input_low *= (export_levels - 1) / (levels - 1)

return input_low, input_high, export_levels


def calculate_scale_zero_point(
input_low: np.ndarray, input_high: np.ndarray, level_low: int, level_high: int, narrow_range: bool
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calculates scale and zero_point values for the quantizer.
:param input_low: The minimum limit for an input value based on collected statistics.
:param input_high: The maximum limit for an input value based on collected statistics.
:param level_low: The minimum level in the integer range to quantize.
The default is "0" for an unsigned range, and "-2^(bits-1)" for a signed one.
:param level_high: The maximum level in the integer range to quantize.
The default is "2^bits-1" for an unsigned range, and "2^(bits-1)-1" for a signed one.
:param narrow_range: True if the range of quantized values is narrowed as compared to the
naive case, False otherwise.
:return: Scale and Zero point values.
"""
levels = level_high - level_low if narrow_range else level_high - level_low + 1
scale = np.array((input_high - input_low) / (levels - 1)).astype(np.float32)
expected_level_low = level_low + 1 if narrow_range else level_low
zero_point = expected_level_low - np.round(input_low / scale)
zero_point = np.clip(zero_point.astype(np.int32), level_low, level_high)
return scale, zero_point
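
A quick numeric check of `calculate_scale_zero_point` (hand-computed for the 8-bit unsigned, `narrow_range=False` configuration used by the OpenVINO weight compression above):

```python
import numpy as np

from nncf.quantization.fake_quantize import calculate_scale_zero_point

# Asymmetric range [-0.5, 1.5] mapped onto uint8 levels [0, 255]:
input_low, input_high = np.array(-0.5), np.array(1.5)
scale, zero_point = calculate_scale_zero_point(input_low, input_high, 0, 255, narrow_range=False)
# levels = 256, so scale = (1.5 - (-0.5)) / 255 ~= 0.00784
# zero_point = 0 - round(-0.5 / scale) = 0 - round(-63.75) = 64
assert np.isclose(scale, 2 / 255) and zero_point == 64
# The endpoints land on the first and last levels:
# round(-0.5 / scale) + 64 == 0 and round(1.5 / scale) + 64 == 255
```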
16 changes: 8 additions & 8 deletions nncf/quantization/quantize_model.py
@@ -226,22 +226,22 @@ def quantize_with_accuracy_control(


@api(canonical_alias="nncf.compress_weights")
def compress_weights(model: TModel, use_fake_quantize: bool = False) -> TModel:
def compress_weights(model: TModel) -> TModel:
"""
Compress model weights.
:param model: A model to be compressed.
:param use_fake_quantize: Disables real compression of weights in Linear and Embedding layers.
If True inserts fake quantization operations,
else compress weights to int8 and inserts custom dequantization.
:return: The model with compressed weight and dequantization or model with original weights and fake quantization.
Not trainable.
:return: The non-trainable model with compressed weights.
"""
backend = get_backend(model)
if backend == BackendType.TORCH:
import nncf.torch
from nncf.torch.quantization.quantize_model import compress_weights_impl

return nncf.torch.compress_weights(model, use_fake_quantize)
return compress_weights_impl(model)
if backend == BackendType.OPENVINO:
from nncf.openvino.quantization.quantize_model import compress_weights_impl

return compress_weights_impl(model)

raise RuntimeError(f"Unsupported type of backend: {backend}")
