Skip to content

Commit

Permalink
Infer the value of shape expr to avoid dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
blackkker committed Aug 5, 2022
1 parent 485bfaf commit e375933
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ..backend.name_transforms import sanitize_name
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import lstm_cell, to_int_list, shape_of
from .common import lstm_cell, to_int_list, shape_of, try_infer_value
from .tflite_flexbuffer import FlexBufferDecoder

__all__ = ["from_tflite"]
Expand Down Expand Up @@ -599,7 +599,21 @@ def convert_reshape(self, op):
if len(input_tensors) == 2:
shape_tensor = input_tensors[1]
if self.has_expr(shape_tensor.tensor_idx):
target_shape = self.get_expr(shape_tensor.tensor_idx)
target_expr = self.get_expr(shape_tensor.tensor_idx)
target_value, success = try_infer_value(
target_expr,
parameters={k: _nd.array(np.array(v)) for k, v in self.exp_tab.params.items()},
)
if success:
# convert to flattened list
from itertools import chain

try:
target_shape = list(chain(*target_value))
except TypeError:
target_shape = list(chain(target_value))
else:
target_shape = target_expr
else:
target_shape = self.get_tensor_value(shape_tensor)
# convert to flattened list
Expand Down

0 comments on commit e375933

Please sign in to comment.