From 16d0d56458c30a947bf32a87ac0bb220d66bcb26 Mon Sep 17 00:00:00 2001 From: Nikolay Nez <34389970+n-nez@users.noreply.github.com> Date: Mon, 27 Apr 2020 20:34:58 +0900 Subject: [PATCH] [Pytorch] fix translation of transpose when axis argument is as a list (#5451) --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e185f5817e87..64f30f35b376 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -923,7 +923,7 @@ def _impl(inputs, input_types): axes[src] = dst axes[dst] = src else: - axes = inputs[1] + axes = _infer_shape(inputs[1], prelude.mod) return _op.transform.transpose(data, axes) return _impl diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 20f1185c879c..a53f3540ef29 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -767,9 +767,14 @@ class Transpose2(Module): def forward(self, *args): return args[0].transpose(-2, -1) + class Transpose3(Module): + def forward(self, *args): + return args[0].permute(0,2,3,1) + input_data = torch.rand(input_shape).float() verify_model(Transpose1().float().eval(), input_data=input_data) verify_model(Transpose2().float().eval(), input_data=input_data) + verify_model(Transpose3().float().eval(), input_data=input_data) def test_forward_size(): torch.set_grad_enabled(False)