Skip to content

Commit

Permalink
[Experimental] TorchFX PTQ backend (openvinotoolkit#2764)
Browse files Browse the repository at this point in the history
### Changes

* Torch FX experimental PTQ backend is presented (MinMax, FBC)
* Torch metatypes are updated with new namespace: ATEN
* Some Torch metatypes are updated by new operations names 

### Reason for changes

To begin the Torch FX backend development

### Related tickets

141640

### Tests

* Resnet18 imagnette sanity test
  • Loading branch information
daniil-lyakhov authored Jul 26, 2024
1 parent 33bbf6e commit d94b93b
Show file tree
Hide file tree
Showing 26 changed files with 1,797 additions and 11 deletions.
14 changes: 13 additions & 1 deletion nncf/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph:
if model_backend == BackendType.OPENVINO:
from nncf.openvino.graph.nncf_graph_builder import GraphConverter

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH:
return model.nncf.get_graph()
Expand Down Expand Up @@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
from nncf.torch.model_transformer import PTModelTransformer

return PTModelTransformer(model)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer

return FXModelTransformer(model)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value)
)
Expand All @@ -95,7 +103,7 @@ def create(model: TModel) -> Engine:
from nncf.openvino.engine import OVNativeEngine

return OVNativeEngine(model)
if model_backend == BackendType.TORCH:
if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.engine import PTEngine

return PTEngine(model)
Expand Down Expand Up @@ -151,6 +159,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
from nncf.torch.statistics.aggregator import PTStatisticsAggregator

return PTStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator

return FXStatisticsAggregator(dataset)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific statistics aggregator because {} is not supported!".format(
model_backend.value
Expand Down
4 changes: 2 additions & 2 deletions nncf/common/graph/patterns/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam
Dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict
)
return registry
if backend == BackendType.TORCH:
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict)
Expand Down Expand Up @@ -77,7 +77,7 @@ def _get_backend_ignored_patterns_map(
Dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict
)
return registry
if backend == BackendType.TORCH:
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict)
Expand Down
24 changes: 21 additions & 3 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

class BackendType(Enum):
TORCH = "Torch"
TORCH_FX = "TorchFX"
TENSORFLOW = "Tensorflow"
ONNX = "ONNX"
OPENVINO = "OpenVINO"
Expand All @@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]:
"""
frameworks = [
("torch", BackendType.TORCH),
("torch.fx", BackendType.TORCH_FX),
("tensorflow", BackendType.TENSORFLOW),
("onnx", BackendType.ONNX),
("openvino.runtime", BackendType.OPENVINO),
Expand All @@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]:

def is_torch_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.nn.Module, otherwise False.
Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False.
:param model: A target model.
:return: True if the model is an instance of torch.nn.Module, otherwise False.
:return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False.
"""
import torch # type: ignore
import torch.fx # type: ignore

return isinstance(model, torch.nn.Module)
return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


def is_torch_fx_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.fx.GraphModule, otherwise False.
:param model: A target model.
:return: True if the model is an instance of torch.fx.GraphModule, otherwise False.
"""
import torch.fx

return isinstance(model, torch.fx.GraphModule)


def is_tensorflow_model(model: TModel) -> bool:
Expand Down Expand Up @@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType:
"""
available_backends = get_available_backends()

if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model):
return BackendType.TORCH_FX

if BackendType.TORCH in available_backends and is_torch_model(model):
return BackendType.TORCH

Expand Down
10 changes: 10 additions & 0 deletions nncf/experimental/torch/fx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
37 changes: 37 additions & 0 deletions nncf/experimental/torch/fx/commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Union

import torch.fx

from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType


class FXApplyTransformationCommand(Command):
"""
Command to apply given transformation to a model.
"""

def __init__(
self,
transformation_fn: Callable[[torch.fx.GraphModule], None],
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
):
"""
:param transformation_fn: Target transformation function.
:param priority: Transformation priority.
"""
super().__init__(TransformationType.INSERT)
self.tranformation_fn = transformation_fn
self.priority = priority
116 changes: 116 additions & 0 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import List

import torch
import torch.fx
from torch.fx.passes.split_utils import split_by_tags

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout


class FXModelTransformer(ModelTransformer):
"""
Applies transformations upon Torch FX model.
"""

def __init__(self, model: torch.fx.GraphModule):
super().__init__(model)

self._command_transformation_ordered_pairs = [
(FXApplyTransformationCommand, self._apply_transformation),
(PTModelExtractionCommand, self._apply_model_extraction),
]

def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
"""
Transforms the target model according to given transformation layout.
:param transformation_layout: Given transformation layout.
:return: Target model transformered according to the given transformation layout.
"""
# TODO(dlyakhov): Manage priorities of transformations.
transformations = transformation_layout.transformations
aggregated_transformations = defaultdict(list)
for transformation in transformations:
aggregated_transformations[transformation.__class__].append(transformation)

model = self._model
for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs:
transformations = aggregated_transformations[transformation_cls]
if transformations:
model = transformation_fn(model, transformations)

# Do not use model.graph.eliminate_dead_code()
# because the computational statistics code
# is interpolated as dead code.
model.recompile()
return model

@staticmethod
def _apply_model_extraction(
model: torch.fx.GraphModule,
transformations: List[PTModelExtractionCommand],
) -> torch.fx.GraphModule:
"""
Returns a submodel extracted from the given model by the given transformation.
:param model: Given model.
:param transformations: List of one transformation which specifies
how to retrieve a submodule from the model. In case list contains
more than one element this function raises an assert.
:return: Returns a submodel extracted from the given model by the given transformation.
"""
transformation = transformations[-1]
assert len(transformation.input_node_names) == 1
assert transformation.input_node_names == transformation.output_node_names
node_name = transformation.input_node_names[0]

tags = ["before", "extracted", "after"]
i = 0
for node in model.graph.nodes:
if node.name == node_name:
node.tag = tags[1]
weights = [node.all_input_nodes[1]]
while weights:
w_node = weights.pop()
assert w_node.tag in tags[0:2]
w_node.tag = tags[1]
weights.extend(w_node.all_input_nodes)
i = 2
continue
node.tag = tags[i]

# TODO(dlyakhov): reduce memory consumption by
# more optimal splitting implementation.
splitted_gm = split_by_tags(model, tags)
return splitted_gm.extracted

@staticmethod
def _apply_transformation(
model: torch.fx.GraphModule,
transformations: List[FXApplyTransformationCommand],
) -> torch.fx.GraphModule:
"""
Applies transformations to the given model.
:param model: Target model.
:param transformations: Transformations to apply to the model.
:return: Target model after all transformations were applied.
"""
for transformation in transformations:
transformation.tranformation_fn(model)
return model
Loading

0 comments on commit d94b93b

Please sign in to comment.