Skip to content

Commit

Permalink
[TorchFX][FBC] Constant linear layers support
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 2, 2024
1 parent f5aa90f commit 796be3c
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout

Expand Down Expand Up @@ -97,7 +98,23 @@ def _apply_model_extraction(
# TODO(dlyakhov): reduce memory consumption by
# more optimal splitting implementation.
splitted_gm = split_by_tags(model, tags)
return splitted_gm.extracted

extracted_model = splitted_gm.extracted
graph: torch.fx.Graph = extracted_model.graph
# Check extracted model has inputs.
# It is possible to have two constant inputs
# for a linear layer, an placeholder is being
# placed to the input port.
target_node = get_graph_node_by_name(graph, node_name)
input_node = target_node.all_input_nodes[0]
if input_node.op != "placeholder":
with graph.inserting_before(target_node):
new_input_node = graph.create_node(
"placeholder", "placeholder_node", (), {}, name="placeholder_graph_node"
)
target_node.replace_input_with(input_node, new_input_node)
extracted_model.graph.eliminate_dead_code()
return extracted_model

@staticmethod
def _apply_transformation(
Expand Down

0 comments on commit 796be3c

Please sign in to comment.