From 7f873b5fad027fc55d16871e1de759a9e1daa34a Mon Sep 17 00:00:00 2001 From: anzr299 Date: Fri, 26 Jul 2024 14:57:38 +0400 Subject: [PATCH] Add shape to edge for quantize dequantize nodes --- nncf/experimental/torch/fx/nncf_graph_builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py index 0863cab72ee..c4716680e47 100644 --- a/nncf/experimental/torch/fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -122,7 +122,8 @@ def get_edge_params( """ output_port_id = 0 if source_node.op in ("get_attr",): - tensor_shape = tuple(getattr(model, source_node.target).shape) + tensor = getattr(model, source_node.target) + tensor_shape = tuple(tensor.shape) elif "val" in source_node.meta: if source_nncf_node.metatype is om.PTBatchNormMetatype: tensor = source_node.meta["val"][0] @@ -137,6 +138,7 @@ def get_edge_params( # TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes. nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.") tensor_shape = None - + if "quantize" in dist_node.name: + dist_node.meta["val"] = tensor input_port_id = dist_node.all_input_nodes.index(source_node) return input_port_id, output_port_id, tensor_shape