Skip to content

Commit

Permalink
Model transformer minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jul 1, 2024
1 parent 2d5a02b commit 71b49f7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 86 deletions.
70 changes: 1 addition & 69 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,12 @@

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.commands import TransformationType
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout


class FXModuleInsertionCommand(Command):
def __init__(
self,
target_points: List[PTTargetPoint],
module_to_insert: torch.nn.Module,
priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
):
super().__init__(TransformationType.INSERT)
self.target_points = target_points
self.module_to_insert = module_to_insert
self.priority = priority


class FXApplyTransformationCommand(Command):
def __init__(
self,
Expand All @@ -63,9 +48,7 @@ def __init__(self, model: torch.fx.GraphModule):
super().__init__(model)

self._command_transformation_ordered_pairs = [
# TODO: Move the module insertion command to a transformation
(FXApplyTransformationCommand, self._apply_transformation),
(FXModuleInsertionCommand, self._apply_module_insertion),
(PTModelExtractionCommand, self._apply_model_extraction),
]

Expand All @@ -82,7 +65,7 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.G
model = transformation_fn(model, transformations)

# Do not eliminate dead code as
# the dead code is coputing statistics :)
# the dead code is computing statistics :)
# model.graph.eliminate_dead_code()
model.recompile()
return model
Expand Down Expand Up @@ -115,64 +98,13 @@ def _apply_model_extraction(
splitted_gm = split_by_tags(model, tags)
return splitted_gm.extracted

@staticmethod
def _apply_module_insertion(
model: torch.fx.GraphModule,
transformations: List[FXModuleInsertionCommand],
) -> torch.fx.GraphModule:
"""
Applies insertion of PTSharedFnInsertionCommand commands. For each command method inserts
a torch module to the torch.fx.GraphModule and inserts call hooks for each command target points.
:param model: Model to apply transformations.
:param transformations: List of the bias correction transformations.
:param device: Target device for the insertion functions. Applies only to
functions which are subclassed from torch.nn.Module. Do nothing in case device is None.
:return: A modified torch.fx.GraphModule.
"""
for transformation in transformations:
# Set fn to the model as an attribute
module_to_insert = transformation.module_to_insert
module_name_in_model = (
";".join(
"_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value)))
for tp in transformation.target_points
)
+ "_"
+ str(id(module_to_insert))
)
assert not hasattr(model, module_name_in_model)
setattr(model, module_name_in_model, module_to_insert)
# Insert call_module nodes to the model
for target_point in transformation.target_points:
FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model)
return model

@staticmethod
def get_graph_node_by_name(graph, name):
for node in graph.nodes:
if node.name == name:
return node
raise RuntimeError(f"Node with name {name} is not found")

@staticmethod
def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint):
target_type = target_point.target_type
target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name)
if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
target_node = target_node.all_input_nodes[target_point.input_port_id]
elif target_type == TargetType.OPERATOR_POST_HOOK:
pass
else:
raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
return target_node

@staticmethod
def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str):
target_node = FXModelTransformer._get_target_node(graph, target_point)
with graph.inserting_after(target_node):
graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node")

@staticmethod
def _apply_transformation(
model: torch.fx.GraphModule,
Expand Down
12 changes: 7 additions & 5 deletions nncf/experimental/torch/fx/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer
from nncf.common.tensor_statistics.aggregator import StatisticsAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.torch.fx.model_transformer import FXModuleInsertionCommand
from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand
from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder
from nncf.tensor import Tensor
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.return_types import maybe_get_values_from_torch_return_type
Expand Down Expand Up @@ -74,11 +75,12 @@ def _get_transformation_layout_extra_outputs(
for _statistic_point in _statistic_points:
for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
for collector in collectors:
transformation = leaf_module_insertion_transformation_builder(
TensorCollectorModule(collector), [_statistic_point.target_point]
)
transformation_commands.append(
FXModuleInsertionCommand(
[_statistic_point.target_point],
TensorCollectorModule(collector),
TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
FXApplyTransformationCommand(
transformation, TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION
)
)

Expand Down
71 changes: 59 additions & 12 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,48 @@
from nncf.torch.graph.transformations.commands import PTTargetPoint


def fake_quantize_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
def fake_quantize_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points)
def module_insertion_tranformation_builder(module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]):
"""
Inserts given module to a target model and calls given module after each target points.
For each target node all original ouputs are being replaced by outputs of corresponded
module call.
"""

def module_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points)
graph = model.graph
for target_point in target_points:
target_node = FXModelTransformer._get_target_node(model.graph, target_point)
with graph.inserting_after(target_node):
fq_node = graph.create_node(
"call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer"
)
target_node = _get_target_node(graph, target_point)
new_node = _insert_call_module(graph, target_node, module_attr_name)
for user in list(target_node.users):
if user is fq_node:
if user is new_node:
continue
user.replace_input_with(target_node, fq_node)
user.replace_input_with(target_node, new_node)

return module_insertion_transformation


def leaf_module_insertion_transformation_builder(module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]):
"""
Inserts given module to a target model and calls given module after each target points.
"""

def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points)
# Insert call_module nodes to the model
graph = model.graph
for target_point in target_points:
target_node = _get_target_node(graph, target_point)
_insert_call_module(graph, target_node, module_attr_name)

return fake_quantize_insertion_transformation
return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor):
"""
Updates constant of the given bias node to the given value.
"""

def bias_update_transformation(model: torch.fx.GraphModule):
graph = model.graph
target_node_name = node.node_name
Expand All @@ -60,11 +83,16 @@ def bias_update_transformation(model: torch.fx.GraphModule):


def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
"""
Inserts quantize-dequantize operations with parameters inherited from the given quantizer to each
given target point.
"""

def qdq_insertion_tranformation(model: torch.fx.GraphModule):
if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1:
raise RuntimeError
for target_point in target_points:
target_node = FXModelTransformer._get_target_node(model.graph, target_point)
target_node = _get_target_node(model.graph, target_point)
insert_one_qdq(model, target_node, quantizer, target_point)

return qdq_insertion_tranformation
Expand Down Expand Up @@ -142,6 +170,25 @@ def insert_one_qdq(
user.replace_input_with(target_node, dq_node)


def _insert_call_module(graph: torch.fx.Graph, target_node: torch.fx.Node, module_attr_name: str):
with graph.inserting_after(target_node):
return graph.create_node(
"call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_graph_node"
)


def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint):
target_type = target_point.target_type
target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name)
if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
target_node = target_node.all_input_nodes[target_point.input_port_id]
elif target_type == TargetType.OPERATOR_POST_HOOK:
pass
else:
raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
return target_node


def _set_module_to_the_graph_module(
model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
) -> str:
Expand Down

0 comments on commit 71b49f7

Please sign in to comment.