Skip to content

Commit

Permalink
Add shape to edge for quantize dequantize nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
anzr299 committed Jul 26, 2024
1 parent d94b93b commit 7f873b5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def get_edge_params(
"""
output_port_id = 0
if source_node.op in ("get_attr",):
tensor_shape = tuple(getattr(model, source_node.target).shape)
tensor = getattr(model, source_node.target)
tensor_shape = tuple(tensor.shape)
elif "val" in source_node.meta:
if source_nncf_node.metatype is om.PTBatchNormMetatype:
tensor = source_node.meta["val"][0]
Expand All @@ -137,6 +138,7 @@ def get_edge_params(
# TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes.
nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.")
tensor_shape = None

if "quantize" in dist_node.name:
dist_node.meta["val"] = tensor
input_port_id = dist_node.all_input_nodes.index(source_node)
return input_port_id, output_port_id, tensor_shape

0 comments on commit 7f873b5

Please sign in to comment.