diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index c50fa57400..92f4b7e18e 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -794,7 +794,7 @@ def aten_ops_scatter(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.select.int)
+@dynamo_tensorrt_converter(torch.ops.aten.select.int, supports_dynamic_shapes=True)
 def aten_ops_select(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 6d9a86f89b..6653e9e1a5 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Sequence, Union, cast
+from typing import Optional, Sequence, Union

 import numpy as np
 import tensorrt as trt
@@ -21,7 +21,7 @@
     has_dynamic_shape,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import Shape, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -32,8 +32,8 @@ def select(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: Shape,
-    index: Shape,
+    dim: int,
+    index: int,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
@@ -42,31 +42,13 @@ def select(
     )

     ranks = len(input.shape)
-    dim = get_positive_dim(cast(int, dim), ranks)
-    dynamic_shape = has_dynamic_shape(input.shape)
-    if dynamic_shape:
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
-    index = index
+    dim = get_positive_dim(dim, ranks)

-    if index >= input.shape[dim]:
-        raise RuntimeError(
-            f"cannot have index greater than the dimension length! {input.shape[dim]}"
-        )
-    output_shape = list(input.shape)
-    output_shape[dim] = 1
-    if dynamic_shape > 0:
-        output_shape = get_shape_with_dynamic_shape(
-            ctx, target, source_ir, name, output_shape, input
-        )
-    index_value = np.array(index, dtype=np.int32)
-    indices_tensor = ctx.net.add_constant(
-        index_value.shape, to_numpy(index_value)
-    ).get_output(0)
+    indices_tensor = get_trt_tensor(
+        ctx, np.array(index, dtype=np.int32), f"{name}_indices_tensor"
+    )
     layer = ctx.net.add_gather(input, indices_tensor, dim)
-    out = layer.get_output(0)
-    if len(out.shape) != 1:
-        layer = ctx.net.add_shuffle(out)
+
     return layer.get_output(0)


diff --git a/tests/py/dynamo/conversion/test_select_aten.py b/tests/py/dynamo/conversion/test_select_aten.py
index 4a9f0666a9..cce3fa0a30 100644
--- a/tests/py/dynamo/conversion/test_select_aten.py
+++ b/tests/py/dynamo/conversion/test_select_aten.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
@@ -9,11 +10,11 @@
 class TestSelectConverterOne(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 1, 0),
+            ("dim_index", 1, 0),
         ]
     )
-    def test_select(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+    def test_select_2d(self, _, dim, index):
+        class select(nn.Module):
             def __init__(self):
                 super().__init__()

@@ -22,19 +23,17 @@ def forward(self, input):

         input = [torch.randn(1, 2)]
         self.run_test(
-            TestModule(),
+            select(),
             input,
         )

-
-class TestSelectConverterTwo(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 1, 0),
+            ("dim_index", 1, 0),
         ]
     )
-    def test_select(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+    def test_select_4d(self, _, dim, index):
+        class select(nn.Module):
             def __init__(self):
                 super().__init__()

@@ -43,33 +42,70 @@ def forward(self, input):

         input = [torch.randn(4, 4, 4, 4)]
         self.run_test(
-            TestModule(),
+            select(),
             input,
         )

-
-class TestSelectConverterWithDynamicShape(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_index", 1, 0),
+            (
+                "partial_dynamic_static_dim",
+                (1, 1, 3),
+                (2, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                2,
+                0,
+            ),
+            (
+                "partial_dynamic_dynamic_dim",
+                (1, 1, 3),
+                (2, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                1,
+                1,
+            ),
+            (
+                "fully_dynamic",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 3, 3),
+                torch.float,
+                1,
+                1,
+            ),
+            (
+                "fully_dynamic_neg_dim",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 3, 3),
+                torch.float,
+                -1,
+                1,
+            ),
         ]
     )
-    def test_select_with_dynamic_shape(self, _, dim, index):
-        class TestModule(torch.nn.Module):
+    def test_dynamic_shape_select(
+        self, _, min_shape, opt_shape, max_shape, type, dim, index
+    ):
+        class select(nn.Module):
             def __init__(self):
                 super().__init__()

             def forward(self, input):
                 return torch.ops.aten.select.int(input, dim, index)

-        input_spec = [
+        input_specs = [
             Input(
-                shape=(-1, 3, 3),
-                dtype=torch.float32,
-                shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))],
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
             ),
         ]
-        self.run_test_with_dynamic_shape(TestModule(), input_spec)
+
+        self.run_test_with_dynamic_shape(select(), input_specs)


 if __name__ == "__main__":
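
Usage sketch (not part of the patch): a minimal, hypothetical example of exercising the dynamic-shape `aten.select.int` path end to end. `torch_tensorrt.Input` with `min_shape`/`opt_shape`/`max_shape` and the dynamo frontend mirror the API used in the updated tests; the module name and shape values below are illustrative assumptions, not taken from this change.

```python
import torch
import torch_tensorrt


class SelectModel(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        # Lowers to torch.ops.aten.select.int(x, 1, 0), the op whose
        # converter is marked supports_dynamic_shapes=True in this change.
        return x.select(1, 0)


# Dim 0 varies across the min/opt/max profile while dim 1 (the select
# target) stays static, analogous to the "partial_dynamic_static_dim"
# test case above.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 2, 3),
        opt_shape=(2, 2, 3),
        max_shape=(3, 2, 3),
        dtype=torch.float,
    )
]

model = SelectModel().eval().cuda()
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
out = trt_model(torch.randn(2, 2, 3).cuda())  # shape (2, 3) after select
```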