diff --git a/.github/workflows/post_pr_merge.yml b/.github/workflows/post_pr_merge.yml
index ea29e72d6f2..9b34f92c46f 100644
--- a/.github/workflows/post_pr_merge.yml
+++ b/.github/workflows/post_pr_merge.yml
@@ -16,6 +16,9 @@ on:
       - develop
     types:
       - closed
+    paths-ignore:
+      - '**/*.md'
+      - 'docs/**/*'
 
 jobs:
   upload-coverage-common:
diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml
index b18ea60137a..6f7ec2906d6 100644
--- a/.github/workflows/precommit.yml
+++ b/.github/workflows/precommit.yml
@@ -95,50 +95,6 @@ jobs:
           token: ${{ secrets.CODECOV_TOKEN }}
           name: coverage_openvino
           flags: OPENVINO
-  torchFX:
-    timeout-minutes: 40
-    defaults:
-      run:
-        shell: bash
-    runs-on: ubuntu-20.04-8-cores
-    env:
-      DEBIAN_FRONTEND: noninteractive
-    steps:
-      - name: Install dependencies
-        run : |
-          sudo apt-get update
-          sudo apt-get --assume-yes install gcc g++ build-essential ninja-build libgl1-mesa-dev libglib2.0-0
-      - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6
-        with:
-          lfs: true
-      - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0
-        with:
-          python-version: 3.8.18
-          cache: pip
-      - name: Runner info
-        continue-on-error: true
-        run: |
-          cat /etc/*release
-          cat /proc/cpuinfo
-      - name: Install NNCF and test requirements
-        run: make install-torch-fx-test
-      - name: Run TorchFX precommit test scope
-        run: |
-          make test-torch-fx
-        env:
-          NNCF_COVERAGE: 1
-          NUM_WORKERS: 4
-      - name: Upload coverage report as artifact
-        uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3
-        with:
-          name: coverage_fx_cpu
-          path: ./coverage.xml
-      - name: Upload coverage report to codecov
-        uses: codecov/codecov-action@125fc84a9a348dbcf27191600683ec096ec9021c # v4.4.1
-        with:
-          token: ${{ secrets.CODECOV_TOKEN }}
-          name: coverage_fx_cpu
-          flags: TORCH
   pytorch-cpu:
     timeout-minutes: 40
     defaults:
diff --git a/Makefile b/Makefile
index 7842dd68ffd..3086eea0b92 100644
--- a/Makefile
+++ b/Makefile
@@ -153,6 +153,7 @@ test-torch-cuda:
 
 test-torch-nightly:
	pytest ${COVERAGE_ARGS} tests/torch -m nightly --junitxml ${JUNITXML_PATH} $(DATA_ARG)
+	test-torch-fx
 
 test-torch-weekly:
	pytest ${COVERAGE_ARGS} tests/torch -m weekly \
diff --git a/nncf/common/factory.py b/nncf/common/factory.py
index d5d13605a07..8cc7dd018e5 100644
--- a/nncf/common/factory.py
+++ b/nncf/common/factory.py
@@ -43,7 +43,7 @@ def create(model: TModel) -> NNCFGraph:
 
             return GraphConverter.create_nncf_graph(model)
         if model_backend == BackendType.TORCH_FX:
-            from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter
+            from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
 
             return GraphConverter.create_nncf_graph(model)
         if model_backend == BackendType.TORCH:
@@ -77,7 +77,7 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
 
             return PTModelTransformer(model)
         if model_backend == BackendType.TORCH_FX:
-            from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
+            from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
 
             return FXModelTransformer(model)
         raise nncf.UnsupportedBackendError(
@@ -108,7 +108,7 @@ def create(model: TModel) -> Engine:
 
             return PTEngine(model)
         if model_backend == BackendType.TORCH_FX:
-            from nncf.experimental.torch_fx.engine import FXEngine
+            from nncf.experimental.torch.fx.engine import FXEngine
 
             return FXEngine(model)
         raise nncf.UnsupportedBackendError(
@@ -164,7 +164,7 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
 
             return PTStatisticsAggregator(dataset)
         if model_backend == BackendType.TORCH_FX:
-            from nncf.experimental.torch_fx.statistics.aggregator import FXStatisticsAggregator
+            from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator
 
             return FXStatisticsAggregator(dataset)
         raise nncf.UnsupportedBackendError(
diff --git a/nncf/experimental/torch_fx/__init__.py b/nncf/experimental/torch_fx/__init__.py
deleted file mode 100644
index 2e49d63977d..00000000000
--- a/nncf/experimental/torch_fx/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/nncf/experimental/torch_fx/engine.py b/nncf/experimental/torch_fx/engine.py
deleted file mode 100644
index 5f9dc2ac221..00000000000
--- a/nncf/experimental/torch_fx/engine.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Dict, Tuple, Union
-
-import torch
-from torch import nn
-
-from nncf.common.engine import Engine
-
-
-class FXEngine(Engine):
-    """
-    Engine for the Pytorch FX backend.
-    """
-
-    def __init__(self, model: nn.Module):
-        """
-        Constructor.
-
-        :param model: Pytorch module to infer.
-        """
-
-        self._model = model
-
-    def infer(
-        self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]
-    ) -> Union[torch.Tensor, Dict[str, Any]]:
-        """
-        Runs Torch model on the provided input.
-
-        :param input_data: Inputs for the model.
-        :return: Model outputs.
-        """
-
-        if isinstance(input_data, dict):
-            return self._model(**input_data)
-        if isinstance(input_data, tuple):
-            return self._model(*input_data)
-        return self._model(input_data)
diff --git a/nncf/experimental/torch_fx/model_transformer.py b/nncf/experimental/torch_fx/model_transformer.py
deleted file mode 100644
index 48b3cf0c1f1..00000000000
--- a/nncf/experimental/torch_fx/model_transformer.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import defaultdict
-
-# from functools import partial
-from typing import Callable, List, Union
-
-import torch
-import torch.fx
-from torch.fx.passes.split_utils import split_by_tags
-
-from nncf.common.graph.model_transformer import ModelTransformer
-from nncf.common.graph.transformations.commands import Command
-from nncf.common.graph.transformations.commands import TargetType
-from nncf.common.graph.transformations.commands import TransformationPriority
-from nncf.common.graph.transformations.commands import TransformationType
-from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
-from nncf.torch.graph.transformations.commands import PTTargetPoint
-from nncf.torch.graph.transformations.layout import PTTransformationLayout
-
-
-class FXModuleInsertionCommand(Command):
-    def __init__(
-        self,
-        target_points: List[PTTargetPoint],
-        module_to_insert: torch.nn.Module,
-        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
-    ):
-        super().__init__(TransformationType.INSERT)
-        self.target_points = target_points
-        self.module_to_insert = module_to_insert
-        self.priority = priority
-
-
-class FXApplyTransformationCommand(Command):
-    def __init__(
-        self,
-        transformation_fn: Callable[[torch.fx.GraphModule], None],
-        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
-    ):
-        super().__init__(TransformationType.INSERT)
-        self.tranformation_fn = transformation_fn
-        self.priority = priority
-
-
-class FXModelTransformer(ModelTransformer):
-    """
-    Applies transformations upon Torch FX model.
-    """
-
-    # TODO: manage priorities of transformations
-
-    def __init__(self, model: torch.fx.GraphModule):
-        super().__init__(model)
-
-        self._command_transformation_ordered_pairs = [
-            # TODO: Move the module insertion command to a transformation
-            (FXApplyTransformationCommand, self._apply_transformation),
-            (FXModuleInsertionCommand, self._apply_module_insertion),
-            (PTModelExtractionCommand, self._apply_model_extraction),
-        ]
-
-    def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
-        transformations = transformation_layout.transformations
-        aggregated_transformations = defaultdict(list)
-        for transformation in transformations:
-            aggregated_transformations[transformation.__class__].append(transformation)
-
-        model = self._model
-        for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs:
-            transformations = aggregated_transformations[transformation_cls]
-            if transformations:
-                model = transformation_fn(model, transformations)
-
-        # Do not eliminate dead code as
-        # the dead code is coputing statistics :)
-        # model.graph.eliminate_dead_code()
-        model.recompile()
-        return model
-
-    @staticmethod
-    def _apply_model_extraction(
-        model: torch.fx.GraphModule,
-        transformations: List[PTModelExtractionCommand],
-    ) -> torch.fx.GraphModule:
-        transformation = transformations[-1]
-        assert len(transformation.input_node_names) == 1
-        assert transformation.input_node_names == transformation.output_node_names
-        node_name = transformation.input_node_names[0]
-
-        tags = ["before", "extracted", "after"]
-        i = 0
-        for node in model.graph.nodes:
-            if node.name == node_name:
-                node.tag = tags[1]
-                weights = [node.all_input_nodes[1]]
-                while weights:
-                    w_node = weights.pop()
-                    assert w_node.tag in tags[0:2]
-                    w_node.tag = tags[1]
-                    weights.extend(w_node.all_input_nodes)
-                i = 2
-                continue
-            node.tag = tags[i]
-
-        splitted_gm = split_by_tags(model, tags)
-        return splitted_gm.extracted
-
-    @staticmethod
-    def _apply_module_insertion(
-        model: torch.fx.GraphModule,
-        transformations: List[FXModuleInsertionCommand],
-    ) -> torch.fx.GraphModule:
-        """
-        Applies insertion of PTSharedFnInsertionCommand commands. For each command method inserts
-        a torch module to the torch.fx.GraphModule and inserts call hooks for each command target points.
-
-        :param model: Model to apply transformations.
-        :param transformations: List of the bias correction transformations.
-        :param device: Target device for the insertion functions. Applies only to
-            functions which are subclassed from torch.nn.Module. Do nothing in case device is None.
-        :return: A modified torch.fx.GraphModule.
-        """
-        for transformation in transformations:
-            # Set fn to the model as an attribute
-            module_to_insert = transformation.module_to_insert
-            module_name_in_model = (
-                ";".join(
-                    "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value)))
-                    for tp in transformation.target_points
-                )
-                + "_"
-                + str(id(module_to_insert))
-            )
-            assert not hasattr(model, module_name_in_model)
-            setattr(model, module_name_in_model, module_to_insert)
-            # Insert call_module nodes to the model
-            for target_point in transformation.target_points:
-                FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model)
-        return model
-
-    @staticmethod
-    def get_graph_node_by_name(graph, name):
-        for node in graph.nodes:
-            if node.name == name:
-                return node
-        raise RuntimeError(f"Node with name {name} is not found")
-
-    @staticmethod
-    def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint):
-        target_type = target_point.target_type
-        target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name)
-        if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
-            target_node = target_node.all_input_nodes[target_point.input_port_id]
-        elif target_type == TargetType.OPERATOR_POST_HOOK:
-            pass
-        else:
-            raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
-        return target_node
-
-    @staticmethod
-    def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str):
-        target_node = FXModelTransformer._get_target_node(graph, target_point)
-        with graph.inserting_after(target_node):
-            graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node")
-
-    @staticmethod
-    def _apply_transformation(
-        model: torch.fx.GraphModule,
-        transformations: List[FXApplyTransformationCommand],
-    ) -> torch.fx.GraphModule:
-        for transformation in transformations:
-            transformation.tranformation_fn(model)
-        return model
diff --git a/nncf/experimental/torch_fx/nncf_graph_builder.py b/nncf/experimental/torch_fx/nncf_graph_builder.py
deleted file mode 100644
index 9990ee3bf2f..00000000000
--- a/nncf/experimental/torch_fx/nncf_graph_builder.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from itertools import chain
-from typing import Tuple
-
-import torch.fx
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
-
-import nncf.torch.graph.operator_metatypes as om
-from nncf.common.graph import NNCFGraph
-from nncf.common.graph import NNCFNode
-from nncf.common.graph.layer_attributes import Dtype
-from nncf.common.graph.operator_metatypes import UnknownMetatype
-from nncf.common.logging import nncf_logger
-from nncf.experimental.torch_fx.transformations import separate_conv_and_bias
-from nncf.experimental.torch_fx.transformations import separate_linear_and_bias
-from nncf.experimental.torch_fx.transformations import view_to_reshape
-from nncf.torch.graph.graph import PTNNCFGraph
-from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES
-
-
-class GraphConverter:
-    """
-    Builds the NNCFGraph from an OpenVINO model.
-    """
-
-    @staticmethod
-    def _get_leaf_node(module: torch.nn.Module, node: torch.fx.Node) -> torch.nn.Module:
-        py_obj = module
-        assert isinstance(node.target, str)
-        atoms = node.target.split(".")
-        for atom in atoms:
-            if not hasattr(py_obj, atom):
-                raise RuntimeError(str(py_obj) + " does not have attribute " + atom + "!")
-            py_obj = getattr(py_obj, atom)
-        return py_obj
-
-    @staticmethod
-    def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]:
-        if node.op == "placeholder":
-            node_type = "input"
-            node_metatype = om.PTInputNoopMetatype
-        elif node.op == "output":
-            node_type = "output"
-            node_metatype = om.PTOutputNoopMetatype
-        elif node.op == "get_attr":
-            node_type = "get_attr"
-            node_metatype = om.PTConstNoopMetatype
-        elif node.op in ("call_function",):
-            if hasattr(node.target, "overloadpacket"):
-                node_type = str(node.target.overloadpacket).split(".")[1]
-            elif node.target.__name__ == "getitem":
-                node_type = "__getitem__"
-            else:
-                # TODO: get correct nodes types from this nodes as well
-                node_type = str(node.target)
-            node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
-        else:
-            node_type = node.op
-            node_metatype = UnknownMetatype
-        if node_metatype is UnknownMetatype:
-            nncf_logger.info(f"Unknown metatype for node: {node}")
-        return node_type, node_metatype
-
-    @staticmethod
-    def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
-        """
-        Creates NNCFGraph from GraphModule.
-        All nodes from model which have valid metatype are added to NNCFGraph.
-        Then, corresponding edges are added to the NNCFGraph with shape, type, output and input port ids.
-
-        :param model: torch fx GraphModule.
-        :return: NNCFGraph.
-        """
-
-        _fuse_conv_bn_(model)
-        # BN fuses to conv bias, conv+bias joined op
-        # needs to be splited for nncf
-        separate_linear_and_bias(model)
-        separate_conv_and_bias(model)
-        view_to_reshape(model)
-
-        nncf_graph = PTNNCFGraph()
-
-        for source_node in model.graph.nodes:
-
-            node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node)
-
-            nncf_node = nncf_graph.add_nncf_node(
-                node_name=source_node.name,
-                node_type=node_type,
-                node_metatype=node_metatype,  # layer_attributes,
-            )
-
-            def get_module_params_or_buffers():
-                for pname, ptensor in chain(leaf_module.named_parameters(), leaf_module.named_buffers()):
-                    pname1 = source_node.name + "." + pname
-                    nncf_param_node = nncf_graph.add_nncf_node(
-                        pname1,
-                        "parameter" if isinstance(ptensor, torch.nn.Parameter) else "buffer",
-                        om.PTConstNoopMetatype,
-                    )
-                    # TODO: Use valid tensor_shape, input_port_id, output_port_id
-                    nncf_graph.add_edge_between_nncf_nodes(
-                        nncf_param_node, nncf_node, tensor_shape=[1, 1, 1, 1], input_port_id=0, output_port_id=0
-                    )
-
-            if source_node.op == "call_module":
-                leaf_module = GraphConverter._get_leaf_node(model, source_node)
-
-                if not isinstance(leaf_module, torch.fx.GraphModule):
-                    get_module_params_or_buffers()
-
-        for source_node in model.graph.nodes:
-
-            source_nncf_node = nncf_graph.get_node_by_name(source_node.name)
-            for idx, dist_node in enumerate(source_node.users):
-                dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id
-                input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
-                    model, source_node, source_nncf_node, dist_node, idx
-                )
-
-                nncf_graph.add_edge_between_nncf_nodes(
-                    source_nncf_node.node_id,
-                    dist_node_id,
-                    tensor_shape=tensor_shape,
-                    input_port_id=input_port_id,
-                    output_port_id=output_port_id,
-                    dtype=Dtype.FLOAT,
-                )
-
-        return nncf_graph
-
-    @staticmethod
-    def get_edge_params(
-        model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node, output_idx: int
-    ):
-        output_port_id = 0
-        if source_node.op in ("get_attr",):
-            tensor_shape = tuple(getattr(model, source_node.target).shape)
-        elif "val" in source_node.meta:
-            if source_nncf_node.metatype is om.PTBatchNormMetatype:
-                tensor = source_node.meta["val"][0]
-            elif source_nncf_node.metatype is om.PTSplitMetatype:
-                tensor = source_node.meta["val"][output_idx]
-                # Assume every split outputs corresponds to an unique output_port_id
-                output_port_id = output_idx
-            else:
-                tensor = source_node.meta["val"]
-            tensor_shape = tuple(tensor.shape)
-        else:
-            nncf_logger.info(
-                f"Edge shape between {source_node.name} and {dist_node.name} is unknown. Using [1,1,1,1] instead."
-            )
-            tensor_shape = [1, 1, 1, 1]
-
-        input_port_id = dist_node.all_input_nodes.index(source_node)
-        return input_port_id, output_port_id, tensor_shape
diff --git a/nncf/experimental/torch_fx/quantization/__init__.py b/nncf/experimental/torch_fx/quantization/__init__.py
deleted file mode 100644
index 2e49d63977d..00000000000
--- a/nncf/experimental/torch_fx/quantization/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/nncf/experimental/torch_fx/quantization/quantize_model.py b/nncf/experimental/torch_fx/quantization/quantize_model.py
deleted file mode 100644
index 0f40800fb49..00000000000
--- a/nncf/experimental/torch_fx/quantization/quantize_model.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from copy import deepcopy
-from typing import Optional
-
-import torch
-import torch.fx
-from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
-from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
-from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
-from torch.ao.quantization.pt2e.utils import _disallow_eval_train
-from torch.fx import GraphModule
-from torch.fx.passes.infra.pass_manager import PassManager
-
-import nncf
-from nncf.common.factory import NNCFGraphFactory
-from nncf.common.quantization.structs import QuantizationPreset
-from nncf.common.quantization.structs import QuantizationScheme
-from nncf.data import Dataset
-from nncf.experimental.torch_fx.transformations import merge_conv_and_bias
-from nncf.parameters import ModelType
-from nncf.parameters import QuantizationMode
-from nncf.parameters import TargetDevice
-from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
-from nncf.quantization.advanced_parameters import QuantizationParameters
-from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
-from nncf.scopes import IgnoredScope
-
-DEFAULT_RANGE_TYPE = "mean_min_max"
-
-
-def quantize_impl(
-    model: torch.fx.GraphModule,
-    calibration_dataset: Dataset,
-    mode: Optional[QuantizationMode] = None,
-    preset: Optional[QuantizationPreset] = None,
-    target_device: TargetDevice = TargetDevice.ANY,
-    subset_size: int = 300,
-    fast_bias_correction: bool = True,
-    model_type: Optional[ModelType] = None,
-    ignored_scope: Optional[IgnoredScope] = None,
-    advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
-) -> torch.nn.Module:
-    """
-    Implementation of the `quantize()` method for the Torch FX backend.
-    """
-    if fast_bias_correction is False:
-        raise ValueError(f"fast_bias_correction={fast_bias_correction} is not supported")
-    if target_device == TargetDevice.CPU_SPR:
-        raise nncf.InternalError("target_device == CPU_SPR is not supported")
-    if mode is not None:
-        raise ValueError(f"mode={mode} is not supported")
-
-    original_graph_meta = model.meta
-
-    copied_model = deepcopy(model)
-
-    if advanced_parameters is None:
-        advanced_parameters = AdvancedQuantizationParameters()
-    # torch.fx supports only assymetric activations quantization
-    # force to use only this type of quantization
-    activations_quantization_params = advanced_parameters.activations_quantization_params
-    if activations_quantization_params is None:
-        activations_quantization_params = QuantizationParameters()
-
-    activations_quantization_params.mode = QuantizationScheme.ASYMMETRIC
-    advanced_parameters.activations_quantization_params = activations_quantization_params
-
-    quantization_algorithm = PostTrainingQuantization(
-        preset=preset,
-        target_device=target_device,
-        subset_size=subset_size,
-        fast_bias_correction=fast_bias_correction,
-        model_type=model_type,
-        ignored_scope=ignored_scope,
-        advanced_parameters=advanced_parameters,
-    )
-    nncf_graph = NNCFGraphFactory.create(copied_model)
-    quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)
-    merge_conv_and_bias(quantized_model)
-
-    # Magic. Without this call compiled model
-    # is not preformant
-    quantized_model = GraphModule(quantized_model, quantized_model.graph)
-
-    quantized_model = _fold_conv_bn_qat(quantized_model)
-    pm = PassManager([DuplicateDQPass()])
-
-    quantized_model = pm(quantized_model).graph_module
-    pm = PassManager([PortNodeMetaForQDQ()])
-    quantized_model = pm(quantized_model).graph_module
-
-    quantized_model.meta.update(original_graph_meta)
-    quantized_model = _disallow_eval_train(quantized_model)
-
-    return quantized_model
diff --git a/nncf/experimental/torch_fx/statistics/__init__.py b/nncf/experimental/torch_fx/statistics/__init__.py
deleted file mode 100644
index 2e49d63977d..00000000000
--- a/nncf/experimental/torch_fx/statistics/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/nncf/experimental/torch_fx/statistics/aggregator.py b/nncf/experimental/torch_fx/statistics/aggregator.py
deleted file mode 100644
index f1ce6ff05ec..00000000000
--- a/nncf/experimental/torch_fx/statistics/aggregator.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict
-
-import numpy as np
-import torch
-
-from nncf.common.factory import TModel
-from nncf.common.graph.graph import NNCFGraph
-from nncf.common.graph.transformations.commands import TransformationPriority
-from nncf.common.graph.transformations.layout import TransformationLayout
-from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer
-from nncf.common.tensor_statistics.aggregator import StatisticsAggregator
-from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
-from nncf.experimental.torch_fx.model_transformer import FXModuleInsertionCommand
-from nncf.tensor import Tensor
-from nncf.torch.nncf_network import NNCFNetwork
-from nncf.torch.return_types import maybe_get_values_from_torch_return_type
-
-
-class TensorCollectorModule(torch.nn.Module):
-    """
-    torch.nn.Module which calls given collector in forward
-    """
-
-    def __init__(self, collector: TensorCollector):
-        super().__init__()
-        self._collector = collector
-
-    def forward(self, x: torch.Tensor):
-        """
-        Register inputs hook function.
-
-        :parameter x: tensor to register in hook.
-        :return: tensor to register in hook.
-        """
-        x_unwrapped = maybe_get_values_from_torch_return_type(x)
-        self._collector.register_input_for_all_reducers(Tensor(x_unwrapped))
-        return x
-
-
-class FXStatisticsAggregator(StatisticsAggregator):
-    HOOKS_GROUP_NAME = "statistics_hooks"
-
-    def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
-        with torch.no_grad():
-            super().collect_statistics(model, graph)
-        # All statistics are collected as a dead code,
-        # so eliminate dead core removed statistcs collector
-        # from the target model. No additional code required
-        # for that, horay!
-        model.graph.eliminate_dead_code()
-        model.recompile()
-
-    def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None:
-        return
-
-    def _get_transformation_layout_extra_outputs(
-        self, statistic_points: StatisticPointsContainer
-    ) -> TransformationLayout:
-        transformation_layout = TransformationLayout()
-        transformation_commands = []
-
-        for _statistic_points in statistic_points.values():
-            for _statistic_point in _statistic_points:
-                for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
-                    for collector in collectors:
-                        transformation_commands.append(
-                            FXModuleInsertionCommand(
-                                [_statistic_point.target_point],
-                                TensorCollectorModule(collector),
-                                TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
-                            )
-                        )
-
-        for transformation_command in transformation_commands:
-            transformation_layout.register(transformation_command)
-
-        return transformation_layout
-
-    @staticmethod
-    def _get_merged_statistic_points(
-        statistic_points: StatisticPointsContainer, model: TModel, graph: NNCFGraph
-    ) -> StatisticPointsContainer:
-        # TODO: mirgate to experimental statistic collector and use common merging algorithm
-        return statistic_points
-
-    @staticmethod
-    def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, Tensor]:
-        return outputs
diff --git a/nncf/experimental/torch_fx/transformations.py b/nncf/experimental/torch_fx/transformations.py
deleted file mode 100644
index d572c06b120..00000000000
--- a/nncf/experimental/torch_fx/transformations.py
+++ /dev/null
@@ -1,350 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#      http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-from typing import Callable, List, Optional
-
-import torch
-import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node
-from torch.ao.quantization.pt2e.utils import _is_conv
-from torch.quantization.fake_quantize import FakeQuantize
-
-from nncf.common.graph.graph import NNCFNode
-from nncf.common.graph.transformations.commands import TargetType
-from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
-from nncf.torch.graph.transformations.commands import PTTargetPoint
-
-
-def fake_quantize_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
-    def fake_quantize_insertion_transformation(model: torch.fx.GraphModule):
-        module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points)
-        graph = model.graph
-        for target_point in target_points:
-            target_node = FXModelTransformer._get_target_node(model.graph, target_point)
-            with graph.inserting_after(target_node):
-                fq_node = graph.create_node(
-                    "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer"
-                )
-                for user in list(target_node.users):
-                    if user is fq_node:
-                        continue
-                    user.replace_input_with(target_node, fq_node)
-
-    return fake_quantize_insertion_transformation
-
-
-def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor):
-    def bias_update_transformation(model: torch.fx.GraphModule):
-        graph = model.graph
-        target_node_name = node.node_name
-        graph_node = FXModelTransformer.get_graph_node_by_name(graph, target_node_name)
-        bias_node = next(iter(graph_node.users))
-        with graph.inserting_before(bias_node):
-            new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value)
-        args = list(bias_node.args)
-        args[1] = new_constant
-        bias_node.args = tuple(args)
-        graph.eliminate_dead_code()
-
-    return bias_update_transformation
-
-
-def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
-    def qdq_insertion_tranformation(model: torch.fx.GraphModule):
-        if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1:
-            raise RuntimeError
-        for target_point in target_points:
-            target_node = FXModelTransformer._get_target_node(model.graph, target_point)
-            insert_one_qdq(model, target_node, quantizer, target_point)
-
-    return qdq_insertion_tranformation
-
-
-def insert_one_qdq(
-    model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize, target_point: PTTargetPoint
-):
-    # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e
-    # 1. extract information for inserting q/dq node from activation_post_process
-    node_type = "call_function"
-    quantize_op: Optional[Callable] = None
-    # scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
-    dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8
-    if quantizer.is_per_channel:
-        qparams = {
-            "_scale_": quantizer.scale,
-            "_zero_point_": quantizer.zero_point,
-            "_axis_": quantizer.ch_axis,
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
-            "_dtype_": dtype,
-        }
-        quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
-        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
-    else:
-        qparams = {
-            "_scale_": float(quantizer.scale),
-            "_zero_point_": int(quantizer.zero_point),
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
-            "_dtype_": dtype,
-        }
-        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
-        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
-
-    # 2. replace activation_post_process node with quantize and dequantize
-    graph = model.graph
-    # TODO: use metatype to get correct input_port_id
-    # Do not quantize already quantized nodes
-    # inserting_before handle only order in the graph generated code.
-    # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes
-    with graph.inserting_before(target_node):
-        quantize_op_inputs = [target_node]
-        for key, value_or_node in qparams.items():
-            # TODO: we can add the information of whether a value needs to
-            # be registered as an attribute in qparams dict itself
-            if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))):
-                # For scale and zero_point values we register them as buffers in the root module.
-                # However, note that when the values are not tensors, as in the case of
-                # per_tensor quantization, they will be treated as literals.
-                # However, registering them as a node seems to cause issue with dynamo
-                # tracing where it may consider tensor overload as opposed to default.
-                # With extra check of scale and zero_point being scalar, it makes
-                # sure that the default overload can be used.
-                # TODO: maybe need more complex attr name here
-                qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node)
-                quantize_op_inputs.append(qparam_node)
-            else:
-                # for qparams that are not scale/zero_point (like axis, dtype) we store
-                # them as literals in the graph.
-                quantize_op_inputs.append(value_or_node)
-    with graph.inserting_after(target_node):
-        quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
-    # use the same qparams from quantize op
-    dq_inputs = [quantized_node] + quantize_op_inputs[1:]
-    user_dq_nodes = []
-    with graph.inserting_after(quantized_node):
-        for user in target_node.users:
-            if user is quantized_node:
-                continue
-            user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))
-
-    for user, dq_node in user_dq_nodes:
-        user.replace_input_with(target_node, dq_node)
-
-
-def _set_module_to_the_graph_module(
-    model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
-) -> str:
-    """
-    Sets given module to the given torch.fx.GraphModule with unique name.
-    """
-    module_to_insert = module_to_insert
-    module_name_in_model = (
-        ";".join(
-            "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
-        )
-        + "_"
-        + str(id(module_to_insert))
-    )
-    assert not hasattr(model, module_name_in_model)
-    setattr(model, module_name_in_model, module_to_insert)
-    return module_name_in_model
-
-
-def _is_linear(n: torch.fx.Node):
-    return n.op == "call_function" and n.target in [torch.ops.aten.linear.default]
-
-
-def separate_linear_and_bias(model: torch.fx.GraphModule):
-    """
-    Separates one joined linear+bias node to two nodes: conv and bias.
-    Needed as nncf does not expect joined conv
-    """
-    add_node_target = torch.ops.aten.add_.Tensor
-    for n in model.graph.nodes:
-        if not _is_linear(n):
-            continue
-        if len(n.args) < 3 or n.args[2] is None:
-            continue
-        linear_node = n
-        linear_bias_node = linear_node.args[2]
-        conv_bias_value = _get_tensor_constant_from_node(linear_bias_node, model)
-        args = list(n.args)
-        args[2] = None
-        linear_node.args = tuple(args)
-        with model.graph.inserting_after(linear_node):
-            new_linear_bias_node = create_getattr_from_value(
-                model,
-                model.graph,
-                linear_bias_node.name + "_",
-                conv_bias_value,
-            )
-        with model.graph.inserting_after(new_linear_bias_node):
-            add_node = model.graph.create_node(
-                "call_function", add_node_target, (linear_node, new_linear_bias_node), {}
-            )
-        for user in list(linear_node.users):
-            if user is add_node:
-                continue
-            user.replace_input_with(linear_node, add_node)
-        if "val" in linear_node.meta:
-            add_node.meta["val"] = linear_node.meta["val"]
-    model.graph.eliminate_dead_code()
-    model.recompile()
-
-
-def view_to_reshape(model: torch.fx.GraphModule):
-    for n in model.graph.nodes:
-        if not (n.op == "call_function" and n.target in [torch.ops.aten.view.default]):
-            continue
-        with model.graph.inserting_after(n):
-            reshape = model.graph.create_node("call_function", torch.ops.aten.reshape.default, tuple(n.args), {})
-            reshape.meta = n.meta
-
-        for user in list(n.users):
-            user.replace_input_with(n, reshape)
-
-    model.graph.eliminate_dead_code()
-    model.recompile()
-
-
-def separate_conv_and_bias(model: torch.fx.GraphModule):
-    """
-    Separates one joined conv+bias node to two nodes: conv and bias.
-    Needed as nncf does not expect joined conv
-    """
-    add_node_target = torch.ops.aten.add_.Tensor
-    for n in model.graph.nodes:
-        if not _is_conv(n):
-            continue
-        if len(n.args) < 3 or n.args[2] is None:
-            continue
-        conv_node = n
-        dims = len(_get_tensor_constant_from_node(conv_node.args[1], model).shape)
-        conv_bias_node = conv_node.args[2]
-        conv_bias_value = _get_tensor_constant_from_node(conv_bias_node, model)
-        args = list(n.args)
-        args[2] = None
-        conv_node.args = tuple(args)
-        with model.graph.inserting_after(conv_node):
-            new_conv_bias_node = create_getattr_from_value(
-                model,
-                model.graph,
-                conv_bias_node.name + "_",
-                conv_bias_value.reshape(
-                    (
-                        1,
-                        -1,
-                    )
-                    + (1,) * (dims - 2)
-                ),
-            )
-        with model.graph.inserting_after(new_conv_bias_node):
-            add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {})
-        for user in list(conv_node.users):
-            if user is add_node:
-                continue
-            user.replace_input_with(conv_node, add_node)
-
-        if "val" in conv_node.meta:
-            add_node.meta["val"] = conv_node.meta["val"]
-    model.graph.eliminate_dead_code()
-    model.recompile()
-
-
-def merge_conv_and_bias(model: torch.fx.GraphModule):
-    """
-    Separates one joined conv+bias node to two nodes: conv and bias.
-    Needed as nncf does not expect joined conv
-    """
-    add_node_targets = (torch.ops.aten.add_.Tensor,)
-    for n in model.graph.nodes:
-        if not _is_conv(n):
-            continue
-        if len(n.args) > 2 and n.args[2] is not None:
-            continue
-        bias_node = next(iter(n.users))
-        if len(n.users) > 1 or bias_node.target not in add_node_targets:
-            continue
-        conv_node = n
-        const_node = None
-        for node in bias_node.all_input_nodes:
-            if node is not conv_node:
-                const_node = node
-                break
-        assert const_node is not None
-        bias_value = _get_tensor_constant_from_node(const_node, model).squeeze()
-        with model.graph.inserting_before(conv_node):
-            new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
-        args = list(conv_node.args)
-        args[2] = new_bias_node
-        conv_node.args = tuple(args)
-        for user in list(bias_node.users):
-            user.replace_input_with(bias_node, conv_node)
-
-    model.graph.eliminate_dead_code()
-    model.recompile()
-
-
-def _is_scaled_dot_product_attention(n: torch.fx.Node):
-    return n.op == "call_function" and n.target in [torch.ops.aten.scaled_dot_product_attention.default]
-
-
-def _unfold_sdp(model: torch.fx.GraphModule, node: torch.fx.Node):
-    transpose_target = torch.ops.aten.transpose.int
-    matmul_target = torch.ops.aten.matmul.default
-    mul_target = torch.ops.aten.multiply.Scalar
-    softmax_target = torch.ops.aten.softmax.int
-
-    query, key, value = node.args
-    q, k, v = (n.meta["val"] for n in node.args)
-    n = query.meta["val"].shape[-1]
-    scale_factor = 1 / math.sqrt(n)
-
-    with model.graph.inserting_before(node):
-        k_transposed = model.graph.create_node("call_function", transpose_target, (key, -2, -1), {})
-        k = k.transpose(-2, -1)
-        k_transposed.meta["val"] = torch.clone(k)
-
-        sa = model.graph.create_node("call_function", matmul_target, (query, k_transposed), {})
-        attn_value = q @ k
-        sa.meta["val"] = torch.clone(attn_value)
-
-        sa_scaled = model.graph.create_node("call_function", mul_target, (sa, float(scale_factor)), {})
-        sa_scaled.meta["val"] = torch.clone(attn_value)
-
-        softmax = model.graph.create_node("call_function", softmax_target, (sa_scaled, -1), {})
-        softmax.meta["val"] = torch.clone(attn_value)
-
-        result = model.graph.create_node("call_function", matmul_target, (softmax, value), {})
-        r = attn_value @ v
-        result.meta["val"] = torch.clone(r)
-
-    for user in list(node.users):
-        user.replace_input_with(node, result)
-    model.graph.eliminate_dead_code()
-
-
-@staticmethod
-def unfold_scaled_dot_product_attention(model: torch.fx.GraphModule):
-    for n in model.graph.nodes:
-        if not _is_scaled_dot_product_attention(n):
-            continue
-        args = n.args
-        if len(args) > 3:
-            raise NotImplementedError(
-                f"Unfolding of scaled dot product attention node {n}"
-                " with more than 3 inputs is not implemented yet"
-            )
-        _unfold_sdp(model, n)
-    model.graph.eliminate_dead_code()
-    model.recompile()
diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py
index 089afd4ab11..850df6f7d7b 100644
--- a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py
+++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py
@@ -22,8 +22,8 @@
 from nncf.common.graph.definitions import NNCFGraphNodeType
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
-from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
-from nncf.experimental.torch_fx.transformations import bias_update_transformation_builder
+from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand
+from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder
 from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
 from nncf.tensor import Tensor
 from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
@@ -83,7 +83,7 @@ def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, ch
     @staticmethod
     def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
         # TODO: make a node_name_vs_node map to speed up the process
-        from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
+        from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
 
         bias_node = nncf_graph.get_next_nodes(node)[0]
         graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name)
diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
index bdeed5343c8..a21fcc883ab 100644
--- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -27,8 +27,8 @@
 from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
-from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
-from nncf.experimental.torch_fx.transformations import qdq_insertion_tranformation_builder
+from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand
+from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import StatisticsType
diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index bdc096a4982..1e717588bc4 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -229,7 +229,7 @@ def quantize(
             advanced_parameters=advanced_parameters,
         )
     if backend == BackendType.TORCH_FX:
-        from nncf.experimental.torch_fx.quantization.quantize_model import quantize_impl
+        from nncf.experimental.torch.fx.quantization.quantize_model import quantize_impl
 
         return quantize_impl(
            model=model,
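
Note on downstream usage: the only user-visible change in the nncf.* hunks above is the package move from nncf.experimental.torch_fx to nncf.experimental.torch.fx. A minimal before/after sketch for code that imported the experimental FX helpers directly (module and symbol names are taken verbatim from the diff):

    # Old layout (removed by this patch):
    # from nncf.experimental.torch_fx.model_transformer import FXModelTransformer

    # New layout:
    from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
    from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder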
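
For context, a hedged end-to-end sketch of the entry point this patch rewires: nncf.quantize() on a torch.fx.GraphModule dispatches, via the BackendType.TORCH_FX branch in nncf/quantization/quantize_model.py above, to the relocated quantize_impl. The TinyConv model and the capture_pre_autograd_graph export step below are illustrative assumptions (the exact capture API varies between torch 2.x versions) and are not part of the diff:

    import torch
    from torch._export import capture_pre_autograd_graph  # assumption: torch 2.x PT2E capture API

    import nncf  # nncf.Dataset and nncf.quantize are public NNCF API


    class TinyConv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

        def forward(self, x):
            return torch.relu(self.conv(x))


    example_input = torch.ones(1, 3, 32, 32)
    # Capture an FX graph the way PT2E tooling does; quantize() then routes it
    # to the TORCH_FX backend.
    fx_model = capture_pre_autograd_graph(TinyConv().eval(), (example_input,))

    calibration_dataset = nncf.Dataset([example_input])  # one-sample calibration set
    quantized_fx_model = nncf.quantize(fx_model, calibration_dataset)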
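
The deleted engine.py above (relocated, per the import updates, to nncf/experimental/torch/fx/engine.py) dispatches on the input type: dicts are splatted as keyword arguments, tuples as positional arguments, and a bare tensor is passed straight through. A small sketch, assuming the relocated module path:

    import torch
    from nncf.experimental.torch.fx.engine import FXEngine

    model = torch.nn.Linear(4, 2)
    engine = FXEngine(model)  # FXEngine only calls the module, so a plain nn.Module works here

    out_tensor = engine.infer(torch.ones(1, 4))    # bare tensor -> model(x)
    out_tuple = engine.infer((torch.ones(1, 4),))  # tuple       -> model(*args)
    # dict -> model(**kwargs); applies when the forward signature has matching names:
    # out_dict = engine.infer({"input": torch.ones(1, 4)})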
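
Finally, a numeric illustration of the node pair that insert_one_qdq() (in the deleted transformations.py above) emits for its per-tensor branch: a quantize_per_tensor call followed by a dequantize_per_tensor call sharing the same qparams. The qparam values below are made up for demonstration, and the _decomposed import is an assumption about where torch registers these ops:

    import torch
    from torch.ao.quantization.fx import _decomposed  # noqa: F401 -- assumed to register quantized_decomposed ops

    x = torch.randn(4)
    scale, zero_point, quant_min, quant_max = 0.02, 0, -128, 127  # example qparams
    q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        x, scale, zero_point, quant_min, quant_max, torch.int8
    )
    dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        q, scale, zero_point, quant_min, quant_max, torch.int8
    )
    # dq is the fake-quantized tensor that users of the original node now consume.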