support parallel nested nnet for Kaldi #1194

Merged

Changes from all commits
1 change: 0 additions & 1 deletion model-optimizer/automation/package_BOM.txt

@@ -782,7 +782,6 @@ mo/front/kaldi/extractors/pnorm_component_ext.py
 mo/front/kaldi/extractors/rectified_linear_component_ext.py
 mo/front/kaldi/extractors/rescale_ext.py
 mo/front/kaldi/extractors/scale_component_ext.py
-mo/front/kaldi/extractors/slice_ext.py
 mo/front/kaldi/extractors/softmax_ext.py
 mo/front/kaldi/extractors/splice_component_ext.py
 mo/front/kaldi/loader/__init__.py
42 changes: 0 additions & 42 deletions model-optimizer/mo/front/kaldi/extractors/slice_ext.py
This file was deleted.

35 changes: 0 additions & 35 deletions model-optimizer/mo/front/kaldi/extractors/slice_ext_test.py
This file was deleted.

78 changes: 39 additions & 39 deletions model-optimizer/mo/front/kaldi/loader/loader.py
@@ -13,14 +13,13 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-import io
 import logging as log
-import struct
 from io import IOBase

 import networkx as nx
 import numpy as np

+from extensions.ops.split import AttributedVariadicSplit
 from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
     find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \
     collect_until_token, collect_until_token_and_read, create_edge_attrs, get_args_for_specifier
@@ -33,7 +32,7 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
"""
Load ParallelComponent of the Kaldi model.
ParallelComponent contains parallel nested networks.
Slice is inserted before nested networks.
VariadicSplit is inserted before nested networks.
Outputs of nested networks concatenate with layer Concat.

:param file_descr: descriptor of the model file
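The new wiring is easiest to see on plain arrays. Below is a minimal NumPy sketch (illustration only, not MO code; the two input widths and the identity stubs are invented) of what the VariadicSplit -> nested nets -> Concat topology computes:

import numpy as np

# a ParallelComponent with two nested nets taking 3 and 5 input features
split_points = [3, 5]                # per-subnetwork input widths along axis 1
x = np.random.rand(1, 8)             # [batch, sum(split_points)]

# VariadicSplit along axis 1; np.split expects cumulative cut indices
cuts = np.cumsum(split_points)[:-1]
parts = np.split(x, cuts, axis=1)    # shapes (1, 3) and (1, 5)

# each nested net transforms its own slice (identity stubs here)
outs = [part for part in parts]

# Concat merges the subnetwork outputs back along axis 1
y = np.concatenate(outs, axis=1)
assert y.shape == (1, 8)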
@@ -44,52 +43,52 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
     nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
     log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))

-    slice_id = graph.unique_id(prefix='Slice')
-    graph.add_node(slice_id, parameters=None, op='slice', kind='op')
-
-    slice_node = Node(graph, slice_id)
-    Node(graph, prev_layer_id).add_output_port(0)
-    slice_node.add_input_port(0)
-    graph.create_edge(Node(graph, prev_layer_id), slice_node, 0, 0)
-    slices_points = []
-
+    split_points = []
     outputs = []
+    inputs = []

     for i in range(nnet_count):
         read_token_value(file_descr, b'<NestedNnet>')
         collect_until_token(file_descr, b'<Nnet>')
         g = Graph()
         load_kalid_nnet1_model(g, file_descr, 'Nested_net_{}'.format(i))
-        input_nodes = [n for n in graph.nodes(data=True) if n[1]['op'] == 'Parameter']
-        shape = input_nodes[0][1]['shape']
-        if i != nnet_count - 1:
-            slices_points.append(shape[1])
-        g.remove_node(input_nodes[0][0])
+
+        # input to nnet1 models is of rank 1, but we also insert batch_size to the 0th axis;
+        # the 1st axis contains input_size of the nested subnetwork;
+        # we split the input from the main network across the subnetworks
+        input_node = Node(g, 'Parameter')
+        split_points.append(input_node['shape'][1])
+        g.remove_node(input_node.id)

         mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
         g = nx.relabel_nodes(g, mapping)
         for val in mapping.values():
             g.node[val]['name'] = val
         graph.add_nodes_from(g.nodes(data=True))
         graph.add_edges_from(g.edges(data=True))
         sorted_nodes = tuple(nx.topological_sort(g))
-        edge_attrs = create_edge_attrs(slice_id, sorted_nodes[0])
-        edge_attrs['out'] = i
-        Node(graph, slice_id).add_output_port(i)
-        Node(graph, sorted_nodes[0]).add_input_port(len(Node(graph, sorted_nodes[0]).in_ports()))
-        graph.create_edge(Node(graph, slice_id), Node(graph, sorted_nodes[0]), i, 0)
-        outputs.append(sorted_nodes[-1])
-    packed_sp = struct.pack("B", 4) + struct.pack("I", len(slices_points))
-    for i in slices_points:
-        packed_sp += struct.pack("I", i)
-    slice_node.parameters = io.BytesIO(packed_sp)
+
+        outputs.append(Node(graph, sorted_nodes[-1]))
+        inputs.append(Node(graph, sorted_nodes[0]))
+
+    split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
+    attrs = {'out_ports_count': nnet_count, 'size_splits': split_points, 'axis': 1, 'name': split_id}
+    variadic_split_node = AttributedVariadicSplit(graph, attrs).create_node()
+    prev_layer_node = Node(graph, prev_layer_id)
+    prev_layer_node.add_output_port(0)
+    graph.create_edge(prev_layer_node, variadic_split_node, 0, 0)

     concat_id = graph.unique_id(prefix='Concat')
     graph.add_node(concat_id, parameters=None, op='concat', kind='op')
-    for i, output in enumerate(outputs):
-        edge_attrs = create_edge_attrs(output, concat_id)
-        edge_attrs['in'] = i
-        Node(graph, output).add_output_port(0)
-        Node(graph, concat_id).add_input_port(i)
-        graph.create_edge(Node(graph, output), Node(graph, concat_id), 0, i)
+    concat_node = Node(graph, concat_id)
+
+    # connect each output of variadic_split_node to the corresponding subnetwork
+    # input in the ParallelComponent, and each subnetwork output to concat_node
+    for i, (input_node, output_node) in enumerate(zip(inputs, outputs)):
+        output_node.add_output_port(0)
+        concat_node.add_input_port(i)
+        graph.create_edge(output_node, concat_node, 0, i)
+        graph.create_edge(variadic_split_node, input_node, i, 0)
     return concat_id
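For contrast with the deleted lines above: the old Slice path serialized the split points into a binary parameters blob that downstream code then had to decode again, whereas AttributedVariadicSplit keeps size_splits and axis as plain node attributes. A rough sketch of that old round trip (the layout is read off the deleted code; the meaning of the leading 4 is an assumption, it looks like an element-size tag):

import struct

slices_points = [3, 8]  # hypothetical split points, as in the removed code

# what the removed code packed into slice_node.parameters
packed = struct.pack("B", 4) + struct.pack("I", len(slices_points))
for p in slices_points:
    packed += struct.pack("I", p)

# what a consumer of those parameters then had to undo
count = struct.unpack_from("I", packed, 1)[0]
points = list(struct.unpack_from("{}I".format(count), packed, 5))
assert points == slices_points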


@@ -145,6 +144,7 @@ def load_kalid_nnet1_model(graph, file_descr, name):

         if component_type == 'parallelcomponent':
             prev_layer_id = load_parallel_component(file_descr, graph, prev_layer_id)
+            find_end_of_component(file_descr, component_type)
             continue

         start_index = file_descr.tell()
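A note on the added call: judging by this fix, load_parallel_component consumes the nested networks but leaves the file descriptor short of the component's closing tag, so the new find_end_of_component(file_descr, component_type) advances the stream past the end of the ParallelComponent before the loop moves on to the next component.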
@@ -231,7 +231,7 @@ def load_components(file_descr, graph, component_layer_map=None):
         file_descr.seek(start_index)
         dim = 0
         try:
-            collect_until_token(file_descr, b'<Dim>', size_search_zone=end_index-start_index)
+            collect_until_token(file_descr, b'<Dim>', size_search_zone=end_index - start_index)
             cur_index = file_descr.tell()
             if start_index < cur_index < end_index:
                 dim = read_binary_integer32_token(file_descr)
@@ -284,9 +284,9 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
         return False
     tokens = s.split(b' ')
     if tokens[0] == b'input-node':
-        in_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0]
+        in_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
         in_name = str(in_name).strip('b').replace('\'', "")
-        in_shape = np.array([1, s[s.find(b'dim=')+len(b'dim='):].split(b' ')[0]], dtype=np.int)
+        in_shape = np.array([1, s[s.find(b'dim=') + len(b'dim='):].split(b' ')[0]], dtype=np.int)

         if in_name not in layer_node_map:
             graph.add_node(in_name, name=in_name, kind='op', op='Parameter', parameters=None, shape=in_shape)
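To make the byte-string slicing concrete, here is a standalone sketch with an invented nnet3 config line (in the loader, s is read from the model file; plain int stands in for np.int, which newer NumPy releases have removed):

import numpy as np

s = b'input-node name=input dim=40'

in_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]  # b'input'
dim = s[s.find(b'dim=') + len(b'dim='):].split(b' ')[0]        # b'40'
in_shape = np.array([1, dim], dtype=int)                       # prepend batch axis
print(in_name, in_shape)                                       # b'input' [ 1 40]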
@@ -295,7 +295,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
         Node(graph, in_name)['op'] = 'Parameter'
         Node(graph, in_name)['shape'] = in_shape
     elif tokens[0] == b'component-node':
-        layer_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0]
+        layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
         layer_name = str(layer_name).strip('b').replace('\'', "")

         component_name = s[s.find(b'component=') + len(b'component='):].split(b' ')[0]
@@ -315,7 +315,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
             component_layer_map[component_name] = [node_name]

         # parse input
-        in_node_id = parse_input_for_node(s[s.find(b'input=')+6:], graph, layer_node_map)
+        in_node_id = parse_input_for_node(s[s.find(b'input=') + 6:], graph, layer_node_map)
         out_port = len(Node(graph, in_node_id).out_nodes())
         in_port = len(Node(graph, node_name).in_nodes())

@@ -331,7 +331,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
                        parameters=None,
                        op='Identity',
                        kind='op')
-        out_name = graph.unique_id(prefix=node_name+"_out")
+        out_name = graph.unique_id(prefix=node_name + "_out")
         graph.add_node(out_name,
                        parameters=None,
                        op='Result',