From 73e9d5abd9741c213876157843a7a4ba8703d697 Mon Sep 17 00:00:00 2001 From: shiwenloong Date: Wed, 5 Aug 2020 17:02:21 +0800 Subject: [PATCH] [PYTORCH]Std op without specified dimensions support --- python/tvm/relay/frontend/pytorch.py | 11 ++++++++--- tests/python/frontend/pytorch/test_forward.py | 5 +++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3dfdb2f70e7f..bbc684ea8a4c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1253,9 +1253,14 @@ def _impl(inputs, input_types): def _std(): def _impl(inputs, input_types): data = inputs[0] - axis = list(_infer_shape(inputs[1])) - keepdims = bool(inputs[3]) - unbiased = bool(inputs[2]) + if len(inputs) == 2: + axis = None + keepdims = False + unbiased = bool(inputs[1]) + else: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[3]) + unbiased = bool(inputs[2]) if unbiased: msg = "Currently only supports standard-deviation calculated via the biased "\ diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e370cd502b59..3c9dfb13fc4c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1869,12 +1869,17 @@ class Std5(Module): def forward(self, *args): return args[0].std(dim=(2,3), keepdim=False, unbiased=False) + class Std6(Module): + def forward(self, *args): + return args[0].std(unbiased=False) + input_data = torch.rand(input_shape).float() verify_model(Std1().float().eval(), input_data=input_data) verify_model(Std2().float().eval(), input_data=input_data) verify_model(Std3().float().eval(), input_data=input_data) verify_model(Std4().float().eval(), input_data=input_data) verify_model(Std5().float().eval(), input_data=input_data) + verify_model(Std6().float().eval(), input_data=input_data) def test_forward_variance():