From 71b49f707e3a56866261a4db13eb14e4b8d03df0 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 1 Jul 2024 18:02:18 +0200 Subject: [PATCH] Model transformer minor refactoring --- .../torch/fx/model_transformer.py | 70 +----------------- .../torch/fx/statistics/aggregator.py | 12 ++-- nncf/experimental/torch/fx/transformations.py | 71 +++++++++++++++---- 3 files changed, 67 insertions(+), 86 deletions(-) diff --git a/nncf/experimental/torch/fx/model_transformer.py b/nncf/experimental/torch/fx/model_transformer.py index 48b3cf0c1f1..9c5d230887c 100644 --- a/nncf/experimental/torch/fx/model_transformer.py +++ b/nncf/experimental/torch/fx/model_transformer.py @@ -20,27 +20,12 @@ from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import Command -from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationPriority from nncf.common.graph.transformations.commands import TransformationType from nncf.torch.graph.transformations.commands import PTModelExtractionCommand -from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.graph.transformations.layout import PTTransformationLayout -class FXModuleInsertionCommand(Command): - def __init__( - self, - target_points: List[PTTargetPoint], - module_to_insert: torch.nn.Module, - priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, - ): - super().__init__(TransformationType.INSERT) - self.target_points = target_points - self.module_to_insert = module_to_insert - self.priority = priority - - class FXApplyTransformationCommand(Command): def __init__( self, @@ -63,9 +48,7 @@ def __init__(self, model: torch.fx.GraphModule): super().__init__(model) self._command_transformation_ordered_pairs = [ - # TODO: Move the module insertion command to a transformation (FXApplyTransformationCommand, self._apply_transformation), - (FXModuleInsertionCommand, self._apply_module_insertion), (PTModelExtractionCommand, self._apply_model_extraction), ] @@ -82,7 +65,7 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.G model = transformation_fn(model, transformations) # Do not eliminate dead code as - # the dead code is coputing statistics :) + # the dead code is computing statistics :) # model.graph.eliminate_dead_code() model.recompile() return model @@ -115,39 +98,6 @@ def _apply_model_extraction( splitted_gm = split_by_tags(model, tags) return splitted_gm.extracted - @staticmethod - def _apply_module_insertion( - model: torch.fx.GraphModule, - transformations: List[FXModuleInsertionCommand], - ) -> torch.fx.GraphModule: - """ - Applies insertion of PTSharedFnInsertionCommand commands. For each command method inserts - a torch module to the torch.fx.GraphModule and inserts call hooks for each command target points. - - :param model: Model to apply transformations. - :param transformations: List of the bias correction transformations. - :param device: Target device for the insertion functions. Applies only to - functions which are subclassed from torch.nn.Module. Do nothing in case device is None. - :return: A modified torch.fx.GraphModule. - """ - for transformation in transformations: - # Set fn to the model as an attribute - module_to_insert = transformation.module_to_insert - module_name_in_model = ( - ";".join( - "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) - for tp in transformation.target_points - ) - + "_" - + str(id(module_to_insert)) - ) - assert not hasattr(model, module_name_in_model) - setattr(model, module_name_in_model, module_to_insert) - # Insert call_module nodes to the model - for target_point in transformation.target_points: - FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model) - return model - @staticmethod def get_graph_node_by_name(graph, name): for node in graph.nodes: @@ -155,24 +105,6 @@ def get_graph_node_by_name(graph, name): return node raise RuntimeError(f"Node with name {name} is not found") - @staticmethod - def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint): - target_type = target_point.target_type - target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) - if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: - target_node = target_node.all_input_nodes[target_point.input_port_id] - elif target_type == TargetType.OPERATOR_POST_HOOK: - pass - else: - raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") - return target_node - - @staticmethod - def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str): - target_node = FXModelTransformer._get_target_node(graph, target_point) - with graph.inserting_after(target_node): - graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node") - @staticmethod def _apply_transformation( model: torch.fx.GraphModule, diff --git a/nncf/experimental/torch/fx/statistics/aggregator.py b/nncf/experimental/torch/fx/statistics/aggregator.py index 774430b1834..efca401c175 100644 --- a/nncf/experimental/torch/fx/statistics/aggregator.py +++ b/nncf/experimental/torch/fx/statistics/aggregator.py @@ -21,7 +21,8 @@ from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer from nncf.common.tensor_statistics.aggregator import StatisticsAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.experimental.torch.fx.model_transformer import FXModuleInsertionCommand +from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder from nncf.tensor import Tensor from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.return_types import maybe_get_values_from_torch_return_type @@ -74,11 +75,12 @@ def _get_transformation_layout_extra_outputs( for _statistic_point in _statistic_points: for collectors in _statistic_point.algorithm_to_tensor_collectors.values(): for collector in collectors: + transformation = leaf_module_insertion_transformation_builder( + TensorCollectorModule(collector), [_statistic_point.target_point] + ) transformation_commands.append( - FXModuleInsertionCommand( - [_statistic_point.target_point], - TensorCollectorModule(collector), - TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, + FXApplyTransformationCommand( + transformation, TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION ) ) diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index 9c608c90290..09306fb53bc 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -25,25 +25,48 @@ from nncf.torch.graph.transformations.commands import PTTargetPoint -def fake_quantize_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): - def fake_quantize_insertion_transformation(model: torch.fx.GraphModule): - module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points) +def module_insertion_tranformation_builder(module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]): + """ + Inserts given module to a target model and calls given module after each target points. + For each target node all original ouputs are being replaced by outputs of corresponded + module call. + """ + + def module_insertion_transformation(model: torch.fx.GraphModule): + module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points) graph = model.graph for target_point in target_points: - target_node = FXModelTransformer._get_target_node(model.graph, target_point) - with graph.inserting_after(target_node): - fq_node = graph.create_node( - "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer" - ) + target_node = _get_target_node(graph, target_point) + new_node = _insert_call_module(graph, target_node, module_attr_name) for user in list(target_node.users): - if user is fq_node: + if user is new_node: continue - user.replace_input_with(target_node, fq_node) + user.replace_input_with(target_node, new_node) + + return module_insertion_transformation + + +def leaf_module_insertion_transformation_builder(module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]): + """ + Inserts given module to a target model and calls given module after each target points. + """ + + def leaf_module_insertion_transformation(model: torch.fx.GraphModule): + module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points) + # Insert call_module nodes to the model + graph = model.graph + for target_point in target_points: + target_node = _get_target_node(graph, target_point) + _insert_call_module(graph, target_node, module_attr_name) - return fake_quantize_insertion_transformation + return leaf_module_insertion_transformation def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor): + """ + Updates constant of the given bias node to the given value. + """ + def bias_update_transformation(model: torch.fx.GraphModule): graph = model.graph target_node_name = node.node_name @@ -60,11 +83,16 @@ def bias_update_transformation(model: torch.fx.GraphModule): def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): + """ + Inserts quantize-dequantize operations with parameters inherited from the given quantizer to each + given target point. + """ + def qdq_insertion_tranformation(model: torch.fx.GraphModule): if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1: raise RuntimeError for target_point in target_points: - target_node = FXModelTransformer._get_target_node(model.graph, target_point) + target_node = _get_target_node(model.graph, target_point) insert_one_qdq(model, target_node, quantizer, target_point) return qdq_insertion_tranformation @@ -142,6 +170,25 @@ def insert_one_qdq( user.replace_input_with(target_node, dq_node) +def _insert_call_module(graph: torch.fx.Graph, target_node: torch.fx.Node, module_attr_name: str): + with graph.inserting_after(target_node): + return graph.create_node( + "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_graph_node" + ) + + +def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint): + target_type = target_point.target_type + target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) + if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: + target_node = target_node.all_input_nodes[target_point.input_port_id] + elif target_type == TargetType.OPERATOR_POST_HOOK: + pass + else: + raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") + return target_node + + def _set_module_to_the_graph_module( model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] ) -> str: