From 796be3cd20d0a3c16676cf1c2cd894f62654ccac Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 2 Aug 2024 19:28:24 +0200 Subject: [PATCH] [TorchFX][FBC] Constant linear layers support --- .../torch/fx/model_transformer.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/nncf/experimental/torch/fx/model_transformer.py b/nncf/experimental/torch/fx/model_transformer.py index 4be8f306051..b811e7275ca 100644 --- a/nncf/experimental/torch/fx/model_transformer.py +++ b/nncf/experimental/torch/fx/model_transformer.py @@ -18,6 +18,7 @@ from nncf.common.graph.model_transformer import ModelTransformer from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand +from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name from nncf.torch.graph.transformations.commands import PTModelExtractionCommand from nncf.torch.graph.transformations.layout import PTTransformationLayout @@ -97,7 +98,23 @@ def _apply_model_extraction( # TODO(dlyakhov): reduce memory consumption by # more optimal splitting implementation. splitted_gm = split_by_tags(model, tags) - return splitted_gm.extracted + + extracted_model = splitted_gm.extracted + graph: torch.fx.Graph = extracted_model.graph + # Check extracted model has inputs. + # It is possible to have two constant inputs + # for a linear layer, an placeholder is being + # placed to the input port. + target_node = get_graph_node_by_name(graph, node_name) + input_node = target_node.all_input_nodes[0] + if input_node.op != "placeholder": + with graph.inserting_before(target_node): + new_input_node = graph.create_node( + "placeholder", "placeholder_node", (), {}, name="placeholder_graph_node" + ) + target_node.replace_input_with(input_node, new_input_node) + extracted_model.graph.eliminate_dead_code() + return extracted_model @staticmethod def _apply_transformation(