TorchFX quantization init
daniil-lyakhov committed Jun 25, 2024
1 parent 4924b7e commit 0360129
Showing 35 changed files with 2,569 additions and 20 deletions.
456 changes: 456 additions & 0 deletions aa_torch_fx.py

Large diffs are not rendered by default.
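The 456 added lines of aa_torch_fx.py are not shown above. Purely as a hypothetical sketch of how the new TorchFX backend might be driven (the model, the symbolic_trace capture, and the random calibration data are all assumptions of this example, not contents of the file):

```python
# Hypothetical sketch only; not the contents of aa_torch_fx.py.
import torch
import torchvision.models as models

import nncf

# Assumption: the backend consumes a torch.fx.GraphModule; symbolic_trace is
# one way to obtain one (the real script may use torch.export instead).
model = models.resnet18(weights=None).eval()
fx_model = torch.fx.symbolic_trace(model)

# Assumption: random tensors stand in for a real calibration dataset.
calibration_dataset = nncf.Dataset([torch.randn(1, 3, 224, 224) for _ in range(10)])
quantized_model = nncf.quantize(fx_model, calibration_dataset)
```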

2 changes: 1 addition & 1 deletion examples/llm_compression/openvino/tiny_llama/main.py
@@ -11,12 +11,12 @@
import time
from functools import partial

import datasets
import numpy as np
import openvino as ov
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer

import datasets
import nncf


(changes to a second file; the file name is not rendered)
@@ -17,12 +17,12 @@

import numpy as np
import openvino as ov
from datasets import load_dataset
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer
from whowhatbench import Evaluator

import nncf
from datasets import load_dataset
from nncf.common.logging import nncf_logger

DataItem = TypeVar("DataItem")
16 changes: 16 additions & 0 deletions nncf/common/factory.py
@@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph:
        if model_backend == BackendType.OPENVINO:
            from nncf.openvino.graph.nncf_graph_builder import GraphConverter

            return GraphConverter.create_nncf_graph(model)
        if model_backend == BackendType.TORCH_FX:
            from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter

            return GraphConverter.create_nncf_graph(model)
        if model_backend == BackendType.TORCH:
            return model.nncf.get_graph()
@@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
            from nncf.torch.model_transformer import PTModelTransformer

            return PTModelTransformer(model)
        if model_backend == BackendType.TORCH_FX:
            from nncf.experimental.torch_fx.model_transformer import FXModelTransformer

            return FXModelTransformer(model)
        raise nncf.UnsupportedBackendError(
            "Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value)
        )
@@ -99,6 +107,10 @@ def create(model: TModel) -> Engine:
            from nncf.torch.engine import PTEngine

            return PTEngine(model)
        if model_backend == BackendType.TORCH_FX:
            from nncf.experimental.torch_fx.engine import FXEngine

            return FXEngine(model)
        raise nncf.UnsupportedBackendError(
            "Cannot create backend-specific engine because {} is not supported!".format(model_backend.value)
        )
@@ -151,6 +163,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
            from nncf.torch.statistics.aggregator import PTStatisticsAggregator

            return PTStatisticsAggregator(dataset)
        if model_backend == BackendType.TORCH_FX:
            from nncf.experimental.torch_fx.statistics.aggregator import FXStatisticsAggregator

            return FXStatisticsAggregator(dataset)
        raise nncf.UnsupportedBackendError(
            "Cannot create backend-specific statistics aggregator because {} is not supported!".format(
                model_backend.value
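With all four factories now dispatching on BackendType.TORCH_FX, an fx-traced model flows through the same entry points as every other backend. A minimal sketch of that dispatch, where ToyModel and the tracing call are illustrative assumptions:

```python
# Sketch of the factory dispatch for a torch.fx.GraphModule.
import torch

from nncf.common.factory import EngineFactory
from nncf.common.factory import NNCFGraphFactory


class ToyModel(torch.nn.Module):  # illustrative stand-in model
    def forward(self, x):
        return torch.relu(x)


fx_model = torch.fx.symbolic_trace(ToyModel())

nncf_graph = NNCFGraphFactory.create(fx_model)  # routed to the torch_fx GraphConverter
engine = EngineFactory.create(fx_model)         # routed to FXEngine
output = engine.infer(torch.ones(1, 4))
```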
10 changes: 10 additions & 0 deletions nncf/common/graph/patterns/manager.py
@@ -48,6 +48,11 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam
        if backend == BackendType.TORCH:
            from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

            registry = PT_HW_FUSED_PATTERNS.registry_dict
            return registry
        if backend == BackendType.TORCH_FX:
            from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS

            registry = PT_HW_FUSED_PATTERNS.registry_dict
            return registry
        raise ValueError(f"Hardware-fused patterns not implemented for {backend} backend.")
@@ -76,6 +81,11 @@ def _get_backend_ignored_patterns_map(
        if backend == BackendType.TORCH:
            from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

            registry = PT_IGNORED_PATTERNS.registry_dict
            return registry
        if backend == BackendType.TORCH_FX:
            from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS

            registry = PT_IGNORED_PATTERNS.registry_dict
            return registry
        raise ValueError(f"Ignored patterns not implemented for {backend} backend.")
24 changes: 21 additions & 3 deletions nncf/common/utils/backend.py
@@ -20,6 +20,7 @@

class BackendType(Enum):
    TORCH = "Torch"
    TORCH_FX = "TorchFX"
    TENSORFLOW = "Tensorflow"
    ONNX = "ONNX"
    OPENVINO = "OpenVINO"
@@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]:
    """
    frameworks = [
        ("torch", BackendType.TORCH),
        ("torch.fx", BackendType.TORCH_FX),
        ("tensorflow", BackendType.TENSORFLOW),
        ("onnx", BackendType.ONNX),
        ("openvino.runtime", BackendType.OPENVINO),
@@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]:

def is_torch_model(model: TModel) -> bool:
    """
    Returns True if the model is an instance of torch.nn.Module, otherwise False.
    Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False.
    :param model: A target model.
    :return: True if the model is an instance of torch.nn.Module, otherwise False.
    :return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False.
    """
    import torch
    import torch.fx

    return isinstance(model, torch.nn.Module)
    return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


def is_torch_fx_model(model: TModel) -> bool:
    """
    Returns True if the model is an instance of torch.fx.GraphModule, otherwise False.
    :param model: A target model.
    :return: True if the model is an instance of torch.fx.GraphModule, otherwise False.
    """
    import torch.fx

    return isinstance(model, torch.fx.GraphModule)


def is_tensorflow_model(model: TModel) -> bool:
@@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType:
    """
    available_backends = get_available_backends()

    if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model):
        return BackendType.TORCH_FX

    if BackendType.TORCH in available_backends and is_torch_model(model):
        return BackendType.TORCH
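Note that detection order matters here: a torch.fx.GraphModule is itself a torch.nn.Module, so get_backend checks TORCH_FX before TORCH, and is_torch_model explicitly excludes GraphModule. A small illustrative check (ToyModel is a stand-in):

```python
# Illustrative check of the new detection logic; ToyModel is a stand-in.
import torch

from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


eager_model = ToyModel()
fx_model = torch.fx.symbolic_trace(eager_model)

assert get_backend(eager_model) == BackendType.TORCH
# A GraphModule is also an nn.Module, which is why is_torch_model()
# must exclude it and the TORCH_FX check must run first.
assert get_backend(fx_model) == BackendType.TORCH_FX
```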
10 changes: 10 additions & 0 deletions nncf/experimental/torch_fx/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
48 changes: 48 additions & 0 deletions nncf/experimental/torch_fx/engine.py
@@ -0,0 +1,48 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Tuple, Union

import torch
from torch import nn

from nncf.common.engine import Engine


class FXEngine(Engine):
"""
Engine for the Pytorch FX backend.
"""

def __init__(self, model: nn.Module):
"""
Constructor.
:param model: Pytorch module to infer.
"""

self._model = model

def infer(
self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]
) -> Union[torch.Tensor, Dict[str, Any]]:
"""
Runs Torch model on the provided input.
:param input_data: Inputs for the model.
:return: Model outputs.
"""

if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
return self._model(*input_data)
return self._model(input_data)
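FXEngine.infer simply unpacks whichever container the calibration pipeline yields. A hedged usage sketch, with the traced add_one function standing in for a real model:

```python
# Hypothetical FXEngine usage; add_one is a stand-in for a real model.
import torch

from nncf.experimental.torch_fx.engine import FXEngine


def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


engine = FXEngine(torch.fx.symbolic_trace(add_one))

engine.infer(torch.zeros(2))          # single tensor -> model(input)
engine.infer((torch.zeros(2),))       # tuple -> model(*input)
engine.infer({"x": torch.zeros(2)})   # dict -> model(**input)
```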
183 changes: 183 additions & 0 deletions nncf/experimental/torch_fx/model_transformer.py
@@ -0,0 +1,183 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Callable, List, Union

import torch
import torch.fx
from torch.fx.passes.split_utils import split_by_tags

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout


class FXModuleInsertionCommand(Command):
    def __init__(
        self,
        target_points: List[PTTargetPoint],
        module_to_insert: torch.nn.Module,
        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
    ):
        super().__init__(TransformationType.INSERT)
        self.target_points = target_points
        self.module_to_insert = module_to_insert
        self.priority = priority


class FXApplyTransformationCommand(Command):
    def __init__(
        self,
        transformation_fn: Callable[[torch.fx.GraphModule], None],
        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
    ):
        super().__init__(TransformationType.INSERT)
        self.transformation_fn = transformation_fn
        self.priority = priority


class FXModelTransformer(ModelTransformer):
    """
    Applies transformations to a Torch FX model.
    """

    # TODO: manage priorities of transformations

    def __init__(self, model: torch.fx.GraphModule):
        super().__init__(model)

        self._command_transformation_ordered_pairs = [
            # TODO: Move the module insertion command to a transformation
            (FXApplyTransformationCommand, self._apply_transformation),
            (FXModuleInsertionCommand, self._apply_module_insertion),
            (PTModelExtractionCommand, self._apply_model_extraction),
        ]

    def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
        # Group commands by type, then apply each group in the fixed order above.
        transformations = transformation_layout.transformations
        aggregated_transformations = defaultdict(list)
        for transformation in transformations:
            aggregated_transformations[transformation.__class__].append(transformation)

        model = self._model
        for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs:
            transformations = aggregated_transformations[transformation_cls]
            if transformations:
                model = transformation_fn(model, transformations)

        # Do not eliminate dead code, as the "dead" code is computing statistics :)
        # model.graph.eliminate_dead_code()
        model.recompile()
        return model

    @staticmethod
    def _apply_model_extraction(
        model: torch.fx.GraphModule,
        transformations: List[PTModelExtractionCommand],
    ) -> torch.fx.GraphModule:
        transformation = transformations[-1]
        assert len(transformation.input_node_names) == 1
        assert transformation.input_node_names == transformation.output_node_names
        node_name = transformation.input_node_names[0]

        # Tag each node as "before", "extracted" or "after" relative to the
        # target node; the weight subgraph feeding the target node is pulled
        # into the "extracted" partition as well.
        tags = ["before", "extracted", "after"]
        i = 0
        for node in model.graph.nodes:
            if node.name == node_name:
                node.tag = tags[1]
                weights = [node.all_input_nodes[1]]
                while weights:
                    w_node = weights.pop()
                    assert w_node.tag in tags[0:2]
                    w_node.tag = tags[1]
                    weights.extend(w_node.all_input_nodes)
                i = 2
                continue
            node.tag = tags[i]

        splitted_gm = split_by_tags(model, tags)
        return splitted_gm.extracted

    @staticmethod
    def _apply_module_insertion(
        model: torch.fx.GraphModule,
        transformations: List[FXModuleInsertionCommand],
    ) -> torch.fx.GraphModule:
        """
        Applies FXModuleInsertionCommand commands: for each command, registers
        the torch module on the torch.fx.GraphModule and inserts a call_module
        node at each of the command's target points.
        :param model: Model to apply transformations to.
        :param transformations: List of module insertion transformations.
        :return: A modified torch.fx.GraphModule.
        """
        for transformation in transformations:
            # Set the module to the model as an attribute under a unique name
            # derived from the target points and the module id.
            module_to_insert = transformation.module_to_insert
            module_name_in_model = (
                ";".join(
                    "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value)))
                    for tp in transformation.target_points
                )
                + "_"
                + str(id(module_to_insert))
            )
            assert not hasattr(model, module_name_in_model)
            setattr(model, module_name_in_model, module_to_insert)
            # Insert call_module nodes into the model graph
            for target_point in transformation.target_points:
                FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model)
        return model

    @staticmethod
    def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node:
        for node in graph.nodes:
            if node.name == name:
                return node
        raise RuntimeError(f"Node with name {name} is not found")

    @staticmethod
    def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint) -> torch.fx.Node:
        # Pre-hook and weight targets resolve to the input node on the given
        # port; post-hook targets resolve to the target node itself.
        target_type = target_point.target_type
        target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name)
        if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
            target_node = target_node.all_input_nodes[target_point.input_port_id]
        elif target_type == TargetType.OPERATOR_POST_HOOK:
            pass
        else:
            raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
        return target_node

    @staticmethod
    def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str):
        target_node = FXModelTransformer._get_target_node(graph, target_point)
        with graph.inserting_after(target_node):
            graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node")

    @staticmethod
    def _apply_transformation(
        model: torch.fx.GraphModule,
        transformations: List[FXApplyTransformationCommand],
    ) -> torch.fx.GraphModule:
        for transformation in transformations:
            transformation.transformation_fn(model)
        return model
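Putting the pieces together, here is a hedged sketch of registering a command in a transformation layout and applying it; ToyModel and the node-counting transformation are illustrative assumptions, not code from this commit:

```python
# Hypothetical end-to-end use of FXModelTransformer.
import torch

from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
from nncf.torch.graph.transformations.layout import PTTransformationLayout


class ToyModel(torch.nn.Module):  # stand-in model
    def forward(self, x):
        return torch.relu(x)


def print_node_count(model: torch.fx.GraphModule) -> None:
    # Stand-in transformation fn; a real one would rewrite model.graph.
    print(f"graph has {len(list(model.graph.nodes))} nodes")


layout = PTTransformationLayout()
layout.register(FXApplyTransformationCommand(print_node_count))

fx_model = torch.fx.symbolic_trace(ToyModel())
transformed = FXModelTransformer(fx_model).transform(layout)
```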