Skip to content

Commit

Permalink
Fix remove_shapeof_subgraphs (#2536)
Browse files Browse the repository at this point in the history
### Changes


![img](https://github.com/openvinotoolkit/nncf/assets/77268007/a2d33e78-2fdf-4aff-8826-258d538e2f63)


### Reason for changes

The `remove_shapeof_subgraphs()` function removes some operations whose
weights are quantized.

### Related tickets

Ref: 131843

### Tests

Current scope
  • Loading branch information
andrey-churkin authored Mar 1, 2024
1 parent 5daa391 commit 0fdc629
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
7 changes: 6 additions & 1 deletion nncf/quantization/algorithms/accuracy_control/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,15 @@ def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) ->
processed = {}
quantizers = [x for x in quantized_model_graph.topological_sort() if x.metatype in quantizer_metatypes]

input_nodes = [
*quantized_model_graph.get_nodes_by_metatypes(self._algo_backend.get_op_with_weights_metatypes()),
*self._algo_backend.get_start_nodes_for_activation_path_tracing(quantized_model_graph),
]

quantized_model_graph_without_shapeof = remove_shapeof_subgraphs(
deepcopy(quantized_model_graph),
self._algo_backend.get_shapeof_metatypes(),
self._algo_backend.get_start_nodes_for_activation_path_tracing(quantized_model_graph),
input_nodes,
)

for quantizer_node in reversed(quantizers):
Expand Down
12 changes: 6 additions & 6 deletions nncf/quantization/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,26 @@ def remove_shapeof_subgraphs(
if node.metatype in shapeof_metatypes:
shape_of_nodes.append(node)
continue
if node in infer_nodes:
if node.node_name in infer_nodes:
continue
infer_nodes.append(node)
infer_nodes.append(node.node_name)
nodes_queue.extend(nncf_graph.get_next_nodes(node))

for shape_of_node in shape_of_nodes:
nodes_to_drop.add(shape_of_node)
nodes_to_drop.add(shape_of_node.node_name)

shape_of_queue = collections.deque()
shape_of_queue.extend(nncf_graph.get_next_nodes(shape_of_node))
while shape_of_queue:
node = shape_of_queue.pop()
if node in nodes_to_drop or node in infer_nodes:
if node.node_name in nodes_to_drop or node.node_name in infer_nodes:
continue
nodes_to_drop.add(node)
nodes_to_drop.add(node.node_name)
# traverse forward and backward to exclude full shape of subgraph
# recursion excluded due to infer_nodes list around subgraph shape
shape_of_queue.extend(nncf_graph.get_next_nodes(node) + nncf_graph.get_previous_nodes(node))

nncf_graph.remove_nodes_from(nodes_to_drop)
nncf_graph.remove_nodes_from([nncf_graph.get_node_by_name(name) for name in nodes_to_drop])
return nncf_graph


Expand Down
5 changes: 5 additions & 0 deletions tests/common/accuracy_control/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@
from tests.common.quantization.metatypes import QUANTIZABLE_METATYPES
from tests.common.quantization.metatypes import QUANTIZE_AGNOSTIC_METATYPES
from tests.common.quantization.metatypes import QUANTIZER_METATYPES
from tests.common.quantization.metatypes import WEIGHT_LAYER_METATYPES
from tests.common.quantization.metatypes import ShapeOfTestMetatype


class AABackendForTests(AccuracyControlAlgoBackend):
@staticmethod
def get_op_with_weights_metatypes() -> List[OperatorMetatype]:
return WEIGHT_LAYER_METATYPES

@staticmethod
def get_quantizer_metatypes() -> List[OperatorMetatype]:
return QUANTIZER_METATYPES
Expand Down

0 comments on commit 0fdc629

Please sign in to comment.