Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental] TorchFX PTQ backend #2764

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion nncf/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph:
if model_backend == BackendType.OPENVINO:
from nncf.openvino.graph.nncf_graph_builder import GraphConverter

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter

return GraphConverter.create_nncf_graph(model)
if model_backend == BackendType.TORCH:
return model.nncf.get_graph()
Expand Down Expand Up @@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
from nncf.torch.model_transformer import PTModelTransformer

return PTModelTransformer(model)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer

return FXModelTransformer(model)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value)
)
Expand All @@ -95,7 +103,7 @@ def create(model: TModel) -> Engine:
from nncf.openvino.engine import OVNativeEngine

return OVNativeEngine(model)
if model_backend == BackendType.TORCH:
if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.engine import PTEngine

return PTEngine(model)
Expand Down Expand Up @@ -151,6 +159,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
from nncf.torch.statistics.aggregator import PTStatisticsAggregator

return PTStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator

return FXStatisticsAggregator(dataset)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific statistics aggregator because {} is not supported!".format(
model_backend.value
Expand Down
4 changes: 2 additions & 2 deletions nncf/common/graph/patterns/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam
Dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict
)
return registry
if backend == BackendType.TORCH:
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict)
Expand Down Expand Up @@ -77,7 +77,7 @@ def _get_backend_ignored_patterns_map(
Dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict
)
return registry
if backend == BackendType.TORCH:
if backend in (BackendType.TORCH, BackendType.TORCH_FX):
from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict)
Expand Down
24 changes: 21 additions & 3 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

class BackendType(Enum):
TORCH = "Torch"
TORCH_FX = "TorchFX"
TENSORFLOW = "Tensorflow"
ONNX = "ONNX"
OPENVINO = "OpenVINO"
Expand All @@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]:
"""
frameworks = [
("torch", BackendType.TORCH),
("torch.fx", BackendType.TORCH_FX),
("tensorflow", BackendType.TENSORFLOW),
("onnx", BackendType.ONNX),
("openvino.runtime", BackendType.OPENVINO),
Expand All @@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]:

def is_torch_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.nn.Module, otherwise False.
Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False.
:param model: A target model.
:return: True if the model is an instance of torch.nn.Module, otherwise False.
:return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False.
"""
import torch # type: ignore
import torch.fx # type: ignore

return isinstance(model, torch.nn.Module)
return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


def is_torch_fx_model(model: TModel) -> bool:
"""
Returns True if the model is an instance of torch.fx.GraphModule, otherwise False.
:param model: A target model.
:return: True if the model is an instance of torch.fx.GraphModule, otherwise False.
"""
import torch.fx

return isinstance(model, torch.fx.GraphModule)
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved


def is_tensorflow_model(model: TModel) -> bool:
Expand Down Expand Up @@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType:
"""
available_backends = get_available_backends()

if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model):
return BackendType.TORCH_FX

if BackendType.TORCH in available_backends and is_torch_model(model):
return BackendType.TORCH

Expand Down
10 changes: 10 additions & 0 deletions nncf/experimental/torch/fx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
37 changes: 37 additions & 0 deletions nncf/experimental/torch/fx/commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Union

import torch.fx

from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType


class FXApplyTransformationCommand(Command):
"""
Command to apply given transformation to a model.
"""

def __init__(
self,
transformation_fn: Callable[[torch.fx.GraphModule], None],
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
):
"""
:param transformation_fn: Target transformation function.
:param priority: Transformation priority.
"""
super().__init__(TransformationType.INSERT)
self.tranformation_fn = transformation_fn
self.priority = priority
116 changes: 116 additions & 0 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import List

import torch
import torch.fx
from torch.fx.passes.split_utils import split_by_tags

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout


class FXModelTransformer(ModelTransformer):
"""
Applies transformations upon Torch FX model.
"""

def __init__(self, model: torch.fx.GraphModule):
super().__init__(model)

self._command_transformation_ordered_pairs = [
(FXApplyTransformationCommand, self._apply_transformation),
(PTModelExtractionCommand, self._apply_model_extraction),
]

def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
"""
Transforms the target model according to given transformation layout.
:param transformation_layout: Given transformation layout.
:return: Target model transformered according to the given transformation layout.
"""
# TODO(dlyakhov): Manage priorities of transformations.
transformations = transformation_layout.transformations
aggregated_transformations = defaultdict(list)
for transformation in transformations:
aggregated_transformations[transformation.__class__].append(transformation)

model = self._model
for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs:
transformations = aggregated_transformations[transformation_cls]
if transformations:
model = transformation_fn(model, transformations)

# Do not use model.graph.eliminate_dead_code()
# because the computational statistics code
# is interpolated as dead code.
model.recompile()
return model

@staticmethod
def _apply_model_extraction(
model: torch.fx.GraphModule,
transformations: List[PTModelExtractionCommand],
) -> torch.fx.GraphModule:
"""
Returns a submodel extracted from the given model by the given transformation.
:param model: Given model.
:param transformations: List of one transformation which specifies
how to retrieve a submodule from the model. In case list contains
more than one element this function raises an assert.
:return: Returns a submodel extracted from the given model by the given transformation.
"""
transformation = transformations[-1]
assert len(transformation.input_node_names) == 1
assert transformation.input_node_names == transformation.output_node_names
node_name = transformation.input_node_names[0]

tags = ["before", "extracted", "after"]
i = 0
for node in model.graph.nodes:
if node.name == node_name:
node.tag = tags[1]
weights = [node.all_input_nodes[1]]
while weights:
w_node = weights.pop()
assert w_node.tag in tags[0:2]
w_node.tag = tags[1]
weights.extend(w_node.all_input_nodes)
i = 2
continue
node.tag = tags[i]

# TODO(dlyakhov): reduce memory consumption by
# more optimal splitting implementation.
splitted_gm = split_by_tags(model, tags)
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved
return splitted_gm.extracted

@staticmethod
def _apply_transformation(
model: torch.fx.GraphModule,
transformations: List[FXApplyTransformationCommand],
) -> torch.fx.GraphModule:
"""
Applies transformations to the given model.
:param model: Target model.
:param transformations: Transformations to apply to the model.
:return: Target model after all transformations were applied.
"""
for transformation in transformations:
transformation.tranformation_fn(model)
return model
Loading
Loading