From 5ff73ee40d533f89bf78b8415f3d681fc66a1f06 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Wed, 8 Apr 2020 21:00:16 +0530 Subject: [PATCH] [PYTORCH]Repeat, Reciprocal & Reshape Op support --- python/tvm/relay/frontend/pytorch.py | 42 +++++++++++ tests/python/frontend/pytorch/test_forward.py | 75 +++++++++++++++++++ 2 files changed, 117 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 46068a4e24ed..b8b32e7d8925 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -154,6 +154,34 @@ def _impl(inputs, input_types): return _op.transform.take(data, index, axis=dim) return _impl +def _reciprocal(): + def _impl(inputs, input_types): + data = inputs[0] + return _expr.const(1.0) / data + return _impl + +def _repeat(): + def _impl(inputs, input_types): + data = inputs[0] + reps = _get_dims(inputs[1]) + return _op.transform.tile(data, reps=reps) + return _impl + +def _repeat_interleave(): + def _impl(inputs, input_types): + data = inputs[0] + if isinstance(inputs[1], int): + repeats = inputs[1] + axis = inputs[2] + else: + msg = "Only repeat with one value as repeat is currently supported." + raise AssertionError(msg) + if axis is None: # Flatten the data if no axis is given from torch + data = _op.transform.reshape(data, [-1]) + axis = 0 + return _op.transform.repeat(data, repeats=repeats, axis=axis) + return _impl + def _ones(): def _impl(inputs, input_types): data = inputs[0] @@ -675,6 +703,16 @@ def _impl(inputs, input_types): return _op.transform.reshape(data, new_shape) return _impl +def _reshape(): + def _impl(inputs, input_types): + data = inputs[0] + if isinstance(inputs[1], list): + new_shape = inputs[1] + else: + new_shape = _infer_shape(inputs[1]) + return _op.transform.reshape(data, new_shape) + return _impl + def _clone(): def _impl(inputs, input_types): data = inputs[0] @@ -1082,6 +1120,9 @@ def _wrap_const(c): "aten::div_" : _elemwise("divide"), "aten::ones" : _ones(), "aten::zeros" : _zeros(), + "aten::reciprocal" : _reciprocal(), + "aten::repeat" : _repeat(), + "aten::repeat_interleave" : _repeat_interleave(), "aten::to" : _to(), "aten::squeeze" : _squeeze(), "aten::unsqueeze" : _unsqueeze(), @@ -1122,6 +1163,7 @@ def _wrap_const(c): "aten::addmm" : _dense(), "aten::size" : _size(), "aten::view" : _view(), + "aten::reshape" : _reshape(), "aten::clone" : _clone(), "aten::log_softmax" : _log_softmax(), "aten::sigmoid" : _sigmoid(), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 05bf7e460890..4226463e9527 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -293,6 +293,61 @@ def forward(self, *args): verify_model(Multiply3().float().eval(), input_data=input_data) verify_model(Multiply4().float().eval(), input_data=input_data) +def test_forward_reciprocal(): + torch.set_grad_enabled(False) + input_shape = [2, 1, 10, 1, 10] + class Reciprocal1(Module): + def forward(self, *args): + return args[0].reciprocal() + + input_data = torch.rand(input_shape).float() + verify_model(Reciprocal1().float().eval(), input_data=input_data) + +def test_forward_repeat(): + torch.set_grad_enabled(False) + input_shape = [1, 3] + class Repeat1(Module): + def forward(self, *args): + return args[0].repeat(1, 1) + + class Repeat2(Module): + def forward(self, *args): + return args[0].repeat(4, 2) + + class Repeat3(Module): + def forward(self, *args): + return args[0].repeat(4, 2, 1) + + input_data = torch.rand(input_shape).float() + verify_model(Repeat1().float().eval(), input_data=input_data) + verify_model(Repeat2().float().eval(), input_data=input_data) + verify_model(Repeat3().float().eval(), input_data=input_data) + +def test_forward_repeat_interleave(): + torch.set_grad_enabled(False) + input_shape = [2, 2, 3] + class RepeatInterleave1(Module): + def forward(self, *args): + return args[0].repeat_interleave(2) + + class RepeatInterleave2(Module): + def forward(self, *args): + return args[0].repeat_interleave(3, dim=0) + + class RepeatInterleave3(Module): + def forward(self, *args): + return args[0].repeat_interleave(2, dim=1) + + class RepeatInterleave4(Module): + def forward(self, *args): + return args[0].repeat_interleave(4, dim=2) + + input_data = torch.rand(input_shape).float() + verify_model(RepeatInterleave1().float().eval(), input_data=input_data) + verify_model(RepeatInterleave2().float().eval(), input_data=input_data) + verify_model(RepeatInterleave3().float().eval(), input_data=input_data) + verify_model(RepeatInterleave4().float().eval(), input_data=input_data) + def test_forward_unsqueeze(): torch.set_grad_enabled(False) input_shape = [10, 10] @@ -600,6 +655,22 @@ def init_weight(m): init_weight(ln.eval()) verify_model(ln.eval(), input_data=inp) +def test_forward_reshape(): + torch.set_grad_enabled(False) + input_shape = [2, 1, 10, 1, 10] + new_shape = [2, 1, 10, 10] + class Reshape1(Module): + def forward(self, *args): + return args[0].reshape(new_shape) + + class Reshape2(Module): + def forward(self, *args): + return args[0].reshape([-1]) + + input_data = torch.rand(input_shape).float() + verify_model(Reshape1().float().eval(), input_data=input_data) + verify_model(Reshape2().float().eval(), input_data=input_data) + def test_forward_transpose(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1151,6 +1222,10 @@ def forward(self, xs): test_forward_add() test_forward_subtract() test_forward_multiply() + test_forward_reshape() + test_forward_reciprocal() + test_forward_repeat() + test_forward_repeat_interleave() test_forward_squeeze() test_forward_unsqueeze() test_forward_concatenate()