Skip to content

Commit

Permalink
Tests are fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jul 10, 2024
1 parent ad383b2 commit 8d498db
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 23 deletions.
46 changes: 38 additions & 8 deletions tests/common/quantization/mock_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.insertion_point_graph import InsertionPointGraph
from nncf.common.insertion_point_graph import PostHookInsertionPoint
Expand Down Expand Up @@ -188,16 +189,26 @@ def get_mock_nncf_node_attrs(op_name=None, scope_str=None, metatype=None, type_=


def _add_nodes_with_layer_attrs(
nx_graph: nx.DiGraph, node_keys: List[str], layer_attrs: Dict[str, BaseLayerAttributes]
nx_graph: nx.DiGraph,
node_keys: List[str],
layer_attrs: Dict[str, BaseLayerAttributes],
metatypes: Dict[str, OperatorMetatype] = None,
) -> nx.DiGraph:
for node_key in node_keys:
nx_graph.add_node(node_key, **get_mock_nncf_node_attrs(op_name=node_key))
metatype = None
if metatypes is not None and node_key in metatypes:
metatype = metatypes[node_key]
nx_graph.add_node(node_key, **get_mock_nncf_node_attrs(op_name=node_key, metatype=metatype))

if node_key in layer_attrs:
nx_graph.nodes[node_key][NNCFNode.LAYER_ATTRIBUTES] = layer_attrs[node_key]

return nx_graph


def get_mock_model_graph_with_mergeable_pattern() -> NNCFGraph:
def get_mock_model_graph_with_mergeable_pattern(
conv2d_metatype=None, batchnorm_metatype=None, relu_metatype=None
) -> NNCFGraph:
mock_nx_graph = nx.DiGraph()

# (A)
Expand Down Expand Up @@ -225,7 +236,12 @@ def get_mock_model_graph_with_mergeable_pattern() -> NNCFGraph:
padding_values=[0, 0, 0, 0],
)
}
mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs)
metatypes = {
"conv2d": conv2d_metatype,
"batch_norm": batchnorm_metatype,
"relu": relu_metatype,
}
mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs, metatypes)

mock_nx_graph.add_edges_from(
[
Expand All @@ -238,7 +254,9 @@ def get_mock_model_graph_with_mergeable_pattern() -> NNCFGraph:
return get_nncf_graph_from_mock_nx_graph(mock_nx_graph)


def get_mock_model_graph_with_no_mergeable_pattern() -> NNCFGraph:
def get_mock_model_graph_with_no_mergeable_pattern(
conv2d_metatype=None, batchnorm_metatype=None, relu_metatype=None
) -> NNCFGraph:
mock_nx_graph = nx.DiGraph()

# (A)
Expand Down Expand Up @@ -270,7 +288,12 @@ def get_mock_model_graph_with_no_mergeable_pattern() -> NNCFGraph:
padding_values=[0, 0, 0, 0],
)
}
mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs)
metatypes = {
"conv2d": conv2d_metatype,
"batch_norm": batchnorm_metatype,
"relu": relu_metatype,
}
mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs, metatypes)

mock_nx_graph.add_edges_from(
[
Expand All @@ -285,7 +308,9 @@ def get_mock_model_graph_with_no_mergeable_pattern() -> NNCFGraph:
return get_nncf_graph_from_mock_nx_graph(mock_nx_graph)


def get_mock_model_graph_with_broken_output_edge_pattern() -> NNCFGraph:
def get_mock_model_graph_with_broken_output_edge_pattern(
conv2d_metatype=None, batchnorm_metatype=None, relu_metatype=None
) -> NNCFGraph:
mock_nx_graph = nx.DiGraph()

# (A)
Expand Down Expand Up @@ -314,7 +339,12 @@ def get_mock_model_graph_with_broken_output_edge_pattern() -> NNCFGraph:
padding_values=[0, 0, 0, 0],
)
}
mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs)
metatypes = {
"conv2d": conv2d_metatype,
"batch_norm": batchnorm_metatype,
"relu": relu_metatype,
}
mock_nx_graph = _add_nodes_with_layer_attrs(mock_nx_graph, node_keys, layer_attrs, metatypes)

