[TENSORFLOW] StatefulPartitionedCall/PartitionedCall Ops support added #5617

Merged Jun 4, 2020 · 70 commits (changes shown are from 69 commits)

Commits
241cab3  Implemented functionInvocation Unit Test for StatefulPartitionedCall… (deepakbabel23, May 6, 2020)
6fe8c1b  Placeholder exercises with tvm (deepakbabel23, May 7, 2020)
2b7c2c6  Merge pull request #1 from apache/master (deepakbabel23, May 8, 2020)
7bd7711  Merge pull request #2 from deepakbabel23/master (deepakbabel23, May 11, 2020)
aeb8073  placeholder interim (deepakbabel23, May 11, 2020)
0119857  SPOP Test cases structure (deepakbabel23, May 11, 2020)
077ca4d  New test cases for spop (deepakbabel23, May 12, 2020)
bdff27c  miscellaneous test cases for spop (deepakbabel23, May 12, 2020)
ec2eb89  Merge pull request #3 from apache/master (deepakbabel23, May 12, 2020)
fb807bc  Merge pull request #4 from deepakbabel23/master (deepakbabel23, May 12, 2020)
8ec980c  Placeholder samples… working with shapes explicitly passed (deepakbabel23, May 12, 2020)
8c4a48f  Variables test case. Works with the same fix of shape_dict (deepakbabel23, May 13, 2020)
9aaa650  SPOP Positive test cases, first iteration (deepakbabel23, May 13, 2020)
a1ee137  support output tensors as function args, multiple functions (prashantsail, May 14, 2020)
444f4f8  Corrected indentation (prashantsail, May 14, 2020)
6ce53c4  filewriter is only for debug purposes (prashantsail, May 14, 2020)
e0eb166  support variables in function args (prashantsail, May 14, 2020)
cb1ec4d  First working iteration of positive spop test cases (deepakbabel23, May 14, 2020)
a49369e  Removed commented code, simplified code (prashantsail, May 14, 2020)
898b79b  Code Reorganization: First working iteration of positive spop test cases (deepakbabel23, May 14, 2020)
c6e0eaf  corrected variable name after refactor (prashantsail, May 14, 2020)
6ac4d7f  Code Reorganization: First working iteration of positive spop test cases (deepakbabel23, May 14, 2020)
b046af1  move code inside mapped operator function (prashantsail, May 14, 2020)
4ae8dea  Removed extra line (prashantsail, May 14, 2020)
b3de79d  support variables in function args (prashantsail, May 14, 2020)
f767360  Removed commented code, simplified code (prashantsail, May 14, 2020)
e3517dd  move code inside mapped operator function (prashantsail, May 14, 2020)
16cdcbd  support output tensors as function args, multiple functions (prashantsail, May 14, 2020)
27fb98a  Code Reorganization: First working iteration of positive spop test cases (prashantsail, May 14, 2020)
8589fec  Code Reorganization: First working iteration of positive spop test cases (deepakbabel23, May 14, 2020)
8feb00c  Merge pull request #5 from prashantsail/ps_spop (deepakbabel23, May 14, 2020)
0f637b4  Function invocation: more test cases (deepakbabel23, May 14, 2020)
b77cece  Simplified & merged different function invocation test cases (deepakbabel23, May 15, 2020)
6c71b72  support invocation of nested callables (prashantsail, May 15, 2020)
b253ac9  Simplified and uniform testcases (prashantsail, May 15, 2020)
878313d  support invocation of nested callables (prashantsail, May 15, 2020)
ff11b39  Simplified and uniform testcases (prashantsail, May 15, 2020)
224b93e  Merge branch 'ps_spop' of https://github.com/prashantsail/incubator-t… (prashantsail, May 15, 2020)
14982af  removed duplicate and renamed testcase (prashantsail, May 15, 2020)
4b396c0  Merge pull request #6 from prashantsail/ps_spop (deepakbabel23, May 15, 2020)
332e2f0  Negative scenario added for testing operator statefulness. Only Excep… (deepakbabel23, May 15, 2020)
b0b2fdf  Merge remote-tracking branch 'refs/remotes/origin/spop' into spop (deepakbabel23, May 15, 2020)
7ea2a66  Miscellaneous reorganization changes for spop scenarios (deepakbabel23, May 15, 2020)
da79752  Miscellaneous reorganization changes for spop scenarios (deepakbabel23, May 15, 2020)
b3666ce  Corrected import of tensorflow modules safely using try except and ot… (deepakbabel23, May 15, 2020)
caaf8fd  Negative scenario for resource variables handled (deepakbabel23, May 15, 2020)
a525666  Documentation update for code (deepakbabel23, May 15, 2020)
790c026  SPOP change in function handling (maheshambule, May 15, 2020)
6b8bc37  tf global function support (maheshambule, May 15, 2020)
5e92663  handle nested subgraph (maheshambule, May 18, 2020)
becbad4  refactor (maheshambule, May 18, 2020)
9a45af9  get op def compatible with tf 1x & 2x (prashantsail, May 18, 2020)
fa122df  Fixed linting issues (prashantsail, May 18, 2020)
af44879  Merge pull request #7 from prashantsail/ps_spop (deepakbabel23, May 18, 2020)
9b8f24d  added docstring and a few nits (maheshambule, May 18, 2020)
0c06cfb  resolve conflicts (maheshambule, May 18, 2020)
5888317  Merged changes for positive test cases and negative test cases (deepakbabel23, May 18, 2020)
001f3a0  Merge remote-tracking branch 'refs/remotes/origin/spop' into spop (deepakbabel23, May 18, 2020)
7f3b52b  Moved StatefulPartitionedCall test case to the end of the TC list (deepakbabel23, May 18, 2020)
92e1853  Fixed some typos and semantics (deepakbabel23, May 18, 2020)
bd2952d  Merge pull request #8 from apache/master (deepakbabel23, May 18, 2020)
87ceb84  Merge pull request #9 from deepakbabel23/master (deepakbabel23, May 18, 2020)
7968859  dmlc-core (maheshambule, May 18, 2020)
cdc4af9  dmlc-core (maheshambule, May 18, 2020)
fe58431  fixes (maheshambule, May 18, 2020)
06a80b3  Addressing review comments in the PR for SPOP support (deepakbabel23, May 28, 2020)
cf8f374  Fixed pylint errors (deepakbabel23, May 28, 2020)
bb09d3d  Corrected tensorflow import syntax (deepakbabel23, May 29, 2020)
01244c9  Placed the op_def_registry module import outside of for loop (deepakbabel23, May 29, 2020)
c0308ca  Removed new stateful operators list and combined these operators with… (deepakbabel23, May 31, 2020)
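
For orientation, the sketch below is not taken from the PR's test suite; function, tensor, and output names are illustrative. It shows the scenario this PR enables: invoking a `tf.function` inside a TensorFlow graph records a PartitionedCall/StatefulPartitionedCall node, which the frontend can now convert instead of rejecting.

```python
# A minimal sketch, assuming TF 2.x and a TVM build with this PR applied.
import tensorflow as tf
from tvm import relay

with tf.Graph().as_default() as graph:

    @tf.function  # calls to this are recorded as PartitionedCall nodes
    def fused(x, y):
        return tf.add(tf.multiply(x, x), y)

    a = tf.constant(10.5, name="a")
    b = tf.constant(20.4, name="b")
    out = fused(a, b)
    graph_def = graph.as_graph_def(add_shapes=True)

# Previously the call op was reported as unimplemented; with this PR it is
# lowered to a call to a Relay global function.
mod, params = relay.frontend.from_tensorflow(graph_def, outputs=[out.name])
print(mod["main"])
```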
134 changes: 131 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
-# pylint: disable=import-outside-toplevel
+# pylint: disable=import-outside-toplevel, redefined-builtin
"""TF: Tensorflow frontend."""
import warnings
from collections import defaultdict
@@ -1927,7 +1927,6 @@ def _impl(inputs, attr, params, mod):
        return _res
    return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -2717,8 +2716,10 @@ def __init__(self):
        self._loop_var_order = {}
        self._hash2tfnode = {}
        self._while_loop_name_set = set()
        self._main_graph_proto = self
        self._stateful_ops_list = []

-    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
+    def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
        """Construct relay nodes from tensorflow graph definition - GraphDef.

        Follow the tensorflow graph definition to parse and convert it to Relay.
@@ -2773,6 +2774,12 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
            if freezed_ops:
                raise Exception("Graph is not frozen. Provide a frozen graph. "
                                "Found operators {}".format(freezed_ops))
            stateful_ops = [op for op in missing_operators
                            if op in self._main_graph_proto._stateful_ops_list]
            if stateful_ops:
                raise Exception("Found stateful operators in this graph {}. "
                                "Rejecting the graph as TVM does not support "
                                "stateful operations".format(stateful_ops))

Review (Contributor): Again, this will display an exception with only the stateful missing ops (it doesn't show the normal missing ops). I don't think we need a separate list for stateful missing ops. Just add them to the missing_operators list and display them as one.

Author (Contributor): Okay, merged into the missing-op list itself and removed the separate exception code for stateful ops.

            raise NotImplementedError(
                "The following operators are not implemented: {}".format(missing_operators))
@@ -2885,6 +2892,13 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

        out = out[0] if len(out) == 1 else _expr.Tuple(out)
        func = _function.Function(analysis.free_vars(out), out)
        return func

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """Wrapper around _get_relay_func which converts a TensorFlow graph to a
        Relay function, used as the main function of the Relay module.
        """
        func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
        self._mod["main"] = func
        return self._mod, self._params

@@ -2895,16 +2909,25 @@ def _parse_import_prerequisites(self, graph):
        which are not supported
        """
        missing_operators = set()
        from tensorflow.python.framework import op_def_registry
        for node in graph.node:
            getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,
                                                                      "_registered_ops") else op_def_registry.get
            op_def = getOpDef(node.op)
            if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
                pass
            elif node.op == "Const":
                pass
            elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
                pass
            else:
                if any([node.op in t for t in [_identity_list, _convert_map,
                                               _convert_map_rnn,
                                               _control_flow_nodes]]):
                    pass
                elif op_def is not None and op_def.is_stateful:
                    self._main_graph_proto._stateful_ops_list.append(node.op)
                    missing_operators.add(node.op)
                else:
                    missing_operators.add(node.op)

Review (Contributor): No need for another list. If needed, you may append "(StatefulOperator)" to node.op.

Author (Contributor): Removed the new list; now only adding these stateful ops to the missing-operator list.

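The getOpDef indirection above bridges an API change in op_def_registry between TensorFlow releases. Below is a standalone sketch of the same shim; the op name queried is just an example of a stateful op:

```python
# Standalone sketch of the registry shim: TF 1.x exposes a private dict of
# registered ops, while TF 2.x replaces it with a get() accessor.
from tensorflow.python.framework import op_def_registry

def get_op_def(op_name):
    if hasattr(op_def_registry, "_registered_ops"):  # TF 1.x
        return op_def_registry._registered_ops.get(op_name)
    return op_def_registry.get(op_name)              # TF 2.x

op_def = get_op_def("VariableV2")
print(op_def is not None and op_def.is_stateful)     # True: VariableV2 is stateful
```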
@@ -3155,6 +3178,91 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_

        return op

    def _partition_call_operator(self, inputs, attr):
        """
        Convert TensorFlow PartitionedCall/StatefulPartitionedCall ops into Relay
        function calls, and the function definitions from the TensorFlow graph
        library attribute into Relay global functions.

        Parameters
        ----------
        inputs : List[tvm.relay.Expr]
            List of input symbols.

        attr : Dict[tvm.Attrs]
            Dict of operator attributes.

        Returns
        -------
        op : tvm.relay.Expr
            Converted relay expression.
        """

        try:
            from tensorflow.python.framework import function_def_to_graph
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e))

        main_graph_proto = self._main_graph_proto
        outer_graph_def = main_graph_proto._graph

        node_func_name = attr.get('f').name
        func = next((f for f in outer_graph_def.library.function
                     if f.signature.name == node_func_name), None)
        if func:
            devices = set(node.device for node in func.node_def)
            if len(devices) > 1:
                raise Exception("Found inconsistent device assignment in the "
                                "stateful partitioned subgraph; rejecting "
                                "the subgraph")
            # Convert the function definition to a graph
            func_input_shapes = func.attr["_input_shapes"].list.shape
            subgraph, _ = function_def_to_graph.\
                function_def_to_graph_def(func, func_input_shapes)

            # Compute the subgraph's input shape dictionary
            subgraph_shape_dict, input_expr_dict = {}, {}
            for f_arg, input in zip(func.signature.input_arg, inputs):
                input_expr_dict[f_arg.name] = input
                subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod)

            func_name = 'func_{}'.format(func.signature.name)
            try:
                global_func = main_graph_proto._mod[func_name]
                sub_func = global_func
                sub_params = main_graph_proto._params
            except ValueError:
                # Construct relay nodes from the subgraph
                g1 = SubGraphProto(main_graph_proto)
                sub_func, sub_params = g1.from_tensorflow(subgraph, shape=subgraph_shape_dict)
                main_graph_proto._params.update(sub_params)
                func_expr = _function.Function(sub_func.params, sub_func.body)
                global_func = tvm.relay.GlobalVar(func_name)
                main_graph_proto._mod[global_func] = func_expr

            param_exprs = []
            for param_expr in sub_func.params:
                # sub_params is a subset of sub_func.params
                param_name = param_expr.vid.name_hint
                if param_name in input_expr_dict.keys():
                    param_exprs.append(input_expr_dict[param_name])
                elif param_name in sub_params.keys():
                    param_exprs.append(param_expr)
                else:
                    raise Exception("Input parameter {} not found".format(param_name))

            sb = tvm.relay.scope_builder.ScopeBuilder()
            loop_ret = global_func(*param_exprs)
            sb.ret(loop_ret)
            ret = sb.get()
        else:
            raise Exception("Function not found - {}".format(node_func_name))
        return ret

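The subgraph extraction in _partition_call_operator leans on TensorFlow's function_def_to_graph helper. Below is a standalone sketch of just that lookup-and-lower step; the wrapper function name is mine, not from the PR:

```python
# Standalone sketch: find the FunctionDef referenced by a (Stateful)PartitionedCall
# node's `f` attribute in the graph library and lower it to a GraphDef.
from tensorflow.python.framework import function_def_to_graph

def library_function_to_graph_def(outer_graph_def, func_name):
    func = next((f for f in outer_graph_def.library.function
                 if f.signature.name == func_name), None)
    if func is None:
        raise Exception("Function not found - {}".format(func_name))
    input_shapes = func.attr["_input_shapes"].list.shape
    subgraph, _ = function_def_to_graph.function_def_to_graph_def(func, input_shapes)
    return subgraph
```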
    def _convert_operator(self, op_name, inputs, attrs,
                          graph, identity_list=None, convert_map=None):
        """Convert from Tensorflow operator to relay operator.
@@ -3196,6 +3304,9 @@ def _convert_operator(self, op_name, inputs, attrs,
            sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                             self._params, graph,
                                             convert_map_rnn)

        elif op_name in ["PartitionedCall", "StatefulPartitionedCall"]:
            sym = self._partition_call_operator(inputs, attrs)
        else:
            raise NotImplementedError("Operator {} not implemented.".format(op_name))
        return sym
@@ -3262,6 +3373,22 @@ def _backtrack_construct(self, node_name):

        return self._nodes[node_name]


class SubGraphProto(GraphProto):
    """A helper class for handling a Relay subgraph copied from a TensorFlow GraphDef."""

    def __init__(self, main_graph_proto):
        super().__init__()
        self._main_graph_proto = main_graph_proto  # holds the main graph proto object

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """Wrapper around _get_relay_func which converts a TensorFlow graph to a
        Relay function. Returns the Relay function and parameters.
        """
        func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
        return func, self._params


def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
    """Load tensorflow graph which is a python tensorflow graph object into relay.
    The companion parameters will be handled automatically.
@@ -3288,6 +3415,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
    params : dict of str to tvm.nd.NDArray
        Dict of converted parameters stored in tvm.nd.NDArray format
    """

    g = GraphProto()
    mod, params = g.from_tensorflow(graph, layout, shape, outputs)
    return mod, params