diff --git a/docs/usage/post_training_compression/weights_compression/Usage.md b/docs/usage/post_training_compression/weights_compression/Usage.md index 40a3749498f..b73d06c8fa7 100644 --- a/docs/usage/post_training_compression/weights_compression/Usage.md +++ b/docs/usage/post_training_compression/weights_compression/Usage.md @@ -1,6 +1,6 @@ ## Weights Compression -[OpenVINO](https://github.com/openvinotoolkit/openvino) is the preferred backend to run Weights Compression with, and PyTorch is also supported. +[OpenVINO](https://github.com/openvinotoolkit/openvino) is the preferred backend to run Weights Compression with. PyTorch and Torch FX are also supported. ### The algorithm description @@ -800,7 +800,7 @@ Accuracy/footprint trade-off for `microsoft/Phi-3-mini-4k-instruct`: ### Limitations -- The algorithm is supported for OpenVINO and PyTorch models. +- The algorithm is supported for OpenVINO, PyTorch and Torch FX models. - The compression applies in-place. - The compressed model is not trainable. - INT4_SYM, INT4_ASYM, NF4 and E2M1 modes, grouped quantization and mixed precision selection is available for OpenVINO backend only. diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py index 946ac27ce84..737b329da4b 100644 --- a/nncf/experimental/torch/fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import Counter from typing import Tuple import torch.fx @@ -64,6 +65,22 @@ def _get_layer_attributes( ) return None + @staticmethod + def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype: + """ + Attempts to retrieve the correct subtype for the given node. + + :param node: Given node. + :param metatype: Given node metatype. + :return: Correct FX metatype of the given node if it exists, or the original node metatype otherwise. + """ + if metatype in [om.PTEmbeddingMetatype]: + weight_node = node.args[0] + if weight_node.op == "get_attr": + return om.PTAtenEmbeddingMetatype + + return metatype + @staticmethod def _get_node_type_and_metatype( node: torch.fx.Node, model: torch.fx.GraphModule ) -> Tuple[str, om.OperatorMetatype] @@ -115,16 +132,18 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph: :param model: torch fx GraphModule. :return: NNCFGraph.
""" - nncf_graph = PTNNCFGraph() + const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"]) for source_node in model.graph.nodes: node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node, model) + node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype) + is_shared_node = source_node.op in ("get_attr",) and ( + const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1 + ) nncf_graph.add_nncf_node( - node_name=source_node.name, - node_type=node_type, - node_metatype=node_metatype, + node_name=source_node.name, node_type=node_type, node_metatype=node_metatype, is_shared=is_shared_node ) for source_node in model.graph.nodes: @@ -134,7 +153,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph: input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params( model, source_node, source_nncf_node, dist_node, idx ) - nncf_graph.add_edge_between_nncf_nodes( source_nncf_node.node_id, dist_node_id, @@ -160,7 +178,7 @@ def get_edge_params( :param source_node: Source node in format of torch.fx.Node. :param source_nncf_node: Source node in format of NNCFNode. :param dist_node: Distance node in format of torch.fx.Node. - :param output_idx: Output indes of the source_node. + :param output_idx: Output index of the source_node. :return: Tuple of edge parameters: edge input port id, edge output port id and edge tensor shape. """ diff --git a/nncf/experimental/torch/fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py index d65dfd61f6d..8061f2ab2f4 100644 --- a/nncf/experimental/torch/fx/quantization/quantize_model.py +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -28,11 +28,16 @@ from nncf.data import Dataset from nncf.experimental.torch.fx.transformations import apply_quantization_transformations from nncf.experimental.torch.fx.transformations import revert_quantization_transformations +from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation +from nncf.parameters import CompressWeightsMode from nncf.parameters import ModelType from nncf.parameters import QuantizationMode +from nncf.parameters import SensitivityMetric from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization +from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression from nncf.scopes import IgnoredScope DEFAULT_RANGE_TYPE = "mean_min_max" @@ -49,7 +54,7 @@ def quantize_impl( model_type: Optional[ModelType] = None, ignored_scope: Optional[IgnoredScope] = None, advanced_parameters: Optional[AdvancedQuantizationParameters] = None, -) -> torch.nn.Module: +) -> torch.fx.GraphModule: """ Implementation of the `quantize()` method for the Torch FX backend. 
""" @@ -103,3 +108,46 @@ def quantize_impl( quantized_model = _disallow_eval_train(quantized_model) return quantized_model + + +def compress_weights_impl( + model: torch.fx.GraphModule, + dataset: Dataset, + mode: CompressWeightsMode, + ratio: float, + group_size: int, + ignored_scope: IgnoredScope, + all_layers: bool, + sensitivity_metric: SensitivityMetric, + awq: bool, + subset_size: int, + scale_estimation: bool, + gptq: bool, + lora_correction: bool, + advanced_parameters: Optional[AdvancedCompressionParameters] = None, +) -> torch.fx.GraphModule: + """ + Implementation of the `compress_weights()` method for the Torch Fx backend. + """ + + compression_algorithm = WeightCompression( + mode, + ratio, + group_size, + ignored_scope, + all_layers, + sensitivity_metric, + awq, + subset_size, + scale_estimation, + gptq, + lora_correction, + advanced_parameters, + ) + shared_constants_unification_transformation(model) + graph = NNCFGraphFactory.create(model) + compressed_model = compression_algorithm.apply(model, graph, dataset=dataset) + compressed_model = GraphModule(compressed_model, compressed_model.graph) + compressed_model = _disallow_eval_train(compressed_model) + + return compressed_model diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index d1eb55b67d4..d097f318c37 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -28,7 +28,9 @@ TransformationFNType = Callable[[torch.fx.GraphModule], None] -def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module): +def _set_new_node_meta( + new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module, model: torch.fx.GraphModule +): """ Sets correct meta \"val\" value to the new node. @@ -37,7 +39,11 @@ def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target New node expected to have only one input node. :param target_module: Module which is being called by the new node. 
""" - val = prev_node.meta["val"] + val = ( + prev_node.meta["val"] + if prev_node.op not in ["get_attr"] + else get_tensor_constant_from_node(prev_node, model).data + ) val = val if isinstance(val, tuple) else (val,) retval = [] for t in val: @@ -71,16 +77,16 @@ def module_insertion_transformation(model: torch.fx.GraphModule): target_node = get_graph_node_by_name(graph, target_point.target_node_name) if target_point.target_type == TargetType.OPERATOR_POST_HOOK: - _set_new_node_meta(new_node, target_node, module_to_insert) + _set_new_node_meta(new_node, target_node, module_to_insert, model) with graph.inserting_after(target_node): - for user in target_node.users: + for user in list(target_node.users): if user is new_node: continue user.replace_input_with(target_node, new_node) else: prev_node = target_node.args[target_point.input_port_id] - _set_new_node_meta(new_node, prev_node, module_to_insert) + _set_new_node_meta(new_node, prev_node, module_to_insert, model) target_node.replace_input_with(prev_node, new_node) return module_insertion_transformation @@ -136,17 +142,42 @@ def bias_update_transformation(model: torch.fx.GraphModule): return bias_update_transformation -def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType: +def shared_constants_unification_transformation(model: torch.fx.GraphModule): + """ + checks FX graph for shared constants and eliminates redundant + shared constant while keeping only the first instance of the constant node. + This unification transformation is cruicial since the current algorithms(min_max, solver, BC, etc.) + for torch fx do not utilize the is_shared attribute of nodes for shared constants. + + :param model: Target Torch FX GraphModule + """ + prev_targets = {} + + for source_node in model.graph.nodes: + dist_node = list(source_node.users) + if source_node.target in prev_targets and source_node.op in ("get_attr",): + dist_node[0].replace_input_with(source_node, prev_targets[source_node.target]) + else: + prev_targets[source_node.target] = source_node + + model.graph.eliminate_dead_code() + model.recompile() + + +def constant_update_transformation_builder( + node: NNCFNode, value: torch.Tensor, input_port_id: int = 1 +) -> TransformationFNType: """ Return transformation which updates constant of the given node to the given value. :param node: Node which requires bias constant update. :param value: New value to use as the node constant. + :param input_port_id: Port Id of the constant. :return: Transformation which updates constant of the given node to the given value. """ def constant_update_transformation(model: torch.fx.GraphModule): - constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1) + constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id) return constant_update_transformation @@ -161,9 +192,6 @@ def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value: :param input_port_id: Target constant input port id. """ graph = model.graph - with graph.inserting_before(node): - new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value) - args = list(node.args) # A bias node suppose to have constant on the second input port. if args[input_port_id].op != "get_attr": @@ -174,11 +202,14 @@ def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value: # Update metadata of the new constant node. 
previous_const = args[input_port_id] - new_constant.meta = copy(previous_const.meta) - new_constant.meta["val"] = value + consumer_nodes = list(previous_const.users) + # This list of consumer nodes will always be topologically sorted + # To ensure the updated node has the right order, + # we insert constant node before the node placed at the highest order in topological order. + with graph.inserting_before(consumer_nodes[0]): + new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value) - args[input_port_id] = new_constant - node.args = tuple(args) + previous_const.replace_all_uses_with(new_constant, propagate_meta=True) graph.eliminate_dead_code() @@ -509,6 +540,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None: fuse_conv_bn(model) separate_conv_and_bias(model) separate_linear_and_bias(model) + shared_constants_unification_transformation(model) def revert_quantization_transformations(model: torch.fx.GraphModule) -> None: diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 4dd03496f2d..24fc509c85e 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -135,7 +135,7 @@ def __init__( @property def available_backends(self) -> List[BackendType]: - return [BackendType.OPENVINO, BackendType.TORCH] + return [BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] def _set_backend_entity(self, model: TModel) -> None: """ @@ -152,6 +152,10 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend self._backend_entity = PTWeightCompressionAlgoBackend() + elif model_backend == BackendType.TORCH_FX: + from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXWeightCompressionAlgoBackend + + self._backend_entity = FXWeightCompressionAlgoBackend() else: raise nncf.UnsupportedBackendError( "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index de0af0cbb3b..f46d9727d63 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -50,7 +50,7 @@ class PTWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, } MATMUL_METATYPES = [om.PTLinearMetatype, om.PTMatMulMetatype, om.PTAddmmMetatype] - EMBEDDING_METATYPES = [om.PTEmbeddingMetatype] + EMBEDDING_METATYPES = [om.PTEmbeddingMetatype, om.PTAtenEmbeddingMetatype] CONVOLUTION_METATYPES = [ om.PTConv1dMetatype, om.PTConv2dMetatype, diff --git a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py new file mode 100644 index 00000000000..ca3e2d16331 --- /dev/null +++ b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py @@ -0,0 +1,234 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +import torch.fx + +import nncf +import nncf.errors +import nncf.tensor +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand +from nncf.experimental.torch.fx.model_transformer import FXModelTransformer +from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name +from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node +from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder +from nncf.experimental.torch.fx.transformations import module_insertion_transformation_builder +from nncf.parameters import CompressWeightsMode +from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters +from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm +from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend +from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight +from nncf.tensor import Tensor +from nncf.tensor.definitions import TensorDataType +from nncf.torch.graph import operator_metatypes as om +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.model_graph_manager import get_const_node +from nncf.torch.model_graph_manager import get_weight_tensor_port_ids +from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import SymmetricWeightsDecompressor +from nncf.torch.tensor_statistics.collectors import get_raw_stat_collector + + +class FXWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): + MATMUL_METATYPES = PTWeightCompressionAlgoBackend.MATMUL_METATYPES + EMBEDDING_METATYPES = PTWeightCompressionAlgoBackend.EMBEDDING_METATYPES + CONVOLUTION_METATYPES = PTWeightCompressionAlgoBackend.CONVOLUTION_METATYPES + + @property + def matmul_metatypes(self) -> List[OperatorMetatype]: + return FXWeightCompressionAlgoBackend.MATMUL_METATYPES + + @property + def embedding_metatypes(self) -> List[OperatorMetatype]: + return FXWeightCompressionAlgoBackend.EMBEDDING_METATYPES + + @property + def convolution_metatypes(self) -> List[OperatorMetatype]: + return FXWeightCompressionAlgoBackend.CONVOLUTION_METATYPES + + @staticmethod + def is_node_with_weights(node: NNCFNode, graph: NNCFGraph) -> bool: + return PTWeightCompressionAlgoBackend.is_node_with_weights(node, graph) + + @staticmethod + def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Tuple[str, int]]: + port_ids = 
get_weight_tensor_port_ids(node, graph) + weight_name_port_ids = [(get_const_node(node, pid, graph).node_name, pid) for pid in port_ids] + return weight_name_port_ids + + @staticmethod + def get_reduction_axes(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> Optional[Tuple[int]]: + weight_node = get_const_node(node_with_weight, weight_port_id, graph) + edge = graph.get_edge(weight_node, graph.get_next_nodes(weight_node)[0]) + + ndims = len(edge.tensor_shape) + reduction_axes = None + if node_with_weight.metatype == om.PTAtenEmbeddingMetatype: + reduction_axes = [1] + elif node_with_weight.metatype == om.PTLinearMetatype: + reduction_axes = [ndims - 1] + elif node_with_weight.metatype == om.PTMatMulMetatype: + if weight_port_id == 0: + reduction_axes = [ndims - 1] + elif weight_port_id == 1: + reduction_axes = [max(0, ndims - 2)] + elif node_with_weight.metatype == om.PTAddmmMetatype: + if weight_port_id == 1: + reduction_axes = [ndims - 1] + elif weight_port_id == 2: + reduction_axes = [max(0, ndims - 2)] + elif node_with_weight.metatype in FXWeightCompressionAlgoBackend.CONVOLUTION_METATYPES: + channel_idx = ( + 1 + if node_with_weight.metatype + in [om.PTConvTranspose1dMetatype, om.PTConvTranspose2dMetatype, om.PTConvTranspose3dMetatype] + else 0 + ) + reduction_axes = [i for i in range(ndims) if i != channel_idx] + return tuple(reduction_axes) + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + return PTWeightCompressionAlgoBackend.target_point(target_type, target_node_name, port_id) + + @staticmethod + def raw_statistic_collector(num_samples: Optional[int] = None) -> TensorCollector: + return get_raw_stat_collector(num_samples) + + @staticmethod + def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int: + return PTWeightCompressionAlgoBackend.get_activation_port_id(node, graph) + + def get_weight( + self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.fx.GraphModule, graph: NNCFGraph + ) -> Tensor: + weight_edge = graph.get_input_edge_by_port_id(node_with_weight, weight_port_id) + weight_node = weight_edge.from_node + graph_weight_node = get_graph_node_by_name(model.graph, weight_node.node_name) + weight = get_tensor_constant_from_node(graph_weight_node, model).data + if weight is None: + raise nncf.InternalError(f"Could not find a node in the model by name {weight_node}.") + + return Tensor(weight) + + def get_weight_dtype( + self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.fx.GraphModule, graph: NNCFGraph + ) -> TensorDataType: + return self.get_weight(node_with_weight, weight_port_id, model, graph).dtype + + @staticmethod + def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> Tuple: + weight_node = get_const_node(node_with_weight, weight_port_id, graph) + edge = graph.get_edge(weight_node, node_with_weight) + return tuple(edge.tensor_shape) + + def set_weight( + self, + node_with_weight: NNCFNode, + weight_port_id: int, + model: torch.fx.GraphModule, + graph: NNCFGraph, + weight: Tensor, + ) -> None: + constant_update_transformation_builder(node_with_weight, weight.data, input_port_id=weight_port_id)(model) + + def insert_adapters( + self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool + ) -> None: + pass + + def transform_model( + self, + model: torch.fx.GraphModule, + graph: NNCFGraph, + weight_compression_parameters: Iterable[WeightCompressionParameters], + 
precomputed_scales: Dict[str, Tensor] = None, + precomputed_zero_points: Dict[str, Tensor] = None, + lora_correction_algo: LoraCorrectionAlgorithm = None, + ) -> torch.fx.GraphModule: + transformation_layout = TransformationLayout() + + for wc_params in weight_compression_parameters: + compression_config = wc_params.compression_config + if compression_config.mode not in [ + CompressWeightsMode.INT8_ASYM, + CompressWeightsMode.INT8_SYM, + CompressWeightsMode.INT8, + ]: + raise ValueError(f"{compression_config.mode.value} is not supported.") + weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph) + weight_name = weight_node.node_name + weight = self.get_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph) + if weight is None or not isinstance(weight, Tensor): + raise nncf.InternalError(f"Could not find a nncf.tensor in the model by name {weight_name}.") + + # calculates compressed weights and decompression parameters + compressed_weight = compress_weight( + weight, + wc_params.reduction_axes, + compression_config, + None if precomputed_scales is None else precomputed_scales.get(wc_params.weight_name), + None if precomputed_zero_points is None else precomputed_zero_points.get(wc_params.weight_name), + ) + compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16) + + # pack compressed tensor + if compression_config.mode == CompressWeightsMode.INT8_SYM: + dtype = TensorDataType.int8 + else: + dtype = TensorDataType.uint8 + packed_tensor = compressed_weight.tensor.astype(dtype) + + self.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, packed_tensor) + # creates weight decompressor + if compression_config.mode == CompressWeightsMode.INT8_SYM: + decompressor = SymmetricWeightsDecompressor( + compressed_weight.scale.data, result_dtype=weight.data.dtype + ) + decompressor_type = "symmetric" + else: + packed_zero_point = compressed_weight.zero_point.astype(dtype) + decompressor = AsymmetricWeightsDecompressor( + compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.data.dtype + ) + decompressor_type = "asymmetric" + + # register weight decompression module in the model + graph_weight_node = get_graph_node_by_name(model.graph, wc_params.node_with_weight.node_name) + compressed_weight_name = graph_weight_node.all_input_nodes[wc_params.weight_port_id].name + + decompressor_suffix = "_".join(compressed_weight_name.replace(".", "_").split("_")[:-2]) + decompressor_name = f"{decompressor_type}_weights_decompressor_{decompressor_suffix}" + + # inserts the weight decompressor into the model as the post hook on the model weight + transformation_layout.register( + FXApplyTransformationCommand( + module_insertion_transformation_builder( + decompressor, + [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=compressed_weight_name)], + decompressor_name, + ) + ) + ) + + # apply transformations + transformed_model = FXModelTransformer(model).transform(transformation_layout) + + return transformed_model diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 63993ee58fc..1f633458e98 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -490,6 +490,28 @@ def compress_weights( dataset = None compression_weights_impl = pt_compression_weights_impl + if backend == BackendType.TORCH_FX: + from nncf.experimental.torch.fx.quantization.quantize_model import ( + compress_weights_impl as 
fx_compression_weights_impl, + ) + + if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]: + raise AttributeError( + "TorchFX backend supports only INT8_ASYM, INT8_SYM modes for weight compression, " + f"but given {mode.value} mode." + ) + + if any((awq, scale_estimation, gptq, lora_correction)): + raise AttributeError( + "TorchFX backend does not support 'awq', 'scale_estimation', 'gptq' " + "and 'lora_correction' options. Set them to None." + ) + if dataset: + raise AttributeError( + "TorchFX only supports data-free weights compression. " "Set the 'dataset' option to None." + ) + compression_weights_impl = fx_compression_weights_impl + if backend == BackendType.OPENVINO: from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index 7a182c0d869..8eda5611049 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -28,6 +28,7 @@ ModuleAttributes = TypeVar("ModuleAttributes", bound=BaseLayerAttributes) PT_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes") +FX_OPERATOR_METATYPES = OperatorMetatypeRegistry("operator_metatypes") class PTOperatorMetatype(OperatorMetatype): @@ -918,6 +919,14 @@ class PTEmbeddingMetatype(PTOperatorMetatype): weight_port_ids = [1] +@FX_OPERATOR_METATYPES.register() +class PTAtenEmbeddingMetatype(OperatorMetatype): + name = "EmbeddingOp" + module_to_function_names = {NamespaceTarget.ATEN: ["embedding"]} + hw_config_names = [HWConfigOpName.EMBEDDING] + weight_port_ids = [0] + + @PT_OPERATOR_METATYPES.register(is_subtype=True) class PTModuleEmbeddingBagMetatype(PTModuleOperatorSubtype): name = "EmbeddingBagOp" diff --git a/tests/torch/data/reference_graphs/fx/quantized/synthetic_transformer.dot b/tests/torch/data/reference_graphs/fx/quantized/synthetic_transformer.dot new file mode 100644 index 00000000000..c274e68f66b --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/quantized/synthetic_transformer.dot @@ -0,0 +1,53 @@ +strict digraph { +"0 arg0_1" [id=0, type=input]; +"1 _param_constant0" [id=1, type=get_attr]; +"2 embedding" [id=2, type=embedding]; +"3 linear_updated_constant0" [id=3, type=get_attr]; +"4 embedding_0_0_nncf_smooth_quant_0" [id=4, type=call_module]; +"5 quantize_per_tensor_default" [id=5, type=quantize_per_tensor]; +"6 dequantize_per_tensor_default" [id=6, type=dequantize_per_tensor]; +"7 linear_scale_0" [id=7, type=get_attr]; +"8 linear_zero_point_0" [id=8, type=get_attr]; +"9 quantize_per_channel_default" [id=9, type=quantize_per_channel]; +"10 dequantize_per_channel_default" [id=10, type=dequantize_per_channel]; +"11 _param_constant2_0_0" [id=11, type=get_attr]; +"12 linear" [id=12, type=linear]; +"13 linear_1_updated_constant0" [id=13, type=get_attr]; +"14 add_tensor_0_0_nncf_smooth_quant_0" [id=14, type=call_module]; +"15 quantize_per_tensor_default_1" [id=15, type=quantize_per_tensor]; +"16 dequantize_per_tensor_default_1" [id=16, type=dequantize_per_tensor]; +"17 linear_1_scale_0" [id=17, type=get_attr]; +"18 linear_1_zero_point_0" [id=18, type=get_attr]; +"19 quantize_per_channel_default_1" [id=19, type=quantize_per_channel]; +"20 dequantize_per_channel_default_1" [id=20, type=dequantize_per_channel]; +"21 _param_constant4_0_0" [id=21, type=get_attr]; +"22 linear_1" [id=22, type=linear]; +"23 output" [id=23, type=output]; +"0 arg0_1" -> "2 embedding"; +"1 _param_constant0" -> "2 embedding"; +"2 embedding" -> "4
embedding_0_0_nncf_smooth_quant_0"; +"3 linear_updated_constant0" -> "9 quantize_per_channel_default"; +"4 embedding_0_0_nncf_smooth_quant_0" -> "5 quantize_per_tensor_default"; +"5 quantize_per_tensor_default" -> "6 dequantize_per_tensor_default"; +"6 dequantize_per_tensor_default" -> "12 linear"; +"7 linear_scale_0" -> "9 quantize_per_channel_default"; +"7 linear_scale_0" -> "10 dequantize_per_channel_default"; +"8 linear_zero_point_0" -> "9 quantize_per_channel_default"; +"8 linear_zero_point_0" -> "10 dequantize_per_channel_default"; +"9 quantize_per_channel_default" -> "10 dequantize_per_channel_default"; +"10 dequantize_per_channel_default" -> "12 linear"; +"11 _param_constant2_0_0" -> "12 linear"; +"12 linear" -> "14 add_tensor_0_0_nncf_smooth_quant_0"; +"13 linear_1_updated_constant0" -> "19 quantize_per_channel_default_1"; +"14 add_tensor_0_0_nncf_smooth_quant_0" -> "15 quantize_per_tensor_default_1"; +"15 quantize_per_tensor_default_1" -> "16 dequantize_per_tensor_default_1"; +"16 dequantize_per_tensor_default_1" -> "22 linear_1"; +"17 linear_1_scale_0" -> "19 quantize_per_channel_default_1"; +"17 linear_1_scale_0" -> "20 dequantize_per_channel_default_1"; +"18 linear_1_zero_point_0" -> "19 quantize_per_channel_default_1"; +"18 linear_1_zero_point_0" -> "20 dequantize_per_channel_default_1"; +"19 quantize_per_channel_default_1" -> "20 dequantize_per_channel_default_1"; +"20 dequantize_per_channel_default_1" -> "22 linear_1"; +"21 _param_constant4_0_0" -> "22 linear_1"; +"22 linear_1" -> "23 output"; +} diff --git a/tests/torch/data/reference_graphs/fx/reference_attributes/not_unified_shared_attribute_test_model.json b/tests/torch/data/reference_graphs/fx/reference_attributes/not_unified_shared_attribute_test_model.json new file mode 100644 index 00000000000..fd4489ff8b6 --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/reference_attributes/not_unified_shared_attribute_test_model.json @@ -0,0 +1,20 @@ +{ + "arg0_1": false, + "_param_constant0": false, + "_param_constant1": false, + "conv2d": false, + "_param_constant2": false, + "_param_constant3": false, + "conv2d_1": false, + "_tensor_constant0": true, + "add_": false, + "_tensor_constant0_1": true, + "add__1": false, + "add": false, + "_param_constant4": false, + "_param_constant5": false, + "conv2d_2": false, + "_tensor_constant0_2": true, + "add_1": false, + "output": false +} \ No newline at end of file diff --git a/tests/torch/data/reference_graphs/fx/reference_attributes/unified_shared_attribute_test_model.json b/tests/torch/data/reference_graphs/fx/reference_attributes/unified_shared_attribute_test_model.json new file mode 100644 index 00000000000..4c57c9317d2 --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/reference_attributes/unified_shared_attribute_test_model.json @@ -0,0 +1,18 @@ +{ + "arg0_1": false, + "_param_constant0": false, + "_param_constant1": false, + "conv2d": false, + "_param_constant2": false, + "_param_constant3": false, + "conv2d_1": false, + "_tensor_constant0": true, + "add_": false, + "add__1": false, + "add": false, + "_param_constant4": false, + "_param_constant5": false, + "conv2d_2": false, + "add_1": false, + "output": false +} \ No newline at end of file diff --git a/tests/torch/data/reference_graphs/fx/reference_metatypes/synthetic_transformer.json b/tests/torch/data/reference_graphs/fx/reference_metatypes/synthetic_transformer.json new file mode 100644 index 00000000000..c8375399b97 --- /dev/null +++ 
b/tests/torch/data/reference_graphs/fx/reference_metatypes/synthetic_transformer.json @@ -0,0 +1,12 @@ +{ + "arg0_1": "PTInputNoopMetatype", + "_param_constant0": "PTConstNoopMetatype", + "embedding": "PTAtenEmbeddingMetatype", + "_param_constant1": "PTConstNoopMetatype", + "_param_constant2": "PTConstNoopMetatype", + "linear": "PTLinearMetatype", + "_param_constant3": "PTConstNoopMetatype", + "_param_constant4": "PTConstNoopMetatype", + "linear_1": "PTLinearMetatype", + "output": "PTOutputNoopMetatype" +} \ No newline at end of file diff --git a/tests/torch/data/reference_graphs/fx/synthetic_transformer.dot b/tests/torch/data/reference_graphs/fx/synthetic_transformer.dot new file mode 100644 index 00000000000..2731f77220f --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/synthetic_transformer.dot @@ -0,0 +1,21 @@ +strict digraph { +"0 arg0_1" [id=0, type=input]; +"1 _param_constant0" [id=1, type=get_attr]; +"2 embedding" [id=2, type=embedding]; +"3 _param_constant1" [id=3, type=get_attr]; +"4 _param_constant2" [id=4, type=get_attr]; +"5 linear" [id=5, type=linear]; +"6 _param_constant3" [id=6, type=get_attr]; +"7 _param_constant4" [id=7, type=get_attr]; +"8 linear_1" [id=8, type=linear]; +"9 output" [id=9, type=output]; +"0 arg0_1" -> "2 embedding"; +"1 _param_constant0" -> "2 embedding"; +"2 embedding" -> "5 linear"; +"3 _param_constant1" -> "5 linear"; +"4 _param_constant2" -> "5 linear"; +"5 linear" -> "8 linear_1"; +"6 _param_constant3" -> "8 linear_1"; +"7 _param_constant4" -> "8 linear_1"; +"8 linear_1" -> "9 output"; +} diff --git a/tests/torch/data/reference_graphs/fx/transformed/shared_constants_unification_transformation_test.dot b/tests/torch/data/reference_graphs/fx/transformed/shared_constants_unification_transformation_test.dot new file mode 100644 index 00000000000..7b649047c71 --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/transformed/shared_constants_unification_transformation_test.dot @@ -0,0 +1,36 @@ +strict digraph { +"0 arg0_1" [id=0, type=input]; +"1 _param_constant0" [id=1, type=get_attr]; +"2 _param_constant1" [id=2, type=get_attr]; +"3 conv2d" [id=3, type=conv2d]; +"4 _param_constant2" [id=4, type=get_attr]; +"5 _param_constant3" [id=5, type=get_attr]; +"6 conv2d_1" [id=6, type=conv2d]; +"7 _tensor_constant0" [id=7, type=get_attr]; +"8 add_" [id=8, type=add_]; +"9 add__1" [id=9, type=add_]; +"10 add" [id=10, type=add]; +"11 _param_constant4" [id=11, type=get_attr]; +"12 _param_constant5" [id=12, type=get_attr]; +"13 conv2d_2" [id=13, type=conv2d]; +"14 add_1" [id=14, type=add]; +"15 output" [id=15, type=output]; +"0 arg0_1" -> "3 conv2d"; +"1 _param_constant0" -> "3 conv2d"; +"2 _param_constant1" -> "3 conv2d"; +"3 conv2d" -> "6 conv2d_1"; +"3 conv2d" -> "8 add_"; +"4 _param_constant2" -> "6 conv2d_1"; +"5 _param_constant3" -> "6 conv2d_1"; +"6 conv2d_1" -> "9 add__1"; +"7 _tensor_constant0" -> "8 add_"; +"7 _tensor_constant0" -> "9 add__1"; +"7 _tensor_constant0" -> "14 add_1"; +"8 add_" -> "10 add"; +"9 add__1" -> "10 add"; +"10 add" -> "13 conv2d_2"; +"11 _param_constant4" -> "13 conv2d_2"; +"12 _param_constant5" -> "13 conv2d_2"; +"13 conv2d_2" -> "14 add_1"; +"14 add_1" -> "15 output"; +} diff --git a/tests/torch/fx/test_compress_weights.py b/tests/torch/fx/test_compress_weights.py new file mode 100644 index 00000000000..1d5012d5d57 --- /dev/null +++ b/tests/torch/fx/test_compress_weights.py @@ -0,0 +1,258 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import pytest +import torch +from torch._export import capture_pre_autograd_graph + +from nncf import CompressWeightsMode +from nncf.common.factory import NNCFGraphFactory +from nncf.data.dataset import Dataset +from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node +from nncf.quantization import compress_weights +from nncf.torch.dynamic_graph.patch_pytorch import disable_patching +from tests.torch.ptq.test_weights_compression import ALL_SENSITIVITY_METRICS +from tests.torch.ptq.test_weights_compression import SUPPORTED_MODES +from tests.torch.ptq.test_weights_compression import UNSUPPORTED_MODES +from tests.torch.ptq.test_weights_compression import ConvolutionModel +from tests.torch.ptq.test_weights_compression import DTypeModel +from tests.torch.ptq.test_weights_compression import EmptyModel +from tests.torch.ptq.test_weights_compression import FunctionalModel +from tests.torch.ptq.test_weights_compression import MatMulModel +from tests.torch.ptq.test_weights_compression import ShortTransformer + + +def get_model_size(model): + param_size = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + buffer_size = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + + model_size_mb = (param_size + buffer_size) / 1024**2 + + return model_size_mb + + +def get_compressed_modules_weights( + compressed_model: torch.fx.GraphModule, dtype: torch.dtype, compressed_node_weight_port: Dict[str, int] +): + n_target_modules = 0 + n_compressed_weights = 0 + + for node in compressed_model.graph.nodes: + if node.op == "call_function" and hasattr(node.target, "overloadpacket"): + node_type = str(node.target.overloadpacket).split(".")[1] + if node_type in compressed_node_weight_port: + n_target_modules += 1 + weight_port_id = compressed_node_weight_port[node_type] + weight_decompressor_node = node.all_input_nodes[weight_port_id] + if weight_decompressor_node.all_input_nodes: + compressed_weight_node = weight_decompressor_node.all_input_nodes[0] + weight = get_tensor_constant_from_node(compressed_weight_node, compressed_model).data + if weight.dtype == dtype: + n_compressed_weights += 1 + + return n_target_modules, n_compressed_weights + + +def _capture_model(model, inputs): + with torch.no_grad(): + with disable_patching(): + return capture_pre_autograd_graph(model, (inputs,)) + + +@pytest.mark.parametrize("mode", SUPPORTED_MODES) +def test_compress_weights(mode): + model = ShortTransformer(5, 10) + input_ids = torch.randint(0, 10, (5,)) + exported_model = _capture_model(model, input_ids) + compressed_model = compress_weights(exported_model, mode=mode) + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 + n_compressed_weights = 0 + n_target_modules = 0 + compressed_node_weight_port = {"linear": 1, "embedding": 0} + + n_target_modules, n_compressed_weights = get_compressed_modules_weights( + compressed_model, dtype, compressed_node_weight_port + ) + assert n_target_modules == n_compressed_weights + + 
+@pytest.mark.parametrize("mode", SUPPORTED_MODES) +def test_compress_weights_graph_edge(mode): + model = ShortTransformer(5, 10) + input_ids = torch.randint(0, 10, (5,)) + exported_model = _capture_model(model, input_ids) + compressed_model = compress_weights(exported_model, mode=mode) + nncf_graph = NNCFGraphFactory.create(compressed_model) + for node in nncf_graph.get_all_nodes(): + if "weights_decompressor" in node.node_name and node.node_type == "call_module": + decompressor_node_edge = nncf_graph.get_input_edges(node)[0] + decompressor_constant_edge = nncf_graph.get_edge(node, nncf_graph.get_next_nodes(node)[0]) + assert decompressor_node_edge.tensor_shape == decompressor_constant_edge.tensor_shape + + +@pytest.mark.parametrize("mode", SUPPORTED_MODES) +def test_compress_weights_shared_weights(mocker, mode): + with disable_patching(): + model = ShortTransformer(5, 10, share_weights=True) + input_ids = torch.randint(0, 10, (5,)) + exported_model = _capture_model(model, input_ids) + compressed_model = compress_weights(exported_model, mode=mode) + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 + n_compressed_weights = 0 + n_target_modules = 0 + compressed_node_weight_port = {"linear": 1, "embedding": 0} + + n_target_modules, n_compressed_weights = get_compressed_modules_weights( + compressed_model, dtype, compressed_node_weight_port + ) + assert n_target_modules == n_compressed_weights + + num_decompression_nodes = 0 + spies = [] + for node in compressed_model.graph.nodes: + if node.op == "call_module" and "decompress" in node.name: + num_decompression_nodes += 1 + decompressor_module = getattr(compressed_model, node.target) + spy = mocker.spy(decompressor_module, "forward") + spies.append(spy) + assert num_decompression_nodes == 2 + + compressed_model(input_ids) + + for spy in spies: + assert spy.call_count == 1 + + +@pytest.mark.parametrize("mode", SUPPORTED_MODES) +def test_compressed_model_inference(mode): + torch.manual_seed(42) + model = ShortTransformer(5, 10, share_weights=True) + input_ids = torch.randint(0, 10, (5,)) + exported_model = _capture_model(model, input_ids) + exported_model_output = exported_model(input_ids) + compressed_model = compress_weights(exported_model, mode=mode) + compressed_model_outputs = compressed_model(input_ids) + + assert ( + exported_model_output.shape == compressed_model_outputs.shape + ), "Compressed model output shape is not equal to the model output shape" + assert torch.all(torch.isclose(exported_model_output, compressed_model_outputs, atol=1)).item() + + +@pytest.mark.parametrize("mode", SUPPORTED_MODES) +def test_compress_weights_model_size_conv(mode): + + dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 + model = ConvolutionModel() + + input_ids = torch.randint(0, 10, [1, 3, 300, 300]) + exported_model = _capture_model(model, input_ids) + model_size = get_model_size(exported_model) + compressed_model = compress_weights(exported_model, mode=mode) + compressed_model_size = get_model_size(compressed_model) + + n_compressed_weights = 0 + n_target_modules = 0 + compressed_node_weight_port = {"linear": 1, "conv2d": 1, "conv_transpose2d": 1} + + n_target_modules, n_compressed_weights = get_compressed_modules_weights( + compressed_model, dtype, compressed_node_weight_port + ) + + assert n_compressed_weights == n_target_modules + assert compressed_model_size < model_size + + +@pytest.mark.parametrize("mode", SUPPORTED_MODES) +def test_compress_weights_functional_model(mode): + model = 
FunctionalModel() + decompressor_type = "symmetric" if mode == CompressWeightsMode.INT8_SYM else "asymmetric" + + input_ids = torch.randint(0, 10, [1, 3, 300, 300]) + exported_model = _capture_model(model, input_ids) + compressed_model = compress_weights(exported_model, mode=mode) + + n_compressed_weights = 0 + + for node in compressed_model.graph.nodes: + if decompressor_type in node.name: + n_compressed_weights += 1 + assert n_compressed_weights == 4 + + +@pytest.mark.parametrize("mode", SUPPORTED_MODES) +@pytest.mark.parametrize( + "params", + ( + {"ratio": 0.5}, + {"group_size": 64}, + {"all_layers": True}, + {"all_layers": False}, + *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS), + {"gptq": True}, + {"awq": True}, + {"scale_estimation": True}, + {"lora_correction": True}, + {"dataset": Dataset([1])}, + ), +) +def test_raise_error_with_unsupported_params_for_int8(mode, params): + dummy_torch_model = EmptyModel() + dummy_input = torch.Tensor() + exported_model = _capture_model(dummy_torch_model, dummy_input) + with pytest.raises(AttributeError): + compress_weights(exported_model, mode=mode, **params) + + +@pytest.mark.parametrize("mode", UNSUPPORTED_MODES) +def test_raise_error_with_not_int8(mode): + dummy_torch_model = EmptyModel() + dummy_input = torch.Tensor() + exported_model = _capture_model(dummy_torch_model, dummy_input) + with pytest.raises(AttributeError): + compress_weights(exported_model, mode=mode) + + +def test_get_dtype_attribute_of_parameter(): + model = DTypeModel() + dummy_input = torch.randint(0, 10, [3, 3]) + exported_model = _capture_model(model, dummy_input) + compressed_model = compress_weights(exported_model) + assert compressed_model.matmul_updated_constant0.dtype == torch.uint8 + compressed_model(dummy_input) + assert compressed_model.matmul_updated_constant0.dtype == torch.uint8 + + +@pytest.mark.parametrize("dtype", ("float16", "float32")) +def test_model_devices_and_precisions(use_cuda, dtype): + if use_cuda and not torch.cuda.is_available(): + pytest.skip("Skipping for CPU-only setups") + device = torch.device("cuda" if use_cuda else "cpu") + dtype = torch.float16 if dtype == "float16" else torch.float32 + + model = MatMulModel().to(device) + if dtype == torch.float16: + model.half() + dummy_input = torch.rand((1, 300), dtype=dtype, device=device) + exported_model = _capture_model(model, dummy_input) + compressed_model = compress_weights(exported_model) + result = compressed_model(dummy_input) + + # Scale should always be in float16 + assert compressed_model.state_dict()["asymmetric_weights_decompressor_matmul._scale"].dtype == torch.float16 + # Result should be in the precision of the model + assert result.dtype == dtype diff --git a/tests/torch/fx/test_model_transformer.py b/tests/torch/fx/test_model_transformer.py index 2dc6d251d5f..f619c955336 100644 --- a/tests/torch/fx/test_model_transformer.py +++ b/tests/torch/fx/test_model_transformer.py @@ -17,12 +17,19 @@ import torch from torch._export import capture_pre_autograd_graph +from nncf.common.factory import NNCFGraph +from nncf.common.factory import NNCFGraphFactory from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.experimental.torch.fx.model_transformer import FXModelTransformer from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter +from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name +from nncf.experimental.torch.fx.node_utils import 
get_tensor_constant_from_node +from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder from nncf.experimental.torch.fx.transformations import output_insertion_transformation_builder +from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation from nncf.torch import disable_patching +from nncf.torch.graph.operator_metatypes import CONST_NOOP_METATYPES from nncf.torch.graph.transformations.commands import PTModelExtractionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint from tests.torch.test_compressed_graph import check_graph @@ -124,3 +131,60 @@ def test_output_insertion_transformation(tuple_output, target_point): check_graph( nncf_graph, f"output_insertion_{_target_point_to_str(target_point)}_ref.dot", TRANSFORMED_GRAPH_DIR_NAME ) + + +def count_constants(model) -> int: + num_constant_nodes = 0 + for node in model.graph.nodes: + if node.op == "get_attr": + num_constant_nodes += 1 + return num_constant_nodes + + +def test_create_shared_constant_transformation(): + model = MultiBranchesConnectedModel() + ex_inputs = torch.ones((1, 3, 3, 3)) + captured_model = _capture_model(model, ex_inputs) + shared_constants_unification_transformation(captured_model) + nncf_graph = GraphConverter.create_nncf_graph(captured_model) + check_graph(nncf_graph, "shared_constants_unification_transformation_test.dot", TRANSFORMED_GRAPH_DIR_NAME) + + +def get_shared_constant_nodes(nncf_graph: NNCFGraph): + """ + Gets a dict of constant nodes as key and consumer nodes as values which are shared in the model. + eg: + const + / \ + node1 node2 + + returns ({const:[node1, node2]}) + """ + shared_const_node_consumer_node = {} + for node in nncf_graph.get_all_nodes(): + consumer_nodes = nncf_graph.get_next_nodes(node) + if node.metatype in CONST_NOOP_METATYPES and len(consumer_nodes) > 1: + shared_const_node_consumer_node[node] = consumer_nodes + return shared_const_node_consumer_node + + +def test_update_shared_constant(): + model = MultiBranchesConnectedModel() + ex_inputs = torch.ones((1, 3, 3, 3)) + captured_model = _capture_model(model, ex_inputs) + + shared_constants_unification_transformation(captured_model) + nncf_graph = NNCFGraphFactory.create(captured_model) + shared_constants_consumers_dict = get_shared_constant_nodes(nncf_graph) + + # This returns all the constant nodes as keys and list of consumer as values + consumer_nodes = list(shared_constants_consumers_dict.values())[0] + + constant_update_transformation_builder(consumer_nodes[0], torch.tensor([100]))(captured_model) + + nncf_graph_updated_constant = NNCFGraphFactory.create(captured_model) + updated_const_node = nncf_graph_updated_constant.get_previous_nodes(consumer_nodes[1])[1] + fx_node_to_check_const = get_graph_node_by_name(captured_model.graph, updated_const_node.node_name) + fx_node_to_check_const_value = get_tensor_constant_from_node(fx_node_to_check_const, captured_model) + + assert fx_node_to_check_const_value == torch.tensor([100]) diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index 34edadcbfdb..7ff1e446dc3 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Callable, Dict, Tuple, Type +from typing import Callable, Dict, Tuple, Type, Union import openvino.torch # noqa import pytest @@ -32,11 +32,14 @@ from nncf.common.graph.operator_metatypes import 
OperatorMetatype from nncf.common.utils.os import safe_open from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter +from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.torch.dynamic_graph.patch_pytorch import disable_patching from tests.cross_fw.shared.paths import TEST_ROOT from tests.torch import test_models +from tests.torch.ptq.test_weights_compression import ShortTransformer from tests.torch.test_compressed_graph import check_graph +from tests.torch.test_models.synthetic import MultiBranchesConnectedModel FX_DIR_NAME = Path("fx") FX_QUANTIZED_DIR_NAME = Path("fx") / "quantized" @@ -60,6 +63,7 @@ def torchvision_model_case(model_id: str, input_shape: Tuple[int,]): torchvision_model_case("vit_b_16", (1, 3, 224, 224)), torchvision_model_case("swin_v2_s", (1, 3, 224, 224)), ModelCase(test_models.UNet, "unet", [1, 3, 224, 224]), + ModelCase(partial(ShortTransformer, 5, 10), "synthetic_transformer", [5]), ) @@ -71,18 +75,25 @@ def get_json_filename(model_name): return model_name + ".json" -def get_full_path_to_json(model_json_name: str) -> str: - path_to_dir = TEST_ROOT / "torch" / "data" / "reference_graphs" / "fx" / "reference_metatypes" +def get_full_path_to_json(model_json_name: str, attributes: bool = False) -> str: + property_to_check = "reference_metatypes" if not attributes else "reference_attributes" + path_to_dir = TEST_ROOT / "torch" / "data" / "reference_graphs" / "fx" / property_to_check path_to_json = path_to_dir / model_json_name return path_to_json -def get_ref_metatypes_from_json( - model_name: str, model_metatypes: Dict[NNCFNodeName, Type[OperatorMetatype]] -) -> Dict[NNCFNodeName, Type[OperatorMetatype]]: +def _capture_model(model: torch.nn.Module, inputs: torch.Tensor) -> torch.fx.GraphModule: + with torch.no_grad(): + with disable_patching(): + return capture_pre_autograd_graph(model, (inputs,)) + + +def get_ref_from_json( + model_name: str, model_metatypes: Dict[NNCFNodeName, Union[Type[OperatorMetatype], bool]], attributes=False +) -> Dict[NNCFNodeName, Union[Type[OperatorMetatype], bool]]: model_json_name = get_json_filename(model_name) - complete_path = get_full_path_to_json(model_json_name) + complete_path = get_full_path_to_json(model_json_name, attributes) json_parent_dir = Path(complete_path).parent @@ -98,26 +109,26 @@ def get_ref_metatypes_from_json( @pytest.mark.parametrize("test_case", TEST_MODELS, ids=[m.model_id for m in TEST_MODELS]) def test_model(test_case: ModelCase): - with disable_patching(): - device = torch.device("cpu") - model_name = test_case.model_id - model = test_case.model_builder() - model.to(device) + device = torch.device("cpu") + model_name = test_case.model_id + model = test_case.model_builder() + model.to(device) - with torch.no_grad(): - ex_input = torch.ones(test_case.input_shape) - model.eval() - exported_model = capture_pre_autograd_graph(model, args=(ex_input,)) - nncf_graph = GraphConverter.create_nncf_graph(exported_model) + with torch.no_grad(): + dtype = torch.int32 if test_case.model_id == "synthetic_transformer" else torch.float32 + ex_input = torch.ones(test_case.input_shape, dtype=dtype) + model.eval() + exported_model = _capture_model(model, ex_input) + nncf_graph = GraphConverter.create_nncf_graph(exported_model) - # Check NNCFGrpah - dot_filename = get_dot_filename(model_name) - check_graph(nncf_graph, dot_filename, FX_DIR_NAME) + # Check NNCFGrpah + dot_filename = 
get_dot_filename(model_name) + check_graph(nncf_graph, dot_filename, FX_DIR_NAME) - # Check metatypes - model_metatypes = {n.node_name: n.metatype.__name__ for n in nncf_graph.get_all_nodes()} - ref_metatypes = get_ref_metatypes_from_json(model_name, model_metatypes) - assert model_metatypes == ref_metatypes + # Check metatypes + model_metatypes = {n.node_name: n.metatype.__name__ for n in nncf_graph.get_all_nodes()} + ref_metatypes = get_ref_from_json(model_name, model_metatypes) + assert model_metatypes == ref_metatypes TEST_MODELS_QUANIZED = ( @@ -126,6 +137,10 @@ def test_model(test_case: ModelCase): (torchvision_model_case("mobilenet_v3_small", (1, 3, 224, 224)), {}), (torchvision_model_case("vit_b_16", (1, 3, 224, 224)), {"model_type": nncf.ModelType.TRANSFORMER}), (torchvision_model_case("swin_v2_s", (1, 3, 224, 224)), {"model_type": nncf.ModelType.TRANSFORMER}), + ( + ModelCase(partial(ShortTransformer, 5, 10), "synthetic_transformer", [5]), + {"model_type": nncf.ModelType.TRANSFORMER}, + ), ) @@ -133,26 +148,41 @@ def test_model(test_case: ModelCase): ("model_case", "quantization_parameters"), TEST_MODELS_QUANIZED, ids=[m[0].model_id for m in TEST_MODELS_QUANIZED] ) def test_quantized_model(model_case: ModelCase, quantization_parameters): - with disable_patching(): - model = model_case.model_builder() - example_input = torch.ones(model_case.input_shape) - - with torch.no_grad(): - model.eval() - fx_model = capture_pre_autograd_graph(model, args=(example_input,)) - - def transform_fn(data_item): - return data_item.to("cpu") - - calibration_dataset = nncf.Dataset([example_input], transform_fn) - - quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(disable_bias_correction=True) - quantization_parameters["subset_size"] = 1 - - quantized_model = nncf.quantize(fx_model, calibration_dataset, **quantization_parameters) - # Uncomment to visualize torch fx graph - # from tests.torch.fx.helpers import visualize_fx_model - # visualize_fx_model(quantized_model, f"{model_case.model_id}_int8.svg") - - nncf_graph = GraphConverter.create_nncf_graph(quantized_model) - check_graph(nncf_graph, get_dot_filename(model_case.model_id), FX_QUANTIZED_DIR_NAME) + model = model_case.model_builder() + dtype = torch.int32 if model_case.model_id == "synthetic_transformer" else torch.float32 + example_input = torch.ones(model_case.input_shape, dtype=dtype) + + with torch.no_grad(): + model.eval() + fx_model = _capture_model(model, example_input) + + def transform_fn(data_item): + return data_item.to("cpu") + + calibration_dataset = nncf.Dataset([example_input], transform_fn) + + quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(disable_bias_correction=True) + quantization_parameters["subset_size"] = 1 + + quantized_model = nncf.quantize(fx_model, calibration_dataset, **quantization_parameters) + # Uncomment to visualize torch fx graph + # from tests.torch.fx.helpers import visualize_fx_model + # visualize_fx_model(quantized_model, f"{model_case.model_id}_int8.svg") + + nncf_graph = GraphConverter.create_nncf_graph(quantized_model) + check_graph(nncf_graph, get_dot_filename(model_case.model_id), FX_QUANTIZED_DIR_NAME) + + +@pytest.mark.parametrize("unification", [False, True]) +def test_is_shared_attribute(unification): + model = MultiBranchesConnectedModel() + ex_inputs = torch.ones((1, 3, 3, 3)) + captured_model = _capture_model(model, ex_inputs) + file_prefix = "not_unified" + if unification: + file_prefix = "unified" + 
shared_constants_unification_transformation(captured_model) + nncf_graph = GraphConverter.create_nncf_graph(captured_model) + shared_attributes = {n.node_name: n.is_shared() for n in nncf_graph.get_all_nodes()} + ref_attributes = get_ref_from_json(f"{file_prefix}_shared_attribute_test_model", shared_attributes, attributes=True) + assert shared_attributes == ref_attributes
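
Below is a minimal usage sketch of the Torch FX weight-compression path introduced by this patch, mirroring tests/torch/fx/test_compress_weights.py above; the toy model and input shape are illustrative assumptions and are not part of the patch.

import torch
from torch._export import capture_pre_autograd_graph

import nncf
from nncf import CompressWeightsMode
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching

# Illustrative toy model; any module with linear/embedding weights would work.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 8)).eval()
example_input = torch.ones(1, 16)

# Capture the model into a torch.fx.GraphModule, as the new tests do.
with torch.no_grad():
    with disable_patching():
        fx_model = capture_pre_autograd_graph(model, (example_input,))

# Data-free INT8 weight compression: for the Torch FX backend only INT8_ASYM and
# INT8_SYM are supported, and no dataset may be passed.
compressed_model = nncf.compress_weights(fx_model, mode=CompressWeightsMode.INT8_SYM)
output = compressed_model(example_input)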