[TENSORFLOW]StatefulPartitionedCall/PartitionedCall Ops support added #5617
Changes from 69 commits
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
-# pylint: disable=import-outside-toplevel
+# pylint: disable=import-outside-toplevel, redefined-builtin
 """TF: Tensorflow frontend."""
 import warnings
 from collections import defaultdict
@@ -1927,7 +1927,6 @@ def _impl(inputs, attr, params, mod):
         return _res
     return _impl
 
-
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -2717,8 +2716,10 @@ def __init__(self):
         self._loop_var_order = {}
         self._hash2tfnode = {}
         self._while_loop_name_set = set()
+        self._main_graph_proto = self
+        self._stateful_ops_list = []
 
-    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
+    def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct relay nodes from tensorflow graph definition - GraphDef.
 
         Follow the tensorflow graph definition to parse and convert it to Relay.
@@ -2773,6 +2774,12 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             if freezed_ops:
                 raise Exception("Graph is not frozen. Provide a frozen graph. "
                                 "Found operators {}".format(freezed_ops))
+            stateful_ops = [op for op in missing_operators
+                            if op in self._main_graph_proto._stateful_ops_list]
+            if stateful_ops:
+                raise Exception("Found stateful operators in this graph {}. "\
+                                "Rejecting the graph as TVM does not support stateful operations "\
+                                .format(stateful_ops))
 
             raise NotImplementedError(
                 "The following operators are not implemented: {}".format(missing_operators))
@@ -2885,6 +2892,13 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
         out = out[0] if len(out) == 1 else _expr.Tuple(out)
         func = _function.Function(analysis.free_vars(out), out)
+        return func
+
+    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
+        """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function
+        which is used as main function for the Relay module
+        """
+        func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
         self._mod["main"] = func
         return self._mod, self._params
 
@@ -2895,16 +2909,25 @@ def _parse_import_prerequisites(self, graph):
         which are not supported
         """
         missing_operators = set()
+        from tensorflow.python.framework import op_def_registry
         for node in graph.node:
+            getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,\
+                "_registered_ops") else op_def_registry.get
+            op_def = getOpDef(node.op)
             if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
                 pass
             elif node.op == "Const":
                 pass
+            elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
+                pass
             else:
                 if any([node.op in t for t in [_identity_list, _convert_map,
                                                _convert_map_rnn,
                                                _control_flow_nodes]]):
                     pass
+                elif op_def is not None and op_def.is_stateful:
+                    self._main_graph_proto._stateful_ops_list.append(node.op)
+                    missing_operators.add(node.op)
                 else:
                     missing_operators.add(node.op)
 

Review comment on the _stateful_ops_list addition: No need of another list. If needed you may append (StatefulOperator) to node.op.
Author reply: Removed the new list; only adding these stateful ops to the missing operator list.
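For context, here is a minimal standalone sketch (not part of this diff) of the op_def_registry lookup used above. The hasattr fallback mirrors the difference between older TensorFlow releases, which keep a private _registered_ops dict, and newer ones that expose op_def_registry.get:

# Standalone illustration of the stateful-op check in _parse_import_prerequisites.
# Assumes only a TensorFlow installation; the op names below are ordinary TF op types.
from tensorflow.python.framework import op_def_registry

def is_stateful_op(op_name):
    get_op_def = (op_def_registry._registered_ops.get
                  if hasattr(op_def_registry, "_registered_ops")
                  else op_def_registry.get)
    op_def = get_op_def(op_name)
    return op_def is not None and op_def.is_stateful

print(is_stateful_op("MatMul"))       # False - pure computation
print(is_stateful_op("VarHandleOp"))  # True - resource variables are stateful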
@@ -3155,6 +3178,91 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_
 
         return op
 
+    def _partition_call_operator(self, inputs, attr):
+        """
+        Convert the Relay Partition call ops into Relay Function calls and
+        function definitions from Tensorflow graph library attribute to Relay global
+        functions
+
+        Parameters
+        ----------
+        node: TensorFlow graph node object.
+            A TensorFlow graph node object.
+
+        inputs : List[tvm.relay.Expr]
+            List of input symbols.
+
+        attrs : Dict[tvm.Attrs]
+            Dict of operator attributes.
+
+        Returns
+        -------
+        op : tvm.relay.Expr
+            Converted relay expression.
+        """
+
+        try:
+            from tensorflow.python.framework import function_def_to_graph
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import tensorflow which is required {}".format(e))
+
+        main_graph_proto = self._main_graph_proto
+        outer_graph_def = main_graph_proto._graph
+
+        node_func_name = attr.get('f').name
+        func = next((f for f in outer_graph_def.library.function
+                     if f.signature.name == node_func_name), None)
+        if func:
+            devices = set(node.device for node in func.node_def)
+            if len(devices) > 1:
+                raise Exception("Found inconsistent Device assignment in the "\
+                                "Stateful Partitioned SubGraph. Rejecting "\
+                                "the subgraph ")
+            # Convert function definition to graph
+            func_input_shapes = func.attr["_input_shapes"].list.shape
+            subgraph, _ = function_def_to_graph.\
+                function_def_to_graph_def(func, func_input_shapes)
+
+            # Computing subgraph's input shape dictionary
+            subgraph_shape_dict, input_expr_dict = {}, {}
+            for f_arg, input in zip(func.signature.input_arg, inputs):
+                input_expr_dict[f_arg.name] = input
+                subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod)
+
+            func_name = 'func_{}'.format(func.signature.name)
+            try:
+                global_func = main_graph_proto._mod[func_name]
+                sub_func = global_func
+                sub_params = main_graph_proto._params
+            except ValueError:
+                # Construct relay nodes from the subgraph
+                g1 = SubGraphProto(main_graph_proto)
+                sub_func, sub_params = g1.from_tensorflow(subgraph, shape=subgraph_shape_dict)
+                main_graph_proto._params.update(sub_params)
+                func_expr = _function.Function(sub_func.params, sub_func.body)
+                global_func = tvm.relay.GlobalVar(func_name)
+                main_graph_proto._mod[global_func] = func_expr
+
+            param_exprs = []
+            for param_expr in sub_func.params:
+                # sub_params is subset of sub_func.params
+                param_name = param_expr.vid.name_hint
+                if param_name in input_expr_dict.keys():
+                    param_exprs.append(input_expr_dict[param_name])
+                elif param_name in sub_params.keys():
+                    param_exprs.append(param_expr)
+                else:
+                    raise Exception("Input parameter {} not found".format(param_name))
+
+            sb = tvm.relay.scope_builder.ScopeBuilder()
+            loop_ret = global_func(*param_exprs)
+            sb.ret(loop_ret)
+            ret = sb.get()
+        else:
+            raise Exception("Function not found - {}".format(node_func_name))
+        return ret
+
     def _convert_operator(self, op_name, inputs, attrs,
                           graph, identity_list=None, convert_map=None):
         """Convert from Tensorflow operator to relay operator.
 
@@ -3196,6 +3304,9 @@ def _convert_operator(self, op_name, inputs, attrs,
             sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                              self._params, graph,
                                              convert_map_rnn)
+
+        elif op_name in ["PartitionedCall", "StatefulPartitionedCall"]:
+            sym = self._partition_call_operator(inputs, attrs)
         else:
             raise NotImplementedError("Operator {} not implemented.".format(op_name))
         return sym
 
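To illustrate the kind of graph this dispatch targets, here is a hedged, hypothetical example (not taken from the PR): in TF 2.x, calling one tf.function from another typically leaves a PartitionedCall or StatefulPartitionedCall node in the outer graph, with the callee stored as a FunctionDef in the graph library, which is what _partition_call_operator resolves through attr['f'] and outer_graph_def.library.function:

# Hypothetical functions; the names are illustrative only.
import tensorflow as tf

@tf.function
def inner(x):
    return x + tf.constant(1.0)

@tf.function
def outer(x):
    # The nested tf.function call is usually kept as a *PartitionedCall node
    # instead of being inlined at trace time.
    return inner(x) * 2.0

concrete = outer.get_concrete_function(tf.TensorSpec((2,), tf.float32))
graph_def = concrete.graph.as_graph_def()

print([n.op for n in graph_def.node])                          # expect a *PartitionedCall op
print([f.signature.name for f in graph_def.library.function])  # the callee's FunctionDef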
@@ -3262,6 +3373,22 @@ def _backtrack_construct(self, node_name):
 
         return self._nodes[node_name]
 
+
+class SubGraphProto(GraphProto):
+    """ A helper class for handling relay subgraph copying from Tensorflow GraphDef.
+    """
+    def __init__(self, main_graph_proto):
+        super().__init__()
+        self._main_graph_proto = main_graph_proto  # holds main graph proto object
+
+    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
+        """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function.
+        Return Relay function and params
+        """
+        func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
+        return func, self._params
+
+
 def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     """Load tensorflow graph which is a python tensorflow graph object into relay.
     The companion parameters will be handled automatically.
 
@@ -3288,6 +3415,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     params : dict of str to tvm.nd.NDArray
         Dict of converted parameters stored in tvm.nd.NDArray format
     """
+
     g = GraphProto()
     mod, params = g.from_tensorflow(graph, layout, shape, outputs)
     return mod, params
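For completeness, a hedged usage sketch of the public entry point shown above; the frozen-graph path, input name, and shape are placeholders, not values taken from this PR:

# Import a frozen TensorFlow GraphDef (which may now contain
# PartitionedCall/StatefulPartitionedCall nodes) into Relay.
import tensorflow as tf
from tvm import relay

graph_def = tf.compat.v1.GraphDef()
with open("frozen_model.pb", "rb") as f:   # placeholder path to a frozen graph
    graph_def.ParseFromString(f.read())

mod, params = relay.frontend.from_tensorflow(
    graph_def, layout="NHWC", shape={"input": (1, 224, 224, 3)})  # placeholder input
print(mod["main"])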
Review comment on the separate stateful-ops exception: Again this will display an exception with only the stateful missing ops (it doesn't show the normal missing ops). I don't think we need a separate list for stateful missing ops. Just add them to the missing_operators list and display them as one.
Author reply: Okay. Merged into the missing-op list itself and removed the separate exception code for stateful ops.