Skip to content

Commit

Permalink
support parallel nested nnet for Kaldi (openvinotoolkit#1194)
Browse files Browse the repository at this point in the history
* supported nested nnet1 for Kaldi
  • Loading branch information
pavel-esir authored and Rom committed Aug 28, 2020
1 parent b77a604 commit 1d80ffc
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 117 deletions.
1 change: 0 additions & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,6 @@ mo/front/kaldi/extractors/pnorm_component_ext.py
mo/front/kaldi/extractors/rectified_linear_component_ext.py
mo/front/kaldi/extractors/rescale_ext.py
mo/front/kaldi/extractors/scale_component_ext.py
mo/front/kaldi/extractors/slice_ext.py
mo/front/kaldi/extractors/softmax_ext.py
mo/front/kaldi/extractors/splice_component_ext.py
mo/front/kaldi/loader/__init__.py
Expand Down
42 changes: 0 additions & 42 deletions model-optimizer/mo/front/kaldi/extractors/slice_ext.py

This file was deleted.

35 changes: 0 additions & 35 deletions model-optimizer/mo/front/kaldi/extractors/slice_ext_test.py

This file was deleted.

78 changes: 39 additions & 39 deletions model-optimizer/mo/front/kaldi/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import io
import logging as log
import struct
from io import IOBase

import networkx as nx
import numpy as np

from extensions.ops.split import AttributedVariadicSplit
from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \
collect_until_token, collect_until_token_and_read, create_edge_attrs, get_args_for_specifier
Expand All @@ -33,7 +32,7 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
"""
Load ParallelComponent of the Kaldi model.
ParallelComponent contains parallel nested networks.
Slice is inserted before nested networks.
VariadicSplit is inserted before nested networks.
Outputs of the nested networks are concatenated by a Concat layer.
:param file_descr: descriptor of the model file
Expand All @@ -44,52 +43,52 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))

slice_id = graph.unique_id(prefix='Slice')
graph.add_node(slice_id, parameters=None, op='slice', kind='op')

slice_node = Node(graph, slice_id)
Node(graph, prev_layer_id).add_output_port(0)
slice_node.add_input_port(0)
graph.create_edge(Node(graph, prev_layer_id), slice_node, 0, 0)
slices_points = []

split_points = []
outputs = []
inputs = []

for i in range(nnet_count):
read_token_value(file_descr, b'<NestedNnet>')
collect_until_token(file_descr, b'<Nnet>')
g = Graph()
load_kalid_nnet1_model(g, file_descr, 'Nested_net_{}'.format(i))
input_nodes = [n for n in graph.nodes(data=True) if n[1]['op'] == 'Parameter']
shape = input_nodes[0][1]['shape']
if i != nnet_count - 1:
slices_points.append(shape[1])
g.remove_node(input_nodes[0][0])

# the input of an nnet1 model has rank 1, but a batch dimension is also inserted at axis 0,
# so axis 1 of the Parameter shape holds the input size of the nested subnetwork;
# these sizes are collected to split the main network's input between the subnetworks
input_node = Node(g, 'Parameter')
split_points.append(input_node['shape'][1])
g.remove_node(input_node.id)

mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
g = nx.relabel_nodes(g, mapping)
for val in mapping.values():
g.node[val]['name'] = val
graph.add_nodes_from(g.nodes(data=True))
graph.add_edges_from(g.edges(data=True))
sorted_nodes = tuple(nx.topological_sort(g))
edge_attrs = create_edge_attrs(slice_id, sorted_nodes[0])
edge_attrs['out'] = i
Node(graph, slice_id).add_output_port(i)
Node(graph, sorted_nodes[0]).add_input_port(len(Node(graph, sorted_nodes[0]).in_ports()))
graph.create_edge(Node(graph, slice_id), Node(graph, sorted_nodes[0]), i, 0)
outputs.append(sorted_nodes[-1])
packed_sp = struct.pack("B", 4) + struct.pack("I", len(slices_points))
for i in slices_points:
packed_sp += struct.pack("I", i)
slice_node.parameters = io.BytesIO(packed_sp)

outputs.append(Node(graph, sorted_nodes[-1]))
inputs.append(Node(graph, sorted_nodes[0]))

split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
attrs = {'out_ports_count': nnet_count, 'size_splits': split_points, 'axis': 1, 'name': split_id}
variadic_split_node = AttributedVariadicSplit(graph, attrs).create_node()
prev_layer_node = Node(graph, prev_layer_id)
prev_layer_node.add_output_port(0)
graph.create_edge(prev_layer_node, variadic_split_node, 0, 0)

concat_id = graph.unique_id(prefix='Concat')
graph.add_node(concat_id, parameters=None, op='concat', kind='op')
for i, output in enumerate(outputs):
edge_attrs = create_edge_attrs(output, concat_id)
edge_attrs['in'] = i
Node(graph, output).add_output_port(0)
Node(graph, concat_id).add_input_port(i)
graph.create_edge(Node(graph, output), Node(graph, concat_id), 0, i)
concat_node = Node(graph, concat_id)

# Connect the i-th output of variadic_split_node to the input of the i-th subnetwork in ParallelComponent
# and the output of the i-th subnetwork to the i-th input of concat_node
for i, (input_node, output_node) in enumerate(zip(inputs, outputs)):
output_node.add_output_port(0)
concat_node.add_input_port(i)
graph.create_edge(output_node, concat_node, 0, i)
graph.create_edge(variadic_split_node, input_node, i, 0)
return concat_id


Expand Down Expand Up @@ -145,6 +144,7 @@ def load_kalid_nnet1_model(graph, file_descr, name):

if component_type == 'parallelcomponent':
prev_layer_id = load_parallel_component(file_descr, graph, prev_layer_id)
find_end_of_component(file_descr, component_type)
continue

start_index = file_descr.tell()
Expand Down Expand Up @@ -231,7 +231,7 @@ def load_components(file_descr, graph, component_layer_map=None):
file_descr.seek(start_index)
dim = 0
try:
collect_until_token(file_descr, b'<Dim>', size_search_zone=end_index-start_index)
collect_until_token(file_descr, b'<Dim>', size_search_zone=end_index - start_index)
cur_index = file_descr.tell()
if start_index < cur_index < end_index:
dim = read_binary_integer32_token(file_descr)
Expand Down Expand Up @@ -284,9 +284,9 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
return False
tokens = s.split(b' ')
if tokens[0] == b'input-node':
in_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0]
in_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
in_name = str(in_name).strip('b').replace('\'', "")
in_shape = np.array([1, s[s.find(b'dim=')+len(b'dim='):].split(b' ')[0]], dtype=np.int)
in_shape = np.array([1, s[s.find(b'dim=') + len(b'dim='):].split(b' ')[0]], dtype=np.int)

if in_name not in layer_node_map:
graph.add_node(in_name, name=in_name, kind='op', op='Parameter', parameters=None, shape=in_shape)
Expand All @@ -295,7 +295,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
Node(graph, in_name)['op'] = 'Parameter'
Node(graph, in_name)['shape'] = in_shape
elif tokens[0] == b'component-node':
layer_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0]
layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
layer_name = str(layer_name).strip('b').replace('\'', "")

component_name = s[s.find(b'component=') + len(b'component='):].split(b' ')[0]
Expand All @@ -315,7 +315,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
component_layer_map[component_name] = [node_name]

# parse input
in_node_id = parse_input_for_node(s[s.find(b'input=')+6:], graph, layer_node_map)
in_node_id = parse_input_for_node(s[s.find(b'input=') + 6:], graph, layer_node_map)
out_port = len(Node(graph, in_node_id).out_nodes())
in_port = len(Node(graph, node_name).in_nodes())

Expand All @@ -331,7 +331,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
parameters=None,
op='Identity',
kind='op')
out_name = graph.unique_id(prefix=node_name+"_out")
out_name = graph.unique_id(prefix=node_name + "_out")
graph.add_node(out_name,
parameters=None,
op='Result',
Expand Down

0 comments on commit 1d80ffc

Please sign in to comment.