From 0398f484a0be66185e50e118f118a70d9fa58da0 Mon Sep 17 00:00:00 2001 From: HolyWu Date: Sun, 17 Mar 2024 17:41:26 +0800 Subject: [PATCH] Add support for aten.pixel_unshuffle dynamo converter --- .../dynamo/conversion/aten_ops_converters.py | 23 ++++++++++ .../dynamo/conversion/impl/shuffle.py | 44 +++++++++++++++++++ .../conversion/test_pixel_unshuffle_aten.py | 29 ++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0dd153d0aa..2f37b0d84b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2319,6 +2319,29 @@ def aten_ops_pixel_shuffle( ) +@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_pixel_unshuffle( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.shuffle.pixel_unshuffle( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @enforce_tensor_types({0: (TRTTensor,)}) @dynamo_tensorrt_converter(torch.ops.aten.argmax.default) def aten_ops_argmax( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py index 49ddb76e2c..1d6dd7396f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py @@ -60,3 +60,47 @@ def pixel_shuffle( permuted_tensor, shape[:-3] + (out_channels, out_height, out_width), ) + + +def pixel_unshuffle( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + downscale_factor: int, +) -> TRTTensor: + shape = input.shape + in_channels, in_height, in_width = shape[-3:] + out_channels = in_channels * (downscale_factor**2) + out_height = in_height // downscale_factor + out_width = in_width // downscale_factor + new_shape = shape[:-3] + ( + in_channels, + out_height, + downscale_factor, + out_width, + downscale_factor, + ) + reshaped_tensor = reshape( + ctx, target, source_ir, f"{name}_reshape1", input, new_shape + ) + rank = len(new_shape) + permute_shape = tuple(range(rank - 5)) + ( + rank - 5, # in_channels + rank - 3, # downscale_factor + rank - 1, # downscale_factor + rank - 4, # out_height + rank - 2, # out_width + ) + permuted_tensor = impl.permutation.permute( + ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape + ) + return reshape( + ctx, + target, + source_ir, + f"{name}_reshape2", + permuted_tensor, + shape[:-3] + (out_channels, out_height, out_width), + ) diff --git a/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py b/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py new file mode 100644 index 0000000000..fb93e68499 --- /dev/null +++ b/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py @@ -0,0 +1,29 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestPixelUnshuffleConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 1, 1), 1), + ((1, 1, 12, 12), 3), + ((2, 3, 4, 25, 30), 5), + ] + ) + def test_pixel_unshuffle(self, shape, downscale_factor): + class PixelUnshuffle(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.pixel_unshuffle.default(x, downscale_factor) + + inputs = [torch.randn(shape)] + self.run_test( + PixelUnshuffle(), + inputs, + ) + + +if __name__ == "__main__": + run_tests()