diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3dfdb2f70e7f..bbc684ea8a4c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1253,9 +1253,14 @@ def _impl(inputs, input_types): def _std(): def _impl(inputs, input_types): data = inputs[0] - axis = list(_infer_shape(inputs[1])) - keepdims = bool(inputs[3]) - unbiased = bool(inputs[2]) + if len(inputs) == 2: + axis = None + keepdims = False + unbiased = bool(inputs[1]) + else: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[3]) + unbiased = bool(inputs[2]) if unbiased: msg = "Currently only supports standard-deviation calculated via the biased "\ diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e370cd502b59..3c9dfb13fc4c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1869,12 +1869,17 @@ class Std5(Module): def forward(self, *args): return args[0].std(dim=(2,3), keepdim=False, unbiased=False) + class Std6(Module): + def forward(self, *args): + return args[0].std(unbiased=False) + input_data = torch.rand(input_shape).float() verify_model(Std1().float().eval(), input_data=input_data) verify_model(Std2().float().eval(), input_data=input_data) verify_model(Std3().float().eval(), input_data=input_data) verify_model(Std4().float().eval(), input_data=input_data) verify_model(Std5().float().eval(), input_data=input_data) + verify_model(Std6().float().eval(), input_data=input_data) def test_forward_variance():