diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5ed0f18deb9e..f9a5d9c33f02 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -550,7 +550,11 @@ def _flatten(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.reshape(x, new_shape)) def _permute(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) def _reshape(self, node: fx.node.Node) -> relax.Var: