Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for aten.pixel_unshuffle dynamo converter #2696

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,29 @@ def aten_ops_pixel_shuffle(
)


@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_pixel_unshuffle(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shuffle.pixel_unshuffle(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
def aten_ops_argmax(
Expand Down
44 changes: 44 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,47 @@ def pixel_shuffle(
permuted_tensor,
shape[:-3] + (out_channels, out_height, out_width),
)


def pixel_unshuffle(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
downscale_factor: int,
) -> TRTTensor:
shape = input.shape
in_channels, in_height, in_width = shape[-3:]
out_channels = in_channels * (downscale_factor**2)
out_height = in_height // downscale_factor
out_width = in_width // downscale_factor
new_shape = shape[:-3] + (
in_channels,
out_height,
downscale_factor,
out_width,
downscale_factor,
)
reshaped_tensor = reshape(
ctx, target, source_ir, f"{name}_reshape1", input, new_shape
)
rank = len(new_shape)
permute_shape = tuple(range(rank - 5)) + (
rank - 5, # in_channels
rank - 3, # downscale_factor
rank - 1, # downscale_factor
rank - 4, # out_height
rank - 2, # out_width
)
permuted_tensor = impl.permutation.permute(
ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
)
return reshape(
ctx,
target,
source_ir,
f"{name}_reshape2",
permuted_tensor,
shape[:-3] + (out_channels, out_height, out_width),
)
29 changes: 29 additions & 0 deletions tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestPixelUnshuffleConverter(DispatchTestCase):
@parameterized.expand(
[
((1, 1, 1), 1),
((1, 1, 12, 12), 3),
((2, 3, 4, 25, 30), 5),
]
)
def test_pixel_unshuffle(self, shape, downscale_factor):
class PixelUnshuffle(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.pixel_unshuffle.default(x, downscale_factor)

inputs = [torch.randn(shape)]
self.run_test(
PixelUnshuffle(),
inputs,
)


if __name__ == "__main__":
run_tests()
Loading