From e37593338d16d0021c783ded4d9a37c4d6cdd8ca Mon Sep 17 00:00:00 2001 From: blackkker <823036806@qq.com> Date: Fri, 5 Aug 2022 03:52:07 +0000 Subject: [PATCH] Infer the value of shape expr to avoid dynamic --- python/tvm/relay/frontend/tflite.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 239d72055bff..c38191b389c9 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -33,7 +33,7 @@ from ..backend.name_transforms import sanitize_name from .common import ExprTable from .common import infer_shape as _infer_shape -from .common import lstm_cell, to_int_list, shape_of +from .common import lstm_cell, to_int_list, shape_of, try_infer_value from .tflite_flexbuffer import FlexBufferDecoder __all__ = ["from_tflite"] @@ -599,7 +599,21 @@ def convert_reshape(self, op): if len(input_tensors) == 2: shape_tensor = input_tensors[1] if self.has_expr(shape_tensor.tensor_idx): - target_shape = self.get_expr(shape_tensor.tensor_idx) + target_expr = self.get_expr(shape_tensor.tensor_idx) + target_value, success = try_infer_value( + target_expr, + parameters={k: _nd.array(np.array(v)) for k, v in self.exp_tab.params.items()}, + ) + if success: + # convert to flattened list + from itertools import chain + + try: + target_shape = list(chain(*target_value)) + except TypeError: + target_shape = list(chain(target_value)) + else: + target_shape = target_expr else: target_shape = self.get_tensor_value(shape_tensor) # convert to flattened list