diff --git a/nncf/common/factory.py b/nncf/common/factory.py index 6616f9dbe3a..c5a921c8068 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph: if model_backend == BackendType.OPENVINO: from nncf.openvino.graph.nncf_graph_builder import GraphConverter + return GraphConverter.create_nncf_graph(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter + return GraphConverter.create_nncf_graph(model) if model_backend == BackendType.TORCH: return model.nncf.get_graph() @@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer: from nncf.torch.model_transformer import PTModelTransformer return PTModelTransformer(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch.fx.model_transformer import FXModelTransformer + + return FXModelTransformer(model) raise nncf.UnsupportedBackendError( "Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value) ) @@ -95,7 +103,7 @@ def create(model: TModel) -> Engine: from nncf.openvino.engine import OVNativeEngine return OVNativeEngine(model) - if model_backend == BackendType.TORCH: + if model_backend in (BackendType.TORCH, BackendType.TORCH_FX): from nncf.torch.engine import PTEngine return PTEngine(model) @@ -151,6 +159,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator: from nncf.torch.statistics.aggregator import PTStatisticsAggregator return PTStatisticsAggregator(dataset) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator + + return FXStatisticsAggregator(dataset) raise nncf.UnsupportedBackendError( "Cannot create backend-specific statistics aggregator because {} is not supported!".format( model_backend.value diff --git a/nncf/common/graph/patterns/manager.py b/nncf/common/graph/patterns/manager.py index 08bae0000af..2c32e3abf56 100644 --- a/nncf/common/graph/patterns/manager.py +++ b/nncf/common/graph/patterns/manager.py @@ -47,7 +47,7 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam Dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict ) return registry - if backend == BackendType.TORCH: + if backend in (BackendType.TORCH, BackendType.TORCH_FX): from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict) @@ -77,7 +77,7 @@ def _get_backend_ignored_patterns_map( Dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict ) return registry - if backend == BackendType.TORCH: + if backend in (BackendType.TORCH, BackendType.TORCH_FX): from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict) diff --git a/nncf/common/utils/backend.py b/nncf/common/utils/backend.py index 7589aabb739..a76c0b8670d 100644 --- a/nncf/common/utils/backend.py +++ b/nncf/common/utils/backend.py @@ -20,6 +20,7 @@ class BackendType(Enum): TORCH = "Torch" + TORCH_FX = "TorchFX" TENSORFLOW = "Tensorflow" ONNX = "ONNX" OPENVINO = "OpenVINO" @@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]: """ frameworks = [ ("torch", BackendType.TORCH), + ("torch.fx", BackendType.TORCH_FX), ("tensorflow", BackendType.TENSORFLOW), ("onnx", BackendType.ONNX), ("openvino.runtime", BackendType.OPENVINO), @@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]: def is_torch_model(model: TModel) -> bool: """ - Returns True if the model is an instance of torch.nn.Module, otherwise False. + Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False. :param model: A target model. - :return: True if the model is an instance of torch.nn.Module, otherwise False. + :return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False. """ import torch # type: ignore + import torch.fx # type: ignore - return isinstance(model, torch.nn.Module) + return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module) + + +def is_torch_fx_model(model: TModel) -> bool: + """ + Returns True if the model is an instance of torch.fx.GraphModule, otherwise False. + + :param model: A target model. + :return: True if the model is an instance of torch.fx.GraphModule, otherwise False. + """ + import torch.fx + + return isinstance(model, torch.fx.GraphModule) def is_tensorflow_model(model: TModel) -> bool: @@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType: """ available_backends = get_available_backends() + if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model): + return BackendType.TORCH_FX + if BackendType.TORCH in available_backends and is_torch_model(model): return BackendType.TORCH diff --git a/nncf/experimental/torch/fx/__init__.py b/nncf/experimental/torch/fx/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch/fx/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch/fx/commands.py b/nncf/experimental/torch/fx/commands.py new file mode 100644 index 00000000000..831f177cac7 --- /dev/null +++ b/nncf/experimental/torch/fx/commands.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Union + +import torch.fx + +from nncf.common.graph.transformations.commands import Command +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.commands import TransformationType + + +class FXApplyTransformationCommand(Command): + """ + Command to apply given transformation to a model. + """ + + def __init__( + self, + transformation_fn: Callable[[torch.fx.GraphModule], None], + priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, + ): + """ + :param transformation_fn: Target transformation function. + :param priority: Transformation priority. + """ + super().__init__(TransformationType.INSERT) + self.tranformation_fn = transformation_fn + self.priority = priority diff --git a/nncf/experimental/torch/fx/model_transformer.py b/nncf/experimental/torch/fx/model_transformer.py new file mode 100644 index 00000000000..4be8f306051 --- /dev/null +++ b/nncf/experimental/torch/fx/model_transformer.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from typing import List + +import torch +import torch.fx +from torch.fx.passes.split_utils import split_by_tags + +from nncf.common.graph.model_transformer import ModelTransformer +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand +from nncf.torch.graph.transformations.commands import PTModelExtractionCommand +from nncf.torch.graph.transformations.layout import PTTransformationLayout + + +class FXModelTransformer(ModelTransformer): + """ + Applies transformations upon Torch FX model. + """ + + def __init__(self, model: torch.fx.GraphModule): + super().__init__(model) + + self._command_transformation_ordered_pairs = [ + (FXApplyTransformationCommand, self._apply_transformation), + (PTModelExtractionCommand, self._apply_model_extraction), + ] + + def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule: + """ + Transforms the target model according to given transformation layout. + + :param transformation_layout: Given transformation layout. + :return: Target model transformered according to the given transformation layout. + """ + # TODO(dlyakhov): Manage priorities of transformations. + transformations = transformation_layout.transformations + aggregated_transformations = defaultdict(list) + for transformation in transformations: + aggregated_transformations[transformation.__class__].append(transformation) + + model = self._model + for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs: + transformations = aggregated_transformations[transformation_cls] + if transformations: + model = transformation_fn(model, transformations) + + # Do not use model.graph.eliminate_dead_code() + # because the computational statistics code + # is interpolated as dead code. + model.recompile() + return model + + @staticmethod + def _apply_model_extraction( + model: torch.fx.GraphModule, + transformations: List[PTModelExtractionCommand], + ) -> torch.fx.GraphModule: + """ + Returns a submodel extracted from the given model by the given transformation. + + :param model: Given model. + :param transformations: List of one transformation which specifies + how to retrieve a submodule from the model. In case list contains + more than one element this function raises an assert. + :return: Returns a submodel extracted from the given model by the given transformation. + """ + transformation = transformations[-1] + assert len(transformation.input_node_names) == 1 + assert transformation.input_node_names == transformation.output_node_names + node_name = transformation.input_node_names[0] + + tags = ["before", "extracted", "after"] + i = 0 + for node in model.graph.nodes: + if node.name == node_name: + node.tag = tags[1] + weights = [node.all_input_nodes[1]] + while weights: + w_node = weights.pop() + assert w_node.tag in tags[0:2] + w_node.tag = tags[1] + weights.extend(w_node.all_input_nodes) + i = 2 + continue + node.tag = tags[i] + + # TODO(dlyakhov): reduce memory consumption by + # more optimal splitting implementation. + splitted_gm = split_by_tags(model, tags) + return splitted_gm.extracted + + @staticmethod + def _apply_transformation( + model: torch.fx.GraphModule, + transformations: List[FXApplyTransformationCommand], + ) -> torch.fx.GraphModule: + """ + Applies transformations to the given model. + + :param model: Target model. + :param transformations: Transformations to apply to the model. + :return: Target model after all transformations were applied. + """ + for transformation in transformations: + transformation.tranformation_fn(model) + return model diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py new file mode 100644 index 00000000000..0863cab72ee --- /dev/null +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch.fx + +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph import NNCFNode +from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.operator_metatypes import UnknownMetatype +from nncf.common.logging import nncf_logger +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES + + +class GraphConverter: + """ + Builds the NNCFGraph from an torch.fx.GraphModule instance. + """ + + @staticmethod + def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]: + """ + Retrieves node's type and metatype. + + :param node: Given node. + :return: Node's type and metatype. + """ + if node.op == "placeholder": + node_type = "input" + node_metatype = om.PTInputNoopMetatype + elif node.op == "output": + node_type = "output" + node_metatype = om.PTOutputNoopMetatype + elif node.op == "get_attr": + node_type = "get_attr" + node_metatype = om.PTConstNoopMetatype + elif node.op in ("call_function",): + if hasattr(node.target, "overloadpacket"): + node_type = str(node.target.overloadpacket).split(".")[1] + elif node.target.__name__ == "getitem": + node_type = "__getitem__" + else: + # TODO(dlyakhov): get correct nodes types from this nodes as well + node_type = str(node.target) + node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type) + else: + node_type = node.op + node_metatype = UnknownMetatype + if node_metatype is UnknownMetatype: + nncf_logger.debug(f"Unknown metatype for node: {node}") + return node_type, node_metatype + + @staticmethod + def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph: + """ + Creates NNCFGraph from GraphModule. + All nodes from model which have valid metatype are added to NNCFGraph. + Then, corresponding edges are added to the NNCFGraph with shape, type, output and input port ids. + + :param model: torch fx GraphModule. + :return: NNCFGraph. + """ + + nncf_graph = PTNNCFGraph() + + for source_node in model.graph.nodes: + node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node) + + nncf_graph.add_nncf_node( + node_name=source_node.name, + node_type=node_type, + node_metatype=node_metatype, + ) + + for source_node in model.graph.nodes: + source_nncf_node = nncf_graph.get_node_by_name(source_node.name) + for idx, dist_node in enumerate(source_node.users): + dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id + input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params( + model, source_node, source_nncf_node, dist_node, idx + ) + + nncf_graph.add_edge_between_nncf_nodes( + source_nncf_node.node_id, + dist_node_id, + tensor_shape=tensor_shape, + input_port_id=input_port_id, + output_port_id=output_port_id, + dtype=Dtype.FLOAT, + ) + + return nncf_graph + + @staticmethod + def get_edge_params( + model: torch.fx.GraphModule, + source_node: torch.fx.Node, + source_nncf_node: NNCFNode, + dist_node: torch.fx.Node, + output_idx: int, + ) -> Tuple[int, int, Tuple[int, ...]]: + """ + Retrieves edge params from the given source_node and dist_node pair. + + :param model: A torch.fx.GraphModule instance. + :param source_node: Source node in format of torch.fx.Node. + :param source_nncf_node: Source node in format of NNCFNode. + :param dist_node: Distance node in format of torch.fx.Node. + :param output_idx: Output indes of the source_node. + :return: Tuple of edge parameters: edge input port id, edge output port id and + edge tensor shape. + """ + output_port_id = 0 + if source_node.op in ("get_attr",): + tensor_shape = tuple(getattr(model, source_node.target).shape) + elif "val" in source_node.meta: + if source_nncf_node.metatype is om.PTBatchNormMetatype: + tensor = source_node.meta["val"][0] + elif source_nncf_node.metatype is om.PTSplitMetatype: + tensor = source_node.meta["val"][output_idx] + # Assume every split outputs corresponds to an unique output_port_id + output_port_id = output_idx + else: + tensor = source_node.meta["val"] + tensor_shape = tuple(tensor.shape) + else: + # TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes. + nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.") + tensor_shape = None + + input_port_id = dist_node.all_input_nodes.index(source_node) + return input_port_id, output_port_id, tensor_shape diff --git a/nncf/experimental/torch/fx/node_utils.py b/nncf/experimental/torch/fx/node_utils.py new file mode 100644 index 00000000000..5d03d5e355d --- /dev/null +++ b/nncf/experimental/torch/fx/node_utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +# TODO(dlyakhov): Use torch.fx.graph.find_nodes method instead after +# torch version update (>= 2.4) +def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node: + """ + Retrieves a node with the specified name from the grpah. + Raises a runtime error if graph does not contain node with + the given name. + + :param graph: Given torch fx graph. + :param name: Target node name. + :return: A graph node with the given name. + """ + for node in graph.nodes: + if node.name == name: + return node + raise RuntimeError(f"Node with name {name} is not found") diff --git a/nncf/experimental/torch/fx/quantization/__init__.py b/nncf/experimental/torch/fx/quantization/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch/fx/quantization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch/fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py new file mode 100644 index 00000000000..01aebf68c1f --- /dev/null +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Optional + +import torch +import torch.fx +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat +from torch.ao.quantization.pt2e.utils import _disallow_eval_train +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_manager import PassManager + +import nncf +from nncf.common.factory import NNCFGraphFactory +from nncf.common.logging import nncf_logger +from nncf.common.quantization.structs import QuantizationPreset +from nncf.data import Dataset +from nncf.experimental.torch.fx.transformations import apply_quantization_transformations +from nncf.experimental.torch.fx.transformations import revert_quantization_transformations +from nncf.parameters import ModelType +from nncf.parameters import QuantizationMode +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters +from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization +from nncf.scopes import IgnoredScope + +DEFAULT_RANGE_TYPE = "mean_min_max" + + +def quantize_impl( + model: torch.fx.GraphModule, + calibration_dataset: Dataset, + mode: Optional[QuantizationMode] = None, + preset: Optional[QuantizationPreset] = None, + target_device: TargetDevice = TargetDevice.ANY, + subset_size: int = 300, + fast_bias_correction: bool = True, + model_type: Optional[ModelType] = None, + ignored_scope: Optional[IgnoredScope] = None, + advanced_parameters: Optional[AdvancedQuantizationParameters] = None, +) -> torch.nn.Module: + """ + Implementation of the `quantize()` method for the Torch FX backend. + """ + nncf_logger.warning( + "Experimental Torch FX quantization backend is being used for the given torch.fx.GraphModule model." + " Torch FX PTQ is an experimental feature, consider using Torch or OpenVino PTQ backends" + " in case of errors or a poor model performance." + ) + if fast_bias_correction is False: + raise ValueError(f"fast_bias_correction={fast_bias_correction} is not supported") + if target_device == TargetDevice.CPU_SPR: + raise nncf.InternalError("target_device == CPU_SPR is not supported") + if mode is not None: + raise ValueError(f"mode={mode} is not supported") + + original_graph_meta = model.meta + + copied_model = deepcopy(model) + + quantization_algorithm = PostTrainingQuantization( + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) + + # To make it easier for bias correction algorithms, + # biases are being separated by the followng calls. + apply_quantization_transformations(copied_model) + + nncf_graph = NNCFGraphFactory.create(copied_model) + quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset) + + # Revert applied transformation to keep original model + # bias configuration. + revert_quantization_transformations(quantized_model) + + # Magic. Without this call compiled model + # is not preformant + quantized_model = GraphModule(quantized_model, quantized_model.graph) + + quantized_model = _fold_conv_bn_qat(quantized_model) + pm = PassManager([DuplicateDQPass()]) + + quantized_model = pm(quantized_model).graph_module + pm = PassManager([PortNodeMetaForQDQ()]) + quantized_model = pm(quantized_model).graph_module + + quantized_model.meta.update(original_graph_meta) + quantized_model = _disallow_eval_train(quantized_model) + + return quantized_model diff --git a/nncf/experimental/torch/fx/statistics/__init__.py b/nncf/experimental/torch/fx/statistics/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch/fx/statistics/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch/fx/statistics/aggregator.py b/nncf/experimental/torch/fx/statistics/aggregator.py new file mode 100644 index 00000000000..bf45c4cea0b --- /dev/null +++ b/nncf/experimental/torch/fx/statistics/aggregator.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import numpy as np +import torch + +from nncf.common.factory import TModel +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer +from nncf.common.tensor_statistics.aggregator import StatisticsAggregator +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand +from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder +from nncf.tensor import Tensor +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.return_types import maybe_get_values_from_torch_return_type + + +class TensorCollectorModule(torch.nn.Module): + """ + torch.nn.Module which calls given collector in forward + """ + + def __init__(self, collector: TensorCollector): + super().__init__() + self._collector = collector + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Register inputs hook function. + + :parameter x: tensor to register in hook. + :return: tensor to register in hook. + """ + x_unwrapped = maybe_get_values_from_torch_return_type(x) + self._collector.register_input_for_all_reducers(Tensor(x_unwrapped)) + return x + + +class FXStatisticsAggregator(StatisticsAggregator): + HOOKS_GROUP_NAME = "statistics_hooks" + + def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: + with torch.no_grad(): + super().collect_statistics(model, graph) + # All statistics are collected as a dead code, + # so eliminate dead core removed statistcs collector + # from the target model. No additional code required + # for that, horay! + model.graph.eliminate_dead_code() + model.recompile() + + def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None: + return + + def _get_transformation_layout_extra_outputs( + self, statistic_points: StatisticPointsContainer + ) -> TransformationLayout: + transformation_layout = TransformationLayout() + transformation_commands = [] + + for _statistic_points in statistic_points.values(): + for _statistic_point in _statistic_points: + for collectors in _statistic_point.algorithm_to_tensor_collectors.values(): + for collector in collectors: + transformation = leaf_module_insertion_transformation_builder( + TensorCollectorModule(collector), [_statistic_point.target_point] + ) + transformation_commands.append( + FXApplyTransformationCommand( + transformation, TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION + ) + ) + + for transformation_command in transformation_commands: + transformation_layout.register(transformation_command) + + return transformation_layout + + @staticmethod + def _get_merged_statistic_points( + statistic_points: StatisticPointsContainer, model: TModel, graph: NNCFGraph + ) -> StatisticPointsContainer: + # TODO(dlyakhov): mirgate to experimental statistic collector and use common merging algorithm + return statistic_points + + @staticmethod + def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, Tensor]: + return outputs diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py new file mode 100644 index 00000000000..47ae266ba1b --- /dev/null +++ b/nncf/experimental/torch/fx/transformations.py @@ -0,0 +1,422 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional + +import torch +import torch.fx +from torch.ao.quantization.fx.utils import create_getattr_from_value +from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node +from torch.quantization.fake_quantize import FakeQuantize + +import nncf +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name +from nncf.torch.graph.transformations.commands import PTTargetPoint + +TransformationFNType = Callable[[torch.fx.GraphModule], None] + + +def leaf_module_insertion_transformation_builder( + module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] +) -> TransformationFNType: + """ + Returns transformation which inserts given module to a target model + and calls given module after each target points. + + :param module_to_insert: Given torch.nn.Module to insert. + :param target_points: Target points to insert the target module. + :returns: Transformation which which inserts given module to a target model + and calls given module after each target points. + """ + + def leaf_module_insertion_transformation(model: torch.fx.GraphModule): + module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points) + # Insert call_module nodes to the model + graph = model.graph + for target_point in target_points: + target_node = _get_target_node(graph, target_point) + _insert_call_module(graph, target_node, module_attr_name) + + return leaf_module_insertion_transformation + + +def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType: + """ + Return transformation which updates constant of the given bias node to the given value. + + :param node: Bias node which requires bias constant update. + :param value: New value to use as the bias constant. + :return: Transformation which updates constant of the given bias node to the given value. + """ + + def bias_update_transformation(model: torch.fx.GraphModule): + graph = model.graph + target_node_name = node.node_name + graph_node = get_graph_node_by_name(graph, target_node_name) + if len(graph_node.users) != 1: + raise nncf.InternalError(f"Node with bias have {len(graph_node.users)} users, 1 expected.") + + bias_node = next(iter(graph_node.users)) + with graph.inserting_before(bias_node): + new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value) + + args = list(bias_node.args) + # A bias node suppose to have constant on the second input port. + args[1] = new_constant + bias_node.args = tuple(args) + graph.eliminate_dead_code() + + return bias_update_transformation + + +def qdq_insertion_tranformation_builder( + quantizer: FakeQuantize, target_points: List[PTTargetPoint] +) -> TransformationFNType: + """ + Returns transformation which inserts quantize-dequantize operations with parameters + inherited from the given quantizer to each given target point. + + :param quantizer: Quantizer module to inherit quantization parameters from. + :param target_points: List of target point used to insert quantize-dequantize pairs. + :return: Transformation which inserts quantize-dequantize operations with parameters + inherited from the given quantizer to each given target point. + """ + + def qdq_insertion_tranformation(model: torch.fx.GraphModule): + if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1: + raise RuntimeError( + "Insertion of shared qdq pair for the weights is not supported." + " Please use non shared qdq pairs for the weights quantization." + ) + for target_point in target_points: + target_node = _get_target_node(model.graph, target_point) + insert_one_qdq_after_node(model, target_node, quantizer) + + return qdq_insertion_tranformation + + +def insert_one_qdq_after_node(model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize): + """ + Inserts quantize-dequantize after the target node to the target model. + + :param model: Target model. + :param target_node: Target node, quantizer-dequantizer pair is inserted just after the + target node. + :param quantizer: Quantizer module to inherit quantization parameters from. + """ + + # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op: Optional[Callable] = None + + dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8 + if quantizer.is_per_channel: + qparams = { + "_scale_": quantizer.scale, + "_zero_point_": quantizer.zero_point, + "_axis_": quantizer.ch_axis, + "_quant_min_": quantizer.quant_min, + "_quant_max_": quantizer.quant_max, + "_dtype_": dtype, + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default + else: + qparams = { + "_scale_": float(quantizer.scale), + "_zero_point_": int(quantizer.zero_point), + "_quant_min_": quantizer.quant_min, + "_quant_max_": quantizer.quant_max, + "_dtype_": dtype, + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + + # 2. replace activation_post_process node with quantize and dequantize + graph = model.graph + # TODO(dlyakhov): use metatype to get correct input_port_id + # Do not quantize already quantized nodes + # inserting_before handle only order in the graph generated code. + # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes + with graph.inserting_before(target_node): + quantize_op_inputs = [target_node] + for key, value_or_node in qparams.items(): + # TODO(dlyakhov): we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO(dlaykhov): maybe need more complex attr name here + qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store + # them as literals in the graph. + quantize_op_inputs.append(value_or_node) + with graph.inserting_after(target_node): + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + user_dq_nodes = [] + with graph.inserting_after(quantized_node): + for user in target_node.users: + if user is quantized_node: + continue + user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {}))) + + for user, dq_node in user_dq_nodes: + user.replace_input_with(target_node, dq_node) + + +def _insert_call_module(graph: torch.fx.Graph, target_node: torch.fx.Node, module_attr_name: str): + """ + Inserts module call node to the graph after the target node. + + :param graph: Graph to insert module call node. + :param target_node: Target node, module call node is being iserted just after the target node. + :param module_attr_name: The name of the graph attribute which keeps the target module. + """ + with graph.inserting_after(target_node): + return graph.create_node( + "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_graph_node" + ) + + +def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint) -> torch.fx.Node: + """ + Returns TorchFX graph node correspondent to the target point. + + :param graph: Target torch.fx.Graph. + :param target_point: A target point to find the target node. + :return: TorchFX graph node correspondent to the target point. + """ + # TODO(dlyakhov): Support node insertion on a specific input port id. + target_type = target_point.target_type + target_node = get_graph_node_by_name(graph, target_point.target_node_name) + if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: + target_node = target_node.all_input_nodes[target_point.input_port_id] + elif target_type != TargetType.OPERATOR_POST_HOOK: + raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") + return target_node + + +def _set_module_to_the_graph_module( + model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] +) -> str: + """ + Sets given module to the given torch.fx.GraphModule with unique name. + + :param graph: Target torch.fx.Graph. + :param module_to_insert: Module to insert to the target graph. + :param target_points: Target points which will be used to insert target module + to the graph. + :return: A graph module attribute name which keep given module. + """ + module_to_insert = module_to_insert + # TODO(dlyakhov) Make module name human readable. + module_name_in_model = ( + "__".join( + "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points + ) + + "_" + + str(id(module_to_insert)) + ) + assert not hasattr(model, module_name_in_model) + setattr(model, module_name_in_model, module_to_insert) + return module_name_in_model + + +def apply_quantization_transformations(model: torch.fx.GraphModule) -> None: + """ + Applies quantization transformations to the model. + :param model: Model to apply transformations to. + """ + # BatchNorm operations have 3 output ports, + # to make it easier for alorithms to work + # with the target graph BatchNorm operations + # are being fused + _fuse_conv_bn_(model) + separate_conv_and_bias(model) + separate_linear_and_bias(model) + + +def revert_quantization_transformations(model: torch.fx.GraphModule) -> None: + """ + Reverts quantization transformations from the model. + :param model: Model to revert transformations from. + """ + merge_conv_and_bias(model) + merge_linear_and_bias(model) + + +def _is_linear(n: torch.fx.Node) -> bool: + """ + Return whether the node refers to an aten linear op. + + :param n: The given node. + :return: True if given node is a linear node, else False. + """ + return n.op == "call_function" and n.target in (torch.ops.aten.linear.default,) + + +def _is_conv(n: torch.fx.Node): + """ + Return whether the node refers to an aten conv op. + """ + return n.op == "call_function" and n.target in ( + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ) + + +def separate_linear_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined linear+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + + :param model: Target model. + """ + add_node_target = torch.ops.aten.add_.Tensor + for n in model.graph.nodes: + if not _is_linear(n): + continue + if len(n.args) < 3 or n.args[2] is None: + continue + linear_node = n + linear_bias_node = linear_node.args[2] + while linear_bias_node.op != "get_attr": + # Assume zero argument is on a path to the constant + linear_bias_node = linear_bias_node.args[0] + linear_bias_value = _get_tensor_constant_from_node(linear_bias_node, model) + args = list(n.args) + args[2] = None + linear_node.args = tuple(args) + with model.graph.inserting_after(linear_node): + new_linear_bias_node = create_getattr_from_value( + model, + model.graph, + linear_bias_node.name + "_", + linear_bias_value, + ) + with model.graph.inserting_after(new_linear_bias_node): + add_node = model.graph.create_node( + "call_function", add_node_target, (linear_node, new_linear_bias_node), {} + ) + for user in list(linear_node.users): + if user is add_node: + continue + user.replace_input_with(linear_node, add_node) + if "val" in linear_node.meta: + add_node.meta["val"] = linear_node.meta["val"] + model.graph.eliminate_dead_code() + model.recompile() + + +def separate_conv_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined conv+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + + :param model: Target model. + """ + add_node_target = torch.ops.aten.add_.Tensor + for n in model.graph.nodes: + if not _is_conv(n): + continue + if len(n.args) < 3 or n.args[2] is None: + continue + conv_node = n + dims = len(_get_tensor_constant_from_node(conv_node.args[1], model).shape) + conv_bias_node = conv_node.args[2] + conv_bias_value = _get_tensor_constant_from_node(conv_bias_node, model) + args = list(n.args) + args[2] = None + conv_node.args = tuple(args) + with model.graph.inserting_after(conv_node): + new_conv_bias_node = create_getattr_from_value( + model, model.graph, conv_bias_node.name + "_", conv_bias_value.reshape((1, -1) + (1,) * (dims - 2)) + ) + with model.graph.inserting_after(new_conv_bias_node): + add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {}) + for user in list(conv_node.users): + if user is add_node: + continue + user.replace_input_with(conv_node, add_node) + + if "val" in conv_node.meta: + add_node.meta["val"] = conv_node.meta["val"] + model.graph.eliminate_dead_code() + model.recompile() + + +def merge_conv_and_bias(model: torch.fx.GraphModule): + """ + Merges two separate conv and bias nodes to a one node: conv+bias. + Needed as nncf does not expect joined conv + + :param model: Target model. + """ + _merge_node_and_bias(model, _is_conv) + + +def merge_linear_and_bias(model: torch.fx.GraphModule): + """ + Merges two separate linear and bias nodes to a one node: linear+bias. + + :param model: Target model. + """ + _merge_node_and_bias(model, _is_linear) + + +def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[torch.fx.Node], bool]): + """ + Merges two separate node and bias node to a one node: node+bias. + Check which node should be merged by the given `is_target_node` predicate. + + :param model: Target model. + :param is_target_node: Predicate to specify nodes which shoudld be merged with the bias + """ + add_node_targets = (torch.ops.aten.add_.Tensor,) + for n in model.graph.nodes: + if not is_target_node(n): + continue + if len(n.args) > 2 and n.args[2] is not None: + continue + bias_node = next(iter(n.users)) + if len(n.users) > 1 or bias_node.target not in add_node_targets: + continue + conv_node = n + const_node = None + for node in bias_node.all_input_nodes: + if node is not conv_node: + const_node = node + break + assert const_node is not None + bias_value = _get_tensor_constant_from_node(const_node, model).squeeze() + with model.graph.inserting_before(conv_node): + new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value) + args = list(conv_node.args) + args[2] = new_bias_node + conv_node.args = tuple(args) + for user in list(bias_node.users): + user.replace_input_with(bias_node, conv_node) + + model.graph.eliminate_dead_code() + model.recompile() diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index ff7836035c9..3d104cad3c9 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -93,7 +93,7 @@ def __init__( @property def available_backends(self) -> List[BackendType]: - return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH] + return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] def _set_backend_entity(self, model: TModel) -> None: """ @@ -116,6 +116,12 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend self._backend_entity = PTFastBiasCorrectionAlgoBackend() + elif model_backend == BackendType.TORCH_FX: + from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import ( + FXFastBiasCorrectionAlgoBackend, + ) + + self._backend_entity = FXFastBiasCorrectionAlgoBackend() else: raise nncf.UnsupportedBackendError( "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py new file mode 100644 index 00000000000..d808448307e --- /dev/null +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.fx +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node + +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph import NNCFGraph +from nncf.common.graph import NNCFNode +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand +from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name +from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder +from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend +from nncf.tensor import Tensor +from nncf.torch.graph.transformations.commands import PTModelExtractionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector + + +class FXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + } + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: + port_id = None + if target_type in FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: + target_type = FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) + + @staticmethod + def create_bias_correction_command( + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph + ) -> FXApplyTransformationCommand: + return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data)) + + @staticmethod + def model_extraction_command( + input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]] + ) -> PTModelExtractionCommand: + return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]]) + + @staticmethod + def mean_statistic_collector( + channel_axis: int, + inplace: bool, + num_samples: Optional[int] = None, + window_size: Optional[int] = None, + ) -> TensorCollector: + return get_mean_statistic_collector(num_samples, channel_axis, window_size) + + @staticmethod + def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: + # Pytorch does not have name for extracted node + return None, None + + @staticmethod + def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor: + blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device) + for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): + index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) + blob[index] = data[j].data + return blob + + @staticmethod + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor: + bias_node = nncf_graph.get_next_nodes(node)[0] + # TODO(dlyakhov): make a node_name_vs_node map to speed up the process + graph_bias_node = get_graph_node_by_name(model.graph, bias_node.node_name) + return Tensor(_get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model)) + + @staticmethod + def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: + return 0, 0 + + @staticmethod + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data) + + @staticmethod + def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + weight_node = nncf_graph.get_previous_nodes(node)[1] + return "dequantize" in weight_node.node_type + + @staticmethod + def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + # Assumes that all biases were unfused + if node.metatype in (om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype): + next_nodes = nncf_graph.get_next_nodes(node) + if len(next_nodes) != 1: + return False + return next_nodes[0].metatype in (om.PTAddMetatype,) + + @staticmethod + def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: + return node.node_name, node.node_name + + @staticmethod + def get_activation_channel_axis(node: NNCFNode, pord_id: int, input_shape: Tuple[int]) -> int: + return node.metatype.output_channel_axis diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 2fefb664a18..f8cdd316529 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -328,7 +328,7 @@ def _init_cache(self) -> None: @property def available_backends(self) -> List[BackendType]: - return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH] + return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] def _get_quantizer_constraints( self, @@ -381,6 +381,10 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend self._backend_entity = OVMinMaxAlgoBackend() + elif model_backend == BackendType.TORCH_FX: + from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend + + self._backend_entity = FXMinMaxAlgoBackend() elif model_backend == BackendType.TORCH: from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py new file mode 100644 index 00000000000..c095836e674 --- /dev/null +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -0,0 +1,353 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Set, Tuple + +import torch + +import nncf +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationCommand +from nncf.common.hardware.config import HWConfig +from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode +from nncf.common.quantization.structs import QuantizerConfig +from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand +from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder +from nncf.parameters import ModelType +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import StatisticsType +from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend +from nncf.quantization.fake_quantize import FakeConvertParameters +from nncf.quantization.fake_quantize import FakeQuantizeParameters +from nncf.quantization.range_estimator import AggregatorType +from nncf.quantization.range_estimator import RangeEstimatorParameters +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.graph import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.hardware.config import PTHWConfig +from nncf.torch.model_graph_manager import get_weight_tensor_port_ids +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT +from nncf.torch.quantization.layers import QUANTIZATION_MODULES +from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import get_scale_shape +from nncf.torch.quantization.strip import convert_to_torch_fakequantizer +from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP + + +class FXMinMaxAlgoBackend(MinMaxAlgoBackend): + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + } + + @property + def mat_mul_metatypes(self) -> List[OperatorMetatype]: + return [om.PTLinearMetatype, om.PTMatMulMetatype] + + @property + def post_processing_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def shapeof_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def dropout_metatypes(self) -> List[OperatorMetatype]: + return [om.PTDropoutMetatype] + + @property + def read_variable_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def conv_metatypes(self) -> List[OperatorMetatype]: + return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype] + + @property + def overflow_fix_metatypes(self) -> List[OperatorMetatype]: + return [ + om.PTConv1dMetatype, + om.PTConv2dMetatype, + om.PTConv3dMetatype, + om.PTLinearMetatype, + om.PTConvTranspose1dMetatype, + om.PTConvTranspose2dMetatype, + om.PTConvTranspose3dMetatype, + ] + + @property + def add_metatypes(self) -> List[OperatorMetatype]: + return [om.PTAddMetatype] + + @property + def group_conv_metatypes(self) -> List[OperatorMetatype]: + return self.conv_metatypes + + @property + def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: + return [om.PTScaledDotProductAttentionMetatype] + + @property + def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: + return {om.PTCatMetatype: self.overflow_fix_metatypes} + + @property + def hw_config(self) -> HWConfig: + return PTHWConfig + + @property + def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: + return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT + + @staticmethod + def get_start_nodes_for_activation_path_tracing(nncf_graph: PTNNCFGraph) -> List[NNCFNode]: + return nncf_graph.get_input_nodes() + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: + port_id = None + if target_type in FXMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: + target_type = FXMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) + + @staticmethod + def create_convert_insertion_command( + target_point: PTTargetPoint, + parameters: FakeConvertParameters, + ) -> TransformationCommand: + raise nncf.InternalError("FakeConvert insertion not implemented in PyTorch backend!") + + @staticmethod + def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int, ...]: + return nncf_graph.get_input_shape_for_insertion_point(target_point) + + @staticmethod + def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]: + # TODO(dlyakhov): support transpose conv and other cases + return (0,) + + @staticmethod + def get_statistic_collector( + range_estimator_params: RangeEstimatorParameters, + use_abs_max: bool, + reduction_axes: Optional[Tuple[int, ...]], + aggregation_axes: Optional[Tuple[int, ...]], + inplace: bool, + num_samples: Optional[int] = None, + ) -> TensorCollector: + collector = TensorCollector(MinMaxTensorStatistic) + for params, container_key in zip( + [range_estimator_params.min, range_estimator_params.max], + [MinMaxTensorStatistic.MIN_STAT, MinMaxTensorStatistic.MAX_STAT], + ): + if params.statistics_type not in PT_REDUCERS_MAP: + raise nncf.InternalError( + f"Statistic type: {params.statistics_type} is not supported for Torch PTQ backend yet." + ) + + if params.aggregator_type not in AGGREGATORS_MAP: + raise nncf.InternalError( + f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet." + ) + + statistic_type = params.statistics_type + if statistic_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: + # TODO(dlyakhov): merge two quantile aggregators in one + if container_key == MinMaxTensorStatistic.MIN_STAT: + quantile = params.quantile_outlier_prob + else: + quantile = 1 - params.quantile_outlier_prob + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes, quantile=[quantile]) + else: + if use_abs_max and statistic_type == StatisticsType.MAX: + statistic_type = StatisticsType.ABS_MAX + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes) + + kwargs = { + "num_samples": num_samples, + "aggregation_axes": aggregation_axes, + } + if params.aggregator_type in [AggregatorType.MEAN_NO_OUTLIERS, AggregatorType.MEDIAN_NO_OUTLIERS]: + kwargs.update({"quantile": params.quantile_outlier_prob}) + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) + + collector.register_statistic_branch(container_key, reducer, aggregator) + return collector + + @staticmethod + def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: + return get_weight_tensor_port_ids(node, graph) + + @staticmethod + def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str: + weighted_node = nncf_graph.get_node_by_name(target_point.target_node_name) + weight = nncf_graph.get_previous_nodes(weighted_node)[target_point.input_port_id] + return weight.node_name + + @staticmethod + def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: + # If the nodes share one weight tensor, we should have only one quantizer on that + return weight_name not in quantized_weight_names + + @staticmethod + def get_weight_config(config: QuantizerConfig, model: NNCFNetwork) -> QuantizerConfig: + return config + + @staticmethod + def _get_input_scale_shape( + nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: + is_weights = target_point.is_weight_target_point() + if is_weights: + # TODO(dlyakhov): support transpose conv/ make channel_idx common + channel_idx = 0 + else: + channel_idx = 1 # channel dim for activations + + input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) + scale_shape = tuple( + get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) + ) + + return input_shape, scale_shape, channel_idx + + @staticmethod + def _create_quantizer( + quantizer_config: QuantizerConfig, + scale_shape: Tuple, + parameters: FakeQuantizeParameters, + target_type: TargetType, + ) -> BaseQuantizer: + mode = quantizer_config.mode + quantizer_cls = QUANTIZATION_MODULES.get(mode) + narrow_range = target_type == TargetType.OPERATION_WITH_WEIGHTS and mode == QuantizationMode.SYMMETRIC + quantizer_spec = PTQuantizerSpec.from_config( + quantizer_config, + narrow_range=narrow_range, + scale_shape=scale_shape, + half_range=False, + logarithm_scale=False, + is_quantized_on_export=False, + compression_lr_multiplier=None, + ) + quantizer = quantizer_cls(quantizer_spec) + + # Fill it with minmax + # TODO(dlyakhov) Prevent creation of intermediate objects like nncf quantizer. + FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) + # Convert to the torch fake quantizer + torch_fq = convert_to_torch_fakequantizer(quantizer) + return torch_fq + + @staticmethod + def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: + if isinstance(quantizer, AsymmetricQuantizer): + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) + input_range = parameters.input_high - parameters.input_low + # Subtract eps from the input_range to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) + else: + quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) + # Subtract eps from the scale to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) + + @staticmethod + def create_quantizer_insertion_command( + nncf_graph: NNCFGraph, + target_point: PTTargetPoint, + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> FXApplyTransformationCommand: + _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape( + nncf_graph, target_point, quantizer_config.per_channel + ) + + quantizer = FXMinMaxAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_point.target_type + ) + transformation = qdq_insertion_tranformation_builder(quantizer, [target_point]) + return FXApplyTransformationCommand(transformation) + + @staticmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[PTTargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[PTSharedFnInsertionCommand]: + _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape( + nncf_graph, target_points[0], quantizer_config.per_channel + ) + + quantizer = FXMinMaxAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_points[0].target_type + ) + + transformations = [] + for tp in target_points: + transformation = qdq_insertion_tranformation_builder(quantizer, [tp]) + transformations.append(FXApplyTransformationCommand(transformation)) + return transformations + + @staticmethod + def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]: + types = [] + if model_type == ModelType.TRANSFORMER: + types = [ + om.PTAddMetatype, + om.PTPowerMetatype, + om.PTSubMetatype, + om.PTAvgPool2dMetatype, + om.PTAvgPool3dMetatype, + om.PTMeanMetatype, + om.PTSumMetatype, + om.PTReduceL2, + om.PTDivMetatype, + om.PTMaxMetatype, + om.PTSqueezeMetatype, + om.PTLayerNormMetatype, + om.PTModuleLayerNormMetatype, + om.PTGroupNormMetatype, + om.PTModuleGroupNormMetatype, + # Batchnorm + om.PTBatchNormMetatype, + om.PTModuleBatchNormMetatype, + ] + if device != TargetDevice.CPU_SPR: + types.append(om.PTMulMetatype) + return types + + @staticmethod + def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]: + return set() + + @staticmethod + def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: + retval = set() + for node in nncf_graph.get_all_nodes(): + if node.metatype in [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype]: + retval.add(node) + return list(retval) diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 65b804f10a8..dc56e6daede 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -228,7 +228,21 @@ def quantize( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) + if backend == BackendType.TORCH_FX: + from nncf.experimental.torch.fx.quantization.quantize_model import quantize_impl + return quantize_impl( + model=model, + calibration_dataset=calibration_dataset, + mode=mode, + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 5148496fea6..5d20a0d7ba6 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -371,7 +371,7 @@ def patch_torch_operators(): functions_to_patch = {} for namespace in NamespaceTarget: - if namespace == NamespaceTarget.EXTERNAL: + if namespace in [NamespaceTarget.ATEN, NamespaceTarget.EXTERNAL]: continue functions_to_patch[namespace] = get_all_functions_from_namespace(namespace) diff --git a/nncf/torch/dynamic_graph/structs.py b/nncf/torch/dynamic_graph/structs.py index c767790a92c..d8cf563107f 100644 --- a/nncf/torch/dynamic_graph/structs.py +++ b/nncf/torch/dynamic_graph/structs.py @@ -22,6 +22,7 @@ class NamespaceTarget(Enum): TORCH_TENSOR = "torch.tensor" TORCH_NN_PARAMETER = "torch.nn.parameter" TORCH = "torch" + ATEN = "aten" EXTERNAL = "external_function" diff --git a/nncf/torch/engine.py b/nncf/torch/engine.py index c2a7c051132..2bc17db0416 100644 --- a/nncf/torch/engine.py +++ b/nncf/torch/engine.py @@ -15,6 +15,8 @@ from torch import nn from nncf.common.engine import Engine +from nncf.common.utils.backend import BackendType +from nncf.common.utils.backend import get_backend class PTEngine(Engine): @@ -30,7 +32,8 @@ def __init__(self, model: nn.Module): """ self._model = model - self._model.eval() + if get_backend(model) == BackendType.TORCH: + self._model.eval() def infer( self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]] diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index ca038b77f24..3b998c40531 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -56,6 +56,7 @@ class PTOperatorMetatype(OperatorMetatype): NamespaceTarget.TORCH_NN_FUNCTIONAL: [], NamespaceTarget.TORCH_TENSOR: [], NamespaceTarget.TORCH: [], + NamespaceTarget.ATEN: [], } subtypes: List[Type["PTOperatorMetatype"]] = [] @@ -528,7 +529,7 @@ class PTGELUMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register() class PTSILUMetatype(PTOperatorMetatype): name = "SiluOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"]} + module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"], NamespaceTarget.ATEN: ["silu_"]} @PT_OPERATOR_METATYPES.register() @@ -706,6 +707,7 @@ class PTModuleBatchNormMetatype(PTModuleOperatorSubtype): name = "BatchNormOp" module_to_function_names = { NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"], } @@ -714,6 +716,7 @@ class PTBatchNormMetatype(PTOperatorMetatype): name = "BatchNormOp" module_to_function_names = { NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"], } subtypes = [PTModuleBatchNormMetatype] weight_port_ids = [3] @@ -844,6 +847,7 @@ class PTGatherMetatype(PTOperatorMetatype): module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["index_select", "__getitem__"], NamespaceTarget.TORCH: ["gather", "index_select", "select", "where"], + NamespaceTarget.ATEN: ["slice"], } @@ -880,6 +884,7 @@ class PTSplitMetatype(PTOperatorMetatype): NamespaceTarget.TORCH_NN_FUNCTIONAL: [], NamespaceTarget.TORCH_TENSOR: ["split", "chunk", "unbind"], NamespaceTarget.TORCH: ["split", "chunk", "unbind"], + NamespaceTarget.ATEN: ["split_with_sizes"], } hw_config_names = [HWConfigOpName.SPLIT, HWConfigOpName.CHUNK] @@ -1047,6 +1052,7 @@ class PTInterpolateMetatype(PTOperatorMetatype): name = "InterpolateOp" module_to_function_names = { NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"], + NamespaceTarget.ATEN: ["upsample_nearest2d", "upsample_nearest_exact2d"], } hw_config_names = [HWConfigOpName.INTERPOLATE] num_expected_input_edges = 1 diff --git a/tests/torch/fx/__init__.py b/tests/torch/fx/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/tests/torch/fx/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/torch/fx/helpers.py b/tests/torch/fx/helpers.py new file mode 100644 index 00000000000..8bbc721e0fa --- /dev/null +++ b/tests/torch/fx/helpers.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from fastdownload import FastDownload + + +class TinyImagenetDatasetManager: + DATASET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip" + DATASET_PATH = "~/.cache/nncf/tests/datasets" + + def __init__(self, image_size: int, batch_size: int) -> None: + self.image_size = image_size + self.batch_size = batch_size + + @staticmethod + def download_dataset() -> Path: + downloader = FastDownload(base=TinyImagenetDatasetManager.DATASET_PATH, archive="downloaded", data="extracted") + return downloader.get(TinyImagenetDatasetManager.DATASET_URL) + + @staticmethod + def prepare_tiny_imagenet_200(dataset_dir: Path): + # Format validation set the same way as train set is formatted. + val_data_dir = dataset_dir / "val" + val_images_dir = val_data_dir / "images" + if not val_images_dir.exists(): + return + + val_annotations_file = val_data_dir / "val_annotations.txt" + with open(val_annotations_file, "r") as f: + val_annotation_data = map(lambda line: line.split("\t")[:2], f.readlines()) + for image_filename, image_label in val_annotation_data: + from_image_filepath = val_images_dir / image_filename + to_image_dir = val_data_dir / image_label + if not to_image_dir.exists(): + to_image_dir.mkdir() + to_image_filepath = to_image_dir / image_filename + from_image_filepath.rename(to_image_filepath) + val_annotations_file.unlink() + val_images_dir.rmdir() + + def create_data_loaders(self): + dataset_path = TinyImagenetDatasetManager.download_dataset() + + TinyImagenetDatasetManager.prepare_tiny_imagenet_200(dataset_path) + print(f"Successfully downloaded and prepared dataset at: {dataset_path}") + + train_dir = dataset_path / "train" + val_dir = dataset_path / "val" + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.Resize(self.image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + val_dataset = datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(self.image_size), + transforms.ToTensor(), + normalize, + ] + ), + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, pin_memory=True, sampler=None + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, pin_memory=True + ) + + # Creating separate dataloader with batch size = 1 + # as dataloaders with batches > 1 are not supported yet. + calibration_dataset = torch.utils.data.DataLoader( + val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True + ) + + return train_loader, val_loader, calibration_dataset diff --git a/tests/torch/fx/test_sanity.py b/tests/torch/fx/test_sanity.py new file mode 100644 index 00000000000..a3f0d828260 --- /dev/null +++ b/tests/torch/fx/test_sanity.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import openvino.torch # noqa +import pytest +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.models as models +from torch._export import capture_pre_autograd_graph + +import nncf +from nncf.common.logging.track_progress import track +from nncf.torch.dynamic_graph.patch_pytorch import disable_patching +from tests.torch.fx.helpers import TinyImagenetDatasetManager + +IMAGE_SIZE = 64 +BATCH_SIZE = 128 + + +@pytest.fixture(name="tiny_imagenet_dataset", scope="module") +def tiny_imagenet_dataset_fixture(): + return TinyImagenetDatasetManager(IMAGE_SIZE, BATCH_SIZE).create_data_loaders() + + +@dataclass +class SanitySampleCase: + model_id: str + checkpoint_url: str + top1_int8_ref: float + ref_num_q: int + ref_num_dq: int + + +MODELS = ( + SanitySampleCase( + "resnet18", + "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth", + 55.35, + 51, + 58, + ), +) + + +def get_model(model_id: str, checkpoint_url: str, device: torch.device) -> torch.nn.Module: + num_classes = 200 # 200 is for Tiny ImageNet, default is 1000 for ImageNet + model = getattr(models, model_id)(weights=None) + # Update the last FC layer for Tiny ImageNet number of classes. + model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True) + model.to(device) + checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device("cpu"), progress=False) + model.load_state_dict(checkpoint["state_dict"]) + return model + + +def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, device: torch.device) -> float: + top1_sum = 0.0 + with torch.no_grad(): + for images, target in track(val_loader, total=len(val_loader), description="Validation:"): + images = images.to(device) + target = target.to(device) + + # Compute output. + output = model(images) + + # Measure accuracy and record loss. + [acc1] = accuracy(output, target, topk=(1,)) + top1_sum += acc1.item() + + num_samples = len(val_loader) + top1_avg = top1_sum / num_samples + return top1_avg + + +def accuracy(output: torch.Tensor, target: torch.tensor, topk: Tuple[int, ...] = (1,)): + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def count_q_dq(model: torch.fx.GraphModule): + q, dq = 0, 0 + for node in model.graph.nodes: + if node.op == "call_function" and hasattr(node.target, "overloadpacket"): + node_type = str(node.target.overloadpacket).split(".")[1] + if node_type in ["quantize_per_tensor", "quantize_per_channel"]: + q += 1 + elif node_type in ["dequantize_per_tensor", "dequantize_per_channel"]: + dq += 1 + return q, dq + + +@pytest.mark.parametrize("test_case", MODELS) +def test_sanity(test_case: SanitySampleCase, tiny_imagenet_dataset): + with disable_patching(): + torch.manual_seed(42) + device = torch.device("cpu") + model = get_model(test_case.model_id, test_case.checkpoint_url, device) + _, val_dataloader, calibration_dataset = tiny_imagenet_dataset + + def transform_fn(data_item): + return data_item[0].to(device) + + calibration_dataset = nncf.Dataset(calibration_dataset, transform_fn) + + with torch.no_grad(): + ex_input = next(iter(calibration_dataset.get_inference_data())) + model.eval() + exported_model = capture_pre_autograd_graph(model, args=(ex_input,)) + quantized_model = nncf.quantize(exported_model, calibration_dataset) + quantized_model = torch.compile(quantized_model, backend="openvino") + + top1_int8 = validate(val_dataloader, quantized_model, device) + assert np.isclose(top1_int8, test_case.top1_int8_ref, atol=0.1) + + num_q, num_dq = count_q_dq(quantized_model) + assert num_q == test_case.ref_num_q + assert num_dq == test_case.ref_num_dq diff --git a/tests/torch/requirements.txt b/tests/torch/requirements.txt index bbd3a45c57e..be82652d65f 100644 --- a/tests/torch/requirements.txt +++ b/tests/torch/requirements.txt @@ -19,3 +19,8 @@ datasets==2.14.7 evaluate==0.3.0 openvino timm==0.9.2 + + +# Required for torch/fx tests +torchvision +fastdownload==0.0.7