From e9c68b282135543fd76398d7f0d744cfc60993f9 Mon Sep 17 00:00:00 2001 From: shiwenloong <52487098+shiwenloong@users.noreply.github.com> Date: Fri, 7 Aug 2020 08:55:46 +0800 Subject: [PATCH] [PYTORCH]Std op without specified dimensions support (#6226) --- python/tvm/relay/frontend/pytorch.py | 11 ++++++++--- tests/python/frontend/pytorch/test_forward.py | 5 +++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3dfdb2f70e7f8..bbc684ea8a4cf 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1253,9 +1253,14 @@ def _impl(inputs, input_types): def _std(): def _impl(inputs, input_types): data = inputs[0] - axis = list(_infer_shape(inputs[1])) - keepdims = bool(inputs[3]) - unbiased = bool(inputs[2]) + if len(inputs) == 2: + axis = None + keepdims = False + unbiased = bool(inputs[1]) + else: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[3]) + unbiased = bool(inputs[2]) if unbiased: msg = "Currently only supports standard-deviation calculated via the biased "\ diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e370cd502b592..3c9dfb13fc4ce 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1869,12 +1869,17 @@ class Std5(Module): def forward(self, *args): return args[0].std(dim=(2,3), keepdim=False, unbiased=False) + class Std6(Module): + def forward(self, *args): + return args[0].std(unbiased=False) + input_data = torch.rand(input_shape).float() verify_model(Std1().float().eval(), input_data=input_data) verify_model(Std2().float().eval(), input_data=input_data) verify_model(Std3().float().eval(), input_data=input_data) verify_model(Std4().float().eval(), input_data=input_data) verify_model(Std5().float().eval(), input_data=input_data) + verify_model(Std6().float().eval(), input_data=input_data) def test_forward_variance():