diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 013efae18c0a..18868cf8491c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -936,13 +936,16 @@ def _impl(inputs, input_types):
         data = inputs[0]
         axis = None
         keepdims = False
+
         if len(inputs) > 2: # default, torch have only data, axis=None, keepdims=False
             if isinstance(inputs[1], int):
                 axis = int(inputs[1])
             else:
                 axis = list(_infer_shape(inputs[1]))
             keepdims = bool(inputs[2])
+
         return get_relay_op(name)(data, axis=axis, keepdims=keepdims)
+
     return _impl

 def _std():
@@ -951,11 +954,14 @@ def _impl(inputs, input_types):
         axis = list(_infer_shape(inputs[1]))
         keepdims = bool(inputs[3])
         unbiased = bool(inputs[2])
+
         if unbiased:
             msg = "Currently only supports standard-deviation calculated via the biased "\
                   "estimator. Pytorch's Bessel's correction is not supported."
             raise NotImplementedError(msg)
+
         return _op.reduce.std(data, axis=axis, keepdims=keepdims)
+
     return _impl

 def _variance():
@@ -964,11 +970,14 @@ def _impl(inputs, input_types):
         axis = list(_infer_shape(inputs[1]))
         keepdims = bool(inputs[3])
         unbiased = bool(inputs[2])
+
         if unbiased:
             msg = "Currently only supports standard-deviation calculated via the biased "\
                   "estimator. Pytorch's Bessel's correction is not supported."
             raise NotImplementedError(msg)
+
         return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
+
     return _impl

 def _mean():
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 0c8c11188b0c..91e14c697f35 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1278,6 +1278,7 @@ def forward(self, xs):

     verify_script_model(RNNLoop().eval(), [(10, 10, 4)])

+
 def test_forward_reduce_sum():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1285,18 +1286,23 @@ def test_forward_reduce_sum():
     class ReduceSum1(Module):
         def forward(self, *args):
             return args[0].sum(1)
+
     class ReduceSum2(Module):
         def forward(self, *args):
             return args[0].sum(dim=1, keepdim=False)
+
     class ReduceSum3(Module):
         def forward(self, *args):
             return args[0].sum(dim=2, keepdim=True)
+
     class ReduceSum4(Module):
         def forward(self, *args):
             return args[0].sum(dim=(2,3), keepdim=True)
+
     class ReduceSum5(Module):
         def forward(self, *args):
             return args[0].sum(dim=(2,3), keepdim=False)
+
     input_data = torch.rand(input_shape).float()
     verify_model(ReduceSum1().float().eval(), input_data=input_data)
     verify_model(ReduceSum2().float().eval(), input_data=input_data)
@@ -1304,6 +1310,7 @@ def forward(self, *args):
     verify_model(ReduceSum4().float().eval(), input_data=input_data)
     verify_model(ReduceSum5().float().eval(), input_data=input_data)

+
 def test_forward_reduce_prod():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1311,17 +1318,21 @@ def test_forward_reduce_prod():
     class ReduceProd1(Module):
         def forward(self, *args):
             return args[0].prod(1)
+
     class ReduceProd2(Module):
         def forward(self, *args):
             return args[0].prod(dim=1, keepdim=False)
+
     class ReduceProd3(Module):
         def forward(self, *args):
             return args[0].prod(dim=2, keepdim=True)
+
     input_data = torch.rand(input_shape).float()
     verify_model(ReduceProd1().float().eval(), input_data=input_data)
     verify_model(ReduceProd2().float().eval(), input_data=input_data)
     verify_model(ReduceProd3().float().eval(), input_data=input_data)

+
 def test_forward_argmin():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1329,17 +1340,21 @@ def test_forward_argmin():
     class ArgMin1(Module):
         def forward(self, *args):
             return args[0].argmin(1)
+
     class ArgMin2(Module):
         def forward(self, *args):
             return args[0].argmin(dim=1, keepdim=False)
+
     class ArgMin3(Module):
         def forward(self, *args):
             return args[0].argmin(dim=2, keepdim=True)
+
     input_data = torch.rand(input_shape).float()
     verify_model(ArgMin1().float().eval(), input_data=input_data)
     verify_model(ArgMin2().float().eval(), input_data=input_data)
     verify_model(ArgMin3().float().eval(), input_data=input_data)

+
 def test_forward_argmax():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1347,17 +1362,21 @@ def test_forward_argmax():
     class ArgMax1(Module):
         def forward(self, *args):
             return args[0].argmax(1)
+
     class ArgMax2(Module):
         def forward(self, *args):
             return args[0].argmax(dim=1, keepdim=False)
+
     class ArgMax3(Module):
         def forward(self, *args):
             return args[0].argmax(dim=2, keepdim=True)
+
     input_data = torch.rand(input_shape).float()
     verify_model(ArgMax1().float().eval(), input_data=input_data)
     verify_model(ArgMax2().float().eval(), input_data=input_data)
     verify_model(ArgMax3().float().eval(), input_data=input_data)

+
 def test_forward_std():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1365,18 +1384,23 @@ def test_forward_std():
     class Std1(Module):
         def forward(self, *args):
             return args[0].std(1, unbiased=False)
+
     class Std2(Module):
         def forward(self, *args):
             return args[0].std(dim=1, keepdim=False, unbiased=False)
+
     class Std3(Module):
         def forward(self, *args):
             return args[0].std(dim=2, keepdim=True, unbiased=False)
+
     class Std4(Module):
         def forward(self, *args):
             return args[0].std(dim=(2,3), keepdim=True, unbiased=False)
+
     class Std5(Module):
         def forward(self, *args):
             return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Std1().float().eval(), input_data=input_data)
     verify_model(Std2().float().eval(), input_data=input_data)
@@ -1384,31 +1408,37 @@ def forward(self, *args):
     verify_model(Std4().float().eval(), input_data=input_data)
     verify_model(Std5().float().eval(), input_data=input_data)

-def test_forward_var():
+
+def test_forward_variance():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]

-    class Var1(Module):
+    class Variance1(Module):
         def forward(self, *args):
             return args[0].var(1, unbiased=False)
-    class Var2(Module):
+
+    class Variance2(Module):
         def forward(self, *args):
             return args[0].var(dim=1, keepdim=False, unbiased=False)
-    class Var3(Module):
+
+    class Variance3(Module):
         def forward(self, *args):
             return args[0].var(dim=2, keepdim=True, unbiased=False)
-    class Var4(Module):
+
+    class Variance4(Module):
         def forward(self, *args):
             return args[0].var(dim=(2,3), keepdim=True, unbiased=False)
-    class Var5(Module):
+
+    class Variance5(Module):
         def forward(self, *args):
             return args[0].var(dim=(2,3), keepdim=False, unbiased=False)
+
     input_data = torch.rand(input_shape).float()
-    verify_model(Var1().float().eval(), input_data=input_data)
-    verify_model(Var2().float().eval(), input_data=input_data)
-    verify_model(Var3().float().eval(), input_data=input_data)
-    verify_model(Var4().float().eval(), input_data=input_data)
-    verify_model(Var5().float().eval(), input_data=input_data)
+    verify_model(Variance1().float().eval(), input_data=input_data)
+    verify_model(Variance2().float().eval(), input_data=input_data)
+    verify_model(Variance3().float().eval(), input_data=input_data)
+    verify_model(Variance4().float().eval(), input_data=input_data)
+    verify_model(Variance5().float().eval(), input_data=input_data)


 if __name__ == "__main__":
@@ -1428,7 +1458,7 @@ def forward(self, *args):
     test_forward_argmin()
     test_forward_argmax()
     test_forward_std()
-    test_forward_var()
+    test_forward_variance()
     test_forward_relu()
     test_forward_prelu()
     test_forward_leakyrelu()
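
Reviewer note: a minimal sketch of what the _reduce() change enables end to
end. A traced PyTorch module using a multi-axis reduce with keepdim should now
convert through the frontend. The input name "input0" and the list of
(name, shape) pairs passed to relay.frontend.from_pytorch are assumptions
about the call convention at this revision, not part of the patch:

    import torch
    from tvm import relay

    class SumKeepDim(torch.nn.Module):
        def forward(self, x):
            # Multi-axis reduce with keepdim=True, the case _reduce() now
            # maps to axis=[2, 3], keepdims=True on the Relay side.
            return x.sum(dim=(2, 3), keepdim=True)

    shape = [1, 3, 10, 10]
    data = torch.rand(shape)
    scripted = torch.jit.trace(SumKeepDim().eval(), data)
    # "input0" is an assumed input name; adjust the input list to the
    # format your TVM version expects.
    mod, params = relay.frontend.from_pytorch(scripted, [("input0", shape)])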
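
Likewise, a sketch of the guard in _std()/_variance(): requesting Pytorch's
Bessel-corrected (unbiased) estimator should fail conversion with the
NotImplementedError raised above, while unbiased=False takes the supported
path. Same assumptions as the previous snippet about the from_pytorch call
convention:

    class UnbiasedStd(torch.nn.Module):
        def forward(self, x):
            # unbiased=True requests Bessel's correction, which the
            # converter rejects.
            return x.std(dim=1, unbiased=True)

    scripted = torch.jit.trace(UnbiasedStd().eval(), data)
    try:
        relay.frontend.from_pytorch(scripted, [("input0", shape)])
    except NotImplementedError as err:
        print(err)  # only the biased estimator is supported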