diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index b4a08df99da0a7..2288401c1f065a 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -784,7 +784,6 @@ mo/front/kaldi/extractors/pnorm_component_ext.py mo/front/kaldi/extractors/rectified_linear_component_ext.py mo/front/kaldi/extractors/rescale_ext.py mo/front/kaldi/extractors/scale_component_ext.py -mo/front/kaldi/extractors/slice_ext.py mo/front/kaldi/extractors/softmax_ext.py mo/front/kaldi/extractors/splice_component_ext.py mo/front/kaldi/loader/__init__.py diff --git a/model-optimizer/mo/front/kaldi/extractors/slice_ext.py b/model-optimizer/mo/front/kaldi/extractors/slice_ext.py deleted file mode 100644 index 45af7d7c6ee4f8..00000000000000 --- a/model-optimizer/mo/front/kaldi/extractors/slice_ext.py +++ /dev/null @@ -1,42 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" -import numpy as np - -from mo.front.common.partial_infer.slice import caffe_slice_infer -from mo.front.extractor import FrontExtractorOp -from mo.front.kaldi.loader.utils import read_binary_integer32_token, read_blob -from mo.ops.slice import Slice - - -class SliceFrontExtractor(FrontExtractorOp): - op = 'slice' - enabled = True - - @classmethod - def extract(cls, node): - pb = node.parameters - num_slice_points = read_binary_integer32_token(pb) - mapping_rule = { - 'axis': 1, - 'slice_point': read_blob(pb, num_slice_points, np.int32), - 'batch_dims': 0, - 'spatial_dims': 1, - 'out_ports_count': num_slice_points + 1, - 'infer': caffe_slice_infer - } - node.parameters.close() - Slice.update_node_stat(node, mapping_rule) - return cls.enabled diff --git a/model-optimizer/mo/front/kaldi/extractors/slice_ext_test.py b/model-optimizer/mo/front/kaldi/extractors/slice_ext_test.py deleted file mode 100644 index 4be6c875f7768a..00000000000000 --- a/model-optimizer/mo/front/kaldi/extractors/slice_ext_test.py +++ /dev/null @@ -1,35 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from mo.front.kaldi.extractors.common_ext_test import KaldiFrontExtractorTest -from mo.front.kaldi.extractors.slice_ext import SliceFrontExtractor -from mo.ops.op import Op -from mo.ops.slice import Slice -from mo.utils.unittest.extractors import FakeMultiParam - - -class SliceFrontExtractorTest(KaldiFrontExtractorTest): - @classmethod - def register_op(cls): - Op.registered_ops['Slice'] = Slice - cls.slice_params = { - 'slice_point': [99, 1320], - 'axis': 1 - } - cls.test_node['pb'] = FakeMultiParam(cls.slice_params) - - def test_assertion_no_pb(self): - self.assertRaises(AttributeError, SliceFrontExtractor.extract, None) diff --git a/model-optimizer/mo/front/kaldi/loader/loader.py b/model-optimizer/mo/front/kaldi/loader/loader.py index b471c3c90ba5e1..dcbc01bc25a511 100644 --- a/model-optimizer/mo/front/kaldi/loader/loader.py +++ b/model-optimizer/mo/front/kaldi/loader/loader.py @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. """ -import io import logging as log -import struct from io import IOBase import networkx as nx import numpy as np +from extensions.ops.split import AttributedVariadicSplit from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \ find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \ collect_until_token, collect_until_token_and_read, create_edge_attrs, get_args_for_specifier @@ -33,7 +32,7 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id): """ Load ParallelComponent of the Kaldi model. ParallelComponent contains parallel nested networks. - Slice is inserted before nested networks. + VariadicSplit is inserted before nested networks. Outputs of nested networks concatenate with layer Concat. :param file_descr: descriptor of the model file @@ -44,27 +43,23 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id): nnet_count = read_token_value(file_descr, b'') log.debug('Model contains parallel component with {} nested networks'.format(nnet_count)) - slice_id = graph.unique_id(prefix='Slice') - graph.add_node(slice_id, parameters=None, op='slice', kind='op') - - slice_node = Node(graph, slice_id) - Node(graph, prev_layer_id).add_output_port(0) - slice_node.add_input_port(0) - graph.create_edge(Node(graph, prev_layer_id), slice_node, 0, 0) - slices_points = [] - + split_points = [] outputs = [] + inputs = [] for i in range(nnet_count): read_token_value(file_descr, b'') collect_until_token(file_descr, b'') g = Graph() load_kalid_nnet1_model(g, file_descr, 'Nested_net_{}'.format(i)) - input_nodes = [n for n in graph.nodes(data=True) if n[1]['op'] == 'Parameter'] - shape = input_nodes[0][1]['shape'] - if i != nnet_count - 1: - slices_points.append(shape[1]) - g.remove_node(input_nodes[0][0]) + + # input to nnet1 models is of a rank 1 but we also insert batch_size to 0th axis + # 1st axis contains input_size of the nested subnetwork + # we split input from the main network to subnetworks + input_node = Node(g, 'Parameter') + split_points.append(input_node['shape'][1]) + g.remove_node(input_node.id) + mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph} g = nx.relabel_nodes(g, mapping) for val in mapping.values(): @@ -72,24 +67,28 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id): graph.add_nodes_from(g.nodes(data=True)) graph.add_edges_from(g.edges(data=True)) sorted_nodes = tuple(nx.topological_sort(g)) - edge_attrs = create_edge_attrs(slice_id, sorted_nodes[0]) - edge_attrs['out'] = i - Node(graph, slice_id).add_output_port(i) - Node(graph, sorted_nodes[0]).add_input_port(len(Node(graph, sorted_nodes[0]).in_ports())) - graph.create_edge(Node(graph, slice_id), Node(graph, sorted_nodes[0]), i, 0) - outputs.append(sorted_nodes[-1]) - packed_sp = struct.pack("B", 4) + struct.pack("I", len(slices_points)) - for i in slices_points: - packed_sp += struct.pack("I", i) - slice_node.parameters = io.BytesIO(packed_sp) + + outputs.append(Node(graph, sorted_nodes[-1])) + inputs.append(Node(graph, sorted_nodes[0])) + + split_id = graph.unique_id(prefix='NestedNets/VariadicSplit') + attrs = {'out_ports_count': nnet_count, 'size_splits': split_points, 'axis': 1, 'name': split_id} + variadic_split_node = AttributedVariadicSplit(graph, attrs).create_node() + prev_layer_node = Node(graph, prev_layer_id) + prev_layer_node.add_output_port(0) + graph.create_edge(prev_layer_node, variadic_split_node, 0, 0) + concat_id = graph.unique_id(prefix='Concat') graph.add_node(concat_id, parameters=None, op='concat', kind='op') - for i, output in enumerate(outputs): - edge_attrs = create_edge_attrs(output, concat_id) - edge_attrs['in'] = i - Node(graph, output).add_output_port(0) - Node(graph, concat_id).add_input_port(i) - graph.create_edge(Node(graph, output), Node(graph, concat_id), 0, i) + concat_node = Node(graph, concat_id) + + # Connect each output of variadic_split_node to each subnetwork's inputs in ParallelComponent + # and each subnetwork's output to concat_node + for i, (input_node, output_node) in enumerate(zip(inputs, outputs)): + output_node.add_output_port(0) + concat_node.add_input_port(i) + graph.create_edge(output_node, concat_node, 0, i) + graph.create_edge(variadic_split_node, input_node, i, 0) return concat_id @@ -145,6 +144,7 @@ def load_kalid_nnet1_model(graph, file_descr, name): if component_type == 'parallelcomponent': prev_layer_id = load_parallel_component(file_descr, graph, prev_layer_id) + find_end_of_component(file_descr, component_type) continue start_index = file_descr.tell() @@ -231,7 +231,7 @@ def load_components(file_descr, graph, component_layer_map=None): file_descr.seek(start_index) dim = 0 try: - collect_until_token(file_descr, b'', size_search_zone=end_index-start_index) + collect_until_token(file_descr, b'', size_search_zone=end_index - start_index) cur_index = file_descr.tell() if start_index < cur_index < end_index: dim = read_binary_integer32_token(file_descr) @@ -284,9 +284,9 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map): return False tokens = s.split(b' ') if tokens[0] == b'input-node': - in_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0] + in_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0] in_name = str(in_name).strip('b').replace('\'', "") - in_shape = np.array([1, s[s.find(b'dim=')+len(b'dim='):].split(b' ')[0]], dtype=np.int) + in_shape = np.array([1, s[s.find(b'dim=') + len(b'dim='):].split(b' ')[0]], dtype=np.int) if in_name not in layer_node_map: graph.add_node(in_name, name=in_name, kind='op', op='Parameter', parameters=None, shape=in_shape) @@ -295,7 +295,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map): Node(graph, in_name)['op'] = 'Parameter' Node(graph, in_name)['shape'] = in_shape elif tokens[0] == b'component-node': - layer_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0] + layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0] layer_name = str(layer_name).strip('b').replace('\'', "") component_name = s[s.find(b'component=') + len(b'component='):].split(b' ')[0] @@ -315,7 +315,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map): component_layer_map[component_name] = [node_name] # parse input - in_node_id = parse_input_for_node(s[s.find(b'input=')+6:], graph, layer_node_map) + in_node_id = parse_input_for_node(s[s.find(b'input=') + 6:], graph, layer_node_map) out_port = len(Node(graph, in_node_id).out_nodes()) in_port = len(Node(graph, node_name).in_nodes()) @@ -331,7 +331,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map): parameters=None, op='Identity', kind='op') - out_name = graph.unique_id(prefix=node_name+"_out") + out_name = graph.unique_id(prefix=node_name + "_out") graph.add_node(out_name, parameters=None, op='Result',