forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MO] [Kaldi] Add TDNN Component (openvinotoolkit#1870)
* [MO] [Kaldi] Added TDNN Component * TdnnComponent replacer graphical comment updated * Added SpecAugmentTimeMaskComponent * some refactor of memoryoffset shape_infer * moved memoryoffset splitting to the middle stage * some corrections - set `need_shape_inferenc`=False in split_memoryoffset - use cycle instead of pattern in tdnn_replacer * separated splitting of MemoryOffsets in LSTM and TDNN blocks * set transpose_weights=True in TdnnComponent * Corrected Supported_Frameworks_Layers * corrected comments * separate naming for tdnn and lstm memoryoffset splits * corrected BOM file * corrected generaldropout_ext.py and removed 'has_default' for tdnn_component * corrections after PR review * renamed LSTM -> recurrent; added setting element_size for paired nodes of tdnn_memoffset and othe minor changes * Update split_tdnn_memoryoffset.py * corrected partial infer with new API in elemental.py and split_tdnn_memoryoffset.py
- Loading branch information
1 parent
9f1b4e0
commit 2110a29
Showing
19 changed files
with
592 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
49 changes: 0 additions & 49 deletions
49
model-optimizer/extensions/front/kaldi/split_memoryoffsets.py
This file was deleted.
Oops, something went wrong.
62 changes: 62 additions & 0 deletions
62
model-optimizer/extensions/front/kaldi/split_recurrent_memoryoffset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
""" | ||
Copyright (C) 2018-2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import networkx as nx | ||
|
||
from mo.front.common.replacement import FrontReplacementSubgraph | ||
from mo.graph.graph import Graph | ||
from mo.ops.memoryoffset import MemoryOffset | ||
from mo.ops.result import Result | ||
from mo.utils.error import Error | ||
from mo.utils.graph import Node | ||
|
||
|
||
class SplitRecurrentMemoryOffset(FrontReplacementSubgraph): | ||
""" | ||
Splits MemoryOffsets in recurrent blocks (typically LSTM blocks) into 2 parts. | ||
These parts then will be converted to ReadValue and Assign. Splitting complicates shape inference but | ||
MemoryOffsets in recurrent blocks are cycled and, in order to make topological sort possible | ||
during shape inference, they are splitted earlier on the front phase. In contrast, | ||
MemoryOffsets in TDNN blocks are not cycled, so they will be splitted after shape infer on the middle. | ||
Now only LSTM blocks with MemoryOffset are present. | ||
""" | ||
enabled = True | ||
graph_condition = [lambda graph: graph.graph['fw'] == 'kaldi'] | ||
|
||
@staticmethod | ||
def split_offset(offset_node: Node): | ||
paired_node = MemoryOffset(offset_node.graph, {'name': offset_node.pair_name, 'splitted': True, | ||
'pair_name': offset_node.id, | ||
'element_size': offset_node['element_size'], | ||
't': offset_node.t, | ||
'has_default': offset_node.has_default}).create_node() | ||
offset_node['splitted'] = True | ||
offset_node.out_port(0).get_connection().set_source(paired_node.out_port(0)) | ||
res_node = Result(offset_node.graph, {'name': offset_node.id + '_output'}).create_node() | ||
offset_node.out_port(0).connect(res_node.in_port(0)) | ||
|
||
def find_and_replace_pattern(self, graph: Graph): | ||
for offset_node in graph.get_op_nodes(op='MemoryOffset', splitted=False): | ||
try: | ||
# if graph contains recurrent block -> split MemoryOffset to enable shape infer | ||
nx.find_cycle(graph, offset_node.id) | ||
except nx.NetworkXNoCycle as e: | ||
# MemoryOffset node is not in a recurrent block -- no splitting is needed | ||
return | ||
|
||
if not offset_node.has_valid('element_size'): | ||
raise Error("In a recurrent block 'element_size' for node {} is not set".format(offset_node.id)) | ||
SplitRecurrentMemoryOffset.split_offset(offset_node) |
95 changes: 95 additions & 0 deletions
95
model-optimizer/extensions/front/kaldi/tdnn_component_replacer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
""" | ||
Copyright (C) 2018-2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from extensions.ops.MatMul import FullyConnected | ||
from mo.front.common.replacement import FrontReplacementPattern | ||
from mo.front.tf.graph_utils import create_op_with_const_inputs | ||
from mo.graph.graph import Graph, Node | ||
from mo.graph.graph import rename_nodes | ||
from mo.ops.concat import Concat | ||
from mo.ops.memoryoffset import MemoryOffset | ||
|
||
|
||
class TdnnComponentReplacer(FrontReplacementPattern): | ||
''' | ||
Expand TdnnComponent into MemoryOffsets, Concat and FullyConected nodes | ||
BEFORE: | ||
placeholder | ||
| | ||
TdnnComponent('time_offsets': t1, t2,... tk) | ||
| | ||
_______________________________________________________________ | ||
AFTER: | ||
placeholder | ||
__________________|___________________________ | ||
/ | \ \ | ||
MemoryOffset(t1) MemoryOffset(t2) ... MemoryOffset(tk) | ||
\_____________ _____|______________/____________/ | ||
Concat | ||
| | ||
FullyConnected | ||
| | ||
''' | ||
enabled = True | ||
run_not_recursively = True | ||
|
||
def run_before(self): | ||
from extensions.front.kaldi.memory_offset_adjustment import MemoryOffsetAdjustment | ||
return [MemoryOffsetAdjustment] | ||
|
||
def find_and_replace_pattern(self, graph: Graph): | ||
for node in graph.get_op_nodes(op='tdnncomponent'): | ||
self.replace_tdnn(graph, node) | ||
|
||
def replace_tdnn(self, graph: Graph, tdnn_node: Node): | ||
tdnn_name = tdnn_node.soft_get('name', tdnn_node.id) | ||
|
||
concat_node = Concat(graph, {'axis': 1}).create_node() | ||
rename_nodes([(tdnn_node, tdnn_name + '/to_be_removed'), (concat_node, tdnn_name)]) | ||
|
||
for offset_ind, t in enumerate(tdnn_node['time_offsets']): | ||
concat_node.add_input_port(offset_ind) | ||
if t != 0: | ||
memory_name = tdnn_name + '/MemoryOffset/' + str(abs(t)) | ||
memoryoffset_node = MemoryOffset(graph, {'name': memory_name, 't': t, | ||
'pair_name': memory_name + '_out', | ||
'has_default': False, 'splitted': False}).create_node() | ||
|
||
tdnn_node.in_port(0).get_source().connect(memoryoffset_node.in_port(0)) | ||
memoryoffset_node.out_port(0).connect(concat_node.in_port(offset_ind)) | ||
else: | ||
# 0 time delay is not allowed in IE, it's meaningless | ||
# if time offset is 0 then connect input of tdnncomponent directly to Concat without memoryoffset | ||
tdnn_node.in_port(0).get_source().connect(concat_node.in_port(offset_ind)) | ||
|
||
weights = tdnn_node['weights'] | ||
fc_inputs = {1: weights} | ||
|
||
bias_term = False | ||
if tdnn_node.has_valid('biases'): | ||
assert len(tdnn_node['biases']) == weights.shape[0] | ||
fc_inputs.update({2: tdnn_node['biases']}) | ||
bias_term = True | ||
|
||
fc_node = create_op_with_const_inputs(graph, FullyConnected, fc_inputs, | ||
{'name': tdnn_name + '/FC', 'out-size': weights.shape[0], | ||
'transpose_weights': True, 'bias_term': bias_term}) | ||
|
||
concat_node.out_port(0).connect(fc_node.in_port(0)) | ||
tdnn_node.in_port(0).disconnect() | ||
tdnn_node.out_port(0).get_connection().set_source(fc_node.out_port(0)) |
87 changes: 87 additions & 0 deletions
87
model-optimizer/extensions/front/kaldi/tdnn_component_replacer_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
""" | ||
Copyright (C) 2018-2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import unittest | ||
|
||
import numpy as np | ||
from generator import generator, generate | ||
|
||
from extensions.front.kaldi.tdnn_component_replacer import TdnnComponentReplacer | ||
from mo.utils.ir_engine.compare_graphs import compare_graphs | ||
from mo.utils.unittest.graph import build_graph, regular_op, result, connect_front, const | ||
|
||
|
||
@generator | ||
class TdnnComponentReplacerTest(unittest.TestCase): | ||
|
||
@generate(*[ | ||
([[1, 1, 1], [4, 4, 4]], [1, 2], [-1, 1],), | ||
([[1, 1, 1], [4, 4, 4]], [1, 2], [-1, 1, 2, 10, 1000],), | ||
([[1, 1, 1], [4, 4, 4]], [1, 2], [-1, 0]), | ||
]) | ||
def test_tdnnreplacer(self, weights, biases, time_offsets): | ||
def generate_offsets(): | ||
offset_edges = [] | ||
offset_nodes = {} | ||
|
||
for i, t in enumerate(time_offsets): | ||
offset_nodes.update(**regular_op('memoryoffset_' + str(i), {'type': None})) | ||
|
||
if t != 0: | ||
offset_edges.append(('placeholder', 'memoryoffset_' + str(i), {'out': 0, 'in': 0})) | ||
offset_edges.append(('memoryoffset_' + str(i), 'concat', {'out': 0, 'in': i})) | ||
else: | ||
offset_edges.append(('placeholder', 'concat', {'out': 0, 'in': i})) | ||
|
||
return offset_nodes, offset_edges | ||
|
||
offset_nodes, ref_offset_edges = generate_offsets() | ||
|
||
nodes = { | ||
**offset_nodes, | ||
**regular_op('placeholder', {'type': 'Parameter'}), | ||
**regular_op('tdnncomponent', {'op': 'tdnncomponent', | ||
'weights': np.array(weights), | ||
'biases': np.array(biases), | ||
'time_offsets': np.array(time_offsets)}), | ||
**const('weights', np.array(weights)), | ||
**const('biases', np.array(biases)), | ||
**regular_op('concat', {'type': 'Concat', 'axis': 1}), | ||
**regular_op('memoryoffset_0', {'type': None}), | ||
**regular_op('memoryoffset_1', {'type': None}), | ||
**regular_op('memoryoffset_2', {'type': None}), | ||
**regular_op('fully_connected', {'type': 'FullyConnected'}), | ||
**result('result'), | ||
} | ||
|
||
graph = build_graph(nodes, [ | ||
*connect_front('placeholder', 'tdnncomponent'), | ||
*connect_front('tdnncomponent', 'result') | ||
], nodes_with_edges_only=True) | ||
|
||
graph.stage = 'front' | ||
|
||
ref_graph = build_graph(nodes, [ | ||
*ref_offset_edges, | ||
*connect_front('concat', '0:fully_connected'), | ||
*connect_front('weights', '1:fully_connected'), | ||
*connect_front('biases', '2:fully_connected'), | ||
*connect_front('fully_connected', 'result') | ||
], nodes_with_edges_only=True) | ||
|
||
TdnnComponentReplacer().find_and_replace_pattern(graph) | ||
|
||
(flag, resp) = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True) | ||
self.assertTrue(flag, resp) |
Oops, something went wrong.