Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
kshpv committed Mar 5, 2024
1 parent b0795a5 commit cd9e11a
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
1 change: 1 addition & 0 deletions nncf/onnx/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def _add_nncf_output_nodes(
output_port_id=output_port_id,
dtype=nncf_dtype,
)
input_port_id += 1

@staticmethod
def convert_onnx_dtype_to_nncf_dtype(onnx_dtype: int) -> Dtype:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
strict digraph {
"0 Conv1" [id=0, type=Conv];
"1 nncf_model_input_0" [id=1, type=nncf_model_input];
"2 nncf_model_output_0" [id=2, type=nncf_model_output];
"3 nncf_model_output_1" [id=3, type=nncf_model_output];
"0 Conv1" -> "2 nncf_model_output_0" [label="[]", style=solid];
"1 nncf_model_input_0" -> "0 Conv1" [label="[1, 3, 10, 10]", style=solid];
}
12 changes: 12 additions & 0 deletions tests/onnx/test_nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import torch
from torchvision import models

from nncf.onnx.graph.model_transformer import ONNXModelTransformer
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from tests.onnx.conftest import ONNX_TEST_ROOT
from tests.onnx.models import ALL_SYNTHETIC_MODELS
from tests.onnx.models import OneConvolutionalModel
from tests.onnx.opset_converter import convert_opset_version
from tests.onnx.quantization.common import ModelToTest
from tests.onnx.weightless_model import load_model_topology_with_zeros_weights
Expand Down Expand Up @@ -99,3 +101,13 @@ def test_compare_nncf_graph_detection_real_models(tmp_path, model_to_test):
nx_graph = nncf_graph.get_graph_for_structure_analysis(extended=True)

compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True)


def test_add_output_nodes_with_no_parents_node():
model_to_test = OneConvolutionalModel().onnx_model
model_outputs = (value_info.name for value_info in model_to_test.graph.output)
model_with_output = ONNXModelTransformer._insert_outputs(model_to_test, (*model_outputs, "Conv1_W"))
nncf_graph = GraphConverter.create_nncf_graph(model_with_output)
nx_graph = nncf_graph.get_graph_for_structure_analysis(extended=True)
path_to_dot = REFERENCE_GRAPHS_DIR / "synthetic" / "output_with_no_parents_model.dot"
compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True)

0 comments on commit cd9e11a

Please sign in to comment.