diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 6163ee692f..7c8ecc2b62 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2695,6 +2695,29 @@ def aten_ops_pixel_unshuffle(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.resize_.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_resize(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shuffle.resize(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        sizes=args[1],
+    )
+
+
 @enforce_tensor_types({0: (TRTTensor,)})
 @dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
 def aten_ops_argmax(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index b2d005b175..45927e7709 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -1,5 +1,7 @@
 from typing import Optional, Sequence, Union
 
+import numpy as np
+import tensorrt as trt
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -7,9 +9,11 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     SourceIR,
     cast_trt_tensor,
+    flatten_dims,
     get_trt_tensor,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.utils import Frameworks, unified_dtype_converter
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -131,3 +135,61 @@ def pixel_unshuffle(
         permuted_tensor,
         shape[:-3] + (out_channels, out_height, out_width),
     )
+
+
+def resize(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    sizes: Sequence[int],
+) -> TRTTensor:
+    input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY)
+    input_val = get_trt_tensor(ctx, input, f"{name}_input")
+
+    # Calculate the total number of elements for new and current shape
+    new_num_elements = np.prod(sizes)
+    current_num_elements = np.prod(input_val.shape)
+
+    if new_num_elements > current_num_elements:
+        # Create a padding tensor with the required size and initialize new elements with zeros
+        padding_size = new_num_elements - current_num_elements
+        padding_tensor = ctx.net.add_constant(
+            (padding_size,), trt.Weights(np.zeros(padding_size, dtype=input_np_dtype))
+        ).get_output(0)
+
+        # Flatten input tensor to 1D for concatenation
+        flatten_shape = flatten_dims(input_val, 0, -1)
+        flattened_input = reshape(
+            ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
+        )
+
+        # Concatenate the flattened input tensor and padding tensor
+        reshaped_tensor = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_cat",
+            [flattened_input, padding_tensor],
+            dim=0,
+        )
+    elif new_num_elements < current_num_elements:
+        # Flatten input tensor to 1D for slicing
+        flatten_shape = flatten_dims(input_val, 0, -1)
+        flattened_input = reshape(
+            ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
+        )
+
+        # Slice the flattened input tensor to the desired number of elements
+        slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1])
+        reshaped_tensor = slice_layer.get_output(0)
+    else:
+        reshaped_tensor = input_val
+
+    # Reshape the final output tensor to the target sizes
+    resized_output = reshape(
+        ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes
+    )
+
+    return resized_output
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 2c7071f7ad..5e1a18a0c2 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -2,8 +2,11 @@
 
 import logging
 from dataclasses import fields, replace
+from enum import Enum
 from typing import Any, Callable, Dict, Optional, Sequence, Union
 
+import numpy as np
+import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
@@ -13,12 +16,63 @@
 
 from packaging import version
 
+from .types import TRTDataType
+
 logger = logging.getLogger(__name__)
 
 COSINE_THRESHOLD = 0.99
 DYNAMIC_DIM = -1
 
 
+class Frameworks(Enum):
+    NUMPY = "numpy"
+    TORCH = "torch"
+    TRT = "trt"
+
+
+DataTypeEquivalence: Dict[
+    TRTDataType, Dict[Frameworks, Union[TRTDataType, np.dtype, torch.dtype]]
+] = {
+    trt.int8: {
+        Frameworks.NUMPY: np.int8,
+        Frameworks.TORCH: torch.int8,
+        Frameworks.TRT: trt.int8,
+    },
+    trt.int32: {
+        Frameworks.NUMPY: np.int32,
+        Frameworks.TORCH: torch.int32,
+        Frameworks.TRT: trt.int32,
+    },
+    trt.int64: {
+        Frameworks.NUMPY: np.int64,
+        Frameworks.TORCH: torch.int64,
+        Frameworks.TRT: trt.int64,
+    },
+    trt.float16: {
+        Frameworks.NUMPY: np.float16,
+        Frameworks.TORCH: torch.float16,
+        Frameworks.TRT: trt.float16,
+    },
+    trt.float32: {
+        Frameworks.NUMPY: np.float32,
+        Frameworks.TORCH: torch.float32,
+        Frameworks.TRT: trt.float32,
+    },
+    trt.bool: {
+        Frameworks.NUMPY: bool,
+        Frameworks.TORCH: torch.bool,
+        Frameworks.TRT: trt.bool,
+    },
+}
+
+# Use a numeric major-version check; a lexicographic string comparison would
+# misclassify TensorRT 10.x as older than 7.0.
+if int(trt.__version__.split(".")[0]) >= 7:
+    DataTypeEquivalence[trt.bool] = {
+        Frameworks.NUMPY: np.bool_,
+        Frameworks.TORCH: torch.bool,
+        Frameworks.TRT: trt.bool,
+    }
+
+
 def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
     """Parses a user-provided input argument regarding Python runtime
@@ -317,3 +371,34 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:
         return function_wrapper
 
     return nested_decorator
+
+
+def unified_dtype_converter(
+    dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
+) -> Union[np.dtype, torch.dtype, TRTDataType]:
+    """
+    Convert a data type between its TensorRT, NumPy, and Torch representations.
+
+    Args:
+        dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, NumPy, or Torch data type.
+        to (Frameworks): The framework to convert the data type to.
+
+    Returns:
+        The equivalent data type in the requested framework.
+ """ + assert to in Frameworks, f"Expected valid Framework for translation, got {to}" + trt_major_version = int(trt.__version__.split(".")[0]) + if dtype in (np.int8, torch.int8, trt.int8): + return DataTypeEquivalence[trt.int8][to] + elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool): + return DataTypeEquivalence[trt.bool][to] + elif dtype in (np.int32, torch.int32, trt.int32): + return DataTypeEquivalence[trt.int32][to] + elif dtype in (np.int64, torch.int64, trt.int64): + return DataTypeEquivalence[trt.int64][to] + elif dtype in (np.float16, torch.float16, trt.float16): + return DataTypeEquivalence[trt.float16][to] + elif dtype in (np.float32, torch.float32, trt.float32): + return DataTypeEquivalence[trt.float32][to] + else: + raise TypeError("%s is not a supported dtype" % dtype) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 826d74ac96..e8c9882c1f 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -157,12 +157,20 @@ def run_test_custom_compare_results( res_trt = trt_mod(*cuda_inputs).cpu() res_cpu = mod(*cuda_inputs).cpu() assert len(res_trt) == len(res_cpu) - for output_trt, output_cpu, comparator in zip( - res_trt, res_cpu, comparators - ): - comp_func = comparator[0] - args = comparator[1] - self.assertTrue(comp_func(output_trt, output_cpu, *args)) + comparator = comparators + + if len(cuda_inputs) == 1: + for comparator in comparators: + comp_func = comparator[0] + args = comparator[1] + self.assertTrue(comp_func(res_trt, res_cpu, *args)) + else: + for output_trt, output_cpu, comparator in zip( + res_trt, res_cpu, comparators + ): + comp_func = comparator[0] + args = comparator[1] + self.assertTrue(comp_func(output_trt, output_cpu, *args)) def run_test_with_error(self, mod, inputs, interpreter, expect_error): with self.assertRaises(expect_error): diff --git a/tests/py/dynamo/conversion/test_resize_aten.py b/tests/py/dynamo/conversion/test_resize_aten.py new file mode 100644 index 0000000000..8318035d86 --- /dev/null +++ b/tests/py/dynamo/conversion/test_resize_aten.py @@ -0,0 +1,152 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestResizeConverter(DispatchTestCase): + + def compare_resized_tensors(self, tensor1, tensor2, input_shape, target_shape): + # Check if the sizes match + if tensor1.size() != tensor2.size(): + return False + + # Flatten the tensors to ensure we are comparing the valid elements + flat_tensor1 = tensor1.flatten() + flat_tensor2 = tensor2.flatten() + + # Calculate the number of valid elements to compare + input_numel = torch.Size(input_shape).numel() + target_numel = torch.Size(target_shape).numel() + min_size = min(input_numel, target_numel) + + # Compare only the valid elements + return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) + + @parameterized.expand( + [ + ((3,),), + ((5,),), + ((10,),), + ((2, 2),), + ((3, 5),), + ((8, 3),), + ((7, 7),), + ((5, 5, 5),), + ((3, 3, 10),), + ((10, 15, 10),), + ] + ) + def test_resize_1d_input_float(self, target_shape): + class Resize(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.resize_.default(x, target_shape) + + input_shape = (5,) + inputs = [torch.randn(input_shape)] + + comparators = [(self.compare_resized_tensors, [input_shape, target_shape])] + + self.run_test_compare_tensor_attributes_only( + Resize(), + inputs, + expected_ops=[], + 
+            comparators=comparators,
+        )
+
+    @parameterized.expand(
+        [
+            ((3,),),
+            ((5,),),
+            ((10,),),
+            ((3, 5),),
+            ((8, 3),),
+            ((7, 7),),
+            ((5, 5, 5),),
+            ((3, 3, 5),),
+            ((15, 10, 3),),
+            ((15, 10, 12),),
+        ]
+    )
+    def test_resize_1d_input_int(self, target_shape):
+        class Resize(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.resize_.default(x, target_shape)
+
+        input_shape = (5,)
+        inputs = [torch.randint(1, 5, input_shape)]
+
+        comparators = [(self.compare_resized_tensors, [input_shape, target_shape])]
+
+        self.run_test_compare_tensor_attributes_only(
+            Resize(),
+            inputs,
+            expected_ops=[],
+            comparators=comparators,
+        )
+
+    @parameterized.expand(
+        [
+            ((3,),),
+            ((5,),),
+            ((10,),),
+            ((4, 4),),
+            ((3, 5),),
+            ((8, 3),),
+            ((7, 7),),
+            ((20, 12, 13),),
+            ((3, 3, 5),),
+            ((3, 10, 15),),
+        ]
+    )
+    def test_resize_2d_input_float(self, target_shape):
+        class Resize(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.resize_.default(x, target_shape)
+
+        input_shape = (4, 4)
+        inputs = [torch.randn(input_shape)]
+
+        comparators = [(self.compare_resized_tensors, [input_shape, target_shape])]
+
+        self.run_test_compare_tensor_attributes_only(
+            Resize(),
+            inputs,
+            expected_ops=[],
+            comparators=comparators,
+        )
+
+    @parameterized.expand(
+        [
+            ((3,),),
+            ((5,),),
+            ((20,),),
+            ((4, 4),),
+            ((3, 12),),
+            ((12, 3),),
+            ((15, 15),),
+            ((20, 20, 20),),
+            ((3, 3, 10),),
+        ]
+    )
+    def test_resize_2d_input_int(self, target_shape):
+        class Resize(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.resize_.default(x, target_shape)
+
+        input_shape = (4, 4)
+        inputs = [torch.randint(1, 10, input_shape)]
+
+        comparators = [(self.compare_resized_tensors, [input_shape, target_shape])]
+
+        self.run_test_compare_tensor_attributes_only(
+            Resize(),
+            inputs,
+            expected_ops=[],
+            comparators=comparators,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
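
Not part of the patch: a minimal usage sketch of how the new aten.resize_ converter path could be exercised end to end through the Dynamo frontend. The module, input shape, target size, and compile settings below are illustrative assumptions, not taken from this PR.

# Hypothetical usage sketch (assumptions: a CUDA device is available and
# torch_tensorrt is built with this patch). Shapes and settings are illustrative.
import torch
import torch_tensorrt


class ResizeModule(torch.nn.Module):
    def forward(self, x):
        # Grow a (4, 4) input to (3, 3, 10); the TRT converter zero-fills the
        # new elements (eager resize_ leaves them uninitialized).
        return torch.ops.aten.resize_.default(x, (3, 3, 10))


mod = ResizeModule().eval().cuda()
inputs = [torch.randn(4, 4).cuda()]
trt_mod = torch_tensorrt.compile(mod, ir="dynamo", inputs=inputs, min_block_size=1)
print(trt_mod(*inputs).shape)  # expected: torch.Size([3, 3, 10])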