mock_nx_graph.add_edges_from(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@ strict digraph {
"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize];
"2 ShiftScaleParametrized/clone_0" [id=2, type=clone];
"3 ShiftScaleParametrized/sub__0" [id=3, type=sub_];
"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/sub__0|OUTPUT]/symmetric_quantize_0" [id=4, type=symmetric_quantize];
"5 ShiftScaleParametrized/div__0" [id=5, type=div_];
"6 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=6, type=symmetric_quantize];
"7 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=7, type=symmetric_quantize];
"8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=8, type=conv2d];
"9 /nncf_model_output_0" [id=9, type=nncf_model_output];
"4 ShiftScaleParametrized/div__0" [id=4, type=div_];
"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" [id=5, type=symmetric_quantize];
"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=6, type=symmetric_quantize];
"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" [id=7, type=conv2d];
"8 /nncf_model_output_0" [id=8, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0";
"1 SymmetricQuantizer/symmetric_quantize_0" -> "2 ShiftScaleParametrized/clone_0";
"2 ShiftScaleParametrized/clone_0" -> "3 ShiftScaleParametrized/sub__0";
"3 ShiftScaleParametrized/sub__0" -> "4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/sub__0|OUTPUT]/symmetric_quantize_0";
"4 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/sub__0|OUTPUT]/symmetric_quantize_0" -> "5 ShiftScaleParametrized/div__0";
"5 ShiftScaleParametrized/div__0" -> "6 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0";
"6 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"7 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"8 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "9 /nncf_model_output_0";
"3 ShiftScaleParametrized/sub__0" -> "4 ShiftScaleParametrized/div__0";
"4 ShiftScaleParametrized/div__0" -> "5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0";
"5 ShiftScaleParametrized/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[ShiftScaleParametrized/div__0|OUTPUT]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"6 ShiftScaleParametrized/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0";
"7 ShiftScaleParametrized/NNCFConv2d[conv]/conv2d_0" -> "8 /nncf_model_output_0";
}
32 changes: 29 additions & 3 deletions tests/torch/test_model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import List

Expand All @@ -21,6 +22,7 @@
import torch.nn.functional as F
from torch import nn

import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME
from nncf.common.graph.definitions import MODEL_OUTPUT_OP_NAME
from nncf.common.graph.patterns.manager import PatternsManager
Expand Down Expand Up @@ -367,9 +369,33 @@ def test_priority(self, target_type, trace_parameters, priority_type):


MERGE_PATTERN_TEST_CASES = (
[get_mock_model_graph_with_mergeable_pattern, "basic_pattern"],
[get_mock_model_graph_with_no_mergeable_pattern, "no_pattern"],
[get_mock_model_graph_with_broken_output_edge_pattern, "broken_output_edges_pattern"],
[
partial(
get_mock_model_graph_with_mergeable_pattern,
conv2d_metatype=om.PTConv2dMetatype,
batchnorm_metatype=om.PTBatchNormMetatype,
relu_metatype=om.PTRELUMetatype,
),
"basic_pattern",
],
[
partial(
get_mock_model_graph_with_no_mergeable_pattern,
conv2d_metatype=om.PTConv2dMetatype,
batchnorm_metatype=om.PTBatchNormMetatype,
relu_metatype=om.PTRELUMetatype,
),
"no_pattern",
],
[
partial(
get_mock_model_graph_with_broken_output_edge_pattern,
conv2d_metatype=om.PTConv2dMetatype,
batchnorm_metatype=om.PTBatchNormMetatype,
relu_metatype=om.PTRELUMetatype,
),
"broken_output_edges_pattern",
],
)


Expand Down

0 comments on commit 8d498db

Please sign in to comment.