diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py index c005073b941..613474fb48a 100644 --- a/nncf/experimental/torch/fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -68,7 +68,6 @@ def _get_layer_attributes( def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype: """ Attempts to retrieve correct subtype for the given node. - :param node: Given node. :param metatype: Given node metatype. :param model: Target GraphModule instance. @@ -138,6 +137,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph: for source_node in model.graph.nodes: node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node, model) node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype) + is_shared_node = source_node.op in ("get_attr",) and ( const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1 ) diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index fd40399e22e..273f98415f9 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -927,6 +927,7 @@ class PTEmbeddingMetatype(PTOperatorMetatype): @FX_OPERATOR_METATYPES.register() + class PTAtenEmbeddingMetatype(OperatorMetatype): name = "EmbeddingOp" module_to_function_names = {NamespaceTarget.ATEN: ["embedding"]}