From e58e3c2c30f7cf6cc0489b87ef8ad982b917fbff Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Tue, 4 Aug 2020 14:19:33 +0100 Subject: [PATCH] Reshape with dynamic shape arg Reshape operation updated to take shape from second operand. In case if shape is provided using second operand it can be a tensor now. --- python/tvm/relay/frontend/tflite.py | 35 +++++++++++++++----- tests/python/frontend/tflite/test_forward.py | 29 ++++++++++++---- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 6e032b1efda8f..d36bda35238fd 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -459,26 +459,43 @@ def convert_reshape(self, op): raise ImportError("The tflite package must be installed") input_tensors = self.get_input_tensors(op) - assert input_tensors, "input tensors should not be empty" + assert len(input_tensors) in (1, 2), "input tensors should not be empty" + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "There should be only 1 output tensor" + input_tensor = input_tensors[0] input_tensor_idx = input_tensor.tensor_idx - assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions - op_options = op.BuiltinOptions() - reshape_options = ReshapeOptions() - reshape_options.Init(op_options.Bytes, op_options.Pos) - target_shape = reshape_options.NewShapeAsNumpy() + if len(input_tensors) == 2: + shape_tensor = input_tensors[1] + if self.has_expr(shape_tensor.tensor_idx): + target_shape = self.get_expr(shape_tensor.tensor_idx) + else: + target_shape = self.get_tensor_value(shape_tensor) + # convert to flattened list + from itertools import chain; + try: + target_shape = list(chain(*target_shape)) + except TypeError: + target_shape = list(chain(target_shape)) + + else: + assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions + op_options = op.BuiltinOptions() + reshape_options = ReshapeOptions() + reshape_options.Init(op_options.Bytes, op_options.Pos) + target_shape = tuple(reshape_options.NewShapeAsNumpy()) in_expr = self.get_expr(input_tensor_idx) # If the tensors are quantized, ensure that input/output qnn params are same. if input_tensor.qnn_params: - output_tensors = self.get_output_tensors(op) - assert len(output_tensors) == 1, "There should be only 1 output tensor" output_tensor = output_tensors[0] assert self.has_same_qnn_params(input_tensor, output_tensor), \ "TFLite reshape requires input and output scale and zero points to be equal" - out = _op.reshape(in_expr, newshape=tuple(target_shape)) + + out = _op.reshape(in_expr, newshape=target_shape) return out def _convert_resize(self, method, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 603eb11696241..30a663176a3f2 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -984,20 +984,35 @@ def test_forward_transpose_conv(): # Reshape # ------- -def _test_reshape(data, out_shape): +def _test_reshape(data, out_shape, wrap_shape): """ One iteration of reshape operation with given data and out shape """ with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - out = array_ops.reshape(in_data, out_shape) - compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + out_shape = out_shape if not wrap_shape\ + else np.array(out_shape, dtype=np.int32) + + in_shape = out_shape if not wrap_shape\ + else array_ops.placeholder(shape=out_shape.shape,\ + dtype=out_shape.dtype,\ + name="Newshape") + + out = array_ops.reshape(in_data, in_shape) + + compare_tflite_with_tvm( + [data, out_shape] if wrap_shape else [data],\ + ['Placeholder:0', 'Newshape:0'] if wrap_shape else ['Placeholder:0'],\ + [in_data, in_shape] if wrap_shape else [in_data],\ + [out], + mode='vm') def test_forward_reshape(): - _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3]) - _test_reshape(np.arange(6), [-1, 2]) - _test_reshape(np.arange(6), [3, -1]) - _test_reshape(np.arange(6), [-1]) + for wrap in [True, False]: + _test_reshape(np.arange(6.0, dtype=np.float32), [2, 3], wrap) + _test_reshape(np.arange(6), [-1, 2], wrap) + _test_reshape(np.arange(6), [3, -1], wrap) + _test_reshape(np.arange(6), [-1], wrap) #######################################################################