Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MO] Add support to moc_frontend of ":" as delimiter for --input and --output #6543

Merged
merged 12 commits into from
Jul 29, 2021
18 changes: 13 additions & 5 deletions model-optimizer/mo/front/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,17 @@ def extract_node_attrs(graph: Graph, extractor: callable):
return graph


def raise_no_node(node_name: str):
raise Error('No node with name {}'.format(node_name))


def raise_node_name_collision(node_name: str, found_nodes: list):
raise Error('Name collision was found, there are several nodes for mask "{}": {}. '
'If your intention was to specify port for node, please instead specify node names connected to '
'this port. If your intention was to specify the node name, please add port to the node '
'name'.format(node_name, found_nodes))


def get_node_id_with_ports(graph: Graph, node_name: str, skip_if_no_port=True):
"""
Extracts port and node ID out of user provided name
Expand Down Expand Up @@ -483,12 +494,9 @@ def get_node_id_with_ports(graph: Graph, node_name: str, skip_if_no_port=True):

found_names.append((in_port, out_port, name))
if len(found_names) == 0:
raise Error('No node with name {}'.format(node_name))
raise_no_node(node_name)
if len(found_names) > 1:
raise Error('Name collision was found, there are several nodes for mask "{}": {}. '
'If your intention was to specify port for node, please instead specify node names connected to '
'this port. If your intention was to specify the node name, please add port to the node '
'name'.format(node_name, [name for _, _, name in found_names]))
raise_node_name_collision(node_name, [name for _, _, name in found_names])
in_port, out_port, name = found_names[0]
node_id = graph.get_node_id_by_name(name)
if in_port is not None:
Expand Down
62 changes: 51 additions & 11 deletions model-optimizer/mo/moc_frontend/extractor.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,75 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging as log
import re
from collections import defaultdict
from copy import copy

import numpy as np

from mo.front.extractor import raise_no_node, raise_node_name_collision
from mo.utils.error import Error

from ngraph.frontend import InputModel # pylint: disable=no-name-in-module,import-error

import numpy as np


def decode_name_with_port(input_model: InputModel, node_name: str):
"""
Decode name with optional port specification w/o traversing all the nodes in the graph
TODO: in future node_name can specify input/output port groups and indices (58562)
TODO: in future node_name can specify input/output port groups as well as indices (58562)
:param input_model: Input Model
:param node_name: user provided node name
:return: decoded place in the graph
"""
# Check exact match with one of the names in the graph first
found_nodes = []
found_node_names = []

node = input_model.get_place_by_tensor_name(node_name)
if node:
return node
found_node_names.append('Tensor:' + node_name)
found_nodes.append(node)

node = input_model.get_place_by_operation_name(node_name)
if node:
found_node_names.append('Operation:' + node_name)
found_nodes.append(node)

regexp_post = r'(.+):(\d+)'
match_post = re.search(regexp_post, node_name)
if match_post:
node_post = input_model.get_place_by_operation_name(match_post.group(1))
if node_post:
node_post = node_post.get_output_port(
outputPortIndex=int(match_post.group(2)))
if node_post:
found_node_names.append(match_post.group(1))
found_nodes.append(node_post)

regexp_pre = r'(\d+):(.+)'
match_pre = re.search(regexp_pre, node_name)
if match_pre:
node_pre = input_model.get_place_by_operation_name(match_pre.group(2))
if node_pre:
node_pre = node_pre.get_input_port(
inputPortIndex=int(match_pre.group(1)))
if node_pre:
found_node_names.append(match_pre.group(2))
found_nodes.append(node_pre)

if len(found_nodes) == 0:
raise_no_node(node_name)

# Check that there is no collision, all found places shall point to same data
if not all([n.is_equal_data(found_nodes[0]) for n in found_nodes]):
raise_node_name_collision(node_name, found_node_names)

# TODO: ONNX specific (59408)
# To comply with legacy behavior, for ONNX-only there shall be considered additional 2 possibilities
# 1) "abc:1" - get_place_by_tensor_name("abc").get_producing_operation().get_output_port(1)
# 2) "1:abc" - get_place_by_tensor_name("abc").get_producing_operation().get_input_port(1)
# This logic is not going to work with other frontends

# TODO: Add support for input/output group name and port index here (58562)
# Legacy frontends use format "number:name:number" to specify input and output port indices
# For new frontends this logic shall be extended to additionally support input and output group names
raise Error('There is no node with name {}'.format(node_name))
# For new frontends logic shall be extended to additionally support input and output group names
return found_nodes[0]


def fe_input_user_data_repack(input_model: InputModel, input_user_shapes: [None, list, dict, np.ndarray],
Expand Down
9 changes: 9 additions & 0 deletions model-optimizer/unit_tests/mo/frontend_ngraph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def test_frontends():
assert not status.returncode


def test_moc_extractor():
setup_env()
args = [sys.executable, '-m', 'pytest',
os.path.join(os.path.dirname(__file__), 'moc_frontend/moc_extractor_test_actual.py'), '-s']

status = subprocess.run(args, env=os.environ)
assert not status.returncode


def test_main_test():
setup_env()
args = [sys.executable, '-m', 'pytest',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import unittest

from mo.moc_frontend.extractor import decode_name_with_port
from mo.utils.error import Error

import pytest


mock_available = True

try:
# pylint: disable=no-name-in-module,import-error
from mock_mo_python_api import get_model_statistic, get_place_statistic, \
clear_frontend_statistic, clear_model_statistic, clear_place_statistic, \
clear_setup, set_equal_data, set_max_port_counts

# pylint: disable=no-name-in-module,import-error
from ngraph.frontend import FrontEndManager

except Exception:
print("No mock frontend API available,"
"ensure to use -DENABLE_TESTS=ON option when running these tests")
mock_available = False

# FrontEndManager shall be initialized and destroyed after all tests finished
# This is because destroy of FrontEndManager will unload all plugins,
# no objects shall exist after this
if mock_available:
fem = FrontEndManager()

mock_needed = pytest.mark.skipif(not mock_available,
reason="mock MO fe is not available")


class TestMainFrontend(unittest.TestCase):
def setUp(self):
clear_frontend_statistic()
clear_model_statistic()
clear_place_statistic()
clear_setup()
set_max_port_counts(10, 10)
self.fe = fem.load_by_framework('mock_mo_ngraph_frontend')
self.model = self.fe.load('abc.bin')

# Mock model has 'tensor' tensor place
@mock_needed
def test_decode_name_with_port_tensor(self):
node = decode_name_with_port(self.model, "tensor")
model_stat = get_model_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 1
assert node

# Mock model has 'operation' operation place
@mock_needed
def test_decode_name_with_port_op(self):
node = decode_name_with_port(self.model, "operation")
model_stat = get_model_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 1
assert node

# pylint: disable=wrong-spelling-in-comment
# Mock model doesn't have 'mocknoname' place
@mock_needed
def test_decode_name_with_port_noname(self):
with self.assertRaisesRegex(Error, 'No\\ node\\ with\\ name.*mocknoname*'):
decode_name_with_port(self.model, 'mocknoname')
model_stat = get_model_statistic()
assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 1

# Mock model has both tensor and operation with same name and non-equal data
# Collision is expected
@mock_needed
def test_decode_name_with_port_collision_op_tensor(self):
with self.assertRaisesRegex(Error, 'Name\\ collision.*tensorAndOp*'):
decode_name_with_port(self.model, 'tensorAndOp')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 1
assert place_stat.is_equal_data > 0

# Mock model has 'operation' and output port up to 10
@mock_needed
def test_decode_name_with_port_delim_op_out(self):
node = decode_name_with_port(self.model, 'operation:7')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.get_output_port == 1
assert place_stat.lastArgInt == 7
assert node

# Mock model has 'operation' and input port up to 10
@mock_needed
def test_decode_name_with_port_delim_op_in(self):
node = decode_name_with_port(self.model, '7:operation')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.get_input_port == 1
assert place_stat.lastArgInt == 7
assert node

# Mock model has 'operation' and 'operation:0' op places, collision is expected
@mock_needed
def test_decode_name_with_port_delim_op_collision_out(self):
with self.assertRaisesRegex(Error, 'Name\\ collision(?!.*Tensor.*).*operation\\:0*'):
decode_name_with_port(self.model, 'operation:0')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.is_equal_data > 0
assert place_stat.get_output_port == 1
assert place_stat.lastArgInt == 0

# Mock model has 'operation' and '0:operation' op places, collision is expected
@mock_needed
def test_decode_name_with_port_delim_op_collision_in(self):
with self.assertRaisesRegex(Error, 'Name\\ collision(?!.*Tensor.*).*0\\:operation*'):
decode_name_with_port(self.model, '0:operation')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.is_equal_data > 0
assert place_stat.get_input_port == 1
assert place_stat.lastArgInt == 0

# Mock model has 'tensor' and 'tensor:0' tensor places, no collision is expected
@mock_needed
def test_decode_name_with_port_delim_tensor_no_collision_out(self):
node = decode_name_with_port(self.model, 'tensor:0')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.get_output_port == 0
assert node

# Mock model has 'tensor' and '0:tensor' tensor places, no collision is expected
@mock_needed
def test_decode_name_with_port_delim_tensor_no_collision_in(self):
node = decode_name_with_port(self.model, '0:tensor')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.get_input_port == 0
assert node

# Mock model doesn't have such '1234:operation' or output port=1234 for 'operation'
@mock_needed
def test_decode_name_with_port_delim_no_port_out(self):
with self.assertRaisesRegex(Error, 'No\\ node\\ with\\ name.*operation\\:1234*'):
decode_name_with_port(self.model, 'operation:1234')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.get_output_port == 1
assert place_stat.lastArgInt == 1234

# Mock model doesn't have such '1234:operation' or input port=1234 for 'operation'
@mock_needed
def test_decode_name_with_port_delim_no_port_in(self):
with self.assertRaisesRegex(Error, 'No\\ node\\ with\\ name.*1234\\:operation*'):
decode_name_with_port(self.model, '1234:operation')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.get_input_port == 1
assert place_stat.lastArgInt == 1234

# Mock model has tensor with name 'conv2d:0' and operation 'conv2d' with output port = 1
# It is setup to return 'is_equal_data=True' for these tensor and port
# So no collision is expected
@mock_needed
def test_decode_name_with_port_delim_equal_data_out(self):
set_equal_data('conv2d', 'conv2d')
node = decode_name_with_port(self.model, 'conv2d:0')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.get_output_port == 1
assert place_stat.is_equal_data > 0
assert node

# Mock model has tensor with name '0:conv2d' and operation 'conv2d' with input port = 1
# It is setup to return 'is_equal_data=True' for these tensor and port
# So no collision is expected
@mock_needed
def test_decode_name_with_port_delim_equal_data_in(self):
set_equal_data('conv2d', 'conv2d')
node = decode_name_with_port(self.model, '0:conv2d')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 2
assert place_stat.get_input_port == 1
assert place_stat.is_equal_data > 0
assert node

# Stress case: Mock model has:
# Tensor '8:9'
# Operation '8:9'
# Operation '8' with output port = 9
# Operation '9' with input port = 8
# All places point to same data - no collision is expected
@mock_needed
def test_decode_name_with_port_delim_all_same_data(self):
set_equal_data('8', '9')
node = decode_name_with_port(self.model, '8:9')
model_stat = get_model_statistic()
place_stat = get_place_statistic()

assert model_stat.get_place_by_tensor_name == 1
assert model_stat.get_place_by_operation_name == 3
assert place_stat.get_input_port == 1
assert place_stat.get_output_port == 1
# At least 3 comparisons of places are expected
assert place_stat.is_equal_data > 2
assert node
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ FeStat FrontEndMockPy::m_stat = {};
ModelStat InputModelMockPy::m_stat = {};
PlaceStat PlaceMockPy::m_stat = {};

std::string MockSetup::m_equal_data_node1 = {};
std::string MockSetup::m_equal_data_node2 = {};
int MockSetup::m_max_input_port_index = 0;
int MockSetup::m_max_output_port_index = 0;

PartialShape InputModelMockPy::m_returnShape = {};

extern "C" MOCK_API FrontEndVersion GetAPIVersion()
Expand Down
Loading