diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index efb95f97d2b4..46b5cecbae7f 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1577,6 +1577,22 @@ def _impl(inputs, input_types): return _impl +def _reflection_pad2d(): + def _impl(inputs, input_types): + if isinstance(inputs[1], list): + pad_list = inputs[1] + else: + pad_list = list(_infer_shape(inputs[1])) + padding_left = pad_list[0] + padding_right = pad_list[1] + padding_top = pad_list[2] + padding_bottom = pad_list[3] + paddings = [[0, 0], [0, 0], [padding_top, padding_bottom], [padding_left, padding_right]] + + return _op.nn.mirror_pad(inputs[0], paddings, mode='REFLECT') + return _impl + + # Helper functions for operator implementation def _convert_dtype_value(val): convert_torch_dtype_map = {7:"torch.float64", @@ -1695,6 +1711,7 @@ def _get_convert_map(prelude): "aten::prelu" : _prelu(), "aten::leaky_relu" : _leaky_relu(), "aten::elu" : _elu(), + "aten::elu_" : _elu(), "aten::celu" : _celu(), "aten::gelu" : _gelu(), "aten::selu" : _selu(), @@ -1798,6 +1815,7 @@ def _get_convert_map(prelude): "aten::embedding" : _embedding(), "aten::one_hot" : _one_hot(), "aten::mm" : _matmul(prelude), + "aten::reflection_pad2d" : _reflection_pad2d(), "relay::tensor_array_stack" : _tensor_array_stack(prelude), "aten::add" : _add(prelude), "aten::add_" : _add(prelude), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 30036db7a77c..50c3ede94ad9 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1020,6 +1020,15 @@ def test_adaptive_pool3d(): verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp) +def test_forward_reflection_pad2d(): + inp = torch.rand((1, 1, 3, 3)) + verify_model(torch.nn.ReflectionPad2d(2).eval(), inp) + verify_model(torch.nn.ReflectionPad2d((1, 1, 2, 0)).eval(), inp) + + inp = torch.rand((2, 4, 5, 6)) + verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp) + + def test_conv3d(): for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), @@ -2183,6 +2192,7 @@ def forward(self, *args): test_forward_split() test_upsample() test_to() + test_forward_reflection_pad2d() test_adaptive_pool3d() test_conv3d()