diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 239d72055bff..c38191b389c9 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -33,7 +33,7 @@ from ..backend.name_transforms import sanitize_name from .common import ExprTable from .common import infer_shape as _infer_shape -from .common import lstm_cell, to_int_list, shape_of +from .common import lstm_cell, to_int_list, shape_of, try_infer_value from .tflite_flexbuffer import FlexBufferDecoder __all__ = ["from_tflite"] @@ -599,7 +599,21 @@ def convert_reshape(self, op): if len(input_tensors) == 2: shape_tensor = input_tensors[1] if self.has_expr(shape_tensor.tensor_idx): - target_shape = self.get_expr(shape_tensor.tensor_idx) + target_expr = self.get_expr(shape_tensor.tensor_idx) + target_value, success = try_infer_value( + target_expr, + parameters={k: _nd.array(np.array(v)) for k, v in self.exp_tab.params.items()}, + ) + if success: + # convert to flattened list + from itertools import chain + + try: + target_shape = list(chain(*target_value)) + except TypeError: + target_shape = list(chain(target_value)) + else: + target_shape = target_expr else: target_shape = self.get_tensor_value(shape_tensor) # convert to flattened list