diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ab9e9e656516..6b4a534e51bd 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except -# pylint: disable=import-outside-toplevel +# pylint: disable=import-outside-toplevel, redefined-builtin """TF: Tensorflow frontend.""" import warnings from collections import defaultdict @@ -1927,7 +1927,6 @@ def _impl(inputs, attr, params, mod): return _res return _impl - # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2717,8 +2716,9 @@ def __init__(self): self._loop_var_order = {} self._hash2tfnode = {} self._while_loop_name_set = set() + self._main_graph_proto = self - def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. Follow the tensorflow graph definition to parse and convert it to Relay. @@ -2885,6 +2885,13 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): out = out[0] if len(out) == 1 else _expr.Tuple(out) func = _function.Function(analysis.free_vars(out), out) + return func + + def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): + """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function + which is used as main function for the Relay module + """ + func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs) self._mod["main"] = func return self._mod, self._params @@ -2895,16 +2902,24 @@ def _parse_import_prerequisites(self, graph): which are not supported """ missing_operators = set() + from tensorflow.python.framework import op_def_registry for node in graph.node: + getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,\ + "_registered_ops") else op_def_registry.get + op_def = getOpDef(node.op) if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault': pass elif node.op == "Const": pass + elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]: + pass else: if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn, _control_flow_nodes]]): pass + elif op_def is not None and op_def.is_stateful: + missing_operators.add(node.op) else: missing_operators.add(node.op) @@ -3155,6 +3170,91 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ return op + def _partition_call_operator(self, inputs, attr): + """ + Convert the Relay Partition call ops into Relay Function calls and + function definitions from Tensorflow graph library attribute to Relay global + functions + + Parameters + ---------- + node: TensorFlow graph node object. + A TensorFlow graph node object. + + inputs : List[tvm.relay.Expr] + List of input symbols. + + attrs : Dict[tvm.Attrs] + Dict of operator attributes. + + Returns + ------- + op : tvm.relay.Expr + Converted relay expression. + """ + + try: + from tensorflow.python.framework import function_def_to_graph + except ImportError as e: + raise ImportError( + "Unable to import tensorflow which is required {}".format(e)) + + main_graph_proto = self._main_graph_proto + outer_graph_def = main_graph_proto._graph + + node_func_name = attr.get('f').name + func = next((f for f in outer_graph_def.library.function + if f.signature.name == node_func_name), None) + if func: + devices = set(node.device for node in func.node_def) + if len(devices) > 1: + raise Exception("Found inconsistent Device assignment in the "\ + "Stateful Partitioned SubGraph. Rejecting "\ + "the subgraph ") + # Convert function definition to graph + func_input_shapes = func.attr["_input_shapes"].list.shape + subgraph, _ = function_def_to_graph.\ + function_def_to_graph_def(func, func_input_shapes) + + # Computing subgraph's input shape dictionary + subgraph_shape_dict, input_expr_dict = {}, {} + for f_arg, input in zip(func.signature.input_arg, inputs): + input_expr_dict[f_arg.name] = input + subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod) + + func_name = 'func_{}'.format(func.signature.name) + try: + global_func = main_graph_proto._mod[func_name] + sub_func = global_func + sub_params = main_graph_proto._params + except ValueError: + # Construct relay nodes from the subgraph + g1 = SubGraphProto(main_graph_proto) + sub_func, sub_params = g1.from_tensorflow(subgraph, shape=subgraph_shape_dict) + main_graph_proto._params.update(sub_params) + func_expr = _function.Function(sub_func.params, sub_func.body) + global_func = tvm.relay.GlobalVar(func_name) + main_graph_proto._mod[global_func] = func_expr + + param_exprs = [] + for param_expr in sub_func.params: + # sub_params is subset of sub_func.params + param_name = param_expr.vid.name_hint + if param_name in input_expr_dict.keys(): + param_exprs.append(input_expr_dict[param_name]) + elif param_name in sub_params.keys(): + param_exprs.append(param_expr) + else: + raise Exception("Input parameter {} not found".format(param_name)) + + sb = tvm.relay.scope_builder.ScopeBuilder() + loop_ret = global_func(*param_exprs) + sb.ret(loop_ret) + ret = sb.get() + else: + raise Exception("Function not found - {}".format(node_func_name)) + return ret + def _convert_operator(self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None): """Convert from Tensorflow operator to relay operator. @@ -3196,6 +3296,9 @@ def _convert_operator(self, op_name, inputs, attrs, sym = self._convert_rnn_operator(op_name, inputs, attrs, self._params, graph, convert_map_rnn) + + elif op_name in ["PartitionedCall", "StatefulPartitionedCall"]: + sym = self._partition_call_operator(inputs, attrs) else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) return sym @@ -3262,6 +3365,22 @@ def _backtrack_construct(self, node_name): return self._nodes[node_name] + +class SubGraphProto(GraphProto): + """ A helper class for handling relay subgraph copying from Tensorflow GraphDef. + """ + def __init__(self, main_graph_proto): + super().__init__() + self._main_graph_proto = main_graph_proto # holds main graph proto object + + def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): + """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function. + Return Relay function and params + """ + func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs) + return func, self._params + + def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): """Load tensorflow graph which is a python tensorflow graph object into relay. The companion parameters will be handled automatically. @@ -3288,6 +3407,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): params : dict of str to tvm.nd.NDArray Dict of converted parameters stored in tvm.nd.NDArray format """ + g = GraphProto() mod, params = g.from_tensorflow(graph, layout, shape, outputs) return mod, params diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c3313b69a0bd..93bf7394c80c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -36,6 +36,10 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops import init_ops +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import gen_functional_ops from distutils.version import LooseVersion import tvm from tvm import te @@ -176,6 +180,7 @@ def name_without_num(name): if init_global_variables: sess.run(variables.global_variables_initializer()) final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) + tf_output = run_tf_graph(sess, in_data, in_name, out_name) for device in ["llvm", "cuda"]: @@ -1138,13 +1143,13 @@ def test_read_variable_op(): tf_output = run_tf_graph(sess, in_data, in_name, out_name) shape_dict = {e: i.shape for e, i in zip(in_name, in_data)} - with pytest.raises(Exception) as exexcinfo: + with pytest.raises(Exception) as execinfo: mod, params = relay.frontend.from_tensorflow(final_graph_def, layout=None, shape=shape_dict, outputs=None) - assert exexcinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.") + assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph") # Now convert the variables to constant and run inference on the converted graph final_graph_def = tf.graph_util.convert_variables_to_constants( @@ -3179,10 +3184,342 @@ def test_forward_isfinite(): _verify_infiniteness_ops(tf.is_finite, "isfinite") +def _test_spop_placeholder_without_shape_info(): + with tf.Graph().as_default(): + + @function.Defun(*[tf.int32]*2) + def Forward(x,y): + print(x.name) + print(y.name) + b = tf.add(x, y) + return b + pl1 = tf.placeholder(tf.int32,name="pl1") + pl2 = tf.placeholder(tf.int32,name="pl2") + pl3 = tf.placeholder(tf.int32, name="pl3") + data = np.array([[-1, 1], [2, -2]], dtype=np.int32) + data2 = np.array([[-2, 3], [4, -6]], dtype=np.int32) + data3 = np.array([[-2, 3], [4, -6]], dtype=np.int32) + z1 = gen_functional_ops.StatefulPartitionedCall(args=[pl1,pl2], Tout=[tf.int32],f=Forward) + z2 = z1 + pl3 + compare_tf_with_tvm([data, data2, data3], ['pl1:0', 'pl2:0', 'pl3:0'], + ['StatefulPartitionedCall:0',z2.name], mode='vm', init_global_variables=True) + + +def _test_spop_placeholder_with_shape_and_default_value(): + with tf.Graph().as_default(): + data = np.ones([1], dtype=int).astype(np.int32) + dataVar = tf.Variable(data, shape=data.shape) + pl1 = array_ops.placeholder_with_default(dataVar,shape=data.shape,name="pl1") + tpl = tf.convert_to_tensor(pl1, dtype=tf.int32) + + @function.Defun(*[tf.int32]) + def pl_with_default(pl): + return tf.expand_dims(tf.multiply(pl, pl), 0) + + z = gen_functional_ops.StatefulPartitionedCall(args=[tpl], Tout=[tf.int32], f=pl_with_default) + compare_tf_with_tvm(data, ['pl1:0'], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True) + + +def _test_spop_placeholder_numpy_arange_feed(): + with tf.Graph().as_default(): + t1 = tf.placeholder(tf.int32, (3, 3, 3), "t1") + t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3)) + t2 = tf.placeholder(tf.int32, (3, 3, 3), "t2") + t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3)) + + @tf.function + def add(x, y): + return tf.add(x, y, "add_t1_t2") + + t3 = add(t1, t2) + compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True) + + +def _test_spop_placeholder_numpy_array_feed(): + with tf.Graph().as_default(): + t1_data = np.array([[-1, 1, 3], [2, -2, 4], [2, -3, 14]], dtype=np.int32) + t2_data = np.array([[-2, 1, 2], [12, -2, 14], [12, -3, 4]], dtype=np.int32) + t1 = tf.placeholder(tf.int32, name="t1") + t2 = tf.placeholder(tf.int32, name="t2") + + @tf.function + def add(x, y): + return tf.add(x, y, "add_t1_t2") + + t3 = add(t1, t2) + compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True) + + +def _test_spop_function_invocation_basic(): + with tf.Graph().as_default(): + + def fun1(a): + return tf.multiply(a,a) + + def fun2(b): + return tf.multiply(b,10) + + @tf.function + def fun3(x,y): + x = fun2(x) + y = fun1(y) + z = tf.add(x,y) + return z + + t3 = fun3(tf.constant(10.5), tf.constant(20.4)) + + compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True) + + +def _test_spop_function_invocation_nested(): + with tf.Graph().as_default(): + t1 = tf.placeholder(tf.int32, (3, 3, 3), name="t1") + t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3)) + t2 = tf.placeholder(tf.int32, name="t2") + t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3)) + + @tf.function + def myfunc(x, y): + return tf.add(x, y, "myfunc") + + @tf.function + def myfunc2(x, y): + z = myfunc(x, y) + l = myfunc(z, y) + m = myfunc(l,z) + return tf.add(l, m, "myfunc2") + + res1 = myfunc(t1, t2) + res2 = myfunc2(res1, t1) + + compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [res2.name], mode='vm', init_global_variables=True) + + +def _test_spop_function_invocation_no_autograph(): + with tf.Graph().as_default(): + + @tf.function(autograph=False) + def fun1(a): + return tf.multiply(a,a) + + @tf.function(autograph=False) + def fun2(b): + return tf.multiply(b,10) + + @tf.function + def fun3(x,y): + x = fun2(x) + y = fun1(y) + z = tf.add(x,y) + return z + + t3 = fun3(tf.constant(10.5), tf.constant(20.4)) + + compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True) + + +def _test_spop_function_invocation_defun(): + with tf.Graph().as_default(): + + def fun1(a): + return tf.multiply(a,a) + + def fun2(b): + return tf.multiply(b,b) + + @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3") + def fun3(x,y): + x = fun2(x) + y = fun1(y) + z = tf.add(x,y) + return z + + op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)], + Tout=[dtypes.float32], f=fun3, name="SpopFnInvocation") + compare_tf_with_tvm([],[], 'SpopFnInvocation:0', mode='vm', init_global_variables=True) + + +def _test_spop_arithmetic(): + with tf.Graph().as_default(): + @function.Defun(*[dtypes.int32]*3) + def arithmetic(m,x,c): + z = tf.add(tf.multiply(m, x), c) + return z + + m = tf.constant(10) + x = tf.constant(20) + c = tf.constant(2) + spopFn = gen_functional_ops.StatefulPartitionedCall(args=[m,x,c],Tout=[tf.int32], f=arithmetic) + + compare_tf_with_tvm([],[],'StatefulPartitionedCall:0', mode='vm', init_global_variables=True) + + +def _test_spop_control_flow(): + with tf.Graph().as_default(): + + @function.Defun(*[dtypes.float32] * 2) + def Body1(x, y): + with ops.device("/job:localhost/replica:0/task:0/device:CPU:0"): + z = math_ops.multiply(x, y) + i = 0 + while i<10 : + i +=1 + if i == 5: + continue + z = math_ops.multiply(x, y*i) + return z + + op = gen_functional_ops.StatefulPartitionedCall( + args=[constant_op.constant(32.), constant_op.constant(100.)], + Tout=[dtypes.float32], f=Body1) + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True) + + +def _test_spop_variables(): + with tf.Graph().as_default(): + const1 = tf.constant(10) + const2 = tf.constant(20) + var1 = tf.Variable(const1, dtype=tf.int32) + var2 = tf.Variable(const2, dtype=tf.int32) + + @function.Defun(tf.int32,tf.int32) + def Forward(x,y): + return tf.multiply(x,y) + + z = gen_functional_ops.StatefulPartitionedCall(args=[var1,var2],Tout=[tf.int32], f=Forward) + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', init_global_variables=True, mode="vm") + + +def _test_spop_constants(): + with tf.Graph().as_default(): + @function.Defun(*[dtypes.int32] * 2) + def constantsFn(x, y): + vv = tf.constant([2, 3, 4], name="vv") + z = tf.add(vv + x, y) + return z + + a = tf.constant(20000, name = "a") + b = tf.constant(40000, name = "b") + spopFn = gen_functional_ops.StatefulPartitionedCall(args=[a, b], Tout=[tf.int32], f=constantsFn) + + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True) + + +def _test_spop_stateful(): + # This test case is to test that TVM rejects any TF stateful operations + # (including Resource Variables) except StatefulPartitionedCall/PartitionedCall + # (as these two operators can still be used as container graphs to execute + # "stateless" operations internally. + tf.reset_default_graph() + with tf.Graph().as_default(): + + @tf.function + def FunctionWithStatefulOp_One(i): + b = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed=10) + y = tf.multiply(b, i) + return y + + @tf.function + def FunctionWithStatefulOp(m, n): + a = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed = 10) + x = tf.multiply(a,m) + y = FunctionWithStatefulOp_One(n) + z = tf.multiply(x,y) + return z + + op = FunctionWithStatefulOp(constant_op.constant(1.), constant_op.constant(2.)) + with pytest.raises(Exception) as execinfo: + compare_tf_with_tvm([], [], [op.name], init_global_variables=True, mode="vm") + assert execinfo.value.args[0].startswith( + "The following operators are not implemented") + + +def _test_spop_device_assignment(): + # This test case is to test that TVM rejects inconsistent device assignment + # while using StatefulPartitionedCall/PartitionedCall operators which in case of TVM will + # be used as container graphs to internally execute "stateless" operations. + + tf.reset_default_graph() + with tf.Graph().as_default(): + + def fun1(a): + with ops.device("/GPU:0"): + return tf.multiply(a,a) + + def fun2(b): + with ops.device("/job:localhost/replica:0/task:0/device:CPU:1"): + return tf.multiply(b,b) + + @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3") + def fun3(x,y): + with ops.device("/CPU:0"): + x = fun2(x) + with ops.device("/job:localhost/replica:0/task:0/device:CPU:2"): + y = fun1(y) + with ops.device("/job:localhost/replica:0/task:0/device:CPU:3"): + z = tf.add(x,y) + return z + + op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)], + Tout=[dtypes.float32], f=fun3) + with pytest.raises(Exception) as execinfo: + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', + mode='vm', init_global_variables=True) + assert execinfo.value.args[0].startswith("Found inconsistent Device assignment") + + +def _test_spop_resource_variables(): + # This test case is to test that TVM rejects any graph containing + # resource variables with StatefulPartitionedOp. + + tf.reset_default_graph() + with tf.Graph().as_default(): + + const1 = tf.constant(10) + const2 = tf.constant(20) + var1 = tf.Variable(const1, dtype=tf.int32, use_resource=True) + var2 = tf.Variable(const2, dtype=tf.int32, use_resource=True) + + @tf.function + def resourceVariablesTest(x, y): + return tf.multiply(x, y) + + op = resourceVariablesTest(var1,var2) + with pytest.raises(Exception) as execinfo: + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', + mode='vm', init_global_variables=True) + assert execinfo.value.args[0].startswith("Graph is not frozen." + " Provide a frozen graph") + +def test_forward_spop(): + _test_spop_stateful() + _test_spop_device_assignment() + _test_spop_resource_variables() + + #Placeholder test cases + _test_spop_placeholder_without_shape_info() + _test_spop_placeholder_with_shape_and_default_value() + _test_spop_placeholder_numpy_arange_feed() + _test_spop_placeholder_numpy_array_feed() + + #Function Invocation test cases + _test_spop_function_invocation_basic() + _test_spop_function_invocation_nested() + _test_spop_function_invocation_no_autograph() + _test_spop_function_invocation_defun() + + #Test cases for various other TF constructs + _test_spop_arithmetic() + _test_spop_control_flow() + _test_spop_variables() + _test_spop_constants() + + ####################################################################### # Main # ---- if __name__ == '__main__': + # Transforms test_forward_slice() test_forward_transpose() @@ -3307,3 +3644,6 @@ def test_forward_isfinite(): # Sharing params case using Mean ops test_sharing_node() + + # StatefulPartitionedCall + test_forward_spop()