From 430f3b3a183de76ec141ba0591d359c0f512ed02 Mon Sep 17 00:00:00 2001
From: Siju Samuel
Date: Sat, 11 Apr 2020 20:28:10 +0530
Subject: [PATCH] [PYTORCH] Reduce ops support added

---
 python/tvm/relay/frontend/pytorch.py          |  40 ++++-
 tests/python/frontend/pytorch/test_forward.py | 138 ++++++++++++++++++
 2 files changed, 177 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 506f6ba3ceb7..013efae18c0a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -934,7 +934,41 @@ def _impl(inputs, input_types):
 def _reduce(name):
     def _impl(inputs, input_types):
         data = inputs[0]
-        return get_relay_op(name)(data)
+        axis = None
+        keepdims = False
+        if len(inputs) > 2:  # by default torch passes only data; axis=None, keepdims=False
+            if isinstance(inputs[1], int):
+                axis = int(inputs[1])
+            else:
+                axis = list(_infer_shape(inputs[1]))
+            keepdims = bool(inputs[2])
+        return get_relay_op(name)(data, axis=axis, keepdims=keepdims)
+    return _impl
+
+def _std():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = list(_infer_shape(inputs[1]))
+        keepdims = bool(inputs[3])
+        unbiased = bool(inputs[2])
+        if unbiased:
+            msg = "Currently only supports standard deviation calculated via the biased "\
+                  "estimator. PyTorch's Bessel's correction is not supported."
+            raise NotImplementedError(msg)
+        return _op.reduce.std(data, axis=axis, keepdims=keepdims)
+    return _impl
+
+def _variance():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = list(_infer_shape(inputs[1]))
+        keepdims = bool(inputs[3])
+        unbiased = bool(inputs[2])
+        if unbiased:
+            msg = "Currently only supports variance calculated via the biased "\
+                  "estimator. PyTorch's Bessel's correction is not supported."
+            raise NotImplementedError(msg)
+        return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
     return _impl
 
 def _mean():
@@ -1381,6 +1415,10 @@ def _get_convert_map(prelude):
         "aten::permute"                         : _transpose(prelude),
         "aten::sum"                             : _reduce("sum"),
         "aten::prod"                            : _reduce("prod"),
+        "aten::argmin"                          : _reduce("argmin"),
+        "aten::argmax"                          : _reduce("argmax"),
+        "aten::std"                             : _std(),
+        "aten::var"                             : _variance(),
         "aten::sqrt"                            : _sqrt(),
         'aten::floor'                           : _floor(),
         "aten::detach"                          : _identity(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 8e9928510220..0c8c11188b0c 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1278,6 +1278,138 @@ def forward(self, xs):
     verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
 
 
+def test_forward_reduce_sum():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class ReduceSum1(Module):
+        def forward(self, *args):
+            return args[0].sum(1)
+    class ReduceSum2(Module):
+        def forward(self, *args):
+            return args[0].sum(dim=1, keepdim=False)
+    class ReduceSum3(Module):
+        def forward(self, *args):
+            return args[0].sum(dim=2, keepdim=True)
+    class ReduceSum4(Module):
+        def forward(self, *args):
+            return args[0].sum(dim=(2,3), keepdim=True)
+    class ReduceSum5(Module):
+        def forward(self, *args):
+            return args[0].sum(dim=(2,3), keepdim=False)
+    input_data = torch.rand(input_shape).float()
+    verify_model(ReduceSum1().float().eval(), input_data=input_data)
+    verify_model(ReduceSum2().float().eval(), input_data=input_data)
+    verify_model(ReduceSum3().float().eval(), input_data=input_data)
+    verify_model(ReduceSum4().float().eval(), input_data=input_data)
+    verify_model(ReduceSum5().float().eval(), input_data=input_data)
+
+def test_forward_reduce_prod():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class ReduceProd1(Module):
+        def forward(self, *args):
+            return args[0].prod(1)
+    class ReduceProd2(Module):
+        def forward(self, *args):
+            return args[0].prod(dim=1, keepdim=False)
+    class ReduceProd3(Module):
+        def forward(self, *args):
+            return args[0].prod(dim=2, keepdim=True)
+    input_data = torch.rand(input_shape).float()
+    verify_model(ReduceProd1().float().eval(), input_data=input_data)
+    verify_model(ReduceProd2().float().eval(), input_data=input_data)
+    verify_model(ReduceProd3().float().eval(), input_data=input_data)
+
+def test_forward_argmin():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class ArgMin1(Module):
+        def forward(self, *args):
+            return args[0].argmin(1)
+    class ArgMin2(Module):
+        def forward(self, *args):
+            return args[0].argmin(dim=1, keepdim=False)
+    class ArgMin3(Module):
+        def forward(self, *args):
+            return args[0].argmin(dim=2, keepdim=True)
+    input_data = torch.rand(input_shape).float()
+    verify_model(ArgMin1().float().eval(), input_data=input_data)
+    verify_model(ArgMin2().float().eval(), input_data=input_data)
+    verify_model(ArgMin3().float().eval(), input_data=input_data)
+
+def test_forward_argmax():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class ArgMax1(Module):
+        def forward(self, *args):
+            return args[0].argmax(1)
+    class ArgMax2(Module):
+        def forward(self, *args):
+            return args[0].argmax(dim=1, keepdim=False)
+    class ArgMax3(Module):
+        def forward(self, *args):
+            return args[0].argmax(dim=2, keepdim=True)
+    input_data = torch.rand(input_shape).float()
+    verify_model(ArgMax1().float().eval(), input_data=input_data)
+    verify_model(ArgMax2().float().eval(), input_data=input_data)
+    verify_model(ArgMax3().float().eval(), input_data=input_data)
+
+def test_forward_std():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Std1(Module):
+        def forward(self, *args):
+            return args[0].std(1, unbiased=False)
+    class Std2(Module):
+        def forward(self, *args):
+            return args[0].std(dim=1, keepdim=False, unbiased=False)
+    class Std3(Module):
+        def forward(self, *args):
+            return args[0].std(dim=2, keepdim=True, unbiased=False)
+    class Std4(Module):
+        def forward(self, *args):
+            return args[0].std(dim=(2,3), keepdim=True, unbiased=False)
+    class Std5(Module):
+        def forward(self, *args):
+            return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Std1().float().eval(), input_data=input_data)
+    verify_model(Std2().float().eval(), input_data=input_data)
+    verify_model(Std3().float().eval(), input_data=input_data)
+    verify_model(Std4().float().eval(), input_data=input_data)
+    verify_model(Std5().float().eval(), input_data=input_data)
+
+def test_forward_var():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Var1(Module):
+        def forward(self, *args):
+            return args[0].var(1, unbiased=False)
+    class Var2(Module):
+        def forward(self, *args):
+            return args[0].var(dim=1, keepdim=False, unbiased=False)
+    class Var3(Module):
+        def forward(self, *args):
+            return args[0].var(dim=2, keepdim=True, unbiased=False)
+    class Var4(Module):
+        def forward(self, *args):
+            return args[0].var(dim=(2,3), keepdim=True, unbiased=False)
+    class Var5(Module):
+        def forward(self, *args):
+            return args[0].var(dim=(2,3), keepdim=False, unbiased=False)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Var1().float().eval(), input_data=input_data)
+    verify_model(Var2().float().eval(), input_data=input_data)
+    verify_model(Var3().float().eval(), input_data=input_data)
+    verify_model(Var4().float().eval(), input_data=input_data)
+    verify_model(Var5().float().eval(), input_data=input_data)
+
 if __name__ == "__main__":
     # Single operator tests
@@ -1291,6 +1423,12 @@ def forward(self, xs):
     test_forward_squeeze()
     test_forward_unsqueeze()
     test_forward_concatenate()
+    test_forward_reduce_sum()
+    test_forward_reduce_prod()
+    test_forward_argmin()
+    test_forward_argmax()
+    test_forward_std()
+    test_forward_var()
     test_forward_relu()
     test_forward_prelu()
     test_forward_leakyrelu()
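
A minimal usage sketch (not part of the diff): it converts a traced module that calls `std` with `unbiased=False`, mirroring what `verify_model()` does in test_forward.py. The class name `BiasedStd` and the input name `"input0"` are placeholders, and the `relay.frontend.from_pytorch` argument format may differ slightly across TVM versions.

```python
import torch
from tvm import relay


class BiasedStd(torch.nn.Module):
    # std with unbiased=False; the converter raises NotImplementedError for unbiased=True.
    def forward(self, x):
        return x.std(dim=(2, 3), keepdim=True, unbiased=False)


inp = torch.rand(1, 3, 10, 10)
scripted = torch.jit.trace(BiasedStd().eval(), inp)

# "input0" is an arbitrary graph input name; older TVM releases take (name, shape)
# pairs, newer ones also accept (name, (shape, dtype)).
mod, params = relay.frontend.from_pytorch(scripted, [("input0", (1, 3, 10, 10))])
print(mod)  # the printed Relay module should now contain the converted reduce op
```

The same pattern applies to `sum`, `prod`, `argmin`, `argmax`, and `var`; only `std` and `var` carry the `unbiased=False` restriction introduced by this patch.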