Compress weights with a single reduction axis only (#2254)
### Changes

Exclude from weight compression the nodes that have more than one reduction axis.
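
As an illustration of the rule, here is a minimal, self-contained sketch. The helper below is a simplified stand-in for NNCF's `get_channel_agnostic_reduction_axes`, and the choice of channel axis for the chatglm-like embedding is an assumption made for the example, not taken from the model:

```python
from typing import Sequence, Tuple


def reduction_axes(channel_axes: Sequence[int], shape: Sequence[int]) -> Tuple[int, ...]:
    # Simplified stand-in: every axis that is not a channel axis is reduced over.
    return tuple(i for i in range(len(shape)) if i not in channel_axes)


def is_compressible(channel_axes: Sequence[int], shape: Sequence[int]) -> bool:
    # After this change, a weight is compressed only when exactly one reduction axis remains.
    return len(reduction_axes(channel_axes, shape)) == 1


# Typical 2D MatMul weight: one output-channel axis, one reduction axis -> compressed.
print(is_compressible(channel_axes=[0], shape=[4096, 11008]))  # True
# A 3D embedding with a single channel axis leaves two reduction axes -> skipped.
print(is_compressible(channel_axes=[0], shape=[8132, 32, 2]))  # False
```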

### Reason for changes

There is only one known model with multiple reduction axes: `chatglm`, which has an embedding layer of shape [8132, 32, 2]. It was decided not to quantize this layer, since int8 quantization would save only about 6 MB in a 4 GB model while risking an accuracy drop, and the layer cannot be quantized group-wise.

The idea is to switch to multiple reduction axes when it is really needed.

### Related tickets

n/a

### Tests

Tested on 104 LLM models from the share with IRs. In all cases except chatglm there is a single reduction axis.
ljaljushkin authored Nov 14, 2023
1 parent fd18823 commit b4b2e19
Showing 7 changed files with 52 additions and 48 deletions.
2 changes: 1 addition & 1 deletion nncf/common/graph/layer_attributes.py
@@ -95,7 +95,7 @@ def __init__(
:param weight_requires_grad: Is True if gradients need to be computed for the corresponding Tensor,
False otherwise.
:param weight_shape: shape of weight tensor.
:param filter_dimension_idx: the axis along which the filters are stored.
:param filter_dimension_idx: the axis, along which the filters are stored.
"""
super().__init__(weight_requires_grad=weight_requires_grad, with_bias=with_bias)
self.weight_shape = weight_shape
2 changes: 1 addition & 1 deletion nncf/common/graph/operator_metatypes.py
@@ -21,7 +21,7 @@ class OperatorMetatype:
:param name: The name of the operator.
:param hw_config_names: The names of the hardware configurations.
:param output_channel_axis: The axis along which the output channels of the operator are arranged.
:param output_channel_axis: The axis, along which the output channels of the operator are arranged.
:param ignored_input_ports: Input ports of the operations that should not be considered for purposes of compression.
"""

2 changes: 1 addition & 1 deletion nncf/common/pruning/tensor_processor.py
@@ -28,7 +28,7 @@ def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> NNCFTensor:
Join a list of NNCFTensors along an existing axis.
:param tensors: List of NNCFTensors.
:param axis: The axis along which the tensors will be joined.
:param axis: The axis, along which the tensors will be joined.
:returns: The concatenated List of the tensors.
"""

2 changes: 1 addition & 1 deletion nncf/common/tensor_statistics/collectors.py
@@ -358,7 +358,7 @@ def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
Join a sequence of arrays along an existing axis.
:param x: The input tensor.
:param axis: The axis along which the arrays will be joined.
:param axis: The axis, along which the arrays will be joined.
:return: The concatenated array.
"""

8 changes: 4 additions & 4 deletions nncf/onnx/graph/node_utils.py
@@ -152,11 +152,11 @@ def get_reduction_shape(shape: List[int], axis: int) -> ReductionAxes:

def _get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:
"""
Returns weight tensor axis along quantizer parameters are calculated.
Returns weight tensor axis, along which quantizer parameters are calculated.
:param node: NNCFNode, which has a weight on input port_id.
:param port_id: Input port id on which there is a weight of a node.
:return: Axis along quantizer parameters are calculated.
:return: Axis, along which quantizer parameters are calculated.
"""
weight_channel_axis = node.metatype.weight_channel_axis
if node.layer_attributes.has_node_attrs():
@@ -174,9 +174,9 @@ def _get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:

def _get_activation_quantization_axis() -> int:
"""
Returns activation tensor axis along quantizer parameters are calculated.
Returns activation tensor axis, along which quantizer parameters are calculated.
:return: Axis along quantizer parameters are calculated.
:return: Axis, along which quantizer parameters are calculated.
"""
return 1 # Activations have channel first layout: [N, C, Z, Y, X]

70 changes: 37 additions & 33 deletions nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, TypeVar, Union
from typing import List, Optional, Tuple, TypeVar

import numpy as np
import openvino.runtime as ov
@@ -73,13 +73,23 @@ def do_compression(
continue
const_shape = nncf_node.layer_attributes.constant_attributes[weight_port_id]["shape"]
channel_axes = get_weight_channel_axes(nncf_node, weight_port_id)
axes = get_channel_agnostic_reduction_axes(channel_axes, const_shape)
reduction_axes = get_channel_agnostic_reduction_axes(channel_axes, const_shape)
if isinstance(reduction_axes, tuple) and len(reduction_axes) != 1:
nncf_logger.warning(
f"Weight compression expects a single reduction axes, but given {len(reduction_axes)}. "
f"Weight shape: {const_shape}, reduction axes: {reduction_axes}, node name: {nncf_node.name}. "
"The node won't be quantized."
)
continue
reduction_axis = reduction_axes[0] if isinstance(reduction_axes, tuple) else reduction_axes

fq_name = f"{weight_op_friendly_name}/fq_weights_{weight_port_id}"
num_weights = np.prod(const_shape)
weight_params = WeightNodeParams(axes, num_weights, fq_name, weight_node, original_weight_dtype)
weight_params = WeightNodeParams(
reduction_axis, num_weights, fq_name, weight_node, original_weight_dtype
)
all_weight_params.append(weight_params)
quantized_nodes_ids.add(id(weight_node))

if mode != CompressWeightsMode.INT8:
primary_config = WeightCompressionConfig(mode=mode, group_size=group_size)
_assign_mixed_precision(all_weight_params, ratio, primary_config)
@@ -98,7 +108,7 @@ def do_compression(
config = wp.compression_config
if config.mode == CompressWeightsMode.NF4:
original_shape = weight.shape
norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, wp.reduction_axes, group_size)
norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, wp.reduction_axis, group_size)
compressed_const = opset.constant(norm_weight, dtype=ov.Type.nf4, name=weight_name)
convert = opset.convert(compressed_const, original_weight_dtype)
mul = opset.multiply(convert, scale.astype(original_weight_dtype), name=wp.fq_name)
@@ -107,7 +117,7 @@ def _get_weight_quantization_axis(
last_output = mul.output(0)
else:
original_shape = weight.shape
compressed_weights, scale, zero_point = _do_integer_quantization(weight, wp.reduction_axes, config)
compressed_weights, scale, zero_point = _do_integer_quantization(weight, wp.reduction_axis, config)
compression_type = np.uint8 if config.num_bits == 8 else ov.Type.u4
compressed_weights_node = opset.constant(compressed_weights, dtype=compression_type, name=weight_name)
convert_weights_node = opset.convert(compressed_weights_node, original_weight_dtype)
@@ -153,15 +163,15 @@ class WeightNodeParams:
"""
Information about weight node in the ov.Model that is useful for weight compression.
:param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max).
:param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max).
:param num_weights: Number of elements in the weight array.
:param fq_name: Name for the inserted weight compression operation.
:param weight_node: The weight node itself.
:param original_weight_dtype: Type of elements in the weight array.
:param compression_config: Configuration of weight compression for the weight node.
"""

reduction_axes: Union[int, Tuple[int]]
reduction_axis: int
num_weights: int
fq_name: str
weight_node: ov.Node
@@ -170,7 +180,7 @@ class WeightNodeParams:


def _do_integer_quantization(
weight: np.ndarray, reduction_axes: Union[int, Tuple[int]], config: WeightCompressionConfig
weight: np.ndarray, reduction_axis: int, config: WeightCompressionConfig
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
The method quantizes the given weights to integer data type in accordance with the compression config.
@@ -186,7 +196,7 @@ def _do_integer_quantization(
(scales).
:param weight: Weight array to compress.
:param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max).
:param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max).
:param config: Information on how to compress (quantize) a specific weight.
:return: The compressed weights, scale and zero point that was used for its quantization.
"""
@@ -200,16 +210,16 @@ def _do_integer_quantization(

if group_size != -1:
# weights are reshaped from [a1, r, a2] to [a1, r//gs, gs, a2]
weight, reduction_axes = _reshape_weights_for_grouped_quantization(weight, reduction_axes, group_size)
weight, reduction_axis = _reshape_weights_for_grouped_quantization(weight, reduction_axis, group_size)

if mode in [CompressWeightsMode.INT8, CompressWeightsMode.INT4_ASYM]:
min_values = np.min(weight, axis=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
max_values = np.max(weight, axis=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
min_values = np.min(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
max_values = np.max(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2]
scale, zero_point = calculate_scale_zero_point(
min_values, max_values, level_low, level_high, narrow_range=False
)
else:
scale = np.max(np.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2]
scale = np.max(np.abs(weight), axis=reduction_axis, keepdims=True) # [a1, r//gs, 1, a2]
level_low_sym = -(2 ** (num_bits - 1))
level_high_sym = 2 ** (num_bits - 1) - 1
scale = scale / level_high_sym
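
For illustration, here is a standalone NumPy sketch of the asymmetric branch above, quantizing along a single reduction axis. The scale/zero-point formula is a common textbook variant assumed for this sketch, not copied from `calculate_scale_zero_point`:

```python
import numpy as np


def quantize_int8_asym(weight: np.ndarray, reduction_axis: int):
    # Per-channel statistics along the single reduction axis: [a1, r, a2] -> [a1, 1, a2].
    min_values = np.min(weight, axis=reduction_axis, keepdims=True)
    max_values = np.max(weight, axis=reduction_axis, keepdims=True)
    level_low, level_high = 0, 255  # unsigned 8-bit quantization levels
    scale = (max_values - min_values) / (level_high - level_low)
    scale = np.where(scale == 0, np.finfo(weight.dtype).eps, scale)  # guard constant channels
    zero_point = np.round(level_low - min_values / scale)
    compressed = np.clip(np.round(weight / scale + zero_point), level_low, level_high).astype(np.uint8)
    return compressed, scale, zero_point


w = np.array([[-1.0, 0.0, 2.0], [0.0, 0.5, 1.5]], dtype=np.float32)
q, s, zp = quantize_int8_asym(w, reduction_axis=1)
restored = (q.astype(np.float32) - zp) * s  # approximately equal to w, up to rounding error
```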
@@ -223,50 +233,44 @@ def _do_integer_quantization(
return compressed_weights, scale, zero_point


def _get_integer_quantization_error(
weight: np.ndarray, reduction_axes: Union[int, Tuple[int]], config: WeightCompressionConfig
) -> float:
def _get_integer_quantization_error(weight: np.ndarray, reduction_axis: int, config: WeightCompressionConfig) -> float:
"""
Calculates a quantity characterizing the difference between floating point weights and fake quantized
(compressed and decompressed) to integer ones.
:param weight: Weight array to compress.
:param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max).
:param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max).
:param config: Information on how to compress (quantize) a specific weight.
:return: The quantity characterizing the error of integer quantization.
"""
orig_shape = weight.shape
compressed_weights, scale, zero_point = _do_integer_quantization(weight, reduction_axes, config)
compressed_weights, scale, zero_point = _do_integer_quantization(weight, reduction_axis, config)

decompressed_weight = compressed_weights.astype(dtype=scale.dtype)
decompressed_weight = (compressed_weights - zero_point) * scale

decompressed_weight = decompressed_weight.reshape(orig_shape)
diff = (decompressed_weight - weight) ** 2
layer_err = np.mean(diff, axis=reduction_axes)
layer_err = np.mean(diff, axis=reduction_axis)
val = np.max(layer_err)
return val


def _reshape_weights_for_grouped_quantization(
weight: np.ndarray, reduction_axes: Union[int, Tuple[int]], group_size: int
weight: np.ndarray, reduction_axis: int, group_size: int
) -> Tuple[np.ndarray, int]:
"""
Reshapes weights for group-wise quantization and return a new reduction axis for collecting statistics per group
dimension. Having weights with shapes [c_out, c_in] and group size = 128, shape of reshaped weights is
[c_out, c_in // 128, 128].
:param weight: Weight array to compress.
:param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max).
:param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max).
:param group_size: Number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale).
:return: reshaped weights and new reduction axis.
"""
assert group_size != -1
if isinstance(reduction_axes, tuple) and len(reduction_axes) != 1:
raise RuntimeError(
f"group-quantization is supported for a single reduction axes, but got {len(reduction_axes)}"
)
reduction_axis = reduction_axes[0] if isinstance(reduction_axes, tuple) else reduction_axes
assert isinstance(reduction_axis, int)
channel_size = weight.shape[reduction_axis]
if channel_size % group_size != 0:
raise RuntimeError(f"Channel size {channel_size} should be divisible by size of group {group_size}")
@@ -280,24 +284,24 @@ def _reshape_weights_for_grouped_quantization(


def _get_norm_weight_and_nf4_scale(
weight: np.ndarray, reduction_axes: Tuple[int], group_size: int = -1
weight: np.ndarray, reduction_axis: int, group_size: int = -1
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calculates scale for nf4 quantization and normalizes weights by the scale.
Weights are reshaped in case of positive value of group size.
:param weight: Weight array to compress.
:param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max).
:param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max).
:param group_size: Number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale).
The value -1 means no grouping. Defaults to -1.
:return: Normalized weights and nf4 scale.
"""
if group_size != -1:
# weights are reshaped: [a1, r, a2] -> [a1, r//gs, gs, a2]
weight, reduction_axis = _reshape_weights_for_grouped_quantization(weight, reduction_axes, group_size)
weight, reduction_axis = _reshape_weights_for_grouped_quantization(weight, reduction_axis, group_size)
scale = np.max(np.abs(weight), axis=reduction_axis, keepdims=True) # [a1, r//gs, 1, a2]
else:
scale = np.max(np.abs(weight), axis=reduction_axes, keepdims=True) # [a1, 1, a2]
scale = np.max(np.abs(weight), axis=reduction_axis, keepdims=True) # [a1, 1, a2]
eps = np.finfo(weight.dtype).eps
# NOTE: adding machine epsilon to avoid division by zero
scale[np.abs(scale) < eps] = eps
@@ -372,8 +376,8 @@ def _assign_mixed_precision(
for weight_param in track(all_weight_params[1:-1], description="Searching for Mixed-Precision Configuration"):
weight = get_const_value(weight_param.weight_node)
backup_config = weight_param.compression_config
reduction_axes = weight_param.reduction_axes
backup_error = _get_integer_quantization_error(weight, reduction_axes, backup_config)
reduction_axis = weight_param.reduction_axis
backup_error = _get_integer_quantization_error(weight, reduction_axis, backup_config)
eps = np.finfo(weight.dtype).eps
error = 1 / (backup_error + eps)
errors.append(error)
14 changes: 7 additions & 7 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -298,7 +298,7 @@ def __str__(self):
@pytest.mark.parametrize("desc", LIST_DESCS, ids=map(str, LIST_DESCS))
def test_quantization_error_calculation(desc: QuantErrorDesc):
weight = desc.weight
axis = (1,)
axis = 1
actual_error = _get_integer_quantization_error(weight, axis, desc.config)
ref_error = desc.ref_error
atol = desc.atol if desc.atol is not None else 1e-8
@@ -374,20 +374,20 @@ def test_weight_compress_with_ignored_scope(ignored_scope, num_compressed):
@pytest.mark.parametrize("desc", CALCULATE_SCALE_DESCS)
def test_calculate_scale_per_group(desc: CalculateScaleDesc):
reshaped_weight, reduction_axis = _reshape_weights_for_grouped_quantization(
desc.weight, reduction_axes=desc.axis, group_size=desc.group_size
desc.weight, reduction_axis=desc.axis, group_size=desc.group_size
)
act_scale = np.max(np.abs(reshaped_weight), axis=reduction_axis, keepdims=True) # [a1, r//gs, 1, a2]
assert np.allclose(act_scale, desc.ref_scale)


def test_raise_error_for_many_axes():
with pytest.raises(RuntimeError):
_reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axes=(0, 1), group_size=1)
with pytest.raises(AssertionError):
_reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axis=(0, 1), group_size=1)


def test_raise_error_with_incorrect_group_size():
with pytest.raises(RuntimeError):
_reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axes=(0,), group_size=3)
def test_raise_error_with_tuple():
with pytest.raises(AssertionError):
_reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axis=(0,), group_size=3)


def test_raise_error_with_int8_and_non_default_ratio(mocker):
