From 34c95a89eae13043aa3a269ad7d27d28837ffb56 Mon Sep 17 00:00:00 2001 From: Dhruva Ray Date: Thu, 4 Jun 2020 23:25:38 +0530 Subject: [PATCH] [Frontend][TFLite] Add parser support for shape and range (#5329) * [Relay][Frontend][TFLite] Add parser support for shape and range Signed-off-by: Dhruva Ray * Incorporated review comments and used new functions Signed-off-by: Dhruva Ray * Few cosmetic changes Signed-off-by: Dhruva Ray * Removed an extra line added by rebase... Signed-off-by: Dhruva Ray --- python/tvm/relay/frontend/tflite.py | 35 ++++ tests/python/frontend/tflite/test_forward.py | 168 ++++++++++++++++--- 2 files changed, 179 insertions(+), 24 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 15d025399691..08ea715ece5e 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -114,6 +114,7 @@ def __init__(self, model, subgraph, exp_tab): 'PAD': self.convert_pad, 'POW': self.convert_pow, 'PRELU': self.convert_prelu, + 'RANGE': self.convert_range, 'QUANTIZE': self.convert_quantize, 'REDUCE_ANY': self.convert_reduce_any, 'REDUCE_MAX': self.convert_reduce_max, @@ -126,6 +127,7 @@ def __init__(self, model, subgraph, exp_tab): 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, 'SELECT': self.convert_select, + 'SHAPE': self.convert_shape, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, 'SOFTMAX': self.convert_softmax, @@ -609,6 +611,39 @@ def convert_tanh(self, op): return out + def convert_range(self, op): + """Convert TFLite Range""" + try: + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be 3" + + start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2] + + expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]] + + # out type inference + if delta.tensor.Type() == TensorType.FLOAT32: + out_type = self.get_tensor_type_str(delta.tensor.Type()) + else: + out_type = self.get_tensor_type_str(start.tensor.Type()) + + out = _op.arange(expressions[0], expressions[1], expressions[2], out_type) + + return out + + def convert_shape(self, op): + """Convert TFLite Shape""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + out = _op.shape_of(self.get_tensor_expr(input_tensors[0])) + + return out + def convert_relu(self, op): """Convert TFLite ReLU""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index d5dafd87a513..29515403c6b6 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -83,8 +83,34 @@ def get_real_image_object_detection(im_height, im_width): data = np.reshape(x, (1, im_height, im_width, 3)) return data +def vmobj_to_list(o): + if isinstance(o, tvm.nd.NDArray): + return [o.asnumpy().tolist()] + elif isinstance(o, tvm.runtime.container.ADT): + result = [] + for f in o: + result.extend(vmobj_to_list(f)) + return result + elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + if o.constructor.name_hint == 'Cons': + tl = vmobj_to_list(o.fields[1]) + hd = vmobj_to_list(o.fields[0]) + hd.extend(tl) + return hd + elif o.constructor.name_hint == 'Nil': + return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % + o.constructor.name_hint) + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', - out_names=None): + out_names=None, mode='graph_runtime'): """ Generic function to compile on relay and execute on tvm """ # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 try: @@ -109,27 +135,43 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target shape_dict=shape_dict, dtype_dict=dtype_dict) - with tvm.transform.PassContext(opt_level=3): - graph, lib, params = relay.build(mod, target, params=params) - - ctx = tvm.context(target, 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - # set inputs - for i, e in enumerate(input_node): - m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) - - m.set_input(**params) - # execute - m.run() - # get outputs - assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format( - out_names, num_output) - tvm_output_list = [] - for i in range(0, num_output): - tvm_output = m.get_output(i) - tvm_output_list.append(tvm_output.asnumpy()) - return tvm_output_list + if mode in ['debug', 'vm']: + ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm") + inputs = [] + for param in mod['main'].params: + found = False + for i, n in enumerate(input_node): + if n == param.name_hint: + found = True + inputs.append(tvm.nd.array(input_data[i])) + break + # Interpreter doesn't bind constants, so still need to find in params + if not found: + inputs.append(tvm.nd.array(params[param.name_hint])) + result = ex.evaluate()(*inputs) + return vmobj_to_list(result) + else: + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build(mod, target, params=params) + + ctx = tvm.context(target, 0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + # set inputs + for i, e in enumerate(input_node): + m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + + m.set_input(**params) + # execute + m.run() + # get outputs + assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format( + out_names, num_output) + tvm_output_list = [] + for i in range(0, num_output): + tvm_output = m.get_output(i) + tvm_output_list.append(tvm_output.asnumpy()) + return tvm_output_list def run_tflite_graph(tflite_model_buf, input_data): @@ -160,7 +202,7 @@ def run_tflite_graph(tflite_model_buf, input_data): def compare_tflite_with_tvm(in_data, in_name, input_tensors, output_tensors, init_global_variables=False, - out_names=None, quantized=False, input_range=None): + out_names=None, quantized=False, input_range=None, mode='graph_runtime'): """Generic function to generate and compare TFLite and TVM output""" in_data = convert_to_list(in_data) in_name = convert_to_list(in_name) @@ -202,7 +244,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, continue tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device, - num_output=len(out_names), out_names=out_names) + num_output=len(out_names), out_names=out_names, mode=mode) # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output # range for the specific operator. While adding test ensure that we aren't getting only clipped values @@ -859,6 +901,80 @@ def test_all_resize(): if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()): _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False) +####################################################################### +# Range +# ----- +def _test_range(start, limit, delta): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + start_scalar, limit_scalar, delta_scalar = \ + tf.placeholder(dtype=start.dtype, shape=(), name="start"), \ + tf.placeholder(dtype=limit.dtype, shape=(), name="limit"), \ + tf.placeholder(dtype=delta.dtype, shape=(), name="delta") + + out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range") + + compare_tflite_with_tvm( + [start, limit, delta], + ["start", "limit", "delta"], + [start_scalar, limit_scalar, delta_scalar], + [out], + mode="vm", + quantized=False + ) + +def _test_range_default(): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + inputs = [ + tf.placeholder(dtype=tf.int32, shape=(), name="p1"), + tf.placeholder(dtype=tf.int32, shape=(), name="p2") + ] + outputs = [ + tf.range(start = inputs[0], limit = inputs[1]), # use default delta + tf.range(start = inputs[1]) # use start as limit with 0 as the first item in the range + ] + + compare_tflite_with_tvm( + [np.int32(1), np.int32(18)], + ["p1", "p2"], + inputs, + outputs, + mode="vm" + ) + +def test_forward_range(): + _test_range(np.int32(1), np.int32(18), np.int32(3)) + _test_range(np.int32(1), np.int32(18), np.float32(3.1)) # increment is of type float + _test_range(np.float32(1.0), np.int32(18), np.int32(3.1)) # start is of type float + _test_range_default() + +####################################################################### +# Shape +# ----- +def test_forward_shape(): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + data = np.array([1, 18, 3], dtype=np.int32) + start = tf.placeholder(dtype=tf.int32, shape=[], name="start") + limit = tf.placeholder(dtype=tf.int32, shape=[], name="limit") + delta = tf.placeholder(dtype=tf.int32, shape=[], name="delta") + r = tf.range(start, limit, delta, tf.int32, name="range") + out = tf.shape(r, out_type=tf.dtypes.int32) + compare_tflite_with_tvm( + [x for x in np.nditer(data)], + ["start", "limit", "delta"], + [start, limit, delta], + [out], + mode="vm" + ) + ####################################################################### # Concatenation # ------------- @@ -2363,6 +2479,9 @@ def test_forward_mediapipe_hand_landmark(): # Tile test_forward_tile() + # Query + test_forward_shape() + # Transforms test_forward_concatenation() test_forward_pad() @@ -2370,6 +2489,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_unpack() test_forward_reshape() test_all_resize() + test_forward_range() test_forward_squeeze() test_forward_slice() test_forward_topk()