[TENSORFLOW] StatefulPartitionedCall/PartitionedCall Ops support added (#5617)

* Implemented function invocation unit test for StatefulPartitionedCall operator (working) and initial changes for placeholder (not working as of now)

* Placeholder exercises with tvm

* placeholder interim

* SPOP Test cases structure

* New test cases for spop

* miscellaneous test cases for spop

* Placeholder samples, working with shapes explicitly passed

* Variables test case. Works with the same fix of shape_dict

* SPOP Positive test cases first iteration

* support output tensors as function args, multiple functions

* Corrected Indentation

* filewriter is only for debug purposes

* support variables in function args

* First working iteration of positive spop test cases

* Removed commented code, simplified code

* Code Reorganization- First working iteration of positive spop test cases

* corrected variable name after refactor

* Code Reorganization- First working iteration of positive spop test cases

* move code inside mapped operator function

* Removed extra line

* support variables in function args

* Removed commented code, simplified code

* move code inside mapped operator function

* Code Reorganization- First working iteration of positive spop test cases

# Conflicts:
#	tests/python/frontend/tensorflow/test_forward.py

* Code Reorganization- First working iteration of positive spop test cases

* Function invocation more test cases

* Simplified & Merged different Function Invocation Test cases

* support invocation of nested callables

no need to explicitly handle partitioned and
statefulPartitioned condition in convert_operator function

* Simplified and Uniform testcases

* support invocation of nested callables

no need to explicitly handle partitioned and
statefulPartitioned condition in convert_operator function

* Simplified and Uniform testcases

* removed duplicate and renamed testcase

* Negative scenario added for testing operator statefulness. The only exceptions among stateful operators are PartitionedCall & StatefulPartitionedCall, which are able to execute even stateless operators within them

* Miscellaneous reorganization changes for spop scenarios

* Miscellaneous reorganization changes for spop scenarios

* Corrected import of tensorflow modules safely using try except and other code reorganization

* Negative scenario for resource variables handled

* Documentation update for code

* SPOP change in function handling

* handle nested subgraph

* refactor

* get op def compatible with tf 1x & 2x

* Fixed linting issues

* added docstring and a few nits

* Merged changes for positive test cases and negative test cases

* Moved StatefulPartitionedCall test case to the end of the TC list

* Fixed some typos and semantics

* dmlc-core

* dmlc-core

* fixes

* Addressing Review comments in the PR for SPOP support

* Fixed pylint errors

* Corrected tensorflow import syntax

* Placed the op_def_registry module import outside of for loop

* Removed new stateful operators list and combined these operators with missing operators to display as a single list. Also removed throwing a separate exception for stateful ops

Co-authored-by: Prashant Sail <[email protected]>
Co-authored-by: maheshambule <[email protected]>
3 people authored Jun 4, 2020
1 parent 3d61dc8 commit 43dcbc6
Showing 2 changed files with 465 additions and 5 deletions.
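
For context, a minimal usage sketch of what this commit enables: importing a GraphDef whose tf.function calls show up as PartitionedCall/StatefulPartitionedCall nodes. The function names, shapes, and exact TF lowering below are illustrative assumptions, not part of this commit.

    import tensorflow as tf
    from tvm import relay

    @tf.function
    def inner(x):
        # a nested callable; its call site may lower to a (Stateful)PartitionedCall node
        return tf.multiply(x, x)

    @tf.function
    def outer(x, y):
        return tf.add(inner(x), inner(y))

    # The concrete function's GraphDef keeps 'inner' in graph_def.library and calls it
    # through PartitionedCall nodes (the exact lowering depends on the TF version).
    spec = tf.TensorSpec([2, 2], tf.float32)
    graph_def = outer.get_concrete_function(spec, spec).graph.as_graph_def(add_shapes=True)

    # Shapes are passed explicitly via shape_dict, as noted in the commit history above.
    shape_dict = {"x": (2, 2), "y": (2, 2)}
    mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
    print(mod)
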
126 changes: 123 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
# pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel, redefined-builtin
"""TF: Tensorflow frontend."""
import warnings
from collections import defaultdict
@@ -1927,7 +1927,6 @@ def _impl(inputs, attr, params, mod):
        return _res
    return _impl


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -2717,8 +2716,9 @@ def __init__(self):
        self._loop_var_order = {}
        self._hash2tfnode = {}
        self._while_loop_name_set = set()
        self._main_graph_proto = self

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
    def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
        """Construct relay nodes from tensorflow graph definition - GraphDef.
        Follow the tensorflow graph definition to parse and convert it to Relay.
@@ -2885,6 +2885,13 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

        out = out[0] if len(out) == 1 else _expr.Tuple(out)
        func = _function.Function(analysis.free_vars(out), out)
        return func

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function
        which is used as main function for the Relay module
        """
        func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
        self._mod["main"] = func
        return self._mod, self._params

@@ -2895,16 +2902,24 @@ def _parse_import_prerequisites(self, graph):
        which are not supported
        """
        missing_operators = set()
        from tensorflow.python.framework import op_def_registry
        for node in graph.node:
            getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,\
                "_registered_ops") else op_def_registry.get
            op_def = getOpDef(node.op)
            if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
                pass
            elif node.op == "Const":
                pass
            elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
                pass
            else:
                if any([node.op in t for t in [_identity_list, _convert_map,
                                               _convert_map_rnn,
                                               _control_flow_nodes]]):
                    pass
                elif op_def is not None and op_def.is_stateful:
                    missing_operators.add(node.op)
                else:
                    missing_operators.add(node.op)
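
As an aside, a small hedged illustration of the TF 1.x / 2.x compatibility lookup used above: older TF exposes the op registry as op_def_registry._registered_ops, while newer TF exposes op_def_registry.get(). The op name below is only an example.

    from tensorflow.python.framework import op_def_registry

    def get_op_def(op_name):
        getter = (op_def_registry._registered_ops.get
                  if hasattr(op_def_registry, "_registered_ops")
                  else op_def_registry.get)
        return getter(op_name)

    op_def = get_op_def("AssignVariableOp")  # example op; resource-variable ops are typically stateful
    print(op_def is not None and op_def.is_stateful)
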

@@ -3149,6 +3164,91 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_

        return op

    def _partition_call_operator(self, inputs, attr):
        """
        Convert the Relay Partition call ops into Relay Function calls and
        function definitions from Tensorflow graph library attribute to Relay global
        functions
        Parameters
        ----------
        node: TensorFlow graph node object.
            A TensorFlow graph node object.
        inputs : List[tvm.relay.Expr]
            List of input symbols.
        attrs : Dict[tvm.Attrs]
            Dict of operator attributes.
        Returns
        -------
        op : tvm.relay.Expr
            Converted relay expression.
        """

        try:
            from tensorflow.python.framework import function_def_to_graph
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e))

        main_graph_proto = self._main_graph_proto
        outer_graph_def = main_graph_proto._graph

        node_func_name = attr.get('f').name
        func = next((f for f in outer_graph_def.library.function
                     if f.signature.name == node_func_name), None)
        if func:
            devices = set(node.device for node in func.node_def)
            if len(devices) > 1:
                raise Exception("Found inconsistent Device assignment in the "\
                                "Stateful Partitioned SubGraph. Rejecting "\
                                "the subgraph ")
            # Convert function definition to graph
            func_input_shapes = func.attr["_input_shapes"].list.shape
            subgraph, _ = function_def_to_graph.\
                function_def_to_graph_def(func, func_input_shapes)

            # Computing subgraph's input shape dictionary
            subgraph_shape_dict, input_expr_dict = {}, {}
            for f_arg, input in zip(func.signature.input_arg, inputs):
                input_expr_dict[f_arg.name] = input
                subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod)

            func_name = 'func_{}'.format(func.signature.name)
            try:
                global_func = main_graph_proto._mod[func_name]
                sub_func = global_func
                sub_params = main_graph_proto._params
            except ValueError:
                # Construct relay nodes from the subgraph
                g1 = SubGraphProto(main_graph_proto)
                sub_func, sub_params = g1.from_tensorflow(subgraph, shape=subgraph_shape_dict)
                main_graph_proto._params.update(sub_params)
                func_expr = _function.Function(sub_func.params, sub_func.body)
                global_func = tvm.relay.GlobalVar(func_name)
                main_graph_proto._mod[global_func] = func_expr

            param_exprs = []
            for param_expr in sub_func.params:
                # sub_params is subset of sub_func.params
                param_name = param_expr.vid.name_hint
                if param_name in input_expr_dict.keys():
                    param_exprs.append(input_expr_dict[param_name])
                elif param_name in sub_params.keys():
                    param_exprs.append(param_expr)
                else:
                    raise Exception("Input parameter {} not found".format(param_name))

            sb = tvm.relay.scope_builder.ScopeBuilder()
            loop_ret = global_func(*param_exprs)
            sb.ret(loop_ret)
            ret = sb.get()
        else:
            raise Exception("Function not found - {}".format(node_func_name))
        return ret

    def _convert_operator(self, op_name, inputs, attrs,
                          graph, identity_list=None, convert_map=None):
        """Convert from Tensorflow operator to relay operator.
@@ -3190,6 +3290,9 @@ def _convert_operator(self, op_name, inputs, attrs,
            sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                             self._params, graph,
                                             convert_map_rnn)

        elif op_name in ["PartitionedCall", "StatefulPartitionedCall"]:
            sym = self._partition_call_operator(inputs, attrs)
        else:
            raise NotImplementedError("Operator {} not implemented.".format(op_name))
        return sym
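
A hedged sketch of the Relay pattern that _partition_call_operator relies on: the called subgraph is converted once into a module-level function bound to a GlobalVar, and every PartitionedCall site then becomes a call to that GlobalVar. Names and shapes below are illustrative only; the actual converter additionally routes the call through a ScopeBuilder and reuses the GlobalVar on repeated invocations.

    import tvm
    from tvm import relay

    mod = tvm.IRModule()

    # The subgraph body is converted once and registered under a GlobalVar,
    # mirroring the 'func_{signature.name}' naming used above.
    gvar = relay.GlobalVar("func_example")
    x = relay.var("x", shape=(2, 2), dtype="float32")
    mod[gvar] = relay.Function([x], relay.add(x, x))

    # Each call site simply calls the GlobalVar with its own inputs.
    y = relay.var("y", shape=(2, 2), dtype="float32")
    mod["main"] = relay.Function([y], gvar(y))
    print(mod)
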
@@ -3253,6 +3356,22 @@ def _backtrack_construct(self, node_name):

        return out[0]


class SubGraphProto(GraphProto):
    """ A helper class for handling relay subgraph copying from Tensorflow GraphDef.
    """
    def __init__(self, main_graph_proto):
        super().__init__()
        self._main_graph_proto = main_graph_proto  # holds main graph proto object

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function.
        Return Relay function and params
        """
        func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs)
        return func, self._params


def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
"""Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically.
@@ -3279,6 +3398,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
    params : dict of str to tvm.nd.NDArray
        Dict of converted parameters stored in tvm.nd.NDArray format
    """

    g = GraphProto()
    mod, params = g.from_tensorflow(graph, layout, shape, outputs)
    return mod, params
(The diff for the second changed file, tests/python/frontend/tensorflow/test_forward.py, did not load and is not shown here.)