From ffee0e45652fabf32d12924fb78f0f4419b320a3 Mon Sep 17 00:00:00 2001 From: Lyalyushkin Nikolay Date: Thu, 5 Oct 2023 16:08:13 +0200 Subject: [PATCH 01/10] Baseline mixed nf4-int8 quantization in NNCF via OV backend (#2150) ### Changes Implementation of [nncf.compress_weights()](https://openvinotoolkit.github.io/nncf/autoapi/nncf/index.html#nncf.compress_weights) for the case of mixed nf4-int8 grouped quantization. ### Reason for changes int4 support for llm ### Related tickets 119710 ### Tests Some tests are not going to work with `ov.nf4.type` until switching to OpenVINO with https://github.com/openvinotoolkit/openvino/pull/19900 test_calculate_scale_per_group test_quantization_error_calculation test_compare_compressed_weights_nf4 test_compress_weights_nf4 Things to complete: - [x] Reshape issue with grouped quantization (https://github.com/openvinotoolkit/openvino/pull/19987) - [x] NF4 accuracy in openvino backend vs reference implementation in pytorch. ![image](https://github.com/openvinotoolkit/nncf/assets/4014476/42d43e37-37f8-4ff9-afa3-f97dc0744c82) - [x] Nf4 compression time vs torch baseline vs itn8 weight compression ![image](https://github.com/openvinotoolkit/nncf/assets/4014476/73fe01a7-d84c-4810-9f4a-f29bcc6278e9) - [x] Test for mixed precision - [x] docs, style --- nncf/__init__.py | 1 + nncf/openvino/quantization/quantize_model.py | 22 +- .../quantization/weights_compression.py | 376 +++++++++++++++++- nncf/parameters.py | 15 + nncf/quantization/fake_quantize.py | 3 + nncf/quantization/quantize_model.py | 34 +- nncf/torch/quantization/quantize_model.py | 26 +- .../torch/quantization/weights_compression.py | 3 +- .../test_calculation_quantizer_params.py | 5 + .../IntegerModel_compressed_weights_nf4.json | 54 +++ tests/openvino/native/models.py | 49 +++ .../quantization/test_weights_compression.py | 282 ++++++++++++- tests/torch/ptq/test_weights_compression.py | 17 + 13 files changed, 848 insertions(+), 39 deletions(-) create mode 100644 tests/openvino/native/data/reference_scales/IntegerModel_compressed_weights_nf4.json diff --git a/nncf/__init__.py b/nncf/__init__.py index 07fac90b908..f61fe7ec3bc 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -18,6 +18,7 @@ from nncf.common.strip import strip from nncf.config import NNCFConfig from nncf.data import Dataset +from nncf.parameters import CompressWeightsMode from nncf.parameters import DropType from nncf.parameters import ModelType from nncf.parameters import TargetDevice diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index fba0d2853cc..ad07f6fdcf4 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -25,6 +25,7 @@ from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed from nncf.openvino.quantization.quantize_ifmodel import apply_algorithm_if_bodies from nncf.openvino.quantization.weights_compression import insert_pre_compression_operations +from nncf.parameters import CompressWeightsMode from nncf.parameters import DropType from nncf.parameters import ModelType from nncf.parameters import TargetDevice @@ -438,9 +439,26 @@ def quantize_with_accuracy_control_impl( ) -def compress_weights_impl(model: ov.Model) -> ov.Model: +def compress_weights_impl( + model: ov.Model, + mode: CompressWeightsMode = CompressWeightsMode.INT8, + ratio: Optional[float] = None, + group_size: Optional[int] = None, +) -> ov.Model: """ Implementation of the `compress_weights()` method for the OpenVINO backend. + + :param model: an OpenVINO model for compression. + :param mode: Defines a mode for weight compression. + INT8 stands for 8-bit integer quantization of all weights. + NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers + are always compressed to a backup precision which is 8-bit integer by default. All others are quantized whether + to NF4 or to a backup precision depending on criteria and the given ratio. + :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 and + the rest to INT8). + :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). + The value -1 means no grouping. + :return: The non-trainable model with compressed weights and dequantization operations. """ - insert_pre_compression_operations(model) + insert_pre_compression_operations(model, mode, ratio, group_size) return model diff --git a/nncf/openvino/quantization/weights_compression.py b/nncf/openvino/quantization/weights_compression.py index 8b23834e36d..fa588cd6c63 100644 --- a/nncf/openvino/quantization/weights_compression.py +++ b/nncf/openvino/quantization/weights_compression.py @@ -9,35 +9,346 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Type, Union +from dataclasses import dataclass +from typing import List, Optional, Tuple, Type, TypeVar, Union import numpy as np import openvino.runtime as ov from openvino.runtime import opset9 as opset from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.logging import nncf_logger +from nncf.common.quantization.statistics import _proportion_str +from nncf.common.utils.helpers import create_table from nncf.openvino.graph.metatypes.openvino_metatypes import OVEmbeddingMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype from nncf.openvino.graph.metatypes.openvino_metatypes import get_operation_const_op from nncf.openvino.graph.node_utils import get_const_value from nncf.openvino.graph.node_utils import get_matmul_channel_axes +from nncf.parameters import CompressWeightsMode from nncf.quantization.fake_quantize import calculate_scale_zero_point +TWeightType = TypeVar("TWeightType") -def insert_pre_compression_operations(model: ov.Model, bits: int = 8) -> None: +NF4_QUANTILES = np.array( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] +) + + +CENTER_OF_NF4_QUANTILES = np.array( + [ + -0.8480964004993439, + -0.6106329262256622, + -0.4599952697753906, + -0.33967943489551544, + -0.23460740596055984, + -0.13791173323988914, + -0.045525018125772476, + 0.03979014977812767, + 0.1202552504837513, + 0.2035212516784668, + 0.2920137718319893, + 0.3893125355243683, + 0.5016634166240692, + 0.6427869200706482, + 0.8614784181118011, + ] +) + + +@dataclass +class WeightCompressionConfig: + """ + Information on how to compress (quantize) a specific weight. + + :param num_bits: number of bits for storing a single quantized value. 8, by default. + :param is_nf4: is NF4 format used for quantization. False, by default. + :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). + The value -1 means no grouping. Defaults to -1. + """ + + num_bits: Optional[int] = 8 + is_nf4: Optional[bool] = False + group_size: Optional[int] = -1 + + +@dataclass +class WeightNodeParams: + """ + Information about weight node in the ov.Model that is useful for weight compression. + + :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :param num_weights: number of elements in the weight array. + :param fq_name: name for the inserted weight compression operation. + :param weight_node: the weight node itself. + :param original_weight_dtype: type of elements in the weight array. + :param compression_config: configuration of weight compression for the weight node. + """ + + reduction_axes: Union[int, Tuple[int]] + num_weights: int + fq_name: str + weight_node: ov.Node + original_weight_dtype: TWeightType + compression_config = WeightCompressionConfig() + + +def _int8_compress( + weight: np.ndarray, reduction_axes: Union[int, Tuple[int]] +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Do unsigned int8 asymmetric weight compression - quantization to [0, 255] range. + + :param weight: Weight array to compress + :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :return: compressed weights in unsigned int8, scale and zero point that was used for its quantization. + """ + num_bits = 8 + level_low = 0 + level_high = 2**num_bits - 1 + + min_values = np.min(weight, axis=reduction_axes, keepdims=True) + max_values = np.max(weight, axis=reduction_axes, keepdims=True) + + scale, zero_point = calculate_scale_zero_point(min_values, max_values, level_low, level_high, narrow_range=False) + + compressed_weights = np.round(weight / scale + zero_point) + compressed_weights = np.clip(compressed_weights, level_low, level_high).astype(np.uint8) + return compressed_weights, scale, zero_point + + +def _get_int8_err(weight: np.ndarray, reduction_axes: Union[int, Tuple[int]]) -> float: + """ + Calculates a quantity characterizing the difference between floating point weights and its int8 fake quantized + (compressed and decompressed) version. + + :param weight: Weight array to compress. + :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :return: The quantity characterizing the int8 error. + """ + compressed_weights, scale, zero_point = _int8_compress(weight, reduction_axes) + + decompressed_weight = compressed_weights.astype(dtype=scale.dtype) + decompressed_weight = (compressed_weights - zero_point) * scale + + diff = (decompressed_weight - weight) ** 2 + layer_err = np.mean(diff, axis=reduction_axes) + val = np.max(layer_err) + return val + + +def _calculate_scale_per_group( + weight: np.ndarray, reduction_axes: Union[int, Tuple[int]], group_size: int +) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculates scale and reshapes weights for group-wise quantization. + Having weights with shapes [c_out, c_in] and group size = 128, the shape of scale is [c_out, c_in // 128, 1], and + shape of weights is [c_out, c_in // 128, 128]. + + :param weight: Weight array to compress. + :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). + :return: Scale and reshaped weights. + """ + assert group_size != -1 + if isinstance(reduction_axes, tuple) and len(reduction_axes) != 1: + raise RuntimeError( + f"group-quantization is supported for a single reduction axes, but got {len(reduction_axes)}" + ) + reduction_axis = reduction_axes[0] if isinstance(reduction_axes, tuple) else reduction_axes + channel_size = weight.shape[reduction_axis] + if channel_size % group_size != 0: + raise RuntimeError(f"Channel size {channel_size} should be divisible by size of group {group_size}") + + num_groups_per_channel = channel_size // group_size + shape = list(weight.shape) # [a1, r, a2] - "r" refers to number of channels along reduction axis + shape[reduction_axis : reduction_axis + 1] = (num_groups_per_channel, group_size) + reshaped_weight = weight.reshape(shape) # [a1, r, a2] -> [a1, r//gs, gs, a2], when "gs" is group size + scale = np.max(np.abs(reshaped_weight), axis=reduction_axis + 1, keepdims=True) # [a1, r//gs, 1, a2] + return scale, reshaped_weight + + +def _get_norm_weight_and_nf4_scale( + weight: np.ndarray, reduction_axes: Tuple[int], group_size: int = -1 +) -> Tuple[np.ndarray, np.ndarray]: """ - Compress weights of Linear and Embedding layers to uint8. - The result of compression is the same as asymmetric weight quantization. + Calculates scale for nf4 quantization and normalizes weights by the scale. + Weights are reshaped in case of positive value of group size. + + :param weight: Weight array to compress. + :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). + The value -1 means no grouping. Defaults to -1. + :return: Normalized weights and nf4 scale. + """ + if group_size != -1: + # shape of scale : [a1, r//gs, 1, a2], scale of weight: [a1, r//gs, r, a2] + scale, weight = _calculate_scale_per_group(weight, reduction_axes, group_size) + else: + scale = np.max(np.abs(weight), axis=reduction_axes, keepdims=True) # [a1, 1, a2] + eps = np.finfo(weight.dtype).eps + # NOTE: adding machine epsilon to avoid division by zero + scale[np.abs(scale) < eps] = eps + norm_weight = weight / scale + return norm_weight, scale + + +def _get_nf4_error(weight: np.ndarray, reduction_axes: Tuple[int], group_size: int = -1) -> float: + """ + Calculates a quantity characterizing the difference between floating point weights and its nf4 fake quantized + (compressed and decompressed) version. + + :param weight: Weight array to compress. + :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :return: The quantity characterizing the nf4 error. + """ + original_shape = weight.shape + + norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, reduction_axes, group_size) + + index_of_quantile = np.searchsorted(CENTER_OF_NF4_QUANTILES, norm_weight) + nf4_rounded = NF4_QUANTILES[index_of_quantile] + + decompressed_weight = nf4_rounded * scale + decompressed_weight = decompressed_weight.reshape(original_shape) + diff = (decompressed_weight - weight) ** 2 + layer_err = np.mean(diff, axis=reduction_axes) + val = np.max(layer_err) + return val + + +def _proportion_str(num_weights_list: List[int], total_num_weights: int, total_num_params: int) -> str: + percentage = sum(num_weights_list) / max(total_num_weights, 1) * 100 + return f"{percentage:.0f}% ({len(num_weights_list)} / {total_num_params})" + + +def _get_bitwidth_distribution_str(all_weight_params: List[WeightNodeParams]) -> str: + """ + Generates a table that shows the ratio of weights quantized to different number of bits. + + :param all_weight_params: List of information about each weight node. + :return: A string containing the table. + """ + total_num_weights = sum(ws.num_weights for ws in all_weight_params) + num_internal_weights = 0 + num_params = len(all_weight_params) + num_internal_params = 0 + if num_params > 2: + num_internal_params = num_params - 2 + num_bits_vs_num_weights_map = {} + for i, data in enumerate(all_weight_params): + num_bits = data.compression_config.num_bits + n_total, n_internal = num_bits_vs_num_weights_map.get(num_bits, ([], [])) + if i not in (0, num_params - 1): + n_internal.append(data.num_weights) + num_internal_weights += data.num_weights + n_total.append(data.num_weights) + num_bits_vs_num_weights_map[num_bits] = (n_total, n_internal) + + # Table creation + header = ["Num bits (N)", "% all weight", "% internal weights"] + rows = [] + for bitwidth, (n_total, n_internal) in num_bits_vs_num_weights_map.items(): + rows.append( + [ + bitwidth, + _proportion_str(n_total, total_num_weights, num_params), + _proportion_str(n_internal, num_internal_weights, num_internal_params), + ] + ) + + table = create_table(header, rows) + pretty_string = f"Statistics of the bitwidth distribution:\n{table}" + return pretty_string + + +def _assign_mixed_precision(all_weight_params: List[WeightNodeParams], ratio: float, group_size: int) -> None: + """ + Assigns mixed quantization scheme (e.g. uniform int8 or non-uniform nf4) for weights based on some criteria. + + :param all_weight_params: List of information about each weight node. The quantization scheme is added to this info. + :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 + and the rest to INT8). + :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). + The value -1 means no grouping. + """ + nf4_config = WeightCompressionConfig(num_bits=4, is_nf4=True, group_size=group_size) + if ratio != 1: + # NOTE: first and last layer is always in 8 bit. + errors = [] + num_internal_weights = 0 + for weight_param in all_weight_params[1:-1]: + weight = get_const_value(weight_param.weight_node) + axes = weight_param.reduction_axes + nf4_error = _get_nf4_error(weight, axes, group_size) + int8_error = _get_int8_err(weight, axes) + eps = np.finfo(weight.dtype).eps + error = nf4_error / (int8_error + eps) + errors.append(error) + num_internal_weights += weight_param.num_weights + # NOTE: index is defined in the array of all weight params by taking into account that errors were not + # calculated for first and last layers. + indexes_of_layers_in_ascending_order_of_errors = [ + i[0] + 1 for i in sorted(enumerate(errors), reverse=False, key=lambda x: x[1]) + ] + num_weights_in_4bit = 0 + for index in indexes_of_layers_in_ascending_order_of_errors: + weight_param = all_weight_params[index] + current_ratio = (num_weights_in_4bit + weight_param.num_weights) / num_internal_weights + if current_ratio >= ratio: + break + weight_param.compression_config = nf4_config + num_weights_in_4bit += weight_param.num_weights + + else: + for weight_param in all_weight_params[1:-1]: + weight_param.compression_config = nf4_config + nncf_logger.info(_get_bitwidth_distribution_str(all_weight_params)) + + +def insert_pre_compression_operations( + model: ov.Model, + mode: CompressWeightsMode, + ratio: float, + group_size: int, +) -> None: + """ + Compress weights of Linear and Embedding layers to 8-bit integer or to nf4 depending on mode, ratio and group size. :param model: The model to be transformed. - :param bits: Number of bits for quantization. + :param mode: Defines a mode for weight compression. + INT8 stands for 8-bit integer quantization of all weights. + NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers + are always compressed to a backup precision which is 8-bit integer by default. All others are quantized whether + to NF4 or to a backup precision depending on criteria and the given ratio. + :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 + and the rest to INT8). + :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). + The value -1 means no grouping. """ allowed_metatypes_to_const_port = {OVEmbeddingMetatype: [0], OVMatMulMetatype: [0, 1]} - level_low = 0 - level_high = 2**bits - 1 - for node in model.get_ops(): + all_weight_params: List[WeightNodeParams] = [] + quantized_nodes_ids = set() + for node in model.get_ordered_ops(): metatype = get_node_metatype(node) if metatype not in allowed_metatypes_to_const_port: continue @@ -46,7 +357,8 @@ def insert_pre_compression_operations(model: ov.Model, bits: int = 8) -> None: weight_node = get_operation_const_op(node, const_port_id) if weight_node is None: continue - + if id(weight_node) in quantized_nodes_ids: + continue weight_output = weight_node.output(0) weight_name = weight_node.get_friendly_name() target_inputs = weight_output.get_target_inputs() @@ -54,27 +366,47 @@ def insert_pre_compression_operations(model: ov.Model, bits: int = 8) -> None: original_weight_dtype = weight_output.get_element_type().to_dtype() if original_weight_dtype not in [np.float32, np.float16, np.float64]: continue - - weight = get_const_value(weight_node) axes = _get_reduction_axes(metatype, node, const_port_id) - min_values = np.min(weight, axis=axes, keepdims=True) - max_values = np.max(weight, axis=axes, keepdims=True) + fq_name = f"{node.get_friendly_name()}/fq_weights_{const_port_id}" + weight = get_const_value(weight_node) + num_weights = weight.size + weight_params = WeightNodeParams(axes, num_weights, fq_name, weight_node, original_weight_dtype) + all_weight_params.append(weight_params) + quantized_nodes_ids.add(id(weight_node)) - scale, zero_point = calculate_scale_zero_point( - min_values, max_values, level_low, level_high, narrow_range=False - ) + if mode == CompressWeightsMode.NF4: + _assign_mixed_precision(all_weight_params, ratio, group_size) - compressed_weights = np.round(weight / scale + zero_point) - compressed_weights = np.clip(compressed_weights, level_low, level_high).astype(np.uint8) + for wp in all_weight_params: + weight_node = wp.weight_node + original_weight_dtype = wp.original_weight_dtype + weight_output = weight_node.output(0) + weight_name = weight_node.get_friendly_name() + target_inputs = weight_output.get_target_inputs() + + weight = get_const_value(weight_node) + config = wp.compression_config + + if config.is_nf4: + original_shape = weight.shape + norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, wp.reduction_axes, group_size) + compressed_const = opset.constant(norm_weight, dtype=ov.Type.nf4, name=weight_name) + convert = opset.convert(compressed_const, original_weight_dtype) + mul = opset.multiply(convert, scale.astype(original_weight_dtype), name=wp.fq_name) + if config.group_size != -1: + mul = opset.reshape(mul, output_shape=original_shape, special_zero=False) + last_output = mul.output(0) + else: + compressed_weights, scale, zero_point = _int8_compress(weight, wp.reduction_axes) compressed_const = opset.constant(compressed_weights, dtype=np.uint8, name=weight_name) convert = opset.convert(compressed_const, original_weight_dtype) sub = opset.subtract(convert, zero_point.astype(original_weight_dtype)) - fq_name = f"{node.get_friendly_name()}/fq_weights_{const_port_id}" - mul = opset.multiply(sub, scale.astype(original_weight_dtype), name=fq_name) + mul = opset.multiply(sub, scale.astype(original_weight_dtype), name=wp.fq_name) + last_output = mul.output(0) - for target_input in target_inputs: - target_input.replace_source_output(mul.output(0)) + for target_input in target_inputs: + target_input.replace_source_output(last_output) def _get_reduction_axes(metatype: Type[OperatorMetatype], node: ov.Node, weight_port_id: int) -> Union[int, Tuple[int]]: diff --git a/nncf/parameters.py b/nncf/parameters.py index 28ae264834e..1dfdedc4496 100644 --- a/nncf/parameters.py +++ b/nncf/parameters.py @@ -56,3 +56,18 @@ class DropType(Enum): ABSOLUTE = "absolute" RELATIVE = "relative" + + +@api(canonical_alias="nncf.CompressWeightsMode") +class CompressWeightsMode(Enum): + """ + Defines a mode for weight compression. + + :param INT8: Stands for 8-bit integer quantization of all weights. + :param NF4: Stands for a mixed-precision weights quantization to NF4 data type. The first and last + layers are always compressed to a backup precision which is 8-bit integer by default. All others are quantized + whether to NF4 or to a backup precision depending on criteria and the given ratio. + """ + + INT8 = "int8" + NF4 = "nf4" diff --git a/nncf/quantization/fake_quantize.py b/nncf/quantization/fake_quantize.py index 4b68187e68e..f1744813c54 100644 --- a/nncf/quantization/fake_quantize.py +++ b/nncf/quantization/fake_quantize.py @@ -307,6 +307,9 @@ def calculate_scale_zero_point( """ levels = level_high - level_low if narrow_range else level_high - level_low + 1 scale = np.array((input_high - input_low) / (levels - 1)).astype(np.float32) + eps = np.finfo(scale.dtype).eps + # NOTE: adding machine epsilon to avoid division by zero + scale[np.abs(scale) < eps] = eps expected_level_low = level_low + 1 if narrow_range else level_low zero_point = expected_level_low - np.round(input_low / scale) zero_point = np.clip(zero_point.astype(np.int32), level_low, level_high) diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index c52dedd7021..c87479fa401 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -17,6 +17,7 @@ from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend from nncf.data import Dataset +from nncf.parameters import CompressWeightsMode from nncf.parameters import DropType from nncf.parameters import ModelType from nncf.parameters import TargetDevice @@ -226,22 +227,49 @@ def quantize_with_accuracy_control( @api(canonical_alias="nncf.compress_weights") -def compress_weights(model: TModel) -> TModel: +def compress_weights( + model: TModel, mode=CompressWeightsMode.INT8, ratio: Optional[float] = None, group_size: Optional[int] = None +) -> TModel: """ Compress model weights. :param model: A model to be compressed. + :param mode: Defines a mode for weight compression. + INT8 stands for 8-bit integer quantization of all weights. + NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers + are always compressed to a backup precision which is 8-bit integer by default. All others are quantized whether + to NF4 or to a backup precision depending on criteria and the given ratio. + :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 + and the rest to INT8). + :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). + The value -1 means no grouping. :return: The non-trainable model with compressed weights. """ backend = get_backend(model) + if mode == CompressWeightsMode.INT8: + if ratio is None: + ratio = 1 + if group_size is None: + group_size = -1 + if ratio != 1 or group_size != -1: + raise AttributeError( + "INT8 mode assumes per-channel quantization of all layers in 8 bit. " + "Default values of `ratio` (1) and `group_size` (-1) parameters can not be overridden" + ) + if mode == CompressWeightsMode.NF4: + if ratio is None: + ratio = 1 + if group_size is None: + group_size = 128 + if backend == BackendType.TORCH: from nncf.torch.quantization.quantize_model import compress_weights_impl - return compress_weights_impl(model) + return compress_weights_impl(model, mode, ratio, group_size) if backend == BackendType.OPENVINO: from nncf.openvino.quantization.quantize_model import compress_weights_impl - return compress_weights_impl(model) + return compress_weights_impl(model, mode, ratio, group_size) raise RuntimeError(f"Unsupported type of backend: {backend}") diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index 3b0c82406e5..621f3bc6ebb 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -19,6 +19,7 @@ from nncf.config.structures import BNAdaptationInitArgs from nncf.config.structures import QuantizationRangeInitArgs from nncf.data import Dataset +from nncf.parameters import CompressWeightsMode from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters @@ -260,10 +261,31 @@ def send_to_device(tensor): return compressed_model -def compress_weights_impl(model: torch.nn.Module) -> torch.nn.Module: +def compress_weights_impl( + model: torch.nn.Module, + mode=CompressWeightsMode.INT8, + ratio: Optional[float] = None, + group_size: Optional[int] = None, +) -> torch.nn.Module: """ - Implementation of the `compress_weights()` method for the PyTorch backend. + Implementation of the `compress_weights()` method for the PyTorch backend. Currently it supports INT8 + mode only with default ratio and group_size. + + :param model: a Torch model for compression. + :param mode: Defines a mode for weight compression. + INT8 stands for 8-bit integer quantization of all weights. + NF4 stands for a mixed-precision weights quantization to NF4 data type. The first and last layers + are always compressed to a backup precision which is 8-bit integer by default. All others are quantized whether + to NF4 or to a backup precision depending on criteria and the given ratio. + :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 + and the rest to INT8). + :param group_size: number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). + The value -1 means no grouping. + :return: The non-trainable model with compressed weights and dequantization operations. """ + + if mode != CompressWeightsMode.INT8: + raise AttributeError(f"Torch backend supports only INT8 mode for weight compression, but given {mode} mode") compressed_model, _ = replace_modules_by_nncf_modules(model) insert_pre_compression_operations(model) diff --git a/nncf/torch/quantization/weights_compression.py b/nncf/torch/quantization/weights_compression.py index 9fc725fb235..2cf8947fd89 100644 --- a/nncf/torch/quantization/weights_compression.py +++ b/nncf/torch/quantization/weights_compression.py @@ -87,8 +87,7 @@ def insert_pre_compression_operations(module: nn.Module, bits: int = 8) -> Optio Inserts weights compression with dequantization for Linear and Embedding layers. :param module: The module to insert the weights compression. - :param bits: number of bits for compression. Note: compressed weights type is - uint8 with one element per 8 bit. + :param bits: number of bits for compression. Note: type of compressed weights is 8-bit integer. :return: The non-trainable module with inserted operations. """ user_types = list(NNCF_WRAPPED_USER_MODULES_DICT.values()) diff --git a/tests/onnx/quantization/test_calculation_quantizer_params.py b/tests/onnx/quantization/test_calculation_quantizer_params.py index 1e02b4eb25a..8beed22ff01 100644 --- a/tests/onnx/quantization/test_calculation_quantizer_params.py +++ b/tests/onnx/quantization/test_calculation_quantizer_params.py @@ -17,6 +17,8 @@ from nncf.quantization.fake_quantize import calculate_scale_zero_point from tests.post_training.test_templates.test_calculate_quantizer_parameters import TemplateTestFQParams +EPS = np.finfo(np.float32).eps + @pytest.mark.parametrize( ("inp_low, inp_high, level_low, level_high, narrow_range, ref_scale, ref_zero_point"), @@ -30,6 +32,9 @@ (-10, 10, -512, 511, False, 0.01955034, 0), (-10, 10, -128, 127, True, 0.07874016, 0), (0, 25, 0, 15, False, 1.6666666, 0), + (1, 1, -128, 127, True, EPS, -128), + (0, 0, -128, 127, True, EPS, -127), + (np.array([0, 1]), np.array([0, 1]), -128, 127, True, np.array([EPS, EPS]), np.array([-127, -128])), ), ) def test_calculate_scale_zero_point(inp_low, inp_high, level_low, level_high, narrow_range, ref_scale, ref_zero_point): diff --git a/tests/openvino/native/data/reference_scales/IntegerModel_compressed_weights_nf4.json b/tests/openvino/native/data/reference_scales/IntegerModel_compressed_weights_nf4.json new file mode 100644 index 00000000000..671776d6674 --- /dev/null +++ b/tests/openvino/native/data/reference_scales/IntegerModel_compressed_weights_nf4.json @@ -0,0 +1,54 @@ +{ + "matmul_1_data": { + "scale": [ + [ + [ + 0.42268723249435425 + ], + [ + 0.6706244349479675 + ] + ], + [ + [ + 0.997209906578064 + ], + [ + 0.9808353185653687 + ] + ], + [ + [ + 0.6884467601776123 + ], + [ + 0.721488356590271 + ] + ], + [ + [ + 0.9340435266494751 + ], + [ + 0.5715298056602478 + ] + ], + [ + [ + 0.5943000316619873 + ], + [ + 0.8902743458747864 + ] + ], + [ + [ + 0.8326441645622253 + ], + [ + 0.876484215259552 + ] + ] + ] + } +} \ No newline at end of file diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py index 6414fc9e7a4..ca563218df3 100644 --- a/tests/openvino/native/models.py +++ b/tests/openvino/native/models.py @@ -737,3 +737,52 @@ def _create_ov_model(self): result = opset.result(if_node, name="Result") model = ov.Model([result], [input_1, input_2, input_3]) return model + + +class SequentialMatmulModel(OVReferenceModel): + """ + Model for mixed precision weight compression. + Matrices with outliers are defined in such a way that there is a different nf4, int8, relative error. + rel_error = nf4_error / int8_error + The maximum relative error is achieved with not maximum outlier 10000, because nf4 better copes with outliers. + + [[ 0. 1. 2.] + [ 3. 4. 5.] + [ 6. 7. 1000.]] + nf4 error = 28 + int8 error = 13 + rel_error=2 + + [[ 0. 1. 2.] + [ 3. 4. 5.] + [ 6. 7. 10000.]] + nf4 error = 28 + int8 error = 40 + rel_error= 0.7 + + [[ 0. 1. 2.] + [ 3. 4. 5.] + [ 6. 7. 10.]] + nf4 error = 0.06 + int8 error = 16 + rel_error= 0.03 + """ + + def _create_ov_model(self): + input_node = opset.parameter([3, 3], name="Input_1") + main_values = [100, 1000, 10000, 10, 1] + + last_node = input_node + for i, main_value in enumerate(main_values): + weights_data = np.arange(0, 9).reshape(3, 3) + weights_data[-1, -1] = main_value + current_weights = opset.constant(weights_data, dtype=np.float32, name=f"weights_{i}") + current_node = opset.matmul( + last_node, current_weights, transpose_a=False, transpose_b=True, name=f"MatMul_{i}" + ) + last_node = current_node + + result = opset.result(last_node, name="Result") + result.get_output_tensor(0).set_names(set(["Result"])) + model = ov.Model([result], [input_node]) + return model diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 6f23552d81b..52d1d32229a 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -9,13 +9,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial +from typing import List + import numpy as np import openvino.runtime as ov import pytest +from attr import dataclass +from nncf import CompressWeightsMode from nncf.openvino.graph.node_utils import get_const_value +from nncf.openvino.quantization.weights_compression import _calculate_scale_per_group +from nncf.openvino.quantization.weights_compression import _get_int8_err +from nncf.openvino.quantization.weights_compression import _get_nf4_error from nncf.quantization import compress_weights from tests.openvino.native.models import IntegerModel +from tests.openvino.native.models import SequentialMatmulModel from tests.openvino.native.models import WeightsModel from tests.openvino.native.quantization.test_fq_params_calculation import REFERENCE_SCALES_DIR from tests.shared.helpers import compare_stats @@ -28,8 +37,9 @@ @pytest.mark.parametrize("model_creator_func", TEST_MODELS) -def test_compress_weights(model_creator_func): +def test_compress_weights_int8(model_creator_func, tmp_path): ref_compressed_weights = TEST_MODELS[model_creator_func] + name = model_creator_func().__class__.__name__ model = model_creator_func().ov_model compressed_model = compress_weights(model) @@ -38,20 +48,41 @@ def test_compress_weights(model_creator_func): if op.get_type_name() == "Constant" and op.get_friendly_name() in ref_compressed_weights: assert op.get_element_type() == ov.Type(np.uint8) n_compressed_weights += 1 + ov.serialize(compressed_model, tmp_path / (name + ".xml")) + assert n_compressed_weights == len(ref_compressed_weights) + + +@pytest.mark.parametrize("model_creator_func", TEST_MODELS) +def test_compress_weights_nf4(model_creator_func): + if issubclass(IntegerModel, model_creator_func): + pytest.xfail("Waiting for the merge NF4 support in OV - PR 19900") + ref_compressed_weights = TEST_MODELS[model_creator_func] + model = model_creator_func().ov_model + compressed_model = compress_weights(model, mode=CompressWeightsMode.NF4, ratio=1, group_size=1) + + n_compressed_weights = 0 + for op in compressed_model.get_ordered_ops(): + if op.get_type_name() == "Constant" and op.get_friendly_name() in ref_compressed_weights: + if n_compressed_weights in (0, len(ref_compressed_weights) - 1): + assert op.get_element_type() == ov.Type(np.uint8) + else: + assert op.get_element_type() == ov.Type.nf4 + + n_compressed_weights += 1 assert n_compressed_weights == len(ref_compressed_weights) +def get_next_node(node): + target_inputs = node.output(0).get_target_inputs() + assert len(target_inputs) == 1 + next_node = next(iter(target_inputs)).get_node() + return next_node + + def test_compare_compressed_weights(): model = IntegerModel().ov_model compressed_model = compress_weights(model) - - def get_next_node(node): - target_inputs = node.output(0).get_target_inputs() - assert len(target_inputs) == 1 - next_node = next(iter(target_inputs)).get_node() - return next_node - nodes = {} ref_compressed_weights = TEST_MODELS[IntegerModel] for op in compressed_model.get_ops(): @@ -86,3 +117,238 @@ def get_next_node(node): ref_nodes = load_json(ref_stats_path) params = ["compressed_weight", "zero_point", "scale"] compare_stats(ref_nodes, nodes, params) + + +# TODO(nlyalyus) Waiting for the merge NF4 support in OV - PR 19900 +@pytest.mark.xfail +def test_compare_compressed_weights_nf4(): + model = IntegerModel().ov_model + compressed_model = compress_weights(model, mode=CompressWeightsMode.NF4, ratio=1, group_size=3) + + nodes = {} + ref_nf4_weight = TEST_MODELS[IntegerModel][1] + for op in compressed_model.get_ordered_ops(): + if op.get_type_name() == "Constant" and op.get_friendly_name() in ref_nf4_weight: + assert op.get_element_type() == ov.Type.nf4 + # TODO: should be fixed in python api + with pytest.raises(RuntimeError): + get_const_value(op) + + convert_node = get_next_node(op) + assert convert_node.get_type_name() == "Convert" + + mul_node = get_next_node(convert_node) + assert mul_node.get_type_name() == "Multiply" + scale_node = mul_node.input_value(1).get_node() + scale = get_const_value(scale_node) + + reshape_node = get_next_node(mul_node) + assert reshape_node.get_type_name() == "Reshape" + + nodes[op.get_friendly_name()] = { + # "compressed_weight": compressed_weight, + "scale": scale, + } + + ref_stats_path = REFERENCE_SCALES_DIR / "IntegerModel_compressed_weights_nf4.json" + + # from tests.shared.helpers import dump_to_json + # dump_to_json(ref_stats_path, nodes) + + ref_nodes = load_json(ref_stats_path) + params = [ + # "compressed_weight", + "scale" + ] + compare_stats(ref_nodes, nodes, params) + + +@pytest.mark.parametrize("group_size", (1, 3)) +@pytest.mark.parametrize( + ("ratio", "ref_nf4_nodes"), + ( + (1, ["weights_1", "weights_2", "weights_3"]), + (0.8, ["weights_2", "weights_3"]), + (0.4, ["weights_3"]), + (0.3, []), + ), +) +def test_mixed_precision(ratio, group_size, ref_nf4_nodes): + if ratio > 0.3: + pytest.xfail("Waiting for the merge NF4 support in OV - PR 19900") + model = SequentialMatmulModel().ov_model + compressed_model = compress_weights(model, mode=CompressWeightsMode.NF4, ratio=ratio, group_size=group_size) + for op in compressed_model.get_ordered_ops(): + if op.get_type_name() == "Constant" and op.get_friendly_name() in ref_nf4_nodes: + assert op.get_element_type() == ov.Type.nf4 + + +@dataclass +class BaseDesc: + weight: List[float] + ref_error: int = 0 + axis = (1,) + name: str = "" + atol: float = None + + def get_error_fn(self) -> float: + raise NotImplementedError + + def __str__(self): + prefix = "exact_match_" if self.ref_error == 0 else "" + name = self.name.replace(" ", "_") if self.name else self.__class__.__name__ + return prefix + name + + +@dataclass +class Int8Desc(BaseDesc): + def get_error_fn(self) -> float: + return partial(_get_int8_err, reduction_axes=self.axis) + + def __str__(self): + base_str = super().__str__() + return "int8_" + base_str + + +@dataclass +class NF4Desc(BaseDesc): + group_size: int = -1 + + def get_error_fn(self) -> float: + return partial(_get_nf4_error, reduction_axes=self.axis, group_size=self.group_size) + + def __str__(self): + base_str = super().__str__() + return "nf4_" + base_str + + +SCALE_1 = 1.2 +SCALE_2 = 3.4 +SCALE_3 = 5.6 +SCALE_4 = 7.8 +LINSPACE = np.arange(0, 256, 17) +NF4_LOOKUP = np.array( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] +) + +TWO_ROWS_NF4 = np.vstack((NF4_LOOKUP * SCALE_1, NF4_LOOKUP * SCALE_2)) +TWO_OTHER_ROWS_NF4 = np.vstack((NF4_LOOKUP * SCALE_3, NF4_LOOKUP * SCALE_4)) +TWO_ROWS_LINSPACE = np.vstack((LINSPACE * SCALE_1, LINSPACE * SCALE_2)) +TWO_GROUPS_IN_TWO_ROWS_NF4 = np.hstack((TWO_ROWS_NF4, TWO_OTHER_ROWS_NF4)) +TWO_GROUPS_IN_TWO_ROWS_NO_1_NF4 = np.hstack((TWO_ROWS_NF4[:, 1:-1], TWO_OTHER_ROWS_NF4[:, 1:-1])) + +LIST_DESCS = [ + # zero error + Int8Desc(name="2 rows of 0-255 linspace", weight=TWO_ROWS_LINSPACE), + NF4Desc(name="2 rows of exact quantiles", weight=TWO_ROWS_NF4), + NF4Desc(name="two groups in two rows", weight=TWO_GROUPS_IN_TWO_ROWS_NF4, group_size=16), + # non-zero error + Int8Desc(name="2 rows 1-254 linspace", weight=TWO_ROWS_LINSPACE[:, 1:-1], ref_error=239, atol=1), + Int8Desc(name="2 columns of 0-255 linspace", weight=np.transpose(TWO_ROWS_LINSPACE), ref_error=46818, atol=1), + NF4Desc(name="2 rows of exact quantiles without -1 and 1", weight=TWO_ROWS_NF4[:, 1:-1], ref_error=5e-3, atol=1e-3), + NF4Desc(name="2 columns of exact quantiles", weight=np.transpose(TWO_ROWS_NF4), ref_error=1e-2, atol=1e-2), + NF4Desc( + name="two groups in two rows without -1 and 1", + weight=TWO_GROUPS_IN_TWO_ROWS_NO_1_NF4, + group_size=14, + ref_error=2e-2, + atol=1e-2, + ), +] + + +@pytest.mark.parametrize("desc", LIST_DESCS, ids=map(str, LIST_DESCS)) +def test_quantization_error_calculation(desc: BaseDesc): + weight = desc.weight + actual_error = desc.get_error_fn()(weight) + ref_error = desc.ref_error + atol = desc.atol if desc.atol is not None else 1e-8 + assert np.allclose(actual_error, ref_error, atol=atol) + + +WEIGHTS_2x4 = np.array([[-4, -3, -2, -1], [0, 11, 2, 3]]) # [2, 4] +WEIGHTS_abs_max = np.array([4, 2, 11, 3]) # [4] + + +@dataclass +class CalculateScaleDesc: + weight: np.array + ref_scale: np.array + axis: int + group_size: int + + +CALCULATE_SCALE_DESCS = [ + CalculateScaleDesc(weight=WEIGHTS_2x4, ref_scale=WEIGHTS_abs_max.reshape([2, 2, 1]), axis=1, group_size=2), + CalculateScaleDesc(weight=WEIGHTS_2x4, ref_scale=np.abs(WEIGHTS_2x4).reshape([2, 1, 4]), axis=0, group_size=1), + CalculateScaleDesc( + weight=WEIGHTS_2x4.reshape([1, 2, 4, 1]), + ref_scale=WEIGHTS_abs_max.reshape([1, 2, 2, 1, 1]), + axis=2, + group_size=2, + ), + CalculateScaleDesc( + weight=WEIGHTS_2x4.reshape([1, 2, 4, 1]), + ref_scale=np.abs(WEIGHTS_2x4.reshape([1, 2, 4, 1])), + axis=0, + group_size=1, + ), + CalculateScaleDesc( + weight=WEIGHTS_2x4.reshape([2, 2, 2]), ref_scale=WEIGHTS_abs_max.reshape([2, 2, 1, 1]), axis=2, group_size=2 + ), + CalculateScaleDesc( + weight=WEIGHTS_2x4.reshape([2, 2, 2]), + ref_scale=np.array([4, 3, 2, 11]).reshape([2, 1, 1, 2]), + axis=1, + group_size=2, + ), + CalculateScaleDesc( + weight=WEIGHTS_2x4.reshape([2, 2, 2]), + ref_scale=np.array([4, 11, 2, 3]).reshape([1, 1, 2, 2]), + axis=0, + group_size=2, + ), +] + + +@pytest.mark.parametrize("desc", CALCULATE_SCALE_DESCS) +def test_calculate_scale_per_group(desc: CalculateScaleDesc): + act_scale, _ = _calculate_scale_per_group(desc.weight, reduction_axes=desc.axis, group_size=desc.group_size) + assert np.allclose(act_scale, desc.ref_scale) + + +def test_raise_error_for_many_axes(): + with pytest.raises(RuntimeError): + _calculate_scale_per_group(WEIGHTS_2x4, reduction_axes=(0, 1), group_size=1) + + +def test_raise_error_with_incorrect_group_size(): + with pytest.raises(RuntimeError): + _calculate_scale_per_group(WEIGHTS_2x4, reduction_axes=(0,), group_size=3) + + +def test_raise_error_with_int8_and_non_default_ratio(mocker): + with pytest.raises(RuntimeError): + compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, ratio=0.5) + + +def test_raise_error_with_int8_and_non_default_group_size(mocker): + with pytest.raises(RuntimeError): + compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, group_size=64) diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index e71394e9284..d831aea7bb1 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch +from nncf import CompressWeightsMode from nncf.quantization import compress_weights @@ -70,3 +72,18 @@ def test_compress_shared_weights(): for key, val in compressed_model.wte.pre_ops.items(): assert compressed_model.lm_head.get_pre_op(key) is val + + +def test_raise_error_with_int8_and_non_default_ratio(mocker): + with pytest.raises(RuntimeError): + compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, ratio=0.5) + + +def test_raise_error_with_int8_and_non_default_group_size(mocker): + with pytest.raises(RuntimeError): + compress_weights(mocker.Mock(), mode=CompressWeightsMode.INT8, group_size=64) + + +def test_raise_error_with_nf4(mocker): + with pytest.raises(RuntimeError): + compress_weights(mocker.Mock(), mode=CompressWeightsMode.NF4) From 849d2289fa3ef815d2e739b823be664bc373e7ba Mon Sep 17 00:00:00 2001 From: Aleksei Kashapov Date: Thu, 5 Oct 2023 17:45:46 +0200 Subject: [PATCH 02/10] [ONNX, Torch, OpenVINO] Add MVN, GELU, LINEAR_SHIFT_SCALE patterns (#2175) ### Changes Add three new patterns. With these patterns, performance of the model from the ticket is aligned with quantized OpenVINO IR. Left is before; right is after ![image](https://github.com/openvinotoolkit/nncf/assets/32935044/f67668ad-2fc9-49b9-b3ca-f0743239ee0c) ![image](https://github.com/openvinotoolkit/nncf/assets/32935044/34c29b19-848c-42ed-8125-dfee11363fc1) ![image](https://github.com/openvinotoolkit/nncf/assets/32935044/191e3273-9abf-4292-96bd-08a436928187) ### Reason for changes To get the best performance ### Related tickets 117169 ### Tests N/A --- nncf/common/graph/patterns/patterns.py | 3 + nncf/onnx/graph/metatypes/onnx_metatypes.py | 6 ++ nncf/onnx/hardware/fused_patterns.py | 99 +++++++++++++++++++ nncf/openvino/hardware/fused_patterns.py | 8 ++ nncf/torch/hardware/fused_patterns.py | 8 ++ tests/onnx/test_pattern_manager.py | 1 - tests/openvino/native/test_pattern_manager.py | 2 + tests/torch/test_pattern_manager.py | 2 + 8 files changed, 128 insertions(+), 1 deletion(-) diff --git a/nncf/common/graph/patterns/patterns.py b/nncf/common/graph/patterns/patterns.py index 4dce2ac8748..0c97a9d1eec 100644 --- a/nncf/common/graph/patterns/patterns.py +++ b/nncf/common/graph/patterns/patterns.py @@ -287,6 +287,8 @@ class HWFusedPatternNames(Enum): # ATOMIC OPERATIONS L2_NORM = PatternDesc("l2_norm") + MVN = PatternDesc("mvn") + GELU = PatternDesc("gelu") # BLOCK PATTERNS ADD_SCALE_SHIFT_OUTPUT = PatternDesc("add_scale_shift_output") @@ -338,6 +340,7 @@ class HWFusedPatternNames(Enum): LINEAR_ACTIVATIONS_BATCH_NORM = PatternDesc("linear_activations_batch_norm") LINEAR_ACTIVATIONS_SCALE_SHIFT = PatternDesc("linear_activations_scale_shift") LINEAR_ARITHMETIC = PatternDesc("linear_arithmetic") + LINEAR_SHIFT_SCALE = PatternDesc("linear_shift_scale") LINEAR_ARITHMETIC_ACTIVATIONS = PatternDesc("linear_arithmetic_activations") # Found in PicoDet models LINEAR_ARITHMETIC_ACTIVATIONS_ARITHMETIC = PatternDesc("linear_arithmetic_activations_arithmetic") diff --git a/nncf/onnx/graph/metatypes/onnx_metatypes.py b/nncf/onnx/graph/metatypes/onnx_metatypes.py index b998b889dda..35d532caeac 100644 --- a/nncf/onnx/graph/metatypes/onnx_metatypes.py +++ b/nncf/onnx/graph/metatypes/onnx_metatypes.py @@ -617,6 +617,12 @@ class ONNXDeformableConvolutionMetatype(ONNXOpMetatype): op_names = ["DeformConv"] +@ONNX_OPERATION_METATYPES.register() +class ONNXErfMetatype(ONNXOpMetatype): + name = "ErfOp" + op_names = ["Erf"] + + def get_operator_metatypes() -> List[Type[OperatorMetatype]]: """ Returns a list of the operator metatypes. diff --git a/nncf/onnx/hardware/fused_patterns.py b/nncf/onnx/hardware/fused_patterns.py index c496a94bee3..7e2a3efb53d 100644 --- a/nncf/onnx/hardware/fused_patterns.py +++ b/nncf/onnx/hardware/fused_patterns.py @@ -24,6 +24,93 @@ # BLOCK PATTERNS +@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.MVN) +def create_mvn() -> GraphPattern: + pattern = GraphPattern() + pattern_input_node = pattern.add_node( + **{GraphPattern.LABEL_ATTR: "*INPUT_NODE*", GraphPattern.METATYPE_ATTR: GraphPattern.NON_PATTERN_NODE_TYPE} + ) + reduce_mean_node_1 = pattern.add_node( + **{GraphPattern.LABEL_ATTR: "REDUCE_MEAN_1", GraphPattern.METATYPE_ATTR: om.ONNXReduceMeanMetatype} + ) + sub_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "SUBTRACT", + GraphPattern.METATYPE_ATTR: [om.ONNXSubMetatype], + } + ) + pow_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "POW", + GraphPattern.METATYPE_ATTR: [om.ONNXPowMetatype], + } + ) + reduce_mean_node_2 = pattern.add_node( + **{GraphPattern.LABEL_ATTR: "REDUCE_MEAN_2", GraphPattern.METATYPE_ATTR: om.ONNXReduceMeanMetatype} + ) + add_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "ADD", GraphPattern.METATYPE_ATTR: om.ONNXAddLayerMetatype}) + sqrt_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SQRT", GraphPattern.METATYPE_ATTR: om.ONNXSqrtMetatype}) + div_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "DIV", GraphPattern.METATYPE_ATTR: om.ONNXDivLayerMetatype}) + + pattern.add_edge(pattern_input_node, reduce_mean_node_1) + pattern.add_edge(reduce_mean_node_1, sub_node) + pattern.add_edge(pattern_input_node, sub_node) + pattern.add_edge(sub_node, pow_node) + pattern.add_edge(pow_node, reduce_mean_node_2) + pattern.add_edge(reduce_mean_node_2, add_node) + pattern.add_edge(add_node, sqrt_node) + pattern.add_edge(sqrt_node, div_node) + pattern.add_edge(sub_node, div_node) + return pattern + + +@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.MVN_SCALE_SHIFT) +def create_mvn_scale_shift() -> GraphPattern: + mvn = create_mvn() + scale_shift = create_scale_shift() + + mvn.join_patterns(scale_shift) + return mvn + + +@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.GELU) +def create_gelu() -> GraphPattern: + pattern = GraphPattern() + pattern_input_node = pattern.add_node( + **{GraphPattern.LABEL_ATTR: "*INPUT_NODE*", GraphPattern.METATYPE_ATTR: GraphPattern.NON_PATTERN_NODE_TYPE} + ) + div_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "DIV", + GraphPattern.METATYPE_ATTR: [om.ONNXDivLayerMetatype, om.ONNXMulLayerMetatype], + } + ) + erf_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "ERF", + GraphPattern.METATYPE_ATTR: om.ONNXErfMetatype, + } + ) + add_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "ADD", + GraphPattern.METATYPE_ATTR: [om.ONNXAddLayerMetatype, om.ONNXSubMetatype], + } + ) + mul_node = pattern.add_node( + **{ + GraphPattern.LABEL_ATTR: "MUL", + GraphPattern.METATYPE_ATTR: [om.ONNXMulLayerMetatype, om.ONNXDivLayerMetatype], + } + ) + pattern.add_edge(pattern_input_node, div_node) + pattern.add_edge(div_node, erf_node) + pattern.add_edge(erf_node, add_node) + pattern.add_edge(add_node, mul_node) + pattern.add_edge(pattern_input_node, mul_node) + return pattern + + @ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SCALE_SHIFT) def create_scale_shift() -> GraphPattern: pattern = GraphPattern() @@ -372,6 +459,15 @@ def create_linear_arithmetic_activations() -> GraphPattern: return linear +@ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_SHIFT_SCALE) +def create_linear_shift_scale() -> GraphPattern: + linear = linear_operations() + shift_scale = create_shift_scale() + + linear.join_patterns(shift_scale) + return linear + + @ONNX_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_ARITHMETIC_ACTIVATIONS_ARITHMETIC) def create_linear_arithmetic_activations_arithmetic() -> GraphPattern: linear_arithmetic_activations = create_linear_arithmetic_activations() @@ -423,6 +519,9 @@ def atomic_activations_operations() -> GraphPattern: hswish_without_denominator = create_hswish_without_denominator() pattern.add_pattern_alternative(hswish_without_denominator) + + gelu = create_gelu() + pattern.add_pattern_alternative(gelu) return pattern diff --git a/nncf/openvino/hardware/fused_patterns.py b/nncf/openvino/hardware/fused_patterns.py index dcf29333d28..6e1596e47e0 100644 --- a/nncf/openvino/hardware/fused_patterns.py +++ b/nncf/openvino/hardware/fused_patterns.py @@ -575,6 +575,14 @@ def create_linear_arithmetic_activations() -> GraphPattern: return linear +@OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_SHIFT_SCALE) +def create_linear_shift_scale() -> GraphPattern: + linear = linear_operations() + shift_scale = create_shift_scale() + linear.join_patterns(shift_scale) + return linear + + @OPENVINO_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_ARITHMETIC_ACTIVATIONS_ARITHMETIC) def create_linear_arithmetic_activations_arithmetic() -> GraphPattern: linear_arithmetic_activations = create_linear_arithmetic_activations() diff --git a/nncf/torch/hardware/fused_patterns.py b/nncf/torch/hardware/fused_patterns.py index c30e5b0d4c0..b8db07decd4 100644 --- a/nncf/torch/hardware/fused_patterns.py +++ b/nncf/torch/hardware/fused_patterns.py @@ -76,6 +76,14 @@ def create_linear_arithmetic_operations() -> GraphPattern: return linear +@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.LINEAR_SHIFT_SCALE) +def create_linear_shift_scale() -> GraphPattern: + linear = linear_operations() + shift_scale = create_shift_scale() + linear.join_patterns(shift_scale) + return linear + + @PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.BATCH_NORM_ACTIVATIONS) def create_batch_norm_activations_operations() -> GraphPattern: batch_norm = batch_norm_operations() diff --git a/tests/onnx/test_pattern_manager.py b/tests/onnx/test_pattern_manager.py index 379b0a13df0..e06ae77e46d 100644 --- a/tests/onnx/test_pattern_manager.py +++ b/tests/onnx/test_pattern_manager.py @@ -21,7 +21,6 @@ HWFusedPatternNames.LINEAR_CONST_MULTIPLY: "Not relevant for ONNX.", HWFusedPatternNames.ADD_SCALE_SHIFT_OUTPUT: "Not relevant for ONNX.", HWFusedPatternNames.BATCH_INDEX: "Not relevant for ONNX.", - HWFusedPatternNames.MVN_SCALE_SHIFT: "Not relevant for ONNX.", HWFusedPatternNames.NORMALIZE_L2_MULTIPLY: "Not relevant for ONNX.", HWFusedPatternNames.LINEAR_WITH_BIAS: "Linear layers contains biases in ONNX.", HWFusedPatternNames.SE_BLOCK: "Not relevant for ONNX.", diff --git a/tests/openvino/native/test_pattern_manager.py b/tests/openvino/native/test_pattern_manager.py index 6d03c34176c..f4bbfe7b4c8 100644 --- a/tests/openvino/native/test_pattern_manager.py +++ b/tests/openvino/native/test_pattern_manager.py @@ -32,6 +32,8 @@ HWFusedPatternNames.LINEAR_BATCH_NORM_ACTIVATIONS: "Not relevant for OpenVINO.", HWFusedPatternNames.LINEAR_BATCH_NORM_SCALE_SHIFT_ACTIVATIONS: "Not relevant for OpenVINO.", HWFusedPatternNames.LINEAR_SCALE_SHIFT_ACTIVATIONS: "Not relevant for OpenVINO.", + HWFusedPatternNames.MVN: "Not relevant for OpenVINO.", + HWFusedPatternNames.GELU: "Not relevant for OpenVINO.", } IGNORING_IGNORED_PATTERN_REASONS = {} diff --git a/tests/torch/test_pattern_manager.py b/tests/torch/test_pattern_manager.py index 167a673b2fa..0be8fba0e9e 100644 --- a/tests/torch/test_pattern_manager.py +++ b/tests/torch/test_pattern_manager.py @@ -65,6 +65,8 @@ HWFusedPatternNames.LINEAR_SQUEEZE_ACTIVATIONS: "Not relevant for Torch.", HWFusedPatternNames.LINEAR_SQUEEZE_ARITHMETIC_ACTIVATIONS: "Not relevant for Torch.", HWFusedPatternNames.LINEAR_ACTIVATIONS_UNSQUEEZE_BN_SQUEEZE: "Not relevant for Torch.", + HWFusedPatternNames.MVN: "Not relevant for Torch.", + HWFusedPatternNames.GELU: "Not relevant for Torch.", } IGNORING_IGNORED_PATTERN_REASONS = { From 2d5ea2c72a6d5e9961caecbd5929635b2e459002 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Thu, 5 Oct 2023 19:35:06 +0300 Subject: [PATCH 03/10] Fix: missed input_shape to `convert_model` for PT backend. (#2178) ### Changes Add missed `input_shape` to `convert_model` for PT backend. ### Reason for changes IR of quantized models have dynamic shape. --- tests/post_training/pipelines/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/post_training/pipelines/base.py b/tests/post_training/pipelines/base.py index 2837746165b..9984a4df18a 100644 --- a/tests/post_training/pipelines/base.py +++ b/tests/post_training/pipelines/base.py @@ -127,6 +127,7 @@ def __init__( self.model_hf = None self.calibration_dataset = None self.dummy_tensor = None + self.input_size = None self.run_info = RunInfo(model=reported_name, backend=self.backend) @@ -212,7 +213,7 @@ def save_quantized_model(self) -> None: if self.backend == BackendType.OPTIMUM: self.path_quantized_ir = self.output_model_dir / "openvino_model.xml" elif self.backend in PT_BACKENDS: - ov_model = convert_model(self.quantized_model, example_input=self.dummy_tensor) + ov_model = convert_model(self.quantized_model, example_input=self.dummy_tensor, input_shape=self.input_size) self.path_quantized_ir = self.output_model_dir / "model.xml" ov.serialize(ov_model, self.path_quantized_ir) elif self.backend == BackendType.ONNX: From 2c417f50f6937735f07107a68c53fa1772328cc2 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Thu, 5 Oct 2023 19:35:37 +0300 Subject: [PATCH 04/10] Add table of corresponding versions for backends (#2161) ### Changes Add table of corresponding versions for backends. ### Related tickets 120873 --- README.md | 2 +- docs/Installation.md | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3eda0ad7fa0..5962454a9a7 100644 --- a/README.md +++ b/README.md @@ -375,7 +375,7 @@ You may also use one of the Dockerfiles in the [docker](./docker) directory to b - ONNX\* ~=1.13.1 - OpenVINO\* >=2022.3.0 -This repository is tested on Python* 3.8.10, PyTorch* 2.0.1 (NVidia CUDA\* Toolkit 11.7) and TensorFlow* 2.12.1 (NVidia CUDA\* Toolkit 11.8). +This repository is tested on Python* 3.8.10, PyTorch* 2.0.1 (NVidia CUDA\* Toolkit 11.8) and TensorFlow* 2.12.1 (NVidia CUDA\* Toolkit 11.8). ## NNCF Compressed Model Zoo diff --git a/docs/Installation.md b/docs/Installation.md index 53dc91e80ca..063e2e6f7b3 100644 --- a/docs/Installation.md +++ b/docs/Installation.md @@ -34,7 +34,7 @@ Use the same `pip install` syntax as above to install NNCF along with the backen pip install .[] ``` -List of supported backends: `torch`, `tf`, `onnx` and `openvino`. +List of supported backends: `torch`, `tf`, `onnx` and `openvino`. For development purposes install extra packages by @@ -61,3 +61,14 @@ Note that in order for this to work for pip versions >= 21.3, your Git version m ## As a Docker image Use one of the Dockerfiles in the [docker](../docker) directory to build an image with an environment already set up and ready for running NNCF [sample scripts](../README.md#model-compression-samples). + +## Corresponding versions + +The following table lists the recommended corresponding versions of backend packages +as well as the supported versions of Python: + +| NNCF | OpenVINO | PyTorch | ONNX | TensorFlow | Python | +| ------- | ---------- | -------- | -------- | ---------- | ------- | +| `2.6.0` | `2023.1.0` | `2.0.1` | `1.13.1` | `2.12.0` | `3.8` | +| `2.5.0` | `2023.0.0` | `1.13.1` | `1.13.1` | `2.11.1` | `3.8` | +| `2.4.0` | `2022.1.0` | `1.12.1` | `1.12.0` | `2.8.2` | `3.8` | From 403597cdc92683b1323646036967a6619a50fb87 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Fri, 6 Oct 2023 06:32:28 +0100 Subject: [PATCH 05/10] Update calibrate tool (#2174) ### Changes - Add support for models that don't have per-sample metrics. ### Reason for changes Some models don't have per-sample metrics. An error occurred with these models. ### Related tickets N/A ### Tests --- tests/openvino/tools/calibrate.py | 33 ++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/tests/openvino/tools/calibrate.py b/tests/openvino/tools/calibrate.py index 6ca576b0052..59e793ac1e3 100644 --- a/tests/openvino/tools/calibrate.py +++ b/tests/openvino/tools/calibrate.py @@ -125,15 +125,34 @@ class ACValidationFunction: "ndcg": "sigmoid_recom_loss", } - def __init__(self, model_evaluator: ModelEvaluator, metric_name: str, requests_number: Optional[int] = None): + SPECIAL_METRICS = [ + "cmc", + "reid_map", + "pairwise_accuracy_subsets", + "pairwise_accuracy", + "normalized_embedding_accuracy", + "face_recognition_tafa_pair_metric", + "localization_recall", + "coco_orig_keypoints_precision", + "coco_orig_segm_precision", + "coco_orig_keypoints_precision", + "spearman_correlation_coef", + "pearson_correlation_coef", + ] + + def __init__( + self, model_evaluator: ModelEvaluator, metric_name: str, metric_type: str, requests_number: Optional[int] = None + ): """ :param model_evaluator: Model Evaluator. :param metric_name: Name of a metric. + :param metric_type: Type of a metric. :param requests_number: A number of infer requests. If it is `None`, the count will be selected automatically. """ self._model_evaluator = model_evaluator self._metric_name = metric_name + self._metric_type = metric_type self._persample_metric_name = self.METRIC_TO_PERSAMPLE_METRIC.get(self._metric_name, self._metric_name) registered_metrics = model_evaluator.get_metrics_attributes() if self._persample_metric_name not in registered_metrics: @@ -141,6 +160,8 @@ def __init__(self, model_evaluator: ModelEvaluator, metric_name: str, requests_n self._requests_number = requests_number self._values_for_each_item = [] + self._collect_outputs = self._metric_type in self.SPECIAL_METRICS + def __call__(self, compiled_model: ov.CompiledModel, indices: Optional[Iterable[int]] = None) -> float: """ Calculates metrics for the provided model. @@ -203,6 +224,11 @@ def _output_callback(self, raw_predictions, **kwargs): return for sample_id, results in metrics_result.items(): + if self._collect_outputs: + output = list(raw_predictions.values())[0] + self._values_for_each_item.append({"sample_id": sample_id, "metric_value": output}) + continue + for metric_result in results: if metric_result.metric_name != self._persample_metric_name: continue @@ -940,10 +966,11 @@ def quantize_model_with_accuracy_control( ) model_evaluator.load_network([{"model": ov_model}]) + metric_type = accuracy_checker_config["models"][0]["datasets"][0]["metrics"][0]["type"] metric_name = accuracy_checker_config["models"][0]["datasets"][0]["metrics"][0].get("name", None) if metric_name is None: - metric_name = accuracy_checker_config["models"][0]["datasets"][0]["metrics"][0]["type"] - validation_fn = ACValidationFunction(model_evaluator, metric_name) + metric_name = metric_type + validation_fn = ACValidationFunction(model_evaluator, metric_name, metric_type) name_to_quantization_impl_map = { "pot": pot_quantize_with_native_accuracy_control, From d39d3ae971c67465b7a8c0df94a070e1c925b87f Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Fri, 6 Oct 2023 11:19:52 +0200 Subject: [PATCH 06/10] [Torch] Unused code is removed (#2172) ### Changes #2035 cleanup: Unused code is removed --- nncf/torch/tensor_statistics/reduction.py | 65 ---------------------- nncf/torch/tensor_statistics/statistics.py | 6 -- 2 files changed, 71 deletions(-) delete mode 100644 nncf/torch/tensor_statistics/reduction.py diff --git a/nncf/torch/tensor_statistics/reduction.py b/nncf/torch/tensor_statistics/reduction.py deleted file mode 100644 index 4911bc2fdf2..00000000000 --- a/nncf/torch/tensor_statistics/reduction.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Tuple - -import numpy as np -import torch - - -def max_reduce_like(input_: torch.Tensor, ref_tensor_shape: List[int]) -> torch.Tensor: - numel = np.prod(ref_tensor_shape) - if numel == 1: - retval = input_.max() - for _ in ref_tensor_shape: - retval.unsqueeze_(-1) - return retval - tmp_max = input_ - for dim_idx, dim in enumerate(ref_tensor_shape): - if dim == 1: - tmp_max, _ = torch.max(tmp_max, dim_idx, keepdim=True) - return tmp_max - - -def min_reduce_like(input_: torch.Tensor, ref_tensor_shape: List[int]): - numel = np.prod(ref_tensor_shape) - if numel == 1: - retval = input_.min() - for _ in ref_tensor_shape: - retval.unsqueeze_(-1) - return retval - tmp_min = input_ - for dim_idx, dim in enumerate(ref_tensor_shape): - if dim == 1: - tmp_min, _ = torch.min(tmp_min, dim_idx, keepdim=True) - return tmp_min - - -def get_channel_count_and_dim_idx(scale_shape: List[int]) -> Tuple[int, int]: - channel_dim_idx = 0 - channel_count = 1 - for dim_idx, dim in enumerate(scale_shape): - if dim != 1: - channel_dim_idx = dim_idx - channel_count = dim - return channel_count, channel_dim_idx - - -def expand_like(input_: torch.Tensor, scale_shape: List[int]) -> torch.Tensor: - retval = input_ - count, idx = get_channel_count_and_dim_idx(scale_shape) - assert input_.numel() == count - assert len(input_.size()) == 1 - for _ in range(0, idx): - retval = retval.unsqueeze(0) - for _ in range(idx + 1, len(scale_shape)): - retval = retval.unsqueeze(-1) - return retval diff --git a/nncf/torch/tensor_statistics/statistics.py b/nncf/torch/tensor_statistics/statistics.py index ba51df16ce9..187f125d813 100644 --- a/nncf/torch/tensor_statistics/statistics.py +++ b/nncf/torch/tensor_statistics/statistics.py @@ -9,8 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple - import torch from nncf.common.tensor_statistics.statistics import MeanTensorStatistic @@ -20,10 +18,6 @@ from nncf.common.tensor_statistics.statistics import TensorStatistic -def _reshape_all(targets: Tuple[torch.Tensor, ...], target_shape: Tuple[int, ...]): - return map(lambda stat: torch.reshape(stat, target_shape), targets) - - class PTMinMaxTensorStatistic(MinMaxTensorStatistic): @staticmethod def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool: From a97fff9445f7e218c23e2a9b3142f414b0802fe4 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Mon, 9 Oct 2023 07:40:55 +0300 Subject: [PATCH 07/10] Add `cleanup_torchscript_cache` to test_quantize_conformance.py (#2171) ### Changes Add `cleanup_torchscript_cache` function to test_quantize_conformance.py ### Reason for changes After run torch.jit.trace in convert_model, PyTorch does not clear the trace cache automatically. Same function in torch test and openvino notebok: https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/jit_utils.py#L59 https://github.com/openvinotoolkit/openvino_notebooks/blob/1932d4b4e99116bdedaa620c9dc92069fbb1f05e/notebooks/236-stable-diffusion-v2/implementation/conversion_helper_utils.py#L8 --- tests/post_training/pipelines/base.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/post_training/pipelines/base.py b/tests/post_training/pipelines/base.py index 9984a4df18a..0306fc9594e 100644 --- a/tests/post_training/pipelines/base.py +++ b/tests/post_training/pipelines/base.py @@ -279,6 +279,19 @@ def run(self) -> None: self.save_quantized_model() self.get_num_fq() self.validate() + self.cleanup_torchscript_cache() + + @staticmethod + def cleanup_torchscript_cache(): + """ + Helper for removing cached model representation. + + After run torch.jit.trace in convert_model, PyTorch does not clear the trace cache automatically. + """ + # pylint: disable=protected-access + torch._C._jit_clear_class_registry() + torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + torch.jit._state._clear_class_state() def get_run_info(self) -> RunInfo: return self.run_info From 881cca60c7e98c4b2525d0355b35eb7ca0d1756a Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Mon, 9 Oct 2023 06:42:07 +0200 Subject: [PATCH 08/10] Added a helper script for extracting an IR subgraph (#2168) ### Changes There recently have been cases of very large model IRs, for example for LLMs which can't be opened by Netron. This makes it difficult to examine model graph in general and Fake Quantize node placement in particular. The added helper script allows to select a node from a graph and extract some subgraph around it. To control how many surrounding nodes to include, `distance` parameter is used. #### Usage examples ``` python ir_subgraph.py openvino.xml "Constant_1116858" ``` The result will be saved at `./openvino_Constant_1116858_10.xml`. An additional symbolic link will be created at `./openvino_Constant_1116858_10.bin` leading to original `.bin` file so that weights for the subgraph IR are visible through Netron. - Parameter `--distance` can be used to control the subgraph size (10 by default) ``` python ir_subgraph.py openvino.xml "Constant_1116858" --distance 5 ``` The result will be saved at `./openvino_Constant_1116858_5.xml`. - Parameter `--output-path` can be used to control where to save the result. It can either be a file path or a directory path. ``` python ir_subgraph.py openvino.xml "Constant_1116858" --output-path ./subgraphs ``` The result will be saved at `./subgraphs/openvino_Constant_1116858_10.xml`. ``` python ir_subgraph.py openvino.xml "Constant_1116858" --output-path ./subgraphs/Constant_1116858.xml ``` The result will be saved at `./subgraphs/Constant_1116858.xml`. #### Simplifying usage A file can be run like below ``` ir_subgraph.py openvino.xml "Constant_1116858" ``` after performing the following steps 1. Add the directory with the `ir_subgraph.py` file to PATH variable `PATH=$PATH:/path/to/dir` 2. Make it executable `chmod +x /path/to/ir_subgraph.py` 3. Add `#!/path/to/python` as the first line of `ir_subgraph.py` ### Reason for changes Making it easier to analyze large graphs. --- tools/extract_ov_subgraph.py | 306 +++++++++++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 tools/extract_ov_subgraph.py diff --git a/tools/extract_ov_subgraph.py b/tools/extract_ov_subgraph.py new file mode 100644 index 00000000000..d169308156a --- /dev/null +++ b/tools/extract_ov_subgraph.py @@ -0,0 +1,306 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import shutil +import xml.etree.ElementTree as ET +from copy import copy +from copy import deepcopy +from pathlib import Path +from pprint import pprint +from typing import Any, Dict + +import defusedxml.ElementTree as dET +import networkx as nx + + +def xml_to_dict(element: ET.Element): + result = {} + if element.attrib: + result["attributes"] = element.attrib + for child in element: + child_dict = xml_to_dict(child) + if child.tag in result: + if isinstance(result[child.tag], list): + result[child.tag].append(child_dict) + else: + result[child.tag] = [result[child.tag], child_dict] + else: + result[child.tag] = child_dict + if element.text: + result["text"] = element.text + if element.tail: + result["tail"] = element.tail + return result + + +def dict_to_xml(data: Any, parent: ET.Element): + if isinstance(data, dict): + for tag_name, value in data.items(): + if tag_name == "attributes": + parent.attrib.update(value) + elif tag_name == "text": + parent.text = value + elif tag_name == "tail": + parent.tail = value + elif isinstance(value, list): + for item in value: + elem = ET.SubElement(parent, tag_name) + dict_to_xml(item, elem) + else: + elem = ET.SubElement(parent, tag_name) + dict_to_xml(value, elem) + else: + parent.text = str(data) + + +def get_edges(xml_dict: Dict): + def add_edge(edges: Dict, from_layer: int, from_port: int, to_layer: int, to_port: int): + if from_layer not in edges: + edges[from_layer] = {} + if from_port not in edges[from_layer]: + edges[from_layer][from_port] = {} + assert (to_layer, to_port) not in edges[from_layer][from_port] + edges[from_layer][from_port][(to_layer, to_port)] = {} + + edges = {} + for edge in xml_dict["edges"]["edge"]: + edge = edge["attributes"] + + from_layer = int(edge["from-layer"]) + from_port = int(edge["from-port"]) + to_layer = int(edge["to-layer"]) + to_port = int(edge["to-port"]) + + add_edge(edges, from_layer, from_port, to_layer, to_port) + add_edge(edges, to_layer, to_port, from_layer, from_port) + + return edges + + +def get_nodes(xml_dict: Dict, edges: Dict): + all_node_names = set() + nodes = {} + for node in xml_dict["layers"]["layer"]: + try: + attributes = node["attributes"] + data = node["data"]["attributes"] if "data" in node else None + inp = node["input"] if "input" in node else None + out = node["output"] if "output" in node else None + + node_id = int(attributes["id"]) + node_name = attributes["name"] + node_type = attributes["type"] + + assert node_name not in all_node_names + all_node_names.add(node_name) + + assert node_id not in nodes + nodes[node_id] = { + "name": node_name, + "type": node_type, + } + + node_dtype = data["element_type"] if data is not None and "element_type" in data else None + node_shape = data["shape"] if data is not None and "shape" in data else None + if node_dtype is not None: + nodes[node_id]["dtype"] = node_dtype + if node_shape is not None: + nodes[node_id]["shape"] = node_shape + + input_ports = [] if inp is None else inp["port"] + output_ports = [] if out is None else out["port"] + if isinstance(input_ports, dict): + input_ports = [input_ports] + if isinstance(output_ports, dict): + output_ports = [output_ports] + + for port, is_input in zip( + input_ports + output_ports, [True] * len(input_ports) + [False] * len(output_ports) + ): + from_port = int(port["attributes"]["id"]) + precision = port["attributes"]["precision"] + if "dim" in port["attributes"]: + dim = port["attributes"]["dim"] + elif "dim" in port: + dim = port["dim"] + else: + dim = [] + if isinstance(dim, dict): + dim = [dim] + shape = tuple(int(it["text"]) for it in dim) + + # Update properties of the edges leading from this port + if from_port not in edges[node_id]: + # Some edge descriptions may be missing in execution graph + continue + else: + edge = edges[node_id][from_port] + for (to_node_id, to_port), edge_properties_dict in edge.items(): + for name, value in zip(("precision", "shape", "is_input"), (precision, shape, is_input)): + assert name not in edge_properties_dict + edge_properties_dict[name] = value + except Exception as e: + pprint(node) + raise e + + return nodes + + +def create_nx_graph(xml_dict: Dict): + def get_node_label(nodes: Dict, node_id: int): + return nodes[node_id]["name"] + + def get_edge_label(edges: Dict, nodes: Dict, from_node: int, from_port: int, to_node: int, to_port: int): + edge_properties = edges[from_node][from_port][(to_node, to_port)] + return f'"{edge_properties["shape"]}\n{from_port}->{to_port}"' + + edges = get_edges(xml_dict) + nodes = get_nodes(xml_dict, edges) + + G = nx.Graph() + + # Add nodes + for node_id, node_properties in nodes.items(): + node_properties_copy = copy(node_properties) + node_properties_copy["id"] = node_id + G.add_node(get_node_label(nodes, node_id), **node_properties_copy) + + # Add edges + for node_id, from_port_dict in edges.items(): + for from_port, to_port_dict in from_port_dict.items(): + for (to_node_id, to_port), edge_properties in to_port_dict.items(): + G.add_edge( + u_of_edge=get_node_label(nodes, node_id), + v_of_edge=get_node_label(nodes, to_node_id), + label=get_edge_label(edges, nodes, node_id, from_port, to_node_id, to_port), + **edge_properties, + ) + + return G + + +def write_xml(xml_dict: Dict, filepath: Path): + write_root = ET.Element("net") + dict_to_xml(xml_dict, write_root) + xml_str = ET.tostring(write_root).decode() + xml_str = '\n' + xml_str + "\n" + with open(filepath, "w") as f: + f.write(xml_str) + + +def take_model_subgraph(xml_dict: Dict, source_node_name: str, distance: int): + # Create networkx graph from IR xml dictionary + G = create_nx_graph(xml_dict) + + # Traverse graph from target node + dfs_tree = nx.traversal.dfs_tree(G, source=source_node_name, depth_limit=distance) + node_names = set(dfs_tree.nodes) + node_ids = set([G.nodes[it]["id"] for it in node_names]) + + # Keep only the visited nodes + result_xml_dict = deepcopy(xml_dict) + result_xml_dict["layers"]["layer"] = [] + for layer in xml_dict["layers"]["layer"]: + node_name = layer["attributes"]["name"] + if node_name in node_names: + result_xml_dict["layers"]["layer"].append(layer) + + # Keep only the edges that connect the visited nodes + result_xml_dict["edges"]["edge"] = [] + for edge in xml_dict["edges"]["edge"]: + from_layer = int(edge["attributes"]["from-layer"]) + to_layer = int(edge["attributes"]["to-layer"]) + if from_layer in node_ids or to_layer in node_ids: + result_xml_dict["edges"]["edge"].append(edge) + + return result_xml_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Extract a subgraph from a model in OpenVINO Intermediate Representation format.\n\nSubgraph is " + "taken around a given node. Use distance parameter to control how many nodes around the given one to include. " + "The resulting subgraph is saved next to the input .xml file or at --output_path if provided. Additionally, a " + "symbolic link targeting the original .bin file is created.", + epilog="Usage examples:\n" + ' python ir_subgraph.py openvino.xml "Constant_1116858"\n' + ' python ir_subgraph.py openvino.xml "Constant_1116858" --distance 5\n' + ' python ir_subgraph.py openvino.xml "Constant_1116858" --output-path ./subgraphs\n' + ' python ir_subgraph.py openvino.xml "Constant_1116858" --output-path ./subgraphs/Constant_1116858.xml\n', + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("input-path", help="Input IR path.") + parser.add_argument("node", help="Target node name.") + parser.add_argument("--distance", type=int, default=10, help="Distance around the target node (default 10).") + parser.add_argument( + "--output-path", + dest="output_path", + help="Output IR path. Can either be a file path with .xml extension or a directory path.", + ) + + args = parser.parse_args() + + input_path = Path(args.__dict__["input-path"]) + node_name = args.node + distance = args.distance + output_path = Path(args.output_path) if args.output_path is not None else None + + if distance <= 0: + raise ValueError("Distance should be positive") + + if output_path is None or output_path.suffix == "": + output_filename = f"{input_path.stem}_{Path(node_name).stem}_{distance}.xml" + if output_path is None: + output_dir = input_path.parent + output_path = input_path.parent / output_filename + else: + output_dir = output_path + output_path = output_dir / output_filename + else: + output_dir = output_path.parent + + if output_path.exists(): + raise ValueError(f"There is already and IR at {output_path}. Exiting.") + + # Read IR xml as dict + tree = dET.parse(input_path) + root = tree.getroot() + xml_dict = xml_to_dict(root) + + # Take subgraph + subgraph_xml_dict = take_model_subgraph(xml_dict, source_node_name=node_name, distance=distance) + + # Save subgraph xml + if not output_dir.exists(): + output_dir.mkdir(parents=True) + write_xml(subgraph_xml_dict, output_path) + + # Create a symbolic link to original .bin file + bin_input_path = input_path.with_suffix(".bin") + bin_output_path = output_path.with_suffix(".bin") + if bin_output_path.exists(): + os.remove(bin_output_path) + try: + bin_output_path.symlink_to(os.path.relpath(bin_input_path, bin_output_path.parent)) + except OSError as e: + if "[WinError 1314]" in str(e): + if bin_input_path.exists(): + print("Copying original .bin file because can't create a symbolic link due to lack of admin privileges") + shutil.copy(bin_input_path, bin_output_path) + else: + print("Didn't create a copy of original .bin file because it is missing") + else: + raise e + + print("Saved at:", output_path) From bbb7e56877b4dee01e209a60d0e3b94989cc92f2 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Mon, 9 Oct 2023 17:21:43 +0300 Subject: [PATCH 09/10] Tensor for PTQ (#2058) ### Changes - Update `MinMax` and `FastBiasCorrection` to use common Tensor. - Remove converting torch -> numpy -> torch of data. - `FakeQuantizeParameters` collect data wrapped by tensor. - Add support cuda for torch backend. - Add new functions for `Tensor`: - stack - unstack - moveaxis - mean - round - Removed `__all__` from function.py, it's works like default behavior. - Add `statistical_functions.py` for high level functions that used only function from `functions.py` and have no backend specific implementations: - mean_per_channel - Disable warnings for divide operators of numpy ### Related tickets 113315 ### Tests --- .../statistical_functions.py | 30 ++ nncf/experimental/tensor/README.md | 38 +- nncf/experimental/tensor/functions.py | 184 ++++++--- nncf/experimental/tensor/numpy_functions.py | 180 +++++---- nncf/experimental/tensor/tensor.py | 75 ++-- nncf/experimental/tensor/torch_functions.py | 124 ++++-- .../onnx/quantization/quantizer_parameters.py | 4 +- nncf/openvino/graph/model_transformer.py | 16 +- .../fast_bias_correction/algorithm.py | 35 +- .../fast_bias_correction/backend.py | 43 +- .../fast_bias_correction/onnx_backend.py | 41 +- .../fast_bias_correction/openvino_backend.py | 41 +- .../fast_bias_correction/torch_backend.py | 42 +- .../algorithms/min_max/onnx_backend.py | 2 +- .../algorithms/min_max/openvino_backend.py | 2 +- .../algorithms/min_max/torch_backend.py | 26 +- nncf/quantization/fake_quantize.py | 92 +++-- tests/onnx/quantization/common.py | 13 +- .../test_calculate_quantizer_parameters.py | 9 +- .../test_fast_bias_correction.py | 2 +- .../template_test_nncf_tensor.py | 370 ++++++++++++------ .../ptq/test_calculation_quantizer_params.py | 95 +++-- tests/torch/ptq/test_fast_bias_correction.py | 27 ++ 23 files changed, 905 insertions(+), 586 deletions(-) create mode 100644 nncf/experimental/common/tensor_statistics/statistical_functions.py diff --git a/nncf/experimental/common/tensor_statistics/statistical_functions.py b/nncf/experimental/common/tensor_statistics/statistical_functions.py new file mode 100644 index 00000000000..24fc115a058 --- /dev/null +++ b/nncf/experimental/common/tensor_statistics/statistical_functions.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nncf.experimental.tensor import Tensor +from nncf.experimental.tensor import functions as fns + + +def mean_per_channel(x: Tensor, axis: int) -> Tensor: + """ + Computes the mean of elements across given channel dimension of Tensor. + + :param x: Tensor to reduce. + :param axis: The channel dimensions to reduce. + :return: Reduced Tensor. + """ + if len(x.shape) < 3: + return fns.mean(x, axis=0) + pos_axis = axis + x.ndim if axis < 0 else axis + if pos_axis < 0 or pos_axis >= x.ndim: + raise ValueError(f"axis {axis} is out of bounds for array of dimension {x.ndim}") + axis = tuple(i for i in range(x.ndim) if i != pos_axis) + return fns.mean(x, axis=axis) diff --git a/nncf/experimental/tensor/README.md b/nncf/experimental/tensor/README.md index 09e3dc6a1e0..ea8ae2168c9 100644 --- a/nncf/experimental/tensor/README.md +++ b/nncf/experimental/tensor/README.md @@ -6,7 +6,7 @@ making them more portable and reusable. ## Usage -The main idea is common algorithms should use wrapped tensors and provide to backend-specific function unwrapped tensor. +Common algorithms should use wrapped tensors and provide the unwrapped tensor to the backend-specific function. ### Initialization Tensor @@ -32,6 +32,8 @@ tenor_b = Tensor(np.array([1,2])) tensor_a + tenor_b # Tensor(array([2, 4])) ``` +**NOTE** Division operations for the numpy backend are performed with warnings disabled for the same for all backends. + ### Comparison operators All math operations are overrided to operated with wrapped object and return `Tensor` @@ -55,16 +57,16 @@ nncf_tensor.max() # Tensor(2) All available functions you can found in [functions.py](functions.py). ```python -from nncf.experimental.tensor import functions -functions.max(nncf_tensor) # Tensor(2) +from nncf.experimental.tensor import functions as fns +fns.max(nncf_tensor) # Tensor(2) ``` **NOTE** A function requires at least one positional argument, which is used to dispatch the function to the appropriate implementation depending on the type of argument. ```python -functions.max(nncf_tensor) # Correct -functions.max(a=nncf_tensor) # TypeError: wrapper requires at least 1 positional argument +fns.max(nncf_tensor) # Correct +fns.max(a=nncf_tensor) # TypeError: wrapper requires at least 1 positional argument ``` ### Loop over Tensor @@ -100,7 +102,7 @@ tensor_a[0:2] # Tensor(array([[1],[2]])) class Tensor: ... def foo(self, arg1: Type) -> "Tensor": - return functions.foo(self, arg1) + return fns.foo(self, arg1) ``` 2. Add function to [function.py](function.py) @@ -120,28 +122,36 @@ tensor_a[0:2] # Tensor(array([[1],[2]])) return NotImplemented(f"Function `foo` is not implemented for {type(a)}") ``` -3. Add function name to `__all__` in [function.py](function.py) + **NOTE** For the case when the first argument has type `List[Tensor]`, use the `_dispatch_list` function. This function dispatches function by first element in the first argument. + + ```python + @functools.singledispatch + def foo(x: List[Tensor], axis: int = 0) -> Tensor: + if isinstance(x, List): + unwrapped_x = [i.data for i in x] + return Tensor(_dispatch_list(foo, unwrapped_x, axis=axis)) + raise NotImplementedError(f"Function `foo` is not implemented for {type(x)}") + ``` -4. Add backend specific implementation of method to: +3. Add backend specific implementation of method to: - - [numpy_function.py](numpy_function.py) + - [numpy_function.py](numpy_functions.py) ```python - @functions.foo.register(np.ndarray) - @functions.foo.register(np.number) + @_register_numpy_types(fns.foo) def _(a: TType, arg1: Type) -> np.ndarray: return np.foo(a, arg1) ``` - - [torch_function.py](torch_function.py) + - [torch_function.py](torch_functions.py) ```python - @functions.foo.register(torch.Tensor) + @fns.foo.register(torch.Tensor) def _(a: torch.Tensor, arg1: Type) -> torch.Tensor: return torch.foo(a, arg1) ``` -5. Add test of method to [test template](tests/shared/test_templates/template_test_nncf_tensor.py) for Tensor class +4. Add test of method to [test template](../../../tests/shared/test_templates/template_test_nncf_tensor.py) for Tensor class ### Add new backend diff --git a/nncf/experimental/tensor/functions.py b/nncf/experimental/tensor/functions.py index 30f27a65cce..33e55c1ef50 100644 --- a/nncf/experimental/tensor/functions.py +++ b/nncf/experimental/tensor/functions.py @@ -10,14 +10,12 @@ # limitations under the License. import functools -from typing import List, Optional, Tuple, TypeVar, Union +from typing import Callable, List, Optional, Tuple, Union -from nncf.experimental.tensor import Tensor -from nncf.experimental.tensor import unwrap_tensor_data from nncf.experimental.tensor.enums import TensorDataType from nncf.experimental.tensor.enums import TensorDeviceType - -TTensor = TypeVar("TTensor") +from nncf.experimental.tensor.tensor import Tensor +from nncf.experimental.tensor.tensor import unwrap_tensor_data def _tensor_guard(func: callable): @@ -36,7 +34,7 @@ def wrapper(*args, **kwargs): @functools.singledispatch @_tensor_guard -def device(a: TTensor) -> TensorDeviceType: +def device(a: Tensor) -> TensorDeviceType: """ Return the device of the tensor. @@ -48,7 +46,7 @@ def device(a: TTensor) -> TensorDeviceType: @functools.singledispatch @_tensor_guard -def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: +def squeeze(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Remove axes of length one from a. @@ -63,7 +61,7 @@ def squeeze(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTenso @functools.singledispatch @_tensor_guard -def flatten(a: TTensor) -> TTensor: +def flatten(a: Tensor) -> Tensor: """ Return a copy of the tensor collapsed into one dimension. @@ -75,7 +73,7 @@ def flatten(a: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def max(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin +def max(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Return the maximum of an array or maximum along an axis. @@ -88,7 +86,7 @@ def max(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: @functools.singledispatch @_tensor_guard -def min(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin +def min(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Return the minimum of an array or minimum along an axis. @@ -101,7 +99,7 @@ def min(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: @functools.singledispatch @_tensor_guard -def abs(a: TTensor) -> TTensor: # pylint: disable=redefined-builtin +def abs(a: Tensor) -> Tensor: # pylint: disable=redefined-builtin """ Calculate the absolute value element-wise. @@ -113,7 +111,7 @@ def abs(a: TTensor) -> TTensor: # pylint: disable=redefined-builtin @functools.singledispatch @_tensor_guard -def astype(a: TTensor, data_type: TensorDataType) -> TTensor: +def astype(a: Tensor, data_type: TensorDataType) -> Tensor: """ Copy of the tensor, cast to a specified type. @@ -127,7 +125,7 @@ def astype(a: TTensor, data_type: TensorDataType) -> TTensor: @functools.singledispatch @_tensor_guard -def dtype(a: TTensor) -> TensorDataType: +def dtype(a: Tensor) -> TensorDataType: """ Return data type of the tensor. @@ -139,7 +137,7 @@ def dtype(a: TTensor) -> TensorDataType: @functools.singledispatch @_tensor_guard -def reshape(a: TTensor, shape: List[int]) -> TTensor: +def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor: """ Gives a new shape to a tensor without changing its data. @@ -152,7 +150,7 @@ def reshape(a: TTensor, shape: List[int]) -> TTensor: @functools.singledispatch @_tensor_guard -def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin +def all(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Test whether all tensor elements along a given axis evaluate to True. @@ -165,7 +163,9 @@ def all(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: @functools.singledispatch @_tensor_guard -def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor: +def allclose( + a: Tensor, b: Union[Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +) -> Tensor: """ Returns True if two arrays are element-wise equal within a tolerance. @@ -191,7 +191,7 @@ def allclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, e @functools.singledispatch @_tensor_guard -def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: # pylint: disable=redefined-builtin +def any(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: # pylint: disable=redefined-builtin """ Test whether any tensor elements along a given axis evaluate to True. @@ -204,7 +204,7 @@ def any(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: @functools.singledispatch @_tensor_guard -def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> TTensor: +def count_nonzero(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: """ Counts the number of non-zero values in the tensor input. @@ -218,19 +218,21 @@ def count_nonzero(a: TTensor, axis: Optional[Union[int, Tuple[int]]] = None) -> @functools.singledispatch @_tensor_guard -def isempty(a: TTensor) -> TTensor: +def isempty(a: Tensor) -> bool: """ Return True if input tensor is empty. :param a: The input tensor. :return: True if tensor is empty, otherwise False. """ - return Tensor(isempty(a.data)) + return isempty(a.data) @functools.singledispatch @_tensor_guard -def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> TTensor: +def isclose( + a: Tensor, b: Union[Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +) -> Tensor: """ Returns a boolean array where two arrays are element-wise equal within a tolerance. @@ -256,7 +258,7 @@ def isclose(a: TTensor, b: TTensor, rtol: float = 1e-05, atol: float = 1e-08, eq @functools.singledispatch @_tensor_guard -def maximum(x1: TTensor, x2: TTensor) -> TTensor: +def maximum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Element-wise maximum of tensor elements. @@ -269,7 +271,7 @@ def maximum(x1: TTensor, x2: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def minimum(x1: TTensor, x2: TTensor) -> TTensor: +def minimum(x1: Tensor, x2: Union[Tensor, float]) -> Tensor: """ Element-wise minimum of tensor elements. @@ -282,7 +284,7 @@ def minimum(x1: TTensor, x2: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def ones_like(a: TTensor) -> TTensor: +def ones_like(a: Tensor) -> Tensor: """ Return a tensor of ones with the same shape and type as a given tensor. @@ -294,7 +296,7 @@ def ones_like(a: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor: +def where(condition: Tensor, x: Union[Tensor, float], y: Union[Tensor, float]) -> Tensor: """ Return elements chosen from x or y depending on condition. @@ -314,7 +316,7 @@ def where(condition: TTensor, x: TTensor, y: TTensor) -> TTensor: @functools.singledispatch @_tensor_guard -def zeros_like(a: TTensor) -> TTensor: +def zeros_like(a: Tensor) -> Tensor: """ Return an tensor of zeros with the same shape and type as a given tensor. @@ -324,28 +326,114 @@ def zeros_like(a: TTensor) -> TTensor: return Tensor(zeros_like(a.data)) -__all__ = [ - "device", - "squeeze", - "flatten", - "max", - "min", - "abs", - "astype", - "reshape", - "all", - "allclose", - "any", - "count_nonzero", - "isempty", - "isclose", - "maximum", - "minimum", - "ones_like", - "minimum", - "where", - "zeros_like", -] +@functools.singledispatch +def stack(x: List[Tensor], axis: int = 0) -> Tensor: + """ + Stacks a list of Tensors rank-R tensors into one Tensor rank-(R+1) tensor. + + :param x: List of Tensors. + :param axis: The axis to stack along. + :return: Stacked Tensor. + """ + if isinstance(x, List): + return Tensor(_dispatch_list(stack, x, axis=axis)) + raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}") + + +@functools.singledispatch +@_tensor_guard +def unstack(a: Tensor, axis: int = 0) -> List[Tensor]: + """ + Unstack a Tensor into list. + + :param a: Tensor to unstack. + :param axis: The axis to unstack along. + :return: List of Tensor. + """ + res = unstack(a.data, axis=axis) + return [Tensor(i) for i in res] + + +@functools.singledispatch +@_tensor_guard +def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> Tensor: + """ + Move axes of an array to new positions. + + :param a: The array whose axes should be reordered. + :param source: Original positions of the axes to move. These must be unique. + :param destination: Destination positions for each of the original axes. These must also be unique. + :return: Array with moved axes. + """ + return Tensor(moveaxis(a.data, source, destination)) + + +@functools.singledispatch +@_tensor_guard +def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor: + """ + Compute the arithmetic mean along the specified axis. + + :param a: Array containing numbers whose mean is desired. + :param axis: Axis or axes along which the means are computed. + :param keepdims: Destination positions for each of the original axes. These must also be unique. + :return: Array with moved axes. + """ + return Tensor(mean(a.data, axis, keepdims)) + + +@functools.singledispatch +@_tensor_guard +def round(a: Tensor, decimals=0) -> Tensor: # pylint: disable=redefined-builtin + """ + Evenly round to the given number of decimals. + + :param a: Input data. + :param decimals: Number of decimal places to round to (default: 0). If decimals is negative, + it specifies the number of positions to the left of the decimal point. + :return: An array of the same type as a, containing the rounded values. + """ + return Tensor(round(a.data, decimals)) + + +@functools.singledispatch +@_tensor_guard +def _binary_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor: + """ + Applies a binary operation with disable warnings. + + :param a: The first tensor. + :param b: The second tensor. + :param operator_fn: The binary operation function. + :return: The result of the binary operation. + """ + return Tensor(_binary_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) + + +@functools.singledispatch +@_tensor_guard +def _binary_reverse_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor: + """ + Applies a binary reverse operation with disable warnings. + + :param a: The first tensor. + :param b: The second tensor. + :param operator_fn: The binary operation function. + :return: The result of the binary operation. + """ + return Tensor(_binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn)) + + +def _dispatch_list(fn: "functools._SingleDispatchCallable", tensor_list: List[Tensor], *args, **kwargs): + """ + Dispatches the function to the type of the wrapped data of the first element in tensor_list. + + :param fn: A function wrapped by `functools.singledispatch`. + :param tensor_list: List of Tensors. + :return: The result value of the function call. + """ + unwrapped_list = [i.data for i in tensor_list] + return fn.dispatch(type(unwrapped_list[0]))(unwrapped_list, *args, **kwargs) def _initialize_backends(): diff --git a/nncf/experimental/tensor/numpy_functions.py b/nncf/experimental/tensor/numpy_functions.py index be070db4bdb..b4c515b1da1 100644 --- a/nncf/experimental/tensor/numpy_functions.py +++ b/nncf/experimental/tensor/numpy_functions.py @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np -from nncf.experimental.tensor import functions +from nncf.experimental.tensor import functions as fns from nncf.experimental.tensor.enums import TensorDataType from nncf.experimental.tensor.enums import TensorDeviceType @@ -28,137 +28,179 @@ DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()} -@functions.device.register(np.ndarray) -@functions.device.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> TensorDeviceType: +def _register_numpy_types(singledispatch_fn): + """ + Decorator to register function to singledispatch for numpy classes. + + :param singledispatch_fn: singledispatch function. + """ + + def inner(func): + singledispatch_fn.register(np.ndarray)(func) + singledispatch_fn.register(np.generic)(func) + return func + + return inner + + +@_register_numpy_types(fns.device) +def _(a: Union[np.ndarray, np.generic]) -> TensorDeviceType: return TensorDeviceType.CPU -@functions.squeeze.register(np.ndarray) -@functions.squeeze.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: +@_register_numpy_types(fns.squeeze) +def _( + a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[np.ndarray, np.generic]: return np.squeeze(a, axis=axis) -@functions.flatten.register(np.ndarray) -@functions.flatten.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> np.ndarray: +@_register_numpy_types(fns.flatten) +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return a.flatten() -@functions.max.register(np.ndarray) -@functions.max.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: +@_register_numpy_types(fns.max) +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: return np.max(a, axis=axis) -@functions.min.register(np.ndarray) -@functions.min.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: +@_register_numpy_types(fns.min) +def _( + a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None +) -> Union[np.ndarray, np.generic]: return np.min(a, axis=axis) -@functions.abs.register(np.ndarray) -@functions.abs.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> np.ndarray: +@_register_numpy_types(fns.abs) +def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]: return np.absolute(a) -@functions.astype.register(np.ndarray) -@functions.astype.register(np.number) -def _(a: Union[np.ndarray, np.number], dtype: TensorDataType) -> np.ndarray: +@_register_numpy_types(fns.astype) +def _(a: Union[np.ndarray, np.generic], dtype: TensorDataType) -> Union[np.ndarray, np.generic]: return a.astype(DTYPE_MAP[dtype]) -@functions.dtype.register(np.ndarray) -@functions.dtype.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> TensorDataType: +@_register_numpy_types(fns.dtype) +def _(a: Union[np.ndarray, np.generic]) -> TensorDataType: return DTYPE_MAP_REV[np.dtype(a.dtype)] -@functions.reshape.register(np.ndarray) -@functions.reshape.register(np.number) -def _(a: Union[np.ndarray, np.number], shape: Union[int, Tuple[int]]) -> np.ndarray: +@_register_numpy_types(fns.reshape) +def _(a: Union[np.ndarray, np.generic], shape: Union[int, Tuple[int, ...]]) -> np.ndarray: return a.reshape(shape) -@functions.all.register(np.ndarray) -@functions.all.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> Union[np.ndarray, bool]: +@_register_numpy_types(fns.all) +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[np.ndarray, bool]: return np.all(a, axis=axis) -@functions.allclose.register(np.ndarray) -@functions.allclose.register(np.number) +@_register_numpy_types(fns.allclose) def _( - a: Union[np.ndarray, np.number], - b: Union[np.ndarray, np.number], + a: Union[np.ndarray, np.generic], + b: Union[np.ndarray, np.generic, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, -) -> bool: +) -> np.ndarray: return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@functions.any.register(np.ndarray) -@functions.any.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> Union[np.ndarray, bool]: +@_register_numpy_types(fns.any) +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[np.ndarray, bool]: return np.any(a, axis=axis) -@functions.count_nonzero.register(np.ndarray) -@functions.count_nonzero.register(np.number) -def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray: - return np.count_nonzero(a, axis=axis) +@_register_numpy_types(fns.count_nonzero) +def _(a: Union[np.ndarray, np.generic], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> np.ndarray: + return np.array(np.count_nonzero(a, axis=axis)) -@functions.isempty.register(np.ndarray) -@functions.isempty.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> bool: +@_register_numpy_types(fns.isempty) +def _(a: Union[np.ndarray, np.generic]) -> bool: return a.size == 0 -@functions.isclose.register(np.ndarray) -@functions.isclose.register(np.number) +@_register_numpy_types(fns.isclose) def _( - a: Union[np.ndarray, np.number], - b: np.ndarray, + a: Union[np.ndarray, np.generic], + b: Union[np.ndarray, np.generic, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, -): +) -> Union[np.ndarray, bool]: return np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@functions.maximum.register(np.ndarray) -@functions.maximum.register(np.number) -def _(x1: Union[np.ndarray, np.number], x2: np.ndarray) -> np.ndarray: +@_register_numpy_types(fns.maximum) +def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic, float]) -> np.ndarray: return np.maximum(x1, x2) -@functions.minimum.register(np.ndarray) -@functions.minimum.register(np.number) -def _(x1: Union[np.ndarray, np.number], x2: np.ndarray) -> np.ndarray: +@_register_numpy_types(fns.minimum) +def _(x1: Union[np.ndarray, np.generic], x2: Union[np.ndarray, np.generic, float]) -> np.ndarray: return np.minimum(x1, x2) -@functions.ones_like.register(np.ndarray) -@functions.ones_like.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> np.ndarray: +@_register_numpy_types(fns.ones_like) +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.ones_like(a) -@functions.where.register(np.ndarray) -@functions.where.register(np.number) +@_register_numpy_types(fns.where) def _( - condition: Union[np.ndarray, np.number], - x: Union[np.ndarray, np.number, float, bool], - y: Union[np.ndarray, float, bool], + condition: Union[np.ndarray, np.generic], + x: Union[np.ndarray, np.generic, float], + y: Union[np.ndarray, np.generic, float], ) -> np.ndarray: return np.where(condition, x, y) -@functions.zeros_like.register(np.ndarray) -@functions.zeros_like.register(np.number) -def _(a: Union[np.ndarray, np.number]) -> np.ndarray: +@_register_numpy_types(fns.zeros_like) +def _(a: Union[np.ndarray, np.generic]) -> np.ndarray: return np.zeros_like(a) + + +@_register_numpy_types(fns.stack) +def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]: + return np.stack(x, axis=axis) + + +@_register_numpy_types(fns.unstack) +def _(x: Union[np.ndarray, np.generic], axis: int = 0) -> List[np.ndarray]: + return [np.squeeze(e, axis) for e in np.split(x, x.shape[axis], axis=axis)] + + +@_register_numpy_types(fns.moveaxis) +def _(a: np.ndarray, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> np.ndarray: + return np.moveaxis(a, source, destination) + + +@_register_numpy_types(fns.mean) +def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> np.ndarray: + return np.mean(a, axis=axis, keepdims=keepdims) + + +@_register_numpy_types(fns.round) +def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray: + return np.round(a, decimals=decimals) + + +@_register_numpy_types(fns._binary_op_nowarn) # pylint: disable=protected-access +def _( + a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable +) -> Union[np.ndarray, np.generic]: + # Run operator with disabled warning + with np.errstate(invalid="ignore", divide="ignore"): + return operator_fn(a, b) + + +@_register_numpy_types(fns._binary_reverse_op_nowarn) # pylint: disable=protected-access +def _( + a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable +) -> Union[np.ndarray, np.generic]: + # Run operator with disabled warning + with np.errstate(invalid="ignore", divide="ignore"): + return operator_fn(b, a) diff --git a/nncf/experimental/tensor/tensor.py b/nncf/experimental/tensor/tensor.py index daa8e37aff4..76fd05c4ff1 100644 --- a/nncf/experimental/tensor/tensor.py +++ b/nncf/experimental/tensor/tensor.py @@ -8,9 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations - -from typing import Any, List, Optional, Tuple, TypeVar, Union +import operator +from typing import Any, Optional, Tuple, TypeVar, Union from nncf.experimental.tensor.enums import TensorDataType from nncf.experimental.tensor.enums import TensorDeviceType @@ -31,8 +32,12 @@ def data(self) -> TTensor: return self._data @property - def shape(self) -> List[int]: - return list(self.data.shape) + def shape(self) -> Tuple[int, ...]: + return tuple(self.data.shape) + + @property + def ndim(self) -> int: + return self.data.ndim @property def device(self) -> TensorDeviceType: @@ -48,7 +53,7 @@ def __bool__(self) -> bool: def __iter__(self): return TensorIterator(self.data) - def __getitem__(self, index: int) -> "Tensor": + def __getitem__(self, index: int) -> Tensor: return Tensor(self.data[index]) def __str__(self) -> str: @@ -59,86 +64,86 @@ def __repr__(self) -> str: # built-in operations - def __add__(self, other: TTensor) -> "Tensor": + def __add__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data + unwrap_tensor_data(other)) - def __radd__(self, other: TTensor) -> "Tensor": + def __radd__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) + self.data) - def __sub__(self, other: TTensor) -> "Tensor": + def __sub__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data - unwrap_tensor_data(other)) - def __rsub__(self, other: TTensor) -> "Tensor": + def __rsub__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) - self.data) - def __mul__(self, other: TTensor) -> "Tensor": + def __mul__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data * unwrap_tensor_data(other)) - def __rmul__(self, other: TTensor) -> "Tensor": + def __rmul__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(unwrap_tensor_data(other) * self.data) - def __pow__(self, other: TTensor) -> "Tensor": + def __pow__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data ** unwrap_tensor_data(other)) - def __truediv__(self, other: TTensor) -> "Tensor": - return Tensor(self.data / unwrap_tensor_data(other)) + def __truediv__(self, other: Union[Tensor, float]) -> Tensor: + return _call_function("_binary_op_nowarn", self, other, operator.truediv) - def __rtruediv__(self, other: TTensor) -> "Tensor": - return Tensor(unwrap_tensor_data(other) / self.data) + def __rtruediv__(self, other: Union[Tensor, float]) -> Tensor: + return _call_function("_binary_reverse_op_nowarn", self, other, operator.truediv) - def __floordiv__(self, other: TTensor) -> "Tensor": - return Tensor(self.data // unwrap_tensor_data(other)) + def __floordiv__(self, other: Union[Tensor, float]) -> Tensor: + return _call_function("_binary_op_nowarn", self, other, operator.floordiv) - def __rfloordiv__(self, other: TTensor) -> "Tensor": - return Tensor(unwrap_tensor_data(other) // self.data) + def __rfloordiv__(self, other: Union[Tensor, float]) -> Tensor: + return _call_function("_binary_reverse_op_nowarn", self, other, operator.floordiv) - def __neg__(self) -> "Tensor": + def __neg__(self) -> Tensor: return Tensor(-self.data) # Comparison operators - def __lt__(self, other: TTensor) -> "Tensor": + def __lt__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data < unwrap_tensor_data(other)) - def __le__(self, other: TTensor) -> "Tensor": + def __le__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data <= unwrap_tensor_data(other)) - def __eq__(self, other: TTensor) -> "Tensor": + def __eq__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data == unwrap_tensor_data(other)) - def __ne__(self, other: TTensor) -> "Tensor": + def __ne__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data != unwrap_tensor_data(other)) - def __gt__(self, other: TTensor) -> "Tensor": + def __gt__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data > unwrap_tensor_data(other)) - def __ge__(self, other: TTensor) -> "Tensor": + def __ge__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data >= unwrap_tensor_data(other)) # Tensor functions - def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> "Tensor": + def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: return _call_function("squeeze", self, axis) - def flatten(self) -> "Tensor": + def flatten(self) -> Tensor: return _call_function("flatten", self) - def max(self, axis: Optional[TTensor] = None) -> "Tensor": + def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: return _call_function("max", self, axis) - def min(self, axis: Optional[TTensor] = None) -> "Tensor": + def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: return _call_function("min", self, axis) - def abs(self) -> "Tensor": + def abs(self) -> Tensor: return _call_function("abs", self) - def isempty(self) -> "Tensor": + def isempty(self) -> bool: return _call_function("isempty", self) - def astype(self, dtype: TensorDataType): + def astype(self, dtype: TensorDataType) -> Tensor: return _call_function("astype", self, dtype) - def reshape(self, shape: TTensor) -> "Tensor": + def reshape(self, shape: Tuple[int, ...]) -> Tensor: return _call_function("reshape", self, shape) diff --git a/nncf/experimental/tensor/torch_functions.py b/nncf/experimental/tensor/torch_functions.py index 09ef0f1b886..273d5419781 100644 --- a/nncf/experimental/tensor/torch_functions.py +++ b/nncf/experimental/tensor/torch_functions.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from nncf.experimental.tensor import TensorDataType from nncf.experimental.tensor import TensorDeviceType -from nncf.experimental.tensor import functions +from nncf.experimental.tensor import functions as fns DTYPE_MAP = { TensorDataType.float16: torch.float16, @@ -28,7 +28,7 @@ DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()} -@functions.device.register(torch.Tensor) +@fns.device.register(torch.Tensor) def _(a: torch.Tensor) -> TensorDeviceType: DEVICE_MAP = { "cpu": TensorDeviceType.CPU, @@ -37,112 +37,162 @@ def _(a: torch.Tensor) -> TensorDeviceType: return DEVICE_MAP[a.device.type] -@functions.squeeze.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +@fns.squeeze.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: if axis is None: return a.squeeze() + if isinstance(axis, Tuple) and any(1 != a.shape[i] for i in axis): + # Make Numpy behavior, torch.squeeze skips axes that are not equal to one.. + raise ValueError("Cannot select an axis to squeeze out which has size not equal to one") return a.squeeze(axis) -@functions.flatten.register(torch.Tensor) +@fns.flatten.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return a.flatten() -@functions.max.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +@fns.max.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: + # Analog of numpy.max is torch.amax if axis is None: - return torch.max(a) - return torch.max(a, dim=axis).values + return torch.amax(a) + return torch.amax(a, dim=axis) -@functions.min.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +@fns.min.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: + # Analog of numpy.min is torch.amin if axis is None: - return torch.min(a) - return torch.min(a, dim=axis).values + return torch.amin(a) + return torch.amin(a, dim=axis) -@functions.abs.register(torch.Tensor) +@fns.abs.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return torch.absolute(a) -@functions.astype.register(torch.Tensor) +@fns.astype.register(torch.Tensor) def _(a: torch.Tensor, dtype: TensorDataType) -> torch.Tensor: return a.type(DTYPE_MAP[dtype]) -@functions.dtype.register(torch.Tensor) +@fns.dtype.register(torch.Tensor) def _(a: torch.Tensor) -> TensorDataType: return DTYPE_MAP_REV[a.dtype] -@functions.reshape.register(torch.Tensor) -def _(a: torch.Tensor, shape: List[int]) -> torch.Tensor: +@fns.reshape.register(torch.Tensor) +def _(a: torch.Tensor, shape: Tuple[int, ...]) -> torch.Tensor: return a.reshape(shape) -@functions.all.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[torch.Tensor, bool]: +@fns.all.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[torch.Tensor, bool]: if axis is None: return torch.all(a) return torch.all(a, dim=axis) -@functions.allclose.register(torch.Tensor) -def _(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool: +@fns.allclose.register(torch.Tensor) +def _( + a: torch.Tensor, b: Union[torch.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +) -> bool: + if not isinstance(b, torch.Tensor): + b = torch.tensor(b, device=a.device) return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) -@functions.any.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> Union[torch.Tensor, bool]: +@fns.any.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[torch.Tensor, bool]: if axis is None: return torch.any(a) return torch.any(a, dim=axis) -@functions.count_nonzero.register(torch.Tensor) -def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int]]] = None) -> torch.Tensor: +@fns.count_nonzero.register(torch.Tensor) +def _(a: torch.Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor: return torch.count_nonzero(a, dim=axis) -@functions.isempty.register(torch.Tensor) +@fns.isempty.register(torch.Tensor) def _(a: torch.Tensor) -> bool: return a.numel() == 0 -@functions.isclose.register(torch.Tensor) -def _(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False): +@fns.isclose.register(torch.Tensor) +def _( + a: torch.Tensor, b: Union[torch.Tensor, float], rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False +): + if not isinstance(b, torch.Tensor): + b = torch.tensor(b, device=a.device) return torch.isclose(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan) -@functions.maximum.register(torch.Tensor) -def _(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: +@fns.maximum.register(torch.Tensor) +def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: if not isinstance(x2, torch.Tensor): x2 = torch.tensor(x2, device=x1.data.device) return torch.maximum(x1, x2) -@functions.minimum.register(torch.Tensor) -def _(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: +@fns.minimum.register(torch.Tensor) +def _(x1: torch.Tensor, x2: Union[torch.Tensor, float]) -> torch.Tensor: if not isinstance(x2, torch.Tensor): x2 = torch.tensor(x2, device=x1.data.device) return torch.minimum(x1, x2) -@functions.ones_like.register(torch.Tensor) +@fns.ones_like.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return torch.ones_like(a) -@functions.where.register(torch.Tensor) +@fns.where.register(torch.Tensor) def _( condition: torch.Tensor, x: Union[torch.Tensor, float, bool], y: Union[torch.Tensor, float, bool] ) -> torch.Tensor: return torch.where(condition, x, y) -@functions.zeros_like.register(torch.Tensor) +@fns.zeros_like.register(torch.Tensor) def _(a: torch.Tensor) -> torch.Tensor: return torch.zeros_like(a) + + +@fns.stack.register(torch.Tensor) +def _(x: List[torch.Tensor], axis: int = 0) -> List[torch.Tensor]: + return torch.stack(x, dim=axis) + + +@fns.unstack.register(torch.Tensor) +def _(x: torch.Tensor, axis: int = 0) -> List[torch.Tensor]: + if not list(x.shape): + x = x.unsqueeze(0) + return torch.unbind(x, dim=axis) + + +@fns.moveaxis.register(torch.Tensor) +def _(a: torch.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[int, Tuple[int, ...]]) -> torch.Tensor: + return torch.moveaxis(a, source, destination) + + +@fns.mean.register(torch.Tensor) +def _(a: torch.Tensor, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> torch.Tensor: + return torch.mean(a, axis=axis, keepdims=keepdims) + + +@fns.round.register(torch.Tensor) +def _(a: torch.Tensor, decimals=0) -> torch.Tensor: + return torch.round(a, decimals=decimals) + + +@fns._binary_op_nowarn.register(torch.Tensor) # pylint: disable=protected-access +def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor: + return operator_fn(a, b) + + +@fns._binary_reverse_op_nowarn.register(torch.Tensor) # pylint: disable=protected-access +def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor: + return operator_fn(b, a) diff --git a/nncf/onnx/quantization/quantizer_parameters.py b/nncf/onnx/quantization/quantizer_parameters.py index 71b3d976b50..e4470fdce1b 100644 --- a/nncf/onnx/quantization/quantizer_parameters.py +++ b/nncf/onnx/quantization/quantizer_parameters.py @@ -54,8 +54,8 @@ def convert_fq_params_to_onnx_params( if levels not in [255, 256]: raise ValueError("Can only export to INT8/UIN8 256-level ONNX Quantize/Dequantize pairs.") - input_low, input_high = parameters.input_low, parameters.input_high - output_low, output_high = parameters.output_low, parameters.output_high + input_low, input_high = parameters.input_low.data, parameters.input_high.data + output_low, output_high = parameters.output_low.data, parameters.output_high.data if not np.allclose(input_high, output_high) or not np.allclose(input_low, output_low): raise ValueError( "ONNX Quantize/Dequantize pairs only support input_high == output_high and input_low == output_low." diff --git a/nncf/openvino/graph/model_transformer.py b/nncf/openvino/graph/model_transformer.py index 19e43f4b131..16cad27bb65 100644 --- a/nncf/openvino/graph/model_transformer.py +++ b/nncf/openvino/graph/model_transformer.py @@ -249,10 +249,10 @@ def _convert_to_fp16(data): clip_data = np.clip(data, np.finfo(np.float16).min, np.finfo(np.float16).max) return clip_data.astype(np.float16) - input_low = _convert_to_fp16(fq_params.input_low) - input_high = _convert_to_fp16(fq_params.input_high) - output_low = _convert_to_fp16(fq_params.output_low) - output_high = _convert_to_fp16(fq_params.output_high) + input_low = _convert_to_fp16(fq_params.input_low.data) + input_high = _convert_to_fp16(fq_params.input_high.data) + output_low = _convert_to_fp16(fq_params.output_low.data) + output_high = _convert_to_fp16(fq_params.output_high.data) return input_low, input_high, output_low, output_high @staticmethod @@ -266,10 +266,10 @@ def _insert_fake_quantize_op( :param name_to_node_mapping: Mapping from node name to node instance. """ fq_params = transformation.quantizer_parameters - input_low = fq_params.input_low - input_high = fq_params.input_high - output_low = fq_params.output_low - output_high = fq_params.output_high + input_low = fq_params.input_low.data + input_high = fq_params.input_high.data + output_low = fq_params.output_low.data + output_high = fq_params.output_high.data levels = fq_params.levels node_name = transformation.target_point.target_node_name diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 7ed9182d004..65b058b612f 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from math import inf from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from nncf import Dataset @@ -25,6 +26,9 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend +from nncf.experimental.common.tensor_statistics.statistical_functions import mean_per_channel +from nncf.experimental.tensor import Tensor +from nncf.experimental.tensor import functions as fns from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS @@ -167,9 +171,9 @@ def apply( output_name=sub_output_name, ) - bias_shift = self.reshape_bias_shift(bias_shift, bias_value, channel_axis) + bias_shift = self._reshape_bias_shift(bias_shift, bias_value, channel_axis) updated_bias = bias_value + bias_shift - magnitude = self._backend_entity.get_bias_shift_magnitude(bias_value, updated_bias) + magnitude = self._get_bias_shift_magnitude(bias_value, updated_bias) if magnitude < self.threshold: nncf_logger.debug(f"{node_name} bias would be changed") @@ -185,7 +189,22 @@ def apply( return transformed_model - def reshape_bias_shift(self, bias_shift: TTensor, bias_value: TTensor, channel_axis: int) -> TTensor: + @staticmethod + def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Tensor) -> float: + """ + Calculates bias shift magnitude based on the current and updated values. + + :param current_bias_value: The original bias value. + :param updated_bias_value: The updated bias value. + :return: Magnitude between original and updated bias values. + """ + bias_shift_magnitude = inf + if fns.count_nonzero(current_bias_value == 0) == 0: + bias_shift_magnitude = fns.max(fns.abs((updated_bias_value - current_bias_value) / current_bias_value)) + return bias_shift_magnitude + + @staticmethod + def _reshape_bias_shift(bias_shift: Tensor, bias_value: Tensor, channel_axis: int) -> Tensor: """ Reshape bias_shift tensor in case of dimensions of bias_value is more then 1. @@ -198,7 +217,7 @@ def reshape_bias_shift(self, bias_shift: TTensor, bias_value: TTensor, channel_a if bias_value.ndim > 1: new_shape = [1] * bias_value.ndim new_shape[channel_axis] = bias_shift.shape[0] - bias_shift = self._backend_entity.reshape_tensor(bias_shift, new_shape) + bias_shift = bias_shift.reshape(new_shape) return bias_shift def _get_fp_inputs(self, statistic_points: StatisticPointsContainer, node_name: str) -> Tuple[List, List]: @@ -222,7 +241,7 @@ def input_filter_func(point): node_name, input_filter_func, self._algorithm_key ): statistics = tensor_collector.get_statistics() - input_fp.extend(statistics.mean_values) + input_fp.extend(Tensor(statistics.mean_values)) input_shape.extend(statistics.shape) return input_fp, input_shape @@ -245,7 +264,7 @@ def output_filter_func(point): for tensor_collector in statistic_points.get_algo_statistics_for_node( node_name, output_filter_func, self._algorithm_key ): - output_fp.extend(tensor_collector.get_statistics().mean_values) + output_fp.extend(Tensor(tensor_collector.get_statistics().mean_values)) return output_fp def _extract_submodel(self, model_transformer: ModelTransformer, node_name: str) -> TModel: @@ -299,8 +318,8 @@ def _get_bias_shift( engine = EngineFactory.create(model) raw_output = engine.infer(input_blob) q_outputs = self._backend_entity.process_model_output(raw_output, output_name) - q_outputs = self._backend_entity.tensor_processor.mean_per_channel(q_outputs, channel_axis).tensor - bias_shift = self._backend_entity.post_process_output_data(output_fp) - q_outputs + q_outputs = mean_per_channel(q_outputs, channel_axis) + bias_shift = fns.stack(output_fp) - q_outputs return bias_shift def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: diff --git a/nncf/quantization/algorithms/fast_bias_correction/backend.py b/nncf/quantization/algorithms/fast_bias_correction/backend.py index 38618fe2efa..6dde985c5d4 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/backend.py @@ -23,6 +23,7 @@ from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.common.utils.registry import Registry +from nncf.experimental.tensor import Tensor TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") @@ -31,13 +32,6 @@ class FastBiasCorrectionAlgoBackend(ABC): - @property - @abstractmethod - def tensor_processor(self): - """ - Returns backend-specific instance of the NNCFCollectorTensorProcessor. - """ - @staticmethod @abstractmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint: @@ -120,7 +114,7 @@ def create_input_data( @staticmethod @abstractmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: TModel) -> np.ndarray: + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: TModel) -> Tensor: """ Returns bias value in the NumPy format of provided node. @@ -156,7 +150,7 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: @staticmethod @abstractmethod - def process_model_output(raw_data: OutputType, output_name: str) -> NNCFTensor: + def process_model_output(raw_data: OutputType, output_name: str) -> Tensor: """ Returns backend-specific processed output from the model. @@ -176,37 +170,6 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: :return: Boolean indicating whether the node has a bias or not. """ - @staticmethod - @abstractmethod - def get_bias_shift_magnitude(current_bias_value: TTensor, updated_bias_value: TTensor) -> float: - """ - Calculates bias shift magnitude based on the current and updated values. - - :param current_bias_value: The original bias value. - :param updated_bias_value: The updated bias value. - :return: Magnitude between original and updated bias values. - """ - - @staticmethod - @abstractmethod - def post_process_output_data(data: List[TTensor]) -> TTensor: - """ - Convert data to backend specific type. - - :param data: List of data. - :return: Converted data. - """ - - @staticmethod - @abstractmethod - def reshape_tensor(data: TTensor, new_shape: List[int]) -> TTensor: - """ - Reshape tensor. - - :param data: Tensor. - :param new_shape: New shape. - """ - @staticmethod @abstractmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: diff --git a/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py index 02bb54106c6..d0646f6aeb2 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/onnx_backend.py @@ -18,6 +18,7 @@ from nncf.common.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetType from nncf.common.utils.backend import BackendType +from nncf.experimental.tensor import Tensor from nncf.onnx.graph.node_utils import get_bias_value from nncf.onnx.graph.node_utils import is_any_weight_quantized from nncf.onnx.graph.node_utils import is_node_with_bias @@ -27,8 +28,6 @@ from nncf.onnx.graph.transformations.commands import ONNXNullBiasInsertionCommand from nncf.onnx.graph.transformations.commands import ONNXTargetPoint from nncf.onnx.statistics.collectors import ONNXMeanStatisticCollector -from nncf.onnx.statistics.collectors import ONNXNNCFCollectorTensorProcessor -from nncf.onnx.tensor import ONNXNNCFTensor from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend @@ -39,10 +38,6 @@ class ONNXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): def types_to_insert_bias(self): return [] - @property - def tensor_processor(self) -> ONNXNNCFCollectorTensorProcessor: - return ONNXNNCFCollectorTensorProcessor - @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> ONNXTargetPoint: return ONNXTargetPoint(target_type, target_node_name, port_id) @@ -53,9 +48,9 @@ def create_bias_insertion_command(node: NNCFNode) -> ONNXNullBiasInsertionComman @staticmethod def create_bias_correction_command( - node: NNCFNode, bias_value: np.ndarray, nncf_graph: NNCFGraph + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph ) -> ONNXBiasCorrectionCommand: - return create_bias_correction_command(node, bias_value) + return create_bias_correction_command(node, bias_value.data) @staticmethod def model_extraction_command(inputs: List[str], outputs: List[str]) -> ONNXModelExtractionCommand: @@ -76,27 +71,26 @@ def get_sub_input_output_names(subgraph: onnx.ModelProto) -> Tuple[str, str]: @staticmethod def create_input_data( - shape: Tuple[int], data: List[np.ndarray], input_name: str, channel_axis: int + shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int ) -> Dict[str, np.array]: - blob = np.zeros(shape) + blob = np.zeros(shape, dtype=data[0].data.dtype) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) - blob[index] = data[j] - blob = blob.astype(data[0].dtype) + blob[index] = data[j].data input_data = {input_name: blob} return input_data @staticmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto) -> np.ndarray: - return get_bias_value(node, model) + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto) -> Tensor: + return Tensor(get_bias_value(node, model)) @staticmethod def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: return 0, 0 @staticmethod - def process_model_output(raw_data: Dict, output_name: str) -> ONNXNNCFTensor: - return ONNXNNCFTensor(raw_data[output_name]) + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data[output_name]) @staticmethod def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: @@ -106,21 +100,6 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return is_node_with_bias(node) - @staticmethod - def get_bias_shift_magnitude(current_bias_value: np.ndarray, updated_bias_value: np.ndarray) -> float: - bias_shift_magnitude = np.inf - if np.count_nonzero(current_bias_value == 0) == 0: - bias_shift_magnitude = np.max(np.abs((updated_bias_value - current_bias_value) / current_bias_value)) - return bias_shift_magnitude - - @staticmethod - def post_process_output_data(data: List[np.ndarray]) -> np.ndarray: - return np.array(data) - - @staticmethod - def reshape_tensor(data: np.ndarray, new_shape: List[int]) -> np.ndarray: - return data.reshape(new_shape) - @staticmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: return node.node_name, node.node_name diff --git a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py index 61ebb5a695b..d2744da5864 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py @@ -19,6 +19,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.utils.backend import BackendType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.tensor import Tensor from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS from nncf.openvino.graph.node_utils import get_bias_value from nncf.openvino.graph.node_utils import is_node_with_bias @@ -26,28 +27,22 @@ from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand from nncf.openvino.graph.transformations.commands import OVModelExtractionCommand from nncf.openvino.graph.transformations.commands import OVTargetPoint -from nncf.openvino.statistics.collectors import OVNNCFCollectorTensorProcessor from nncf.openvino.statistics.collectors import get_mean_statistic_collector -from nncf.openvino.tensor import OVNNCFTensor from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend @ALGO_BACKENDS.register(BackendType.OPENVINO) class OVFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): - @property - def tensor_processor(self) -> OVNNCFCollectorTensorProcessor: - return OVNNCFCollectorTensorProcessor - @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: return OVTargetPoint(target_type, target_node_name, port_id) @staticmethod def create_bias_correction_command( - node: NNCFNode, bias_value: np.ndarray, nncf_graph: NNCFGraph + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph ) -> OVBiasCorrectionCommand: - return OVCommandCreator.create_command_to_update_bias(node, bias_value, nncf_graph) + return OVCommandCreator.create_command_to_update_bias(node, bias_value.data, nncf_graph) @staticmethod def model_extraction_command(inputs: List[str], outputs: List[str]) -> OVModelExtractionCommand: @@ -68,19 +63,18 @@ def get_sub_input_output_names(subgraph: ov.Model) -> Tuple[str, str]: @staticmethod def create_input_data( - shape: Tuple[int], data: List[np.ndarray], input_name: str, channel_axis: int + shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int ) -> Dict[str, np.ndarray]: - blob = np.zeros(shape) + blob = np.zeros(shape, dtype=data[0].data.dtype) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) - blob[index] = data[j] - blob = blob.astype(data[0].dtype) + blob[index] = data[j].data input_data = {input_name: blob} return input_data @staticmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray: - return get_bias_value(node, nncf_graph, model) + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> Tensor: + return Tensor(get_bias_value(node, nncf_graph, model)) @staticmethod def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: @@ -97,28 +91,13 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return weight_node.metatype in FAKE_QUANTIZE_OPERATIONS @staticmethod - def process_model_output(raw_data: Dict, output_name: str) -> OVNNCFTensor: - return OVNNCFTensor(raw_data[output_name]) + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data[output_name]) @staticmethod def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return is_node_with_bias(node, nncf_graph) - @staticmethod - def get_bias_shift_magnitude(current_bias_value: np.ndarray, updated_bias_value: np.ndarray) -> float: - bias_shift_magnitude = np.inf - if np.count_nonzero(current_bias_value == 0) == 0: - bias_shift_magnitude = np.max(np.abs((updated_bias_value - current_bias_value) / current_bias_value)) - return bias_shift_magnitude - - @staticmethod - def post_process_output_data(data: List[np.ndarray]) -> np.ndarray: - return np.array(data) - - @staticmethod - def reshape_tensor(data: np.ndarray, new_shape: List[int]) -> np.ndarray: - return data.reshape(new_shape) - @staticmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: return node.node_name, node.node_name diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py index cb9b0026e3f..193be8994d9 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py @@ -20,6 +20,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.utils.backend import BackendType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.tensor import Tensor from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend from nncf.torch.graph.transformations.command_creation import create_bias_correction_command @@ -31,8 +32,6 @@ from nncf.torch.model_analyzer import is_node_with_fused_bias from nncf.torch.model_analyzer import is_quantized_weights from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.tensor import PTNNCFTensor -from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector @@ -43,10 +42,6 @@ class PTFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, } - @property - def tensor_processor(self) -> PTNNCFCollectorTensorProcessor: - return PTNNCFCollectorTensorProcessor - @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: @@ -57,9 +52,9 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) - @staticmethod def create_bias_correction_command( - node: NNCFNode, bias_value: np.ndarray, nncf_graph: NNCFGraph + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph ) -> PTBiasCorrectionCommand: - return create_bias_correction_command(node, bias_value) + return create_bias_correction_command(node, bias_value.data) @staticmethod def model_extraction_command(inputs: List[str], outputs: List[str]) -> PTModelExtractionWithFusedBiasCommand: @@ -80,26 +75,24 @@ def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: return None, None @staticmethod - def create_input_data( - shape: Tuple[int], data: List[torch.Tensor], input_name: str, channel_axis: int - ) -> torch.Tensor: - blob = torch.zeros(shape, dtype=data[0].dtype) + def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor: + blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device) for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) - blob[index] = data[j] + blob[index] = data[j].data return blob @staticmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: NNCFNetwork) -> np.ndarray: - return get_fused_bias_value(node, model) + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: NNCFNetwork) -> Tensor: + return Tensor(get_fused_bias_value(node, model)) @staticmethod def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: return 0, 0 @staticmethod - def process_model_output(raw_data: Dict, output_name: str) -> PTNNCFTensor: - return PTNNCFTensor(raw_data) + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data) @staticmethod def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: @@ -109,21 +102,6 @@ def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return is_node_with_fused_bias(node, nncf_graph) - @staticmethod - def get_bias_shift_magnitude(current_bias_value: torch.Tensor, updated_bias_value: torch.Tensor) -> float: - bias_shift_magnitude = torch.inf - if torch.count_nonzero(current_bias_value == 0) == 0: - bias_shift_magnitude = torch.max(torch.abs((updated_bias_value - current_bias_value) / current_bias_value)) - return bias_shift_magnitude - - @staticmethod - def post_process_output_data(data: List[torch.Tensor]) -> torch.Tensor: - return torch.Tensor(data) - - @staticmethod - def reshape_tensor(data: torch.Tensor, new_shape: List[int]) -> torch.Tensor: - return data.reshape(new_shape) - @staticmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: input_node_name = node.node_name diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index 47cf5695832..d3c1e25d0ae 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -103,7 +103,7 @@ def create_quantizer_insertion_command( quantizer_config: QuantizerConfig, parameters: FakeQuantizeParameters, ): - tensor_type = np.int8 if np.any(parameters.input_low < 0) else np.uint8 + tensor_type = np.int8 if np.any(parameters.input_low.data < 0) else np.uint8 if target_point.is_weight_target_point(): tensor_type = np.int8 # The weight is restricted to have only signed range nncf_input_node_next_nodes = ONNXMinMaxAlgoBackend._get_input_edges_mapping(nncf_graph) diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 4ad4e309dc8..5412d42853d 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -138,7 +138,7 @@ def _get_reduction_axes_and_use_abs_max( else: raise NotImplementedError(f"Unsupported target point type {target_point.type}.") - # TODO (l-bat): Disable quantizer propogation through layout changing operations + # TODO (l-bat): Disable quantizer propagation through layout changing operations channel_axis = 1 # OpenVINO activations have channel first layout: [N, C, Z, Y, X] axes = get_channel_agnostic_reduction_axes([channel_axis], shape) return axes, use_abs_max diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 0a8fe5778c5..d258411698e 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -11,7 +11,6 @@ from typing import Dict, List, Optional, Set, Tuple -import numpy as np import torch import nncf.torch.graph.operator_metatypes as om @@ -19,7 +18,6 @@ from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.layer_attributes import WeightedLayerAttributes -from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType from nncf.common.hardware.config import HWConfig @@ -38,7 +36,6 @@ from nncf.torch.graph.graph import PTTargetPoint from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand from nncf.torch.hardware.config import PTHWConfig -from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT from nncf.torch.quantization.init_range import PTRangeInitCollectorParams @@ -112,10 +109,6 @@ def hw_config(self) -> HWConfig: def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT - @staticmethod - def model_transformer(model: NNCFNetwork) -> ModelTransformer: - return PTModelTransformer(model) - @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: @@ -139,10 +132,10 @@ def create_quantizer_insertion_command( def unify_statistics(statistics: List[PTMinMaxTensorStatistic]) -> PTMinMaxTensorStatistic: max_values, min_values = [], [] for statistic in statistics: - max_values.append(torch.tensor(statistic.max_values).flatten()) - min_values.append(torch.tensor(statistic.min_values).flatten()) - max_values = torch.max(torch.tensor(max_values)) - min_values = torch.min(torch.tensor(min_values)) + max_values.append(statistic.max_values.flatten()) + min_values.append(statistic.min_values.flatten()) + max_values = torch.amax(torch.stack(max_values), dim=0) + min_values = torch.amin(torch.stack(min_values), dim=0) return PTMinMaxTensorStatistic(min_values=min_values, max_values=max_values) @staticmethod @@ -279,13 +272,12 @@ def _create_quantizer( def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None: quantizer.eps = 0 if isinstance(quantizer, AsymmetricQuantizer): - quantizer.input_low = torch.nn.Parameter(torch.from_numpy(parameters.input_low)) - quantizer.input_range = torch.nn.Parameter( - torch.from_numpy(np.array(parameters.input_high - parameters.input_low)) - ) + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data) + input_range = parameters.input_high - parameters.input_low + quantizer.input_range = torch.nn.Parameter(input_range.data) else: - quantizer.signed = np.any(parameters.input_low < 0) - quantizer.scale = torch.nn.Parameter(torch.from_numpy(parameters.input_high)) + quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) + quantizer.scale = torch.nn.Parameter(parameters.input_high.data) @staticmethod def _create_quantizer_insertion_command( diff --git a/nncf/quantization/fake_quantize.py b/nncf/quantization/fake_quantize.py index f1744813c54..38b56c97019 100644 --- a/nncf/quantization/fake_quantize.py +++ b/nncf/quantization/fake_quantize.py @@ -21,6 +21,9 @@ from nncf.common.quantization.structs import QuantizerConfig from nncf.common.quantization.structs import QuantizerGroup from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.experimental.tensor import Tensor +from nncf.experimental.tensor import TensorDataType +from nncf.experimental.tensor import functions as fns @dataclass @@ -35,14 +38,14 @@ class FakeQuantizeParameters: :param levels: Number of quantization levels. """ - input_low: np.ndarray - input_high: np.ndarray - output_low: np.ndarray - output_high: np.ndarray + input_low: Tensor + input_high: Tensor + output_low: Tensor + output_high: Tensor levels: int -def fix_zero_filters_symmetric(max_values: np.ndarray, eps: float = 0.01) -> np.ndarray: +def fix_zero_filters_symmetric(max_values: Tensor, eps: float = 0.01) -> Tensor: """ Fixes zero filters for symmetric quantizer. @@ -50,14 +53,12 @@ def fix_zero_filters_symmetric(max_values: np.ndarray, eps: float = 0.01) -> np. :param eps: Correction coefficient. :return: Fixed the high quant number. """ - max_range = np.max(max_values) - lower_threshold = np.maximum(8e-5, eps * max_range) - return np.maximum(lower_threshold, max_values) + max_range = fns.max(max_values) + lower_threshold = fns.maximum(max_range * eps, 8e-5) + return fns.maximum(lower_threshold, max_values) -def fix_zero_filters_asymmetric( - min_values: np.ndarray, max_values: np.ndarray, eps: float = 1e-8 -) -> Tuple[np.ndarray, np.ndarray]: +def fix_zero_filters_asymmetric(min_values: Tensor, max_values: Tensor, eps: float = 1e-8) -> Tuple[Tensor, Tensor]: """ Fixes zero filters for asymmetric quantizer. @@ -69,20 +70,17 @@ def fix_zero_filters_asymmetric( level_high - fixed the high quant number """ ranges = max_values - min_values - ranges = ranges.flatten() if isinstance(ranges, np.ndarray) else np.array([ranges]) min_correction = 8e-4 - corrections = [ - (np.maximum(eps * rng, rng) - rng) * 0.5 if rng > min_correction else min_correction for rng in ranges - ] - corrections = np.array(corrections).reshape(max_values.shape) + corrections = fns.where(ranges > min_correction, (fns.maximum(eps * ranges, ranges) - ranges) * 0.5, min_correction) + level_low = min_values - corrections level_high = max_values + corrections return level_low, level_high def tune_range( - left_border: np.ndarray, right_border: np.ndarray, num_bits: int, unify_zp: bool = False -) -> Tuple[np.ndarray, np.ndarray]: + left_border: Tensor, right_border: Tensor, num_bits: int, unify_zp: bool = False +) -> Tuple[Tensor, Tensor]: """ Tunes asymmetric quantization range to unify the zero point of all channels if `unify_zp` is True, or sets zero quant precisely to zero value otherwise. @@ -101,22 +99,21 @@ def tune_range( if unify_zp: scale = (right_border - left_border) / level_high zero_point = -left_border / scale - avg_zpts = np.round(np.mean(zero_point)) - qval = np.ones_like(left_border) * avg_zpts + avg_zpts = fns.round(fns.mean(zero_point)) + qval = fns.ones_like(left_border) * avg_zpts else: s = level_high / (right_border - left_border) fval = -left_border * s - qval = np.round(fval) + qval = fns.round(fval) - with np.errstate(invalid="ignore", divide="ignore"): - ra = np.where(qval < level_high, qval / (qval - level_high) * right_border, left_border) - rb = np.where(qval > 0.0, (qval - level_high) / qval * left_border, right_border) + ra = fns.where(qval < level_high, qval / (qval - level_high) * right_border, left_border) + rb = fns.where(qval > 0.0, (qval - level_high) / qval * left_border, right_border) range_a = right_border - ra range_b = rb - left_border - mask = np.where(range_a > range_b, 1.0, 0.0) - inv_mask = np.abs(1.0 - mask) + mask = fns.where(range_a > range_b, 1.0, 0.0) + inv_mask = fns.abs(1.0 - mask) ra = mask * ra + inv_mask * left_border rb = inv_mask * rb + mask * right_border @@ -125,12 +122,12 @@ def tune_range( def symmetric_range( - min_values: np.ndarray, - max_values: np.ndarray, + min_values: Tensor, + max_values: Tensor, levels: int, quantizer_config: QuantizerConfig, q_group: QuantizerGroup, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[Tensor, Tensor]: """ Calculates the numbers of the low and high quant for the symmetric quantization scheme. @@ -148,21 +145,23 @@ def symmetric_range( else: signed = quantizer_config.signedness_to_force is True level_low = ( - np.zeros_like(level_high) if np.all(min_values >= 0) and not signed else -level_high * levels / (levels - 2) + fns.zeros_like(level_high) + if fns.all(min_values >= 0) and not signed + else -level_high * levels / (levels - 2) ) - level_low = level_low.astype(np.float32) - level_high = level_high.astype(np.float32) + level_low = level_low.astype(TensorDataType.float32) + level_high = level_high.astype(TensorDataType.float32) return level_low, level_high def asymmetric_range( - min_values: np.ndarray, - max_values: np.ndarray, + min_values: Tensor, + max_values: Tensor, quantizer_config: QuantizerConfig, q_group: QuantizerGroup, unify_zp: bool = False, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[Tensor, Tensor]: """ Calculates the numbers of the low and high quant for the asymmetric quantization scheme. @@ -176,15 +175,15 @@ def asymmetric_range( level_high - the high quant number """ level_low, level_high = fix_zero_filters_asymmetric(min_values, max_values) - level_low = np.where(level_low < 0.0, level_low, 0.0) - level_high = np.where(level_high > 0.0, level_high, 0.0) + level_low = fns.where(level_low < 0.0, level_low, 0.0) + level_high = fns.where(level_high > 0.0, level_high, 0.0) if unify_zp and q_group == QuantizerGroup.ACTIVATIONS: raise NotImplementedError("Unified zero point is not supported for activations.") level_low, level_high = tune_range(level_low, level_high, quantizer_config.num_bits, unify_zp=unify_zp) - level_low = level_low.astype(np.float32) - level_high = level_high.astype(np.float32) + level_low = level_low.astype(TensorDataType.float32) + level_high = level_high.astype(TensorDataType.float32) return level_low, level_high @@ -221,8 +220,8 @@ def calculate_quantizer_parameters( False - the full range is used. :return: Parameters of the FakeQuantize layer. """ - min_values = np.array(statistics.min_values).astype(np.float32) - max_values = np.array(statistics.max_values).astype(np.float32) + min_values = Tensor(statistics.min_values).astype(TensorDataType.float32) + max_values = Tensor(statistics.max_values).astype(TensorDataType.float32) if half_range: input_low, input_high, levels = _calculate_scaled_parameters( @@ -240,21 +239,20 @@ def calculate_quantizer_parameters( input_low, input_high = asymmetric_range(min_values, max_values, quantizer_config, quant_group) if not quantizer_config.per_channel: - input_low = np.squeeze(input_low) - input_high = np.squeeze(input_high) + input_low = fns.squeeze(input_low) + input_high = fns.squeeze(input_high) - input_low, input_high = np.array(input_low), np.array(input_high) output_low, output_high = input_low, input_high return FakeQuantizeParameters(input_low, input_high, output_low, output_high, levels) def _calculate_scaled_parameters( - min_values: np.ndarray, - max_values: np.ndarray, + min_values: Tensor, + max_values: Tensor, quantizer_config: QuantizerConfig, quant_group: QuantizerGroup, narrow_range: bool, -) -> Tuple[np.ndarray, np.ndarray, int]: +) -> Tuple[Tensor, Tensor, int]: """ Calculates FakeQuantize layer attributes scaled to effectively use a half range of the quantization range. diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index 1d3464882fe..01a916b61a5 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -16,6 +16,7 @@ import onnx from nncf import Dataset +from nncf.experimental.tensor import Tensor from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.onnx.graph.onnx_graph import ONNXGraph from nncf.onnx.statistics.statistics import ONNXMinMaxTensorStatistic @@ -32,10 +33,18 @@ def mock_collect_statistics(mocker): - get_statistics_value = ONNXMinMaxTensorStatistic(min_values=-1, max_values=1) + get_statistics_value = ONNXMinMaxTensorStatistic( + min_values=np.array(-1, dtype=np.float32), max_values=np.array(1, dtype=np.float32) + ) _ = mocker.patch( "nncf.quantization.fake_quantize.calculate_quantizer_parameters", - return_value=FakeQuantizeParameters(np.array(0), np.array(0), np.array(0), np.array(0), 256), + return_value=FakeQuantizeParameters( + Tensor(np.array(0, dtype=np.float32)), + Tensor(np.array(0, dtype=np.float32)), + Tensor(np.array(0, dtype=np.float32)), + Tensor(np.array(0, dtype=np.float32)), + 256, + ), ) _ = mocker.patch( "nncf.common.tensor_statistics.aggregator.StatisticsAggregator.collect_statistics", return_value=None diff --git a/tests/post_training/test_templates/test_calculate_quantizer_parameters.py b/tests/post_training/test_templates/test_calculate_quantizer_parameters.py index bffe249d064..ebcc5df7e75 100644 --- a/tests/post_training/test_templates/test_calculate_quantizer_parameters.py +++ b/tests/post_training/test_templates/test_calculate_quantizer_parameters.py @@ -19,6 +19,7 @@ from nncf.common.quantization.structs import QuantizationMode from nncf.common.quantization.structs import QuantizerConfig from nncf.common.quantization.structs import QuantizerGroup +from nncf.experimental.tensor import functions as fns from nncf.quantization.fake_quantize import FakeQuantizeParameters from nncf.quantization.fake_quantize import calculate_quantizer_parameters from tests.post_training.conftest import FQ_CALCULATED_PARAMETERS_PATH @@ -32,10 +33,10 @@ def compare_fq_parameters(ref_params, params): assert ref_params.input_high.shape == params.input_high.shape assert ref_params.output_low.shape == params.output_low.shape assert ref_params.output_high.shape == params.output_high.shape - assert np.allclose(ref_params.input_low, params.input_low) - assert np.allclose(ref_params.input_high, params.input_high) - assert np.allclose(ref_params.output_low, params.output_low) - assert np.allclose(ref_params.output_high, params.output_high) + assert fns.allclose(ref_params.input_low, params.input_low) + assert fns.allclose(ref_params.input_high, params.input_high) + assert fns.allclose(ref_params.output_low, params.output_low) + assert fns.allclose(ref_params.output_high, params.output_high) def get_test_reference_key(q_group, q_config, narrow_range, hf_range): diff --git a/tests/post_training/test_templates/test_fast_bias_correction.py b/tests/post_training/test_templates/test_fast_bias_correction.py index b972ce851cd..c4ea71d6551 100644 --- a/tests/post_training/test_templates/test_fast_bias_correction.py +++ b/tests/post_training/test_templates/test_fast_bias_correction.py @@ -67,7 +67,7 @@ def test_reshape_bias_shift(self, bias_value: list, bias_shift: list, channel_ax algo = FastBiasCorrection(subset_size=1, inplace_statistics=False) # pylint: disable=protected-access algo._backend_entity = self.get_backend() - new_bias_shift = algo.reshape_bias_shift(bias_shift, bias_value, channel_axis) + new_bias_shift = algo._reshape_bias_shift(bias_shift, bias_value, channel_axis) assert list(new_bias_shift.shape) == ref_shape @staticmethod diff --git a/tests/shared/test_templates/template_test_nncf_tensor.py b/tests/shared/test_templates/template_test_nncf_tensor.py index 9fff5e9de1c..461deb14fce 100644 --- a/tests/shared/test_templates/template_test_nncf_tensor.py +++ b/tests/shared/test_templates/template_test_nncf_tensor.py @@ -17,10 +17,11 @@ import pytest +from nncf.experimental.common.tensor_statistics import statistical_functions as s_fns from nncf.experimental.tensor import Tensor from nncf.experimental.tensor import TensorDataType from nncf.experimental.tensor import TensorDeviceType -from nncf.experimental.tensor import functions +from nncf.experimental.tensor import functions as fns TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") @@ -68,6 +69,7 @@ def test_operators_tensor(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", OPERATOR_MAP.keys()) def test_operators_int(self, op_name): @@ -83,6 +85,7 @@ def test_operators_int(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", ("add", "sub", "mul", "truediv", "floordiv")) def test_operators_int_rev(self, op_name): @@ -98,6 +101,7 @@ def test_operators_int_rev(self, op_name): assert res.dtype == res_nncf.data.dtype assert all(res == res_nncf.data) assert isinstance(res_nncf, Tensor) + assert res_nncf.device == nncf_tensor_a.device @pytest.mark.parametrize("op_name", COMPARISON_OPERATOR_MAP.keys()) def test_comparison_tensor(self, op_name): @@ -150,6 +154,7 @@ def test_comparison_int_rev(self, op_name): ([[[[1], [2]], [[1], [2]]]], None, [[1, 2], [1, 2]]), ([[[[1], [2]], [[1], [2]]]], 0, [[[1], [2]], [[1], [2]]]), ([[[[1], [2]], [[1], [2]]]], -1, [[[1, 2], [1, 2]]]), + ([[[[1], [2]], [[1], [2]]]], (0, 3), [[1, 2], [1, 2]]), ), ) def test_squeeze(self, val, axis, ref): @@ -157,11 +162,23 @@ def test_squeeze(self, val, axis, ref): nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) res = nncf_tensor.squeeze(axis=axis) - if isinstance(ref, list): - assert functions.all(res == ref_tensor) - else: - assert res == ref_tensor assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device + + @pytest.mark.parametrize( + "val, axis, exception_type, exception_match", + ( + ([[[[1], [2]], [[1], [2]]]], (0, 1), ValueError, "not equal to one"), + ([[[[1], [2]], [[1], [2]]]], 42, IndexError, "out of"), + ([[[[1], [2]], [[1], [2]]]], (0, 42), IndexError, "out of"), + ), + ) + def test_squeeze_axis_error(self, val, axis, exception_type, exception_match): + tensor = self.to_tensor(val) + nncf_tensor = Tensor(tensor) + with pytest.raises(exception_type, match=exception_match): + nncf_tensor.squeeze(axis=axis) @pytest.mark.parametrize( "val, axis, ref", @@ -177,12 +194,10 @@ def test_fn_squeeze(self, val, axis, ref): tensor = self.to_tensor(val) nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) - res = functions.squeeze(nncf_tensor, axis=axis) - if isinstance(ref, list): - assert functions.all(res == ref_tensor) - else: - assert res == ref_tensor + res = fns.squeeze(nncf_tensor, axis=axis) assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val,ref", @@ -197,31 +212,9 @@ def test_flatten(self, val, ref): nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) res = nncf_tensor.flatten() - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor - assert isinstance(res, Tensor) - - @pytest.mark.parametrize( - "val, axis, ref", - ( - (1, None, 1), - ([1], None, 1), - ([[[[1], [2]], [[3], [4]]]], None, 4), - ([[1, 2], [3, 4]], 1, [2, 4]), - ), - ) - def test_max(self, val, axis, ref): - tensor = self.to_tensor(val) - nncf_tensor = Tensor(tensor) - ref_tensor = self.to_tensor(ref) - res = nncf_tensor.max(axis=axis) - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val, axis, ref", @@ -236,12 +229,10 @@ def test_fn_max(self, val, axis, ref): tensor = self.to_tensor(val) nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) - res = functions.max(nncf_tensor, axis=axis) - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor + res = fns.max(nncf_tensor, axis=axis) assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val, axis, ref", @@ -256,30 +247,9 @@ def test_min(self, val, axis, ref): nncf_tensor = Tensor(self.to_tensor(val)) ref_tensor = self.to_tensor(ref) res = nncf_tensor.min(axis=axis) - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor - assert isinstance(res, Tensor) - - @pytest.mark.parametrize( - "val, axis, ref", - ( - (1, None, 1), - ([1], None, 1), - ([[[[1], [2]], [[3], [4]]]], None, 1), - ([[1, 2], [3, 4]], 1, [1, 3]), - ), - ) - def test_fn_min(self, val, axis, ref): - nncf_tensor = Tensor(self.to_tensor(val)) - ref_tensor = self.to_tensor(ref) - res = functions.min(nncf_tensor, axis=axis) - if isinstance(ref, list): - assert all(res.data == ref_tensor) - else: - assert res.data == ref_tensor assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val, ref", @@ -292,11 +262,9 @@ def test_abs(self, val, ref): nncf_tensor = Tensor(self.to_tensor(val)) nncf_ref_tensor = Tensor(self.to_tensor(ref)) res = nncf_tensor.abs() - if isinstance(ref, list): - assert all(res == nncf_ref_tensor) - else: - assert res == nncf_ref_tensor assert isinstance(res, Tensor) + assert fns.allclose(res, nncf_ref_tensor) + assert res.device == nncf_tensor.device @pytest.mark.parametrize( "val, ref", @@ -308,12 +276,10 @@ def test_abs(self, val, ref): def test_fn_abs(self, val, ref): nncf_tensor = Tensor(self.to_tensor(val)) nncf_ref_tensor = Tensor(self.to_tensor(ref)) - res = functions.abs(nncf_tensor) - if isinstance(ref, list): - assert all(res == nncf_ref_tensor) - else: - assert res == nncf_ref_tensor + res = fns.abs(nncf_tensor) assert isinstance(res, Tensor) + assert fns.allclose(res, nncf_ref_tensor) + assert res.device == nncf_tensor.device def test_getitem(self): arr = [0, 1, 2] @@ -321,6 +287,7 @@ def test_getitem(self): res = nncf_tensor[1] assert res == 1 assert isinstance(res, Tensor) + assert res.device == nncf_tensor.device def test_iter(self): arr = [0, 1, 2] @@ -341,67 +308,72 @@ def test_iter(self): ), ) def test_fn_count_nonzero(self, axis, ref): - tensor = self.to_tensor([[1, 2], [1, 0]]) + tensor = self.to_tensor([[1.0, 2.0], [1.0, 0.0]]) nncf_tensor = Tensor(tensor) ref_tensor = self.to_tensor(ref) - res = functions.count_nonzero(nncf_tensor, axis=axis) - if axis is None: - assert res.data == ref_tensor - else: - assert all(res.data == self.to_tensor(ref)) + res = fns.count_nonzero(nncf_tensor, axis=axis) + assert isinstance(res, Tensor) + assert fns.allclose(res.data, ref_tensor) + assert res.device == nncf_tensor.device def test_fn_zeros_like(self): tensor = self.to_tensor([1, 2]) nncf_tensor = Tensor(tensor) - res = functions.zeros_like(nncf_tensor) + res = fns.zeros_like(nncf_tensor) assert all(res == Tensor(tensor * 0)) assert isinstance(res, Tensor) + assert res.device == nncf_tensor.device def test_fn_maximum(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_b = Tensor(self.to_tensor([2, 1])) tensor_ref = self.to_tensor([2, 2]) - res = functions.maximum(tensor_a, tensor_b) + res = fns.maximum(tensor_a, tensor_b) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device def test_fn_maximum_list(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_b = [2, 1] tensor_ref = self.to_tensor([2, 2]) - res = functions.maximum(tensor_a, tensor_b) + res = fns.maximum(tensor_a, tensor_b) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device def test_fn_minimum(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_b = Tensor(self.to_tensor([2, 1])) tensor_ref = self.to_tensor([1, 1]) - res = functions.minimum(tensor_a, tensor_b) + res = fns.minimum(tensor_a, tensor_b) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device def test_fn_minimum_list(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_b = [2, 1] tensor_ref = self.to_tensor([1, 1]) - res = functions.minimum(tensor_a, tensor_b) + res = fns.minimum(tensor_a, tensor_b) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device def test_fn_ones_like(self): tensor_a = Tensor(self.to_tensor([1, 2])) tensor_ref = self.to_tensor([1, 1]) - res = functions.ones_like(tensor_a) + res = fns.ones_like(tensor_a) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor_a.device @pytest.mark.parametrize( "val, axis, ref", @@ -414,12 +386,10 @@ def test_fn_ones_like(self): ) def test_fn_all(self, val, axis, ref): tensor = Tensor(self.to_tensor(val)) - res = functions.all(tensor, axis=axis) - if isinstance(ref, list): - assert all(res.data == self.to_tensor(ref)) - else: - assert res.data == self.to_tensor(ref) + res = fns.all(tensor, axis=axis) assert isinstance(res, Tensor) + assert fns.allclose(res.data, self.to_tensor(ref)) + assert res.device == tensor.device @pytest.mark.parametrize( "val, axis, ref", @@ -432,19 +402,19 @@ def test_fn_all(self, val, axis, ref): ) def test_fn_any(self, val, axis, ref): tensor = Tensor(self.to_tensor(val)) - res = functions.any(tensor, axis=axis) - if isinstance(ref, list): - assert all(res.data == self.to_tensor(ref)) - else: - assert res == ref + res = fns.any(tensor, axis=axis) + assert isinstance(res, Tensor) + assert fns.allclose(res.data, self.to_tensor(ref)) + assert res.device == tensor.device def test_fn_where(self): tensor = Tensor(self.to_tensor([1, -1])) tensor_ref = self.to_tensor([1, 0]) - res = functions.where(tensor > 0, 1, 0) + res = fns.where(tensor > 0, 1, 0) assert all(res.data == tensor_ref) assert isinstance(res, Tensor) + assert res.device == tensor.device @pytest.mark.parametrize( "val, ref", @@ -456,9 +426,9 @@ def test_fn_where(self): ) def test_fn_isempty(self, val, ref): tensor = Tensor(self.to_tensor(val)) - res = functions.isempty(tensor) + res = fns.isempty(tensor) assert res == ref - assert isinstance(res, Tensor) + assert isinstance(res, bool) @pytest.mark.parametrize( "val, ref", @@ -472,7 +442,7 @@ def test_isempty(self, val, ref): tensor = Tensor(self.to_tensor(val)) res = tensor.isempty() assert res == ref - assert isinstance(res, Tensor) + assert isinstance(res, bool) @pytest.mark.parametrize( "x1, x2, rtol, atol, ref", @@ -482,19 +452,19 @@ def test_isempty(self, val, ref): ([0.1], [0.10001], 0.1, None, True), ([0.1], [0.10001], None, 0.1, True), ([0.1], [0.20001], None, 0.1, False), + ([0.1], 0.1, None, None, True), ), ) def test_fn_allclose(self, x1, x2, rtol, atol, ref): tensor1 = Tensor(self.to_tensor(x1)) tensor2 = Tensor(self.to_tensor(x2)) if rtol is not None: - res = functions.allclose(tensor1, tensor2, rtol=rtol) + res = fns.allclose(tensor1, tensor2, rtol=rtol) elif atol is not None: - res = functions.allclose(tensor1, tensor2, atol=atol) + res = fns.allclose(tensor1, tensor2, atol=atol) else: - res = functions.allclose(tensor1, tensor2) + res = fns.allclose(tensor1, tensor2) assert res == ref - assert isinstance(res, Tensor) @pytest.mark.parametrize( "x1, x2, rtol, atol, ref", @@ -503,17 +473,18 @@ def test_fn_allclose(self, x1, x2, rtol, atol, ref): ([0.1], [0.10001], None, None, [False]), ([0.1], [0.10001], 0.1, None, [True]), ([0.1], [0.10001], None, 0.1, [True]), + ([0.1], 0.1, None, None, [True]), ), ) def test_fn_isclose(self, x1, x2, rtol, atol, ref): tensor1 = Tensor(self.to_tensor(x1)) tensor2 = Tensor(self.to_tensor(x2)) if rtol is not None: - res = functions.isclose(tensor1, tensor2, rtol=rtol) + res = fns.isclose(tensor1, tensor2, rtol=rtol) elif atol is not None: - res = functions.isclose(tensor1, tensor2, atol=atol) + res = fns.isclose(tensor1, tensor2, atol=atol) else: - res = functions.isclose(tensor1, tensor2) + res = fns.isclose(tensor1, tensor2) assert all(res == self.to_tensor(ref)) assert isinstance(res, Tensor) @@ -526,23 +497,202 @@ def test_astype(self): res = tensor.astype(TensorDataType.int8) assert isinstance(res, Tensor) assert res.dtype == TensorDataType.int8 + assert res.device == tensor.device def test_fn_astype(self): tensor = Tensor(self.to_tensor([1])) - res = functions.astype(tensor, TensorDataType.int8) + res = fns.astype(tensor, TensorDataType.int8) assert isinstance(res, Tensor) assert res.dtype == TensorDataType.int8 def test_reshape(self): tensor = Tensor(self.to_tensor([1, 1])) - assert tensor.shape == [2] - assert tensor.reshape([1, 2]).shape == [1, 2] + res = tensor.reshape((1, 2)) + assert tensor.shape == (2,) + assert res.shape == (1, 2) + assert res.device == tensor.device def test_fn_reshape(self): tensor = Tensor(self.to_tensor([1, 1])) - assert tensor.shape == [2] - assert functions.reshape(tensor, [1, 2]).shape == [1, 2] + res = fns.reshape(tensor, (1, 2)) + assert tensor.shape == (2,) + assert res.shape == (1, 2) + assert res.device == tensor.device def test_not_implemented(self): with pytest.raises(NotImplementedError, match="is not implemented for"): - functions.device({}, [1, 2]) + fns.device({}, [1, 2]) + + @pytest.mark.parametrize( + "x, axis, ref", + ( + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 0, + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 1, + [[0.8, 0.1], [0.2, 0.7], [0.2, 0.1]], + ), + ), + ) + def test_fn_unstack(self, x, axis, ref): + tensor = Tensor(self.to_tensor(x)) + ref = [self.to_tensor(r) for r in ref] + + res = fns.unstack(tensor, axis=axis) + + assert isinstance(res, list) + for i, _ in enumerate(ref): + assert all(res[i] == ref[i]) + assert res[i].device == tensor.device + + @pytest.mark.parametrize( + "x, axis, ref", + ( + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 0, + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 1, + [[0.8, 0.1], [0.2, 0.7], [0.2, 0.1]], + ), + ), + ) + def test_fn_stack(self, x, axis, ref): + list_tensor = [Tensor(self.to_tensor(i)) for i in x] + ref = self.to_tensor(ref) + + res = fns.stack(list_tensor, axis=axis) + + assert isinstance(res, Tensor) + assert fns.all(res.data == ref) + assert res.device == list_tensor[0].device + + def test_fn_moveaxis(self): + tensor = [[0, 0, 0], [0, 0, 0]] + tensor = Tensor(self.to_tensor(tensor)) + + res = fns.moveaxis(tensor, 0, -1) + + assert res.shape == (3, 2) + + @pytest.mark.parametrize( + "x, axis, keepdims, ref", + ( + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 0, + False, + [0.45, 0.45, 0.15], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + 0, + True, + [[0.45, 0.45, 0.15]], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + (0, 1), + True, + [[0.35]], + ), + ( + [[0.8, 0.2, 0.2], [0.1, 0.7, 0.1]], + None, + False, + 0.35, + ), + ), + ) + def test_fn_mean(self, x, axis, keepdims, ref): + tensor = Tensor(self.to_tensor(x)) + ref_tensor = self.to_tensor(ref) + + res = fns.mean(tensor, axis, keepdims) + + assert isinstance(res, Tensor) + assert fns.allclose(res.data, ref_tensor) + assert res.device == tensor.device + + @pytest.mark.parametrize( + "val, decimals, ref", + ( + (1.1, 0, 1.0), + ([1.1, 0.9], 0, [1.0, 1.0]), + ([1.11, 0.91], 1, [1.1, 0.9]), + ), + ) + def test_fn_round(self, val, decimals, ref): + tensor = Tensor(self.to_tensor(val)) + ref_tensor = self.to_tensor(ref) + + res = fns.round(tensor, decimals) + + assert isinstance(res, Tensor) + assert fns.allclose(res.data, ref_tensor) + assert res.device == tensor.device + + @pytest.mark.parametrize( + "val, axis, ref", + ( + ( + [[9.0, 9.0], [7.0, 1.0]], + 0, + [8.0, 5.0], + ), + ( + [[[9.0, 9.0], [0.0, 3.0]], [[5.0, 1.0], [7.0, 1.0]]], + 0, + [5.25, 3.5], + ), + ( + [[[9.0, 9.0], [0.0, 3.0]], [[5.0, 1.0], [7.0, 1.0]]], + 2, + [5.25, 3.5], + ), + ( + [ + [[[9.0, 6.0], [8.0, 5.0]], [[3.0, 9.0], [4.0, 6.0]]], + [[[3.0, 9.0], [9.0, 2.0]], [[2.0, 4.0], [2.0, 5.0]]], + ], + 0, + [6.25, 4.5], + ), + ( + [ + [[[9.0, 6.0], [8.0, 5.0]], [[3.0, 9.0], [4.0, 6.0]]], + [[[3.0, 9.0], [9.0, 2.0]], [[2.0, 4.0], [2.0, 5.0]]], + ], + 1, + [6.375, 4.375], + ), + ( + [ + [[[9.0, 6.0], [8.0, 5.0]], [[3.0, 9.0], [4.0, 6.0]]], + [[[3.0, 9.0], [9.0, 2.0]], [[2.0, 4.0], [2.0, 5.0]]], + ], + -1, + [5.0, 5.75], + ), + ), + ) + def test_fn_mean_per_channel(self, val, axis, ref): + tensor = Tensor(self.to_tensor(val)) + ref_tensor = self.to_tensor(ref) + res = s_fns.mean_per_channel(tensor, axis) + assert isinstance(res, Tensor) + assert fns.allclose(res, ref_tensor), f"{res.data}" + assert res.device == tensor.device + + @pytest.mark.parametrize("axis", (3, 4, -4, -5)) + def test_fn_mean_per_channel_incorrect_axis(self, axis): + tensor = Tensor(self.to_tensor([[[9.0, 9.0], [0.0, 3.0]], [[5.0, 1.0], [7.0, 1.0]]])) + with pytest.raises(ValueError, match="is out of bounds for array of dimension"): + s_fns.mean_per_channel(tensor, axis) diff --git a/tests/torch/ptq/test_calculation_quantizer_params.py b/tests/torch/ptq/test_calculation_quantizer_params.py index 3c0cf83a64e..cea666d0065 100644 --- a/tests/torch/ptq/test_calculation_quantizer_params.py +++ b/tests/torch/ptq/test_calculation_quantizer_params.py @@ -25,6 +25,8 @@ from nncf.common.quantization.structs import QuantizationPreset from nncf.common.quantization.structs import QuantizerConfig from nncf.common.quantization.structs import QuantizerGroup +from nncf.experimental.tensor import Tensor +from nncf.experimental.tensor import functions as fn from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend from nncf.quantization.fake_quantize import FakeQuantizeParameters @@ -54,10 +56,10 @@ class CaseSymParams: SYM_CASES = ( CaseSymParams( fq_params=FakeQuantizeParameters( - np.array(-0.49920455, dtype=np.float32), - np.array(0.49530452, dtype=np.float32), - np.array(-0.49920455, dtype=np.float32), - np.array(0.49530452, dtype=np.float32), + Tensor(torch.tensor(-0.49920455, dtype=torch.float32)), + Tensor(torch.tensor(0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(-0.49920455, dtype=torch.float32)), + Tensor(torch.tensor(0.49530452, dtype=torch.float32)), 256, ), per_channel=False, @@ -66,10 +68,10 @@ class CaseSymParams: ), CaseSymParams( fq_params=FakeQuantizeParameters( - np.array(-0.49530452, dtype=np.float32), - np.array(0.49530452, dtype=np.float32), - np.array(-0.49530452, dtype=np.float32), - np.array(0.49530452, dtype=np.float32), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49530452, dtype=torch.float32)), 255, ), per_channel=False, @@ -78,33 +80,33 @@ class CaseSymParams: ), CaseSymParams( fq_params=FakeQuantizeParameters( - np.array([-0.4835594, -0.49530452, -0.49221927], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([0.4797816, 0.49920455, 0.48837382], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([-0.4835594, -0.49530452, -0.49221927], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([0.4797816, 0.49920455, 0.48837382], dtype=np.float32).reshape(1, 3, 1, 1), + Tensor(torch.tensor([-0.4835594, -0.49530452, -0.49221927], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([0.4797816, 0.49920455, 0.48837382], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([-0.4835594, -0.49530452, -0.49221927], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([0.4797816, 0.49920455, 0.48837382], dtype=torch.float32).reshape(1, 3, 1, 1)), 256, ), per_channel=True, quant_group=QuantizerGroup.ACTIVATIONS, - ref_scale=np.array([0.4797816, 0.49920455, 0.48837382]).reshape(1, 3, 1, 1), + ref_scale=torch.tensor([0.4797816, 0.49920455, 0.48837382]).reshape(1, 3, 1, 1), ), CaseSymParams( fq_params=FakeQuantizeParameters( - np.array([-0.48837382, -0.49530452], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([0.48837382, 0.49530452], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([-0.48837382, -0.49530452], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([0.48837382, 0.49530452], dtype=np.float32).reshape(2, 1, 1, 1), + Tensor(torch.tensor([-0.48837382, -0.49530452], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([0.48837382, 0.49530452], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([-0.48837382, -0.49530452], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([0.48837382, 0.49530452], dtype=torch.float32).reshape(2, 1, 1, 1)), 255, ), per_channel=True, quant_group=QuantizerGroup.WEIGHTS, - ref_scale=np.array([0.48837382, 0.49530452]).reshape(2, 1, 1, 1), + ref_scale=torch.tensor([0.48837382, 0.49530452]).reshape(2, 1, 1, 1), ), ) @pytest.mark.parametrize("case_to_test", SYM_CASES) -def test_quantizer_params_sym(case_to_test): +def test_quantizer_params_sym(case_to_test: CaseSymParams): per_ch = case_to_test.per_channel fq_params = case_to_test.fq_params quant_group = case_to_test.quant_group @@ -140,10 +142,10 @@ class CaseAsymParams: ASYM_CASES = ( CaseAsymParams( fq_params=FakeQuantizeParameters( - np.array(-0.49530452, dtype=np.float32), - np.array(0.49143496, dtype=np.float32), - np.array(-0.49530452, dtype=np.float32), - np.array(0.49143496, dtype=np.float32), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49143496, dtype=torch.float32)), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49143496, dtype=torch.float32)), 256, ), per_channel=False, @@ -153,10 +155,10 @@ class CaseAsymParams: ), CaseAsymParams( fq_params=FakeQuantizeParameters( - np.array(-0.49530452, dtype=np.float32), - np.array(0.49143496, dtype=np.float32), - np.array(-0.49530452, dtype=np.float32), - np.array(0.49143496, dtype=np.float32), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49143496, dtype=torch.float32)), + Tensor(torch.tensor(-0.49530452, dtype=torch.float32)), + Tensor(torch.tensor(0.49143496, dtype=torch.float32)), 256, ), per_channel=False, @@ -166,35 +168,35 @@ class CaseAsymParams: ), CaseAsymParams( fq_params=FakeQuantizeParameters( - np.array([-0.48051512, -0.49776307, -0.44099426], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([0.4767611, 0.47861832, 0.48837382], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([-0.48051512, -0.49776307, -0.44099426], dtype=np.float32).reshape(1, 3, 1, 1), - np.array([0.4767611, 0.47861832, 0.48837382], dtype=np.float32).reshape(1, 3, 1, 1), + Tensor(torch.tensor([-0.48051512, -0.49776307, -0.44099426], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([0.4767611, 0.47861832, 0.48837382], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([-0.48051512, -0.49776307, -0.44099426], dtype=torch.float32).reshape(1, 3, 1, 1)), + Tensor(torch.tensor([0.4767611, 0.47861832, 0.48837382], dtype=torch.float32).reshape(1, 3, 1, 1)), 256, ), per_channel=True, quant_group=QuantizerGroup.ACTIVATIONS, - ref_inp_low=np.array([-0.48051512, -0.49776307, -0.44099426]).reshape(1, 3, 1, 1), - ref_inp_range=np.array([0.9572762, 0.9763814, 0.9293681]).reshape(1, 3, 1, 1), + ref_inp_low=torch.tensor([-0.48051512, -0.49776307, -0.44099426]).reshape(1, 3, 1, 1), + ref_inp_range=torch.tensor([0.9572762, 0.9763814, 0.9293681]).reshape(1, 3, 1, 1), ), CaseAsymParams( fq_params=FakeQuantizeParameters( - np.array([-0.4845584, -0.49583155], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([0.48837382, 0.4767611], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([-0.4845584, -0.49583155], dtype=np.float32).reshape(2, 1, 1, 1), - np.array([0.48837382, 0.4767611], dtype=np.float32).reshape(2, 1, 1, 1), + Tensor(torch.tensor([-0.4845584, -0.49583155], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([0.48837382, 0.4767611], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([-0.4845584, -0.49583155], dtype=torch.float32).reshape(2, 1, 1, 1)), + Tensor(torch.tensor([0.48837382, 0.4767611], dtype=torch.float32).reshape(2, 1, 1, 1)), 256, ), per_channel=True, quant_group=QuantizerGroup.WEIGHTS, - ref_inp_low=np.array([-0.4845584, -0.49583155]).reshape(2, 1, 1, 1), - ref_inp_range=np.array([0.97293222, 0.97259265]).reshape(2, 1, 1, 1), + ref_inp_low=torch.tensor([-0.4845584, -0.49583155]).reshape(2, 1, 1, 1), + ref_inp_range=torch.tensor([0.97293222, 0.97259265]).reshape(2, 1, 1, 1), ), ) @pytest.mark.parametrize("case_to_test", ASYM_CASES) -def test_quantizer_params_asym(case_to_test): +def test_quantizer_params_asym(case_to_test: CaseSymParams): per_ch = case_to_test.per_channel fq_params = case_to_test.fq_params quant_group = case_to_test.quant_group @@ -212,8 +214,8 @@ def test_quantizer_params_asym(case_to_test): ) quantizer = PTMinMaxAlgoBackend._create_quantizer(qconfig, scale_shape, fq_params, target_type) assert quantizer.levels == fq_params.levels - assert np.allclose(quantizer.input_low.detach().numpy(), case_to_test.ref_inp_low) - assert np.allclose(quantizer.input_range.detach().numpy(), case_to_test.ref_inp_range) + assert fn.allclose(quantizer.input_low.data, case_to_test.ref_inp_low) + assert fn.allclose(quantizer.input_range.data, case_to_test.ref_inp_range) class LinearTestModel(nn.Module): @@ -270,10 +272,7 @@ def calculate_statistics(data, mode, qgroup, half_range=False): else: max_values = np.amax(data, axes) - statistics = PTMinMaxTensorStatistic( - min_values=torch.from_numpy(np.array(min_values)), - max_values=torch.from_numpy(np.array(max_values)), - ) + statistics = PTMinMaxTensorStatistic(min_values=torch.tensor(min_values), max_values=torch.tensor(max_values)) signedness_to_force = True if qgroup == QuantizerGroup.WEIGHTS else None qconfig = QuantizerConfig(num_bits=8, mode=mode, per_channel=per_ch, signedness_to_force=signedness_to_force) narrow_range = get_quantizer_narrow_range(qconfig, qgroup) @@ -346,8 +345,8 @@ def test_quantizer_parameters_export(tmp_path: Path): for name, param in fq_params.items(): assert name in torch_ptq_params - assert np.allclose(param["input_low"], torch_ptq_params[name]["input_low"]) - assert np.allclose(param["input_high"], torch_ptq_params[name]["input_high"]) + assert fn.allclose(param["input_low"], torch_ptq_params[name]["input_low"]) + assert fn.allclose(param["input_high"], torch_ptq_params[name]["input_high"]) class TestFQParams(TemplateTestFQParams): diff --git a/tests/torch/ptq/test_fast_bias_correction.py b/tests/torch/ptq/test_fast_bias_correction.py index b713aeb802c..7f5639aaeba 100644 --- a/tests/torch/ptq/test_fast_bias_correction.py +++ b/tests/torch/ptq/test_fast_bias_correction.py @@ -59,3 +59,30 @@ def check_bias(model: NNCFNetwork, ref_bias: list): assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}" return raise ValueError("Not found node with bias") + + +class TestTorchCudaFBCAlgorithm(TestTorchFBCAlgorithm): + @staticmethod + def list_to_backend_type(data: List) -> torch.Tensor: + return torch.Tensor(data).cuda() + + @staticmethod + def backend_specific_model(model: bool, tmp_dir: str): + return get_nncf_network(model.cuda(), model.INPUT_SIZE) + + @staticmethod + def fn_to_type(tensor): + return torch.Tensor(tensor).cuda() + + @staticmethod + def check_bias(model: NNCFNetwork, ref_bias: list): + ref_bias = torch.Tensor(ref_bias) + nncf_graph = NNCFGraphFactory.create(model) + for node in nncf_graph.get_all_nodes(): + if not is_node_with_fused_bias(node, nncf_graph): + continue + bias_value = get_fused_bias_value(node, model).cpu() + # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189 + assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}" + return + raise ValueError("Not found node with bias") From 48f87237ff52f58c944b2ab7ce8c08fce079f062 Mon Sep 17 00:00:00 2001 From: Aleksei Kashapov Date: Tue, 10 Oct 2023 07:09:13 +0200 Subject: [PATCH 10/10] [ONNX] Remove ONNXGraph (#2173) ### Changes Remove ONNXGraph and place all its methods into module onnx_helper.py To optimize performance four mappings are introduced for ONNX model manipulation. `get_node_mapping(model: onnx.ModelProto) -> Dict[str, onnx.NodeProto]` Mapping from node name to the corresponding node. This needs to not iterate through all nodes of a model. `get_edge_info_mapping(model: onnx.ModelProto) -> Dict[str, onnx.ValueInfoProto]` Mapping from edge name to corresponding edge information. This needs to not iterate through all edge infos of a model. `get_children_node_mapping(model: onnx.ModelProto) -> Dict[str, List[onnx.NodeProto]]` Mapping from edge name and corresponding nodes that consume that edge as an input. Used to traverse forwardwith the optimal performance. `get_parents_node_mapping(model: onnx.ModelProto) -> Dict[str, onnx.NodeProto]` Mapping from edge name to node which outputs this edge. Used to traverse backward with the optimal performance. Locally measured perf difference after removing ONNXGraph. Model | PR time (sec) | develop time (sec) | SpeedUp -- | -- | -- | -- swinv2_cr_tiny_224| 91.434 | 105.73 | 15.64% visformer_small | 57.265 | 59.097 | 3.2% deit3_small_patch16_224 | 52.31 | 55.503 | 6.1% ### Reason for changes Code refactor ### Related tickets 96982 ### Tests N/A --- nncf/onnx/graph/metatypes/onnx_metatypes.py | 29 +- nncf/onnx/graph/model_transformer.py | 82 +++-- nncf/onnx/graph/nncf_graph_builder.py | 146 ++++---- nncf/onnx/graph/node_utils.py | 18 +- nncf/onnx/graph/onnx_graph.py | 321 ------------------ nncf/onnx/graph/onnx_helper.py | 290 ++++++++++++++++ nncf/onnx/statistics/aggregator.py | 20 +- .../bias_correction/onnx_backend.py | 12 +- tests/onnx/quantization/common.py | 12 +- .../test_qdq_params_calculation.py | 9 +- tests/onnx/test_model_transformer.py | 22 +- tests/onnx/weightless_model.py | 5 +- 12 files changed, 490 insertions(+), 476 deletions(-) delete mode 100644 nncf/onnx/graph/onnx_graph.py create mode 100644 nncf/onnx/graph/onnx_helper.py diff --git a/nncf/onnx/graph/metatypes/onnx_metatypes.py b/nncf/onnx/graph/metatypes/onnx_metatypes.py index 35d532caeac..2644658f8ee 100644 --- a/nncf/onnx/graph/metatypes/onnx_metatypes.py +++ b/nncf/onnx/graph/metatypes/onnx_metatypes.py @@ -9,14 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Type +from typing import Dict, List, Optional, Type import onnx from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry from nncf.common.hardware.opset import HWConfigOpName -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_parent +from nncf.onnx.graph.onnx_helper import get_parents_node_mapping +from nncf.onnx.graph.onnx_helper import get_tensor +from nncf.onnx.graph.onnx_helper import has_tensor ONNX_OPERATION_METATYPES = OperatorMetatypeRegistry("onnx_operator_metatypes") @@ -648,7 +651,12 @@ def get_metatype(model: onnx.ModelProto, node: onnx.NodeProto) -> ONNXOpMetatype return metatype -def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: int) -> Optional[str]: +def get_tensor_edge_name( + model: onnx.ModelProto, + node: onnx.NodeProto, + port_id: int, + parents_node_mapping: Dict[str, onnx.NodeProto], +) -> Optional[str]: """ Returns an edge name associated with a weight of a node laying on an input port_id. @@ -665,9 +673,10 @@ def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: i ONNXTransposeMetatype ONNXQuantizeLinearMetatype - :param onnx_graph: ONNXGraph. + :param model: ONNX model. :param node: Node. :param port_id: Port id on which a weight edge is seeking. + :param parents_node_mapping: Mapping from edge name to node which outputs this edge. :return: Edge name associated with a weight. """ PROPAGATING_NODES = ( @@ -678,14 +687,14 @@ def get_tensor_edge_name(onnx_graph: ONNXGraph, node: onnx.NodeProto, port_id: i + ONNXDequantizeLinearMetatype.get_all_aliases() ) END_NODES = ONNXConstantMetatype.get_all_aliases() - parent = onnx_graph.get_parent(node, port_id) + parent = get_parent(node, port_id, parents_node_mapping) if not parent: - if onnx_graph.has_tensor(node.input[port_id]): + if has_tensor(model, node.input[port_id]): return node.input[port_id] elif parent.op_type in END_NODES: return node.input[port_id] elif parent.op_type in PROPAGATING_NODES: - return get_tensor_edge_name(onnx_graph, parent, 0) + return get_tensor_edge_name(model, parent, 0, parents_node_mapping) return None @@ -734,12 +743,12 @@ def _is_embedding(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: :return: True if the layer is embedding, False - otherwise. """ tensor_port_id = ONNXEmbeddingMetatype.weight_port_ids[0] - onnx_graph = ONNXGraph(model) allowed_types_list = ["TensorProto.FLOAT"] - weight_edge_name = get_tensor_edge_name(onnx_graph, node, tensor_port_id) + parents_node_mapping = get_parents_node_mapping(model) + weight_edge_name = get_tensor_edge_name(model, node, tensor_port_id, parents_node_mapping) if weight_edge_name is not None: - tensor_data_type = onnx_graph.get_tensor(weight_edge_name).data_type + tensor_data_type = get_tensor(model, weight_edge_name).data_type if onnx.helper.tensor_dtype_to_string(tensor_data_type) in allowed_types_list: return True return False diff --git a/nncf/onnx/graph/model_transformer.py b/nncf/onnx/graph/model_transformer.py index a8f5355babf..b6db3d36b0d 100644 --- a/nncf/onnx/graph/model_transformer.py +++ b/nncf/onnx/graph/model_transformer.py @@ -19,7 +19,13 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.onnx.graph.node_utils import get_input_edge -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_children +from nncf.onnx.graph.onnx_helper import get_children_node_mapping +from nncf.onnx.graph.onnx_helper import get_edge_dtype +from nncf.onnx.graph.onnx_helper import get_edge_info_mapping +from nncf.onnx.graph.onnx_helper import get_name_to_node_map +from nncf.onnx.graph.onnx_helper import get_node_index +from nncf.onnx.graph.onnx_helper import get_tensor from nncf.onnx.graph.transformations.commands import ONNXBiasCorrectionCommand from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand @@ -40,7 +46,8 @@ class ONNXModelTransformer(ModelTransformer): ZERO_POINT_NAME_PREFIX = "zero_point_" def __init__(self, model: onnx.ModelProto): - super().__init__(model) + infered_model = onnx.shape_inference.infer_shapes(model) + super().__init__(infered_model) self.onnx_model_extractor = onnx.utils.Extractor(self._model) def _get_target_edge( @@ -48,7 +55,7 @@ def _get_target_edge( port_id: int, node_name: str, transform_type: TargetType, - onnx_graph: ONNXGraph, + node_mapping: Dict[str, onnx.NodeProto], input_edges_mapping: Dict[str, str], ) -> str: """ @@ -57,16 +64,16 @@ def _get_target_edge( :param port_id: Edge number of port. :param node_name: Node name. :param transform_type: Type of transformation. - :param onnx_graph: ONNXGraph. + :param node_mapping: Mapping from a node name to the node. :param input_edges_mapping: Mapping between NNCF Input nodes and - the following ONNX nodes and corresponding input port id. + the following ONNX nodes and corresponding input port id. :return: Target edge name. """ if transform_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]: - return onnx_graph.get_node_edge_names(node_name)["input"][port_id] + return node_mapping[node_name].input[port_id] if node_name in input_edges_mapping: # ADD INPUT NODE CASE - return get_input_edge(node_name, input_edges_mapping, onnx_graph) - return onnx_graph.get_node_edge_names(node_name)["output"][port_id] + return get_input_edge(node_name, input_edges_mapping, node_mapping) + return node_mapping[node_name].output[port_id] def transform(self, transformation_layout: TransformationLayout) -> onnx.ModelProto: """ @@ -123,15 +130,15 @@ def _apply_output_insertion_transformations( :param transformations: ONNXOutputInsertionCommand transformations. :return: New model with inserted outputs. """ - onnx_graph = ONNXGraph(self._model) - model_outputs = set(output.name for output in onnx_graph.get_model_outputs()) + model_outputs = set(output.name for output in self._model.graph.output) + node_mapping = get_name_to_node_map(self._model) for transformation in transformations: port_id = transformation.target_point.port_id node_name = transformation.target_point.target_node_name transform_type = transformation.target_point.type input_edges_mapping = transformation.input_edges_mapping target_edge_name = self._get_target_edge( - port_id, node_name, transform_type, onnx_graph, input_edges_mapping + port_id, node_name, transform_type, node_mapping, input_edges_mapping ) model_outputs.add(target_edge_name) @@ -146,11 +153,11 @@ def _insert_outputs(model: onnx.ModelProto, outputs: Union[List[str], Set[str]]) :param outputs: Edge names to use as outputs. :return: New model with inserted outputs. """ - onnx_graph = ONNXGraph(model) model_outputs = [] + edge_info_mapping = get_edge_info_mapping(model) for output in outputs: - edge = onnx_graph.get_edge(output) - onnx_dtype = ONNXGraph.get_edge_dtype(edge) + edge = edge_info_mapping[output] + onnx_dtype = get_edge_dtype(edge) type_proto = onnx.helper.make_tensor_type_proto(onnx_dtype, shape=None) model_outputs.append(onnx.helper.make_value_info(name=output, type_proto=type_proto)) @@ -193,7 +200,8 @@ def _apply_quantizer_insertion_transformations( """ self._added_target_edges = Counter() for transformation in transformations: - model = self._insert_quantizer_dequantizer(model, transformation) + children_node_mapping = get_children_node_mapping(model) + model = self._insert_quantizer_dequantizer(model, transformation, children_node_mapping) return model def _get_quantize_dequantize_nodes( @@ -274,35 +282,39 @@ def _get_scale_zero_point_tensors( return onnx_scale_tensor, onnx_zero_point_tensor def _get_quantizer_dequantizer_edge_name( - self, transformation: ONNXQuantizerInsertionCommand, onnx_graph: ONNXGraph + self, transformation: ONNXQuantizerInsertionCommand, node_mapping: Dict[str, onnx.NodeProto] ) -> str: """ Returns an edge name on which QuantizeLinear-DequantizeLinear nodes pair has to be inserted. :param transformation: QuantizeLinear-DequantizeLinear insertion transformation. - :param onnx_graph: ONNXGraph. + :param node_mapping: Mapping from a node name to the node. :return: Edge name to insert QuantizeLinear-DequantizeLinear nodes pair. """ port_id = transformation.target_point.port_id node_name = transformation.target_point.target_node_name transform_type = transformation.target_point.type input_edges_mapping = transformation.input_edges_mapping - target_edge_name = self._get_target_edge(port_id, node_name, transform_type, onnx_graph, input_edges_mapping) + target_edge_name = self._get_target_edge(port_id, node_name, transform_type, node_mapping, input_edges_mapping) self._added_target_edges[target_edge_name] += 1 return target_edge_name def _insert_quantizer_dequantizer( - self, model: onnx.ModelProto, transformation: ONNXQuantizerInsertionCommand + self, + model: onnx.ModelProto, + transformation: ONNXQuantizerInsertionCommand, + children_node_mapping: Dict[str, List[onnx.ValueInfoProto]], ) -> onnx.ModelProto: """ Inserts QuantizeLinear-DequantizeLinear nodes pair. :param model: Model to insert new nodes. :param transformation: QuantizeLinear-DequantizeLinear insertion transformation. + :param children_node_mapping: Mapping from edge name to nodes which consume this edge as an input. :return: Updated model with inserted QuantizeLinear-DequantizeLinear pair. """ - onnx_graph = ONNXGraph(model) - target_edge_name = self._get_quantizer_dequantizer_edge_name(transformation, onnx_graph) + node_mapping = get_name_to_node_map(model) + target_edge_name = self._get_quantizer_dequantizer_edge_name(transformation, node_mapping) quantizer, dequantizer = self._get_quantize_dequantize_nodes(transformation, target_edge_name) onnx_scale_tensor, onnx_zero_point_tensor = ONNXModelTransformer._get_scale_zero_point_tensors( transformation, quantizer, dequantizer @@ -310,7 +322,7 @@ def _insert_quantizer_dequantizer( # If several nodes on one edge input_nodes = [] - input_nodes.extend(onnx_graph.get_nodes_by_input(target_edge_name)) + input_nodes.extend(children_node_mapping[target_edge_name]) if not input_nodes: raise RuntimeError( f"Can not add the quantizer to the {target_edge_name} edge. This edge does not have end node." @@ -318,7 +330,7 @@ def _insert_quantizer_dequantizer( if transformation.target_point.type == TargetType.PRE_LAYER_OPERATION: # If we need to change only target nodes input - target_node = onnx_graph.get_node_by_name(transformation.target_point.target_node_name) + target_node = node_mapping[transformation.target_point.target_node_name] for i, inp in enumerate(target_node.input): if inp == target_edge_name: target_node.input[i] = dequantizer.output[0] @@ -336,7 +348,7 @@ def _insert_quantizer_dequantizer( ) model.graph.initializer.extend([onnx_scale_tensor, onnx_zero_point_tensor]) model.graph.value_info.extend([onnx_scale_value_info, onnx_zero_point_info]) - insert_index = onnx_graph.get_node_index(input_nodes[0].name) + insert_index = get_node_index(model, input_nodes[0].name) model.graph.node.insert(insert_index, quantizer) model.graph.node.insert(insert_index + 1, dequantizer) return model @@ -351,13 +363,13 @@ def _apply_bias_correction_transformations( :param transformations: Bias correction transformations. :return: Copy of original model with updated biases. """ - onnx_graph = ONNXGraph(model) + node_mapping = get_name_to_node_map(model) for transformation in transformations: bias_tensor_position = transformation.target_point.port_id node_name = transformation.target_point.target_node_name - onnx_node = onnx_graph.get_node_by_name(node_name) + onnx_node = node_mapping[node_name] bias_initializer_name = onnx_node.input[bias_tensor_position] - bias_initializer = onnx_graph.get_tensor(bias_initializer_name) + bias_initializer = get_tensor(model, bias_initializer_name) new_bias_tensor = onnx.numpy_helper.from_array(transformation.bias_value, bias_initializer_name) bias_initializer.CopyFrom(new_bias_tensor) @@ -370,20 +382,19 @@ def _apply_model_extraction_transformation(self, transformation: ONNXModelExtrac :param transformation: Model extraction transformation. :return: Extracted sub-model. """ - onnx_graph = ONNXGraph(self._model) - input_tensor_names = [] + node_mapping = get_name_to_node_map(self._model) for input_node_name in transformation.inputs: - input_onnx_node = onnx_graph.get_node_by_name(input_node_name) + input_onnx_node = node_mapping[input_node_name] input_tensor_names.append(input_onnx_node.input[0]) output_tensor_names = [] for output_node_name in transformation.outputs: - output_onnx_node = onnx_graph.get_node_by_name(output_node_name) + output_onnx_node = node_mapping[output_node_name] output_tensor_names.append(output_onnx_node.output[0]) if not output_tensor_names: - output_tensor_names = [n.name for n in onnx_graph.get_model_outputs()] + output_tensor_names = [n.name for n in self._model.graph.output] return self.onnx_model_extractor.extract_model(input_tensor_names, output_tensor_names) @@ -397,11 +408,12 @@ def _apply_qdq_node_removing_transformations( :param transformations: Nodes removing transformations. :return: Model with removed nodes. """ - onnx_graph = ONNXGraph(model) for transformation in transformations: - node = onnx_graph.get_node_by_name(transformation.target_point.target_node_name) + node_mapping = get_name_to_node_map(model) + children_node_mapping = get_children_node_mapping(model) + node = node_mapping[transformation.target_point.target_node_name] - node_children = onnx_graph.get_children(node) + node_children = get_children(node, children_node_mapping) for node_child in node_children: for input_id, input_obj in enumerate(node_child.input): if input_obj == node.output[0]: diff --git a/nncf/onnx/graph/nncf_graph_builder.py b/nncf/onnx/graph/nncf_graph_builder.py index b728b4020a8..148de756c30 100644 --- a/nncf/onnx/graph/nncf_graph_builder.py +++ b/nncf/onnx/graph/nncf_graph_builder.py @@ -29,7 +29,16 @@ from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import get_metatype from nncf.onnx.graph.metatypes.onnx_metatypes import get_tensor_edge_name -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_children_node_mapping +from nncf.onnx.graph.onnx_helper import get_edge_dtype +from nncf.onnx.graph.onnx_helper import get_edge_info_mapping +from nncf.onnx.graph.onnx_helper import get_edge_shape +from nncf.onnx.graph.onnx_helper import get_input_port_id_for_node_after_input +from nncf.onnx.graph.onnx_helper import get_model_inputs +from nncf.onnx.graph.onnx_helper import get_output_port_id_for_node_before_output +from nncf.onnx.graph.onnx_helper import get_parents_node_mapping +from nncf.onnx.graph.onnx_helper import get_port_ids_between_nodes +from nncf.onnx.graph.onnx_helper import is_node_has_shared_weight class ONNXLayerAttributes(BaseLayerAttributes): @@ -103,23 +112,28 @@ def get_bias_tensor_port_id(metatype: ONNXOpWithWeightsMetatype) -> Optional[int return None -def _get_weight_port_ids(node: onnx.NodeProto, onnx_graph: ONNXGraph) -> Set[int]: +def _get_weight_port_ids( + node: onnx.NodeProto, + model: onnx.ModelProto, + parents_node_mapping: Dict[str, onnx.NodeProto], +) -> Set[int]: """ Returns all weight input ports. First, add constant weight port ids from metatype. Second, add weight port ids determined dynamically if metatype could have them. :param node: ONNX node. - :param onnx_graph: ONNXGraph. + :param model: ONNX model. + :param parents_node_mapping: Mapping from edge name to node which outputs this edge. :return: Port ids with weights. """ port_ids = set() - metatype = get_metatype(onnx_graph.onnx_model, node) + metatype = get_metatype(model, node) constant_port_ids = get_constant_weight_port_ids(metatype) port_ids.update(constant_port_ids) possible_port_ids = get_possible_weight_port_ids(metatype) for port_id in possible_port_ids: - if get_tensor_edge_name(onnx_graph, node, port_id): + if get_tensor_edge_name(model, node, port_id, parents_node_mapping): port_ids.add(port_id) return port_ids @@ -129,7 +143,7 @@ def _is_node_with_bias(node: onnx.NodeProto, model: onnx.ModelProto) -> bool: Returns True if node has bias tensor, otherwise - False. :param node: ONNX node. - :param onnx_graph: ONNXGraph. + :param model: ONNX model. :return: True if node has bias tensor, otherwise - False. """ metatype = get_metatype(model, node) @@ -139,23 +153,6 @@ def _is_node_with_bias(node: onnx.NodeProto, model: onnx.ModelProto) -> bool: return False -def _get_weight_attr(node: onnx.NodeProto, onnx_graph: ONNXGraph, weight_port_id: int) -> Dict[int, Dict]: - """ - Returns weight attributes. - - :param node: ONNX node. - :param onnx_graph: ONNXGraph. - :param weight_port_ids: Port ids with weights location. - :return: Weight attributes. - """ - weight_attrs = {} - weight_edge_name = node.input[weight_port_id] - edge = onnx_graph.get_edge(weight_edge_name) - weight_shape = ONNXGraph.get_edge_shape(edge) - weight_attrs[weight_port_id] = {"name": weight_edge_name, "shape": weight_shape} - return weight_attrs - - def _get_gemm_attrs(node: onnx.NodeProto) -> Dict[str, int]: """ Returns transpose attrbiutes of GEMM node. @@ -176,7 +173,7 @@ def _get_node_attrs(node: onnx.NodeProto, model: onnx.ModelProto) -> Dict[str, A Returns node attributes. :param node: Node. - :param onnx_graph: ONNXGraph. + :param model: ONNX model. :return : Node attributes. """ metatype = get_metatype(model, node) @@ -185,19 +182,24 @@ def _get_node_attrs(node: onnx.NodeProto, model: onnx.ModelProto) -> Dict[str, A return {} -def _get_bias_attr(node: onnx.NodeProto, onnx_graph: ONNXGraph) -> Dict[str, str]: +def _get_bias_attr( + node: onnx.NodeProto, + model: onnx.ModelProto, + parents_node_mapping: Dict[str, onnx.NodeProto], +) -> Dict[str, str]: """ Returns bias tensor attributes. :param node: ONNX node. - :param onnx_graph: ONNXGraph. + :param model: ONNX model. + :param parents_node_mapping: Mapping from edge name to node which outputs this edge. :return: Bias tensor attributes. """ bias_attrs = {} - metatype = get_metatype(onnx_graph.onnx_model, node) - if _is_node_with_bias(node, onnx_graph.onnx_model): + metatype = get_metatype(model, node) + if _is_node_with_bias(node, model): bias_tensor_port_id = get_bias_tensor_port_id(metatype) - bias_edge_name = get_tensor_edge_name(onnx_graph, node, bias_tensor_port_id) + bias_edge_name = get_tensor_edge_name(model, node, bias_tensor_port_id, parents_node_mapping) bias_attrs["name"] = bias_edge_name return bias_attrs @@ -232,15 +234,22 @@ def _replace_empty_node_name(model: onnx.ModelProto) -> onnx.ModelProto: return model @staticmethod - def _add_nncf_input_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None: + def _add_nncf_input_nodes( + model: onnx.ModelProto, + nncf_graph: NNCFGraph, + edge_info_mapping: Dict[str, onnx.ValueInfoProto], + children_node_mapping: Dict[str, List[onnx.NodeProto]], + ) -> None: """ Adds special NNCF Input nodes to NNCFGraph. For all the ONNX model inputs, the special NNCF Input node is placed and then corresponding edges are added. - :param onnx_graph: ONNXGraph, which helps to get information about the ONNX model. + :param model: ONNX model. :param nncf_graph: NNCFGraph, in which the new nodes will be added. + :param edge_info_mapping: Mapping from edge name to the edge info. + :param children_node_mapping: Mapping from edge name to nodes which consume this edge as an input. :return: None. """ - for i, _input in enumerate(onnx_graph.get_model_inputs()): + for i, _input in enumerate(get_model_inputs(model)): input_name = _input.name layer_attributes = ONNXLayerAttributes() input_node = nncf_graph.add_nncf_node( @@ -249,18 +258,18 @@ def _add_nncf_input_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None: node_metatype=InputNoopMetatype, layer_attributes=layer_attributes, ) - to_nodes = onnx_graph.get_nodes_by_input(input_name) + to_nodes = children_node_mapping[input_name] input_node_node_id = input_node.node_id - edge = onnx_graph.get_edge(input_name) - input_shape = ONNXGraph.get_edge_shape(edge) - onnx_dtype = ONNXGraph.get_edge_dtype(edge) + edge = edge_info_mapping[input_name] + input_shape = get_edge_shape(edge) + onnx_dtype = get_edge_dtype(edge) nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) output_port_id = 0 for node in to_nodes: to_node_id = nncf_graph.get_node_by_name(node.name).node_id - input_port_id = ONNXGraph.get_input_port_id_for_node_after_input(input_name, node) + input_port_id = get_input_port_id_for_node_after_input(input_name, node) nncf_graph.add_edge_between_nncf_nodes( from_node_id=input_node_node_id, to_node_id=to_node_id, @@ -272,15 +281,22 @@ def _add_nncf_input_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None: output_port_id += 1 @staticmethod - def _add_nncf_output_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None: + def _add_nncf_output_nodes( + model: onnx.ModelProto, + nncf_graph: NNCFGraph, + edge_info_mapping: Dict[str, onnx.ValueInfoProto], + parents_node_mapping: Dict[str, onnx.NodeProto], + ) -> None: """ Adds special NNCF Output nodes to NNCFGraph. For all the ONNX model outputs, the special NNCF Output node is placed and then corresponding edges are added. - :param onnx_graph: ONNXGraph, which helps to get information about the ONNX model. + :param model: ONNX model. :param nncf_graph: NNCFGraph, in which the new nodes will be added. + :param edge_info_mapping: Mapping from edge name to the edge info. + :param parents_node_mapping: Mapping from edge name to node which outputs this edge. :return: None. """ - for i, _output in enumerate(onnx_graph.get_model_outputs()): + for i, _output in enumerate(model.graph.output): output_name = _output.name layer_attributes = ONNXLayerAttributes() output_node = nncf_graph.add_nncf_node( @@ -289,16 +305,16 @@ def _add_nncf_output_nodes(onnx_graph: ONNXGraph, nncf_graph: NNCFGraph) -> None node_metatype=OutputNoopMetatype, layer_attributes=layer_attributes, ) - from_node = onnx_graph.get_node_by_output(output_name) + from_node = parents_node_mapping[output_name] output_node_node_id = output_node.node_id - edge = onnx_graph.get_edge(output_name) - output_shape = ONNXGraph.get_edge_shape(edge) - onnx_dtype = ONNXGraph.get_edge_dtype(edge) + edge = edge_info_mapping[output_name] + output_shape = get_edge_shape(edge) + onnx_dtype = get_edge_dtype(edge) nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) input_port_id = 0 from_node_id = nncf_graph.get_node_by_name(from_node.name).node_id - output_port_id = ONNXGraph.get_output_port_id_for_node_before_output(output_name, from_node) + output_port_id = get_output_port_id_for_node_before_output(output_name, from_node) nncf_graph.add_edge_between_nncf_nodes( from_node_id=from_node_id, to_node_id=output_node_node_id, @@ -330,21 +346,27 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: :return: NNCFGraph. """ onnx_model = GraphConverter._replace_empty_node_name(onnx_model) + onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + edge_info_mapping = get_edge_info_mapping(onnx_model) + children_node_mapping = get_children_node_mapping(onnx_model) + parents_node_mapping = get_parents_node_mapping(onnx_model) nncf_graph = NNCFGraph() - onnx_graph = ONNXGraph(onnx_model) - for node in onnx_graph.get_all_nodes(): + for node in onnx_model.graph.node: metatype = get_metatype(onnx_model, node) - weight_port_ids = _get_weight_port_ids(node, onnx_graph) + weight_port_ids = _get_weight_port_ids(node, onnx_model, parents_node_mapping) is_shared = None weight_attrs = {} node_attrs = _get_node_attrs(node, onnx_model) - bias_attrs = _get_bias_attr(node, onnx_graph) + bias_attrs = _get_bias_attr(node, onnx_model, parents_node_mapping) if weight_port_ids: # If node has weight weight_edge_names = [] for weight_port_id in weight_port_ids: - weight_edge_names.append(node.input[weight_port_id]) - weight_attrs.update(_get_weight_attr(node, onnx_graph, weight_port_id)) - if not is_shared and onnx_graph.is_node_has_shared_weight(node, weight_port_id): + weight_edge_name = node.input[weight_port_id] + weight_edge_names.append(weight_edge_name) + edge = edge_info_mapping[weight_edge_name] + weight_shape = get_edge_shape(edge) + weight_attrs[weight_port_id] = {"name": weight_edge_name, "shape": weight_shape} + if not is_shared and is_node_has_shared_weight(node, weight_port_id, children_node_mapping): is_shared = True layer_attributes = ONNXLayerAttributes( @@ -357,22 +379,23 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: layer_attributes=layer_attributes, is_shared=is_shared, ) - for output_node in onnx_graph.get_all_nodes(): - output_edges = onnx_graph.get_node_edge_names(output_node.name)["output"] + + for output_node in onnx_model.graph.node: + output_edges = output_node.output for output_edge in output_edges: - edge = onnx_graph.get_edge(output_edge) + edge = edge_info_mapping.get(output_edge) if edge is None: # If the edge is None it means that the edge was not added during shape inference of ONNX model. # BatchNorm exported in Training mode has unused outputs edges: mean, var, saved_mean, saved_var. # NNCFGraph should not contain such edges. continue - tensor_shape = ONNXGraph.get_edge_shape(edge) - onnx_dtype = ONNXGraph.get_edge_dtype(edge) + tensor_shape = get_edge_shape(edge) + onnx_dtype = get_edge_dtype(edge) nncf_dtype = GraphConverter.convert_onnx_dtype_to_nncf_dtype(onnx_dtype) output_node_id = nncf_graph.get_node_by_name(output_node.name).node_id - input_nodes = onnx_graph.get_nodes_by_input(output_edge) + input_nodes = children_node_mapping[output_edge] for input_node in input_nodes: - port_ids = ONNXGraph.get_port_ids_between_nodes(output_node, input_node) + port_ids = get_port_ids_between_nodes(output_node, input_node) input_port_id = port_ids["input_port_id"] output_port_id = port_ids["output_port_id"] in_node_id = nncf_graph.get_node_by_name(input_node.name).node_id @@ -384,6 +407,7 @@ def create_nncf_graph(onnx_model: onnx.ModelProto) -> NNCFGraph: output_port_id=output_port_id, dtype=Dtype(nncf_dtype), ) - GraphConverter._add_nncf_input_nodes(onnx_graph, nncf_graph) - GraphConverter._add_nncf_output_nodes(onnx_graph, nncf_graph) + + GraphConverter._add_nncf_input_nodes(onnx_model, nncf_graph, edge_info_mapping, children_node_mapping) + GraphConverter._add_nncf_output_nodes(onnx_model, nncf_graph, edge_info_mapping, parents_node_mapping) return nncf_graph diff --git a/nncf/onnx/graph/node_utils.py b/nncf/onnx/graph/node_utils.py index 6575dff6f1c..1e9a162211d 100644 --- a/nncf/onnx/graph/node_utils.py +++ b/nncf/onnx/graph/node_utils.py @@ -21,7 +21,7 @@ from nncf.common.tensor_statistics.collectors import ReductionAxes from nncf.onnx.graph.metatypes import onnx_metatypes as om from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDequantizeLinearMetatype -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_tensor_value from nncf.onnx.graph.transformations.commands import ONNXTargetPoint @@ -45,10 +45,9 @@ def get_bias_value(node_with_bias: NNCFNode, model: onnx.ModelProto) -> np.ndarr :param model: The model that contains this operation. :return: The bias value that is applied to the output tensor of the node's operation. """ - onnx_graph = ONNXGraph(model) assert node_with_bias.layer_attributes.has_bias() bias_name = node_with_bias.layer_attributes.bias_attrs["name"] - return onnx_graph.get_tensor_value(bias_name) + return get_tensor_value(model, bias_name) def get_input_edges_mapping(nncf_graph: NNCFGraph) -> Dict[str, Tuple[str, int]]: @@ -68,20 +67,25 @@ def get_input_edges_mapping(nncf_graph: NNCFGraph) -> Dict[str, Tuple[str, int]] return input_edges_mapping -def get_input_edge(input_node_name: str, input_edges_mapping: Dict[str, Tuple[str, int]], onnx_graph: ONNXGraph) -> str: +def get_input_edge( + input_node_name: str, + input_edges_mapping: Dict[str, Tuple[str, int]], + node_mapping: Dict[str, onnx.NodeProto], +) -> str: """ Returns input edge corresponding to the NNCF input node with the name input_node_name. :param input_node_name: Name of NNCF input node. :param input_edges_mapping: A mapping of NNCF input node names and - a tuple with the consumed node names and their input port ids. - :param onnx_graph: Instance of ONNXGraph of the model. + a tuple with the consumed node names and their input port ids. + :param node_mapping: Mapping of node names to the nodes. :return: Input edge name. """ input_edges = set() for node_info in input_edges_mapping[input_node_name]: name, port_id = node_info - input_edges.add(onnx_graph.get_node_edge_names(name)["input"][port_id]) + node = node_mapping[name] + input_edges.add(node.input[port_id]) assert len(input_edges) == 1 return input_edges.pop() diff --git a/nncf/onnx/graph/onnx_graph.py b/nncf/onnx/graph/onnx_graph.py deleted file mode 100644 index df754263c99..00000000000 --- a/nncf/onnx/graph/onnx_graph.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Iterator, List, Optional, Union - -import numpy as np -import onnx -from onnx import numpy_helper - - -class ONNXGraph: - """ - The class provides the interface to get the necessary information from ONNX model. - """ - - def __init__(self, onnx_model: onnx.ModelProto): - self.onnx_model = onnx_model - self._node_name_to_node = None # type: Dict[str, onnx.NodeProto] - self._edge_name_to_value_info = None # type: Dict[str, onnx.ValueInfoProto] - - def _update_edges(self) -> None: - self.onnx_model = onnx.shape_inference.infer_shapes(self.onnx_model) - value_infos = [ - *self.onnx_model.graph.value_info, - *self.onnx_model.graph.input, - *self.onnx_model.graph.output, - *self.onnx_model.graph.initializer, - ] - self._edge_name_to_value_info = {tensor.name: tensor for tensor in value_infos} - - def _update_node_names(self) -> None: - self._node_name_to_node = {n.name: n for n in self.onnx_model.graph.node} - - def _get_all_tensors(self) -> Iterator[onnx.TensorProto]: - """ - Iterate over all tensors of ONNX model. - - :yield: tensors of ONNX model. - """ - for initializer in self.onnx_model.graph.initializer: - yield initializer - for node in self.onnx_model.graph.node: - for attribute in node.attribute: - if attribute.HasField("t"): - yield attribute.t - yield from attribute.tensors - - def get_all_nodes(self) -> List[onnx.NodeProto]: - """ - Returns model nodes in the original order. - - :return: model nodes. - """ - return self.onnx_model.graph.node - - def get_node_by_name(self, node_name: str) -> Optional[onnx.NodeProto]: - """ - Returns a model node with the name equals to 'node_name' from self._node_name_to_node. - If the self._node_name_to_node is None, fills it with the nodes from the self.onnx_model. - If there is no node with such name returns None. - - :param node_name: Name of the node. - :return: None if the node with the specified name exists - otherwise returns the node. - """ - if self._node_name_to_node is None: - self._update_node_names() - return self._node_name_to_node[node_name] if node_name in self._node_name_to_node else None - - def get_edge(self, edge_name: str) -> Optional[onnx.ValueInfoProto]: - """ - Returns edge by its name or None if the model has no such edge. - If self._edge_name_to_value_info is not initialized runs an initialization. - - :param edge_name: Name of edge. - :return: Edge. - """ - if self._edge_name_to_value_info is None: - self._update_edges() - return self._edge_name_to_value_info.get(edge_name, None) - - def get_model_inputs(self) -> List[onnx.ValueInfoProto]: - """ - Returns all model inputs. - - :return: Model Inputs. - """ - inputs = [] - input_all = [node.name for node in self.onnx_model.graph.input] - input_initializer = [node.name for node in self.onnx_model.graph.initializer] - net_feed_input = list(set(input_all) - set(input_initializer)) - for node in self.onnx_model.graph.input: - if node.name in net_feed_input: - inputs.append(node) - return inputs - - def get_model_outputs(self) -> List[onnx.ValueInfoProto]: - """ - Returns all model outputs. - - :return: Model Outputs. - """ - return list(self.onnx_model.graph.output) - - def get_node_by_output(self, output_name: str) -> Optional[onnx.NodeProto]: - """ - Returns node that have output edge with the name 'output_name'. - - :param output_name: The name of output edge. - :return: Node with corresponding output. - """ - for node in self.get_all_nodes(): - if output_name in node.output: - return node - return None - - def get_nodes_by_input(self, input_name: str) -> List[onnx.NodeProto]: - """ - Returns all nodes that have input with the name 'input_name'. - - :param input_name: The name of input edge. - :return: Nodes with corresponding input. - """ - output = [] - for node in self.get_all_nodes(): - if input_name in node.input: - output.append(node) - return output - - def get_node_edge_names(self, node_name: str) -> Dict[str, List[str]]: - """ - Returns node edge names. - - :param node_name: The name of the node. - :return: Dict with two keys: 'input' and 'output', - which are corresponding to input and output edges accordingly. - """ - if self._node_name_to_node is None: - self._update_node_names() - if node_name in self._node_name_to_node: - return { - "input": list(self._node_name_to_node[node_name].input), - "output": list(self._node_name_to_node[node_name].output), - } - raise RuntimeError("There is no node with the name {}".format(node_name)) - - @staticmethod - def get_input_port_id_for_node_after_input(input_name: str, to_node: onnx.NodeProto) -> int: - """ - Returns input_port_id for 'to_node' connected with the model input with the name 'input_name'. - - :param input_name: Name of the ONNX model Input. - :param to_node: Node, which has input edge with 'input_name' name. - :return: input port number for 'to_node', which is connected to 'input_name'. - """ - for input_port_id, port in enumerate(to_node.input): - if port == input_name: - return input_port_id - raise RuntimeError(f"The node {to_node} does not have input edge with the name {input_name}") - - @staticmethod - def get_output_port_id_for_node_before_output(output_name: str, from_node: onnx.NodeProto) -> int: - """ - Returns output_port_id for 'from_node' connected with the model output with the name 'output_name'. - - :param output_name: Name of the ONNX model Output. - :param from_node: Node, which has output edge with 'output_name' name. - :return: output port number for 'from_node', which is connected to 'output_name'. - """ - for output_port_id, port in enumerate(from_node.output): - if port == output_name: - return output_port_id - raise RuntimeError(f"The node {from_node} does not have output edge with the name {output_name}") - - @staticmethod - def get_port_ids_between_nodes(from_node: onnx.NodeProto, to_node: onnx.NodeProto) -> Dict[str, int]: - """ - Returns input_port_id and output_port_id between 'from_node' and 'to_node'. - - :param from_node: Node, whose output is connected to 'to_node' node. - :param to_node: Node, whose input is connected to 'from_node' node. - :return: Dict{'input_port_id': input port id, 'output_port_id': output port id} - """ - output = {"input_port_id": None, "output_port_id": None} - for port_id, port in enumerate(to_node.input): - if port in from_node.output: - output["input_port_id"] = port_id - for port_id, port in enumerate(from_node.output): - if port in to_node.input: - output["output_port_id"] = port_id - if output["output_port_id"] is None or output["input_port_id"] is None: - raise RuntimeError(f"The nodes {from_node.name} and {to_node.name} do not have edges between.") - return output - - def get_node_index(self, node_name: str) -> int: - """ - Returns the node index in the model. - - :param node_name: Name of the node. - :return: Node index, -1 if there is no such node. - """ - for i, node in enumerate(self.get_all_nodes()): - if node.name == node_name: - return i - return -1 - - def has_tensor(self, tensor_name: str) -> bool: - """ - Returns True whether the model has the tensor with the name equals to tensor_name. - - :param tensor_name: Name of the tensor. - :return: True if the model has such tensor, False - otherwise. - """ - for tensor in self._get_all_tensors(): - if tensor.name == tensor_name: - return True - return False - - def get_tensor_value(self, tensor_name: str) -> np.ndarray: - """ - Returns tensor value of a tensor with the name 'tensor_name'. - - :param tensor_name: Name of the tensor. - :return: The value of the tensor. - """ - tensor = self.get_tensor(tensor_name) - return numpy_helper.to_array(tensor) - - def get_tensor(self, tensor_name: str) -> onnx.TensorProto: - """ - Returns a tensor with the name 'tensor_name'. - - :param initializer_name: Name of the Initializer. - :return: The Initializer. - """ - for tensor in self._get_all_tensors(): - if tensor.name == tensor_name: - return tensor - raise RuntimeError("There is no tensor with the name {}".format(tensor_name)) - - @staticmethod - def get_edge_shape(edge: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> List[int]: - """ - Returns edge shape. - - :param edge: The edge. - :return: Shape of the Tensor. - """ - if isinstance(edge, onnx.TensorProto): - return list(edge.dims) - tensor_type = edge.type.tensor_type - shape = [] - if tensor_type.HasField("shape"): - for d in tensor_type.shape.dim: - if d.HasField("dim_value"): - dim_value = d.dim_value - if isinstance(dim_value, int): - shape.append(dim_value) - else: - return shape - elif d.HasField("dim_param"): - # flexible shape make manually -1 - shape.append(-1) - else: - return shape - return shape - - @staticmethod - def get_edge_dtype(edge: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> int: - """ - Returns the data type of the edge. - - :param edge: The edge. - :return: Data type of the edge. - """ - if isinstance(edge, onnx.ValueInfoProto): - return edge.type.tensor_type.elem_type - return edge.data_type - - def get_parent(self, node: onnx.NodeProto, port_id: int) -> Optional[onnx.NodeProto]: - """ - Returns parents of the node. If there is no parent node, returns None. - - :param node: The child node. - :param port_id: Input port id on which the parent is seeked. - :return: Parent node. - """ - if port_id < len(node.input): - return self.get_node_by_output(node.input[port_id]) - return None - - def get_children(self, node: onnx.NodeProto) -> List[onnx.NodeProto]: - """ - Returns children of the node. - - :param node: The parent node. - :return: All children nodes. - """ - output = [] - node_edges = self.get_node_edge_names(node.name)["output"] - for node_edge in node_edges: - output.extend(self.get_nodes_by_input(node_edge)) - return output - - def is_node_has_shared_weight(self, node: onnx.NodeProto, weight_port_id: int) -> bool: - """ - Returns whether the node share a weight. - - :param node: Node. - :return: True whether node shares a weight - otherwise False. - """ - weight_tensor_edge = self.get_node_edge_names(node.name)["input"][weight_port_id] - nodes = self.get_nodes_by_input(weight_tensor_edge) - return len(nodes) > 1 diff --git a/nncf/onnx/graph/onnx_helper.py b/nncf/onnx/graph/onnx_helper.py new file mode 100644 index 00000000000..f6b082050a0 --- /dev/null +++ b/nncf/onnx/graph/onnx_helper.py @@ -0,0 +1,290 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +from typing import Dict, Iterator, List, Optional, Union + +import numpy as np +import onnx +from onnx import numpy_helper + + +def get_name_to_node_map(model: onnx.ModelProto) -> Dict[str, onnx.NodeProto]: + """ + Returns mapping from node name to the node. + + :param model: Model from mapping is built. + :return: Mapping. + """ + return {node.name: node for node in model.graph.node} + + +def get_edge_info_mapping(model: onnx.ModelProto) -> Dict[str, onnx.ValueInfoProto]: + """ + Retuns mapping from edge name to the edge info. + + :param model: Model from mapping is built. + :return: Mapping. + """ + return { + tensor.name: tensor + for tensor in (*model.graph.value_info, *model.graph.input, *model.graph.output, *model.graph.initializer) + } + + +def get_children_node_mapping(model: onnx.ModelProto) -> Dict[str, List[onnx.NodeProto]]: + """ + Returns a mapping from edge name to nodes which consume this edge as an input. + + :param model: ONNX model. + :return: Mapping from edge name to nodes which consume this edge as an input. + """ + output = defaultdict(list) + for node in model.graph.node: + for edge in node.input: + output[edge].append(node) + return output + + +def get_parents_node_mapping(model: onnx.ModelProto) -> Dict[str, onnx.NodeProto]: + """ + Returns a mapping from edge name to node which outputs this edge. + + :param model: ONNX model. + :return: Mapping from edge name to node which outputs this edge. + """ + output = {} + for node in model.graph.node: + for edge in node.output: + output[edge] = node + return output + + +def get_model_inputs(model: onnx.ModelProto) -> List[onnx.ValueInfoProto]: + """ + Returns all model inputs. + + :param model: ONNX model. + :return: Model Inputs. + """ + inputs = [] + input_all = [node.name for node in model.graph.input] + input_initializer = [node.name for node in model.graph.initializer] + net_feed_input = list(set(input_all) - set(input_initializer)) + for node in model.graph.input: + if node.name in net_feed_input: + inputs.append(node) + return inputs + + +def get_input_port_id_for_node_after_input(input_name: str, to_node: onnx.NodeProto) -> int: + """ + Returns input_port_id for 'to_node' connected with the model input with the name 'input_name'. + + :param input_name: Name of the ONNX model Input. + :param to_node: Node, which has input edge with 'input_name' name. + :return: input port number for 'to_node', which is connected to 'input_name'. + """ + for input_port_id, port in enumerate(to_node.input): + if port == input_name: + return input_port_id + raise RuntimeError(f"The node {to_node} does not have input edge with the name {input_name}") + + +def get_output_port_id_for_node_before_output(output_name: str, from_node: onnx.NodeProto) -> int: + """ + Returns output_port_id for 'from_node' connected with the model output with the name 'output_name'. + + :param output_name: Name of the ONNX model Output. + :param from_node: Node, which has output edge with 'output_name' name. + :return: output port number for 'from_node', which is connected to 'output_name'. + """ + for output_port_id, port in enumerate(from_node.output): + if port == output_name: + return output_port_id + raise RuntimeError(f"The node {from_node} does not have output edge with the name {output_name}") + + +def get_port_ids_between_nodes(from_node: onnx.NodeProto, to_node: onnx.NodeProto) -> Dict[str, int]: + """ + Returns input_port_id and output_port_id between 'from_node' and 'to_node'. + + :param from_node: Node, whose output is connected to 'to_node' node. + :param to_node: Node, whose input is connected to 'from_node' node. + :return: Dict{'input_port_id': input port id, 'output_port_id': output port id} + """ + output = {"input_port_id": None, "output_port_id": None} + for port_id, port in enumerate(to_node.input): + if port in from_node.output: + output["input_port_id"] = port_id + for port_id, port in enumerate(from_node.output): + if port in to_node.input: + output["output_port_id"] = port_id + if output["output_port_id"] is None or output["input_port_id"] is None: + raise RuntimeError(f"The nodes {from_node.name} and {to_node.name} do not have edges between.") + return output + + +def get_node_index(model: onnx.ModelProto, node_name: str) -> Optional[int]: + """ + Returns the node index in the model. + + :param model: ONNX model. + :param node_name: Name of the node. + :return: Node index, -1 if there is no such node. + """ + for i, node in enumerate(model.graph.node): + if node.name == node_name: + return i + return None + + +def _get_all_tensors(model: onnx.ModelProto) -> Iterator[onnx.TensorProto]: + """ + Iterate over all tensors of ONNX model. + + :param model: ONNX model. + :yield: tensors of ONNX model. + """ + for initializer in model.graph.initializer: + yield initializer + for node in model.graph.node: + for attribute in node.attribute: + if attribute.HasField("t"): + yield attribute.t + yield from attribute.tensors + + +def has_tensor(model: onnx.ModelProto, tensor_name: str) -> bool: + """ + Returns True whether the model has the tensor with the name equals to tensor_name. + + :param model: ONNX model. + :param tensor_name: Name of the tensor. + :return: True if the model has such tensor, False - otherwise. + """ + for tensor in _get_all_tensors(model): + if tensor.name == tensor_name: + return True + return False + + +def get_tensor(model: onnx.ModelProto, tensor_name: str) -> onnx.TensorProto: + """ + Returns a tensor with the name 'tensor_name'. + + :param model: ONNX model. + :param tensor_name: Name of the tensor. + :return: The Initializer. + """ + for tensor in _get_all_tensors(model): + if tensor.name == tensor_name: + return tensor + raise RuntimeError("There is no tensor with the name {}".format(tensor_name)) + + +def get_tensor_value(model: onnx.ModelProto, tensor_name: str) -> np.ndarray: + """ + Returns tensor value of a tensor with the name 'tensor_name'. + + :param model: ONNX model. + :param tensor_name: Name of the tensor. + :return: The value of the tensor. + """ + return numpy_helper.to_array(get_tensor(model, tensor_name)) + + +def get_edge_shape(edge: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> List[int]: + """ + Returns edge shape. + + :param edge: The edge. + :return: Shape of the Tensor. + """ + if isinstance(edge, onnx.TensorProto): + return list(edge.dims) + tensor_type = edge.type.tensor_type + shape = [] + if tensor_type.HasField("shape"): + for d in tensor_type.shape.dim: + if d.HasField("dim_value"): + dim_value = d.dim_value + if isinstance(dim_value, int): + shape.append(dim_value) + else: + return shape + elif d.HasField("dim_param"): + # flexible shape make manually -1 + shape.append(-1) + else: + return shape + return shape + + +def get_edge_dtype(edge: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> int: + """ + Returns the data type of the edge. + + :param edge: The edge. + :return: Data type of the edge. + """ + if isinstance(edge, onnx.ValueInfoProto): + return edge.type.tensor_type.elem_type + return edge.data_type + + +def get_parent( + node: onnx.NodeProto, + port_id: int, + parents_node_mapping: Dict[str, onnx.NodeProto], +) -> Optional[onnx.NodeProto]: + """ + Returns parents of the node. If there is no parent node, returns None. + + :param node: The child node. + :param port_id: Input port id on which the parent is seeked. + :param edge_node_mapping: Mapping describing start and consumed nodes of the edges. + :return: Parent node. + """ + if port_id < len(node.input): + return parents_node_mapping.get(node.input[port_id]) + return None + + +def get_children(node: onnx.NodeProto, children_node_mapping: Dict[str, List[onnx.NodeProto]]) -> List[onnx.NodeProto]: + """ + Returns children of the node. + + :param node: The parent node. + :param edge_node_mapping: Mapping describing start and consumed nodes of the edges. + :return: All children nodes. + """ + output = [] + for node_edge in node.output: + output.extend(children_node_mapping[node_edge]) + return output + + +def is_node_has_shared_weight( + node: onnx.NodeProto, + weight_port_id: int, + children_node_mapping: Dict[str, List[onnx.NodeProto]], +) -> bool: + """ + Returns whether the node share a weight. + + :param node: Node. + :param weight_port_id: Port id on which there is a weight. + :param edge_node_mapping: Mapping describing start and consumed nodes of the edges. + :return: True whether node shares a weight - otherwise False. + """ + weight_tensor_edge = node.input[weight_port_id] + nodes = children_node_mapping[weight_tensor_edge] + return len(nodes) > 1 diff --git a/nncf/onnx/statistics/aggregator.py b/nncf/onnx/statistics/aggregator.py index e3435382b5d..a768a855258 100644 --- a/nncf/onnx/statistics/aggregator.py +++ b/nncf/onnx/statistics/aggregator.py @@ -22,7 +22,7 @@ from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.onnx.graph.node_utils import get_input_edge from nncf.onnx.graph.node_utils import get_input_edges_mapping -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_name_to_node_map from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand from nncf.onnx.tensor import ONNXNNCFTensor @@ -30,28 +30,30 @@ class ONNXStatisticsAggregator(StatisticsAggregator): def collect_statistics(self, model: onnx.ModelProto, graph: NNCFGraph) -> None: self.input_edges_mapping = get_input_edges_mapping(graph) - self._onnx_graph = ONNXGraph(model) + self.node_mapping = get_name_to_node_map(model) self._registered_weights = set() super().collect_statistics(model, graph) def _register_statistics( self, outputs: Dict[str, ONNXNNCFTensor], statistic_points: StatisticPointsContainer ) -> None: - for node_name, _statistic_points in statistic_points.items(): + for _statistic_points in statistic_points.values(): for statistic_point in _statistic_points: target_point = statistic_point.target_point port_id = target_point.port_id if target_point.target_node_name in self.input_edges_mapping: # Input case edge_name = get_input_edge( - target_point.target_node_name, self.input_edges_mapping, self._onnx_graph + target_point.target_node_name, + self.input_edges_mapping, + self.node_mapping, ) - statistic_point.register_tensor(outputs[edge_name]) elif target_point.type == TargetType.POST_LAYER_OPERATION: - edge_name = self._onnx_graph.get_node_edge_names(node_name)["output"][port_id] - statistic_point.register_tensor(outputs[edge_name]) + node = self.node_mapping[target_point.target_node_name] + edge_name = node.output[port_id] elif target_point.type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]: - edge_name = self._onnx_graph.get_node_edge_names(node_name)["input"][port_id] - statistic_point.register_tensor(outputs[edge_name]) + node = self.node_mapping[target_point.target_node_name] + edge_name = node.input[port_id] + statistic_point.register_tensor(outputs[edge_name]) def _get_transformation_layout_extra_outputs( self, statistic_points: StatisticPointsContainer diff --git a/nncf/quantization/algorithms/bias_correction/onnx_backend.py b/nncf/quantization/algorithms/bias_correction/onnx_backend.py index 364c93acc5a..d7f34936bfd 100644 --- a/nncf/quantization/algorithms/bias_correction/onnx_backend.py +++ b/nncf/quantization/algorithms/bias_correction/onnx_backend.py @@ -22,7 +22,7 @@ from nncf.onnx.graph.node_utils import get_bias_value from nncf.onnx.graph.node_utils import is_any_weight_quantized from nncf.onnx.graph.node_utils import is_node_with_bias -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_name_to_node_map from nncf.onnx.graph.transformations.command_creation import create_bias_correction_command from nncf.onnx.graph.transformations.commands import ONNXBiasCorrectionCommand from nncf.onnx.graph.transformations.commands import ONNXModelExtractionCommand @@ -101,15 +101,13 @@ def get_bias_value(node: NNCFNode, model: onnx.ModelProto, nncf_graph: NNCFGraph @staticmethod def get_input_name(model: onnx.ModelProto, node_name: str) -> str: - onnx_graph = ONNXGraph(model) - node = onnx_graph.get_node_by_name(node_name) - return node.input[0] + node_mapping = get_name_to_node_map(model) + return node_mapping[node_name].input[0] @staticmethod def get_output_name(model: onnx.ModelProto, node_name: str, output_id: int) -> List[str]: - onnx_graph = ONNXGraph(model) - node = onnx_graph.get_node_by_name(node_name) - return node.output[output_id] + node_mapping = get_name_to_node_map(model) + return node_mapping[node_name].output[output_id] @staticmethod def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index 01a916b61a5..a5b9c8f47e3 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -18,7 +18,9 @@ from nncf import Dataset from nncf.experimental.tensor import Tensor from nncf.onnx.graph.nncf_graph_builder import GraphConverter -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_edge_dtype +from nncf.onnx.graph.onnx_helper import get_edge_info_mapping +from nncf.onnx.graph.onnx_helper import get_edge_shape from nncf.onnx.statistics.statistics import ONNXMinMaxTensorStatistic from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization @@ -62,15 +64,15 @@ def _get_input_keys(original_model: onnx.ModelProto) -> str: def get_random_dataset_for_test(model: onnx.ModelProto, has_batch_dim: bool, length: Optional[int] = 10): keys = _get_input_keys(model) - onnx_graph = ONNXGraph(model) + edge_info_mapping = get_edge_info_mapping(model) def transform_fn(i): output = {} for key in keys: - edge = onnx_graph.get_edge(key) - input_dtype = ONNXGraph.get_edge_dtype(edge) + edge = edge_info_mapping[key] + input_dtype = get_edge_dtype(edge) input_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(input_dtype) - shape = ONNXGraph.get_edge_shape(edge) + shape = get_edge_shape(edge) rng = get_random_generator() tensor = rng.uniform(-1, 1, shape).astype(input_np_dtype) if has_batch_dim: diff --git a/tests/onnx/quantization/test_qdq_params_calculation.py b/tests/onnx/quantization/test_qdq_params_calculation.py index bf16eb152b2..1b3367ab6fa 100644 --- a/tests/onnx/quantization/test_qdq_params_calculation.py +++ b/tests/onnx/quantization/test_qdq_params_calculation.py @@ -15,7 +15,7 @@ import pytest from nncf.common.quantization.structs import QuantizationPreset -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_tensor_value from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix from tests.onnx.conftest import ONNX_TEST_ROOT @@ -36,11 +36,10 @@ def get_q_nodes_params(model: onnx.ModelProto) -> Dict[str, np.ndarray]: output = {} - onnx_graph = ONNXGraph(model) - for node in onnx_graph.get_all_nodes(): + for node in model.graph.node: if node.op_type == "QuantizeLinear": - scale = onnx_graph.get_tensor_value(node.input[1]) - zero_point = onnx_graph.get_tensor_value(node.input[2]) + scale = get_tensor_value(model, node.input[1]) + zero_point = get_tensor_value(model, node.input[2]) output[node.name] = {"scale": scale, "zero_point": zero_point} return output diff --git a/tests/onnx/test_model_transformer.py b/tests/onnx/test_model_transformer.py index 4cf5cb4e332..da039ee2d1a 100644 --- a/tests/onnx/test_model_transformer.py +++ b/tests/onnx/test_model_transformer.py @@ -20,7 +20,8 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.onnx.graph.model_transformer import ONNXModelTransformer from nncf.onnx.graph.nncf_graph_builder import GraphConverter -from nncf.onnx.graph.onnx_graph import ONNXGraph +from nncf.onnx.graph.onnx_helper import get_tensor +from nncf.onnx.graph.onnx_helper import get_tensor_value from nncf.onnx.graph.transformations.commands import ONNXBiasCorrectionCommand from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand from nncf.onnx.graph.transformations.commands import ONNXQDQNodeRemovingCommand @@ -60,7 +61,7 @@ def test_quantizer_insertion(target_layers, should_raise, quantizer_number): if should_raise: try: _ = model_transformer.transform(transformation_layout) - except RuntimeError: + except KeyError: return transformed_model = model_transformer.transform(transformation_layout) onnx.checker.check_model(transformed_model) @@ -124,17 +125,15 @@ def test_inserted_quantizer_parameters(test_parameters): transformed_model = model_transformer.transform(transformation_layout) onnx.checker.check_model(transformed_model) - onnx_graph = ONNXGraph(transformed_model) - # pylint:disable=no-member for node in transformed_model.graph.node: op_type = node.op_type if op_type == "QuantizeLinear": for attr in node.attribute: assert test_parameters.onnx_attributes[attr.name] == onnx.helper.get_attribute_value(attr) - assert np.allclose(onnx_graph.get_tensor_value(node.input[1]), np.array(test_parameters.scale)) - assert np.allclose(onnx_graph.get_tensor_value(node.input[2]), np.array(test_parameters.zero_point)) - assert onnx_graph.get_tensor_value(node.input[2]).dtype == test_parameters.onnx_dtype + assert np.allclose(get_tensor_value(transformed_model, node.input[1]), np.array(test_parameters.scale)) + assert np.allclose(get_tensor_value(transformed_model, node.input[2]), np.array(test_parameters.zero_point)) + assert get_tensor_value(transformed_model, node.input[2]).dtype == test_parameters.onnx_dtype TARGET_LAYERS = [["ReLU1"], ["Conv1", "BN1"], ["Conv1", "BN1", "ReLU1"]] @@ -160,8 +159,7 @@ def test_output_insertion(target_layers, target_layer_outputs): transformed_model = model_transformer.transform(transformation_layout) - onnx_graph = ONNXGraph(transformed_model) - assert Counter([out.name for out in onnx_graph.get_model_outputs()]) == Counter(target_layer_outputs) + assert Counter([out.name for out in transformed_model.graph.output]) == Counter(target_layer_outputs) CONV_LAYERS = [["Conv1", "Conv2"]] @@ -182,11 +180,11 @@ def test_bias_correction(layers, values, refs): model_transformer = ONNXModelTransformer(model) transformed_model = model_transformer.transform(transformation_layout) - onnx_graph = ONNXGraph(transformed_model) + node_dict = {node.name: node for node in transformed_model.graph.node} for conv_layer, bias_reference in zip(layers, refs): - bias_tensor_name = onnx_graph.get_node_by_name(conv_layer).input[2] - bias_tensor = onnx_graph.get_tensor(bias_tensor_name) + bias_tensor_name = node_dict[conv_layer].input[2] + bias_tensor = get_tensor(transformed_model, bias_tensor_name) bias_value = onnx.numpy_helper.to_array(bias_tensor) assert np.all(bias_value == bias_reference) diff --git a/tests/onnx/weightless_model.py b/tests/onnx/weightless_model.py index 046568df8eb..6f34347ba38 100644 --- a/tests/onnx/weightless_model.py +++ b/tests/onnx/weightless_model.py @@ -19,8 +19,6 @@ from onnx import TensorProto # pylint:disable=no-name-in-module from onnx.external_data_helper import uses_external_data -from nncf.onnx.graph.onnx_graph import ONNXGraph - # pylint: disable=no-member @@ -32,8 +30,7 @@ def load_model_topology_with_zeros_weights(model_path: Union[str, Path]) -> onnx :return: Onnx model with filled the all external tensors by random values. """ model = onnx.load_model(model_path, load_external_data=False) - onnx_graph = ONNXGraph(model) - for tensor in onnx_graph.onnx_model.graph.initializer: + for tensor in model.graph.initializer: if uses_external_data(tensor): np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor.data_type) np_tensor = np.zeros(list(tensor.dims)).astype(np_dtype)