Skip to content

Commit

Permalink
support torch.permute
Browse files Browse the repository at this point in the history
  • Loading branch information
mshr-h committed Jul 22, 2024
1 parent f22958e commit aa2e91a
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,11 @@ def _flatten(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.reshape(x, new_shape))

def _permute(self, node: fx.node.Node) -> relax.Var:
import torch # type: ignore

args = self.retrieve_args(node)
if isinstance(args[1], (torch.Size, tuple, list)):
return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1])))
return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:]))

def _reshape(self, node: fx.node.Node) -> relax.Var:
Expand Down

0 comments on commit aa2e91a

Please sign in to comment.