From be85f5de037868e7a75d8260718a80b0311af76e Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 19 Jun 2024 10:14:12 +0900 Subject: [PATCH 1/5] feat: support aten.resize_ converter --- .../dynamo/conversion/aten_ops_converters.py | 24 ++++ .../dynamo/conversion/impl/shuffle.py | 66 +++++++++- .../py/dynamo/conversion/test_resize_aten.py | 117 ++++++++++++++++++ 3 files changed, 206 insertions(+), 1 deletion(-) create mode 100644 tests/py/dynamo/conversion/test_resize_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 6163ee692f..6b5f4681f0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2695,6 +2695,30 @@ def aten_ops_pixel_unshuffle( ) +@dynamo_tensorrt_converter(torch.ops.aten.resize.default) +@dynamo_tensorrt_converter(torch.ops.aten.resize_.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_resize( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.shuffle.resize_( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + sizes=args[1], + ) + + @enforce_tensor_types({0: (TRTTensor,)}) @dynamo_tensorrt_converter(torch.ops.aten.argmax.default) def aten_ops_argmax( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py index b2d005b175..8afc871bf4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py @@ -1,5 +1,7 @@ from typing import Optional, Sequence, Union +import numpy as np +import tensorrt as trt import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt import _enums @@ -7,9 +9,14 @@ from torch_tensorrt.dynamo.conversion.converter_utils import ( SourceIR, cast_trt_tensor, + flatten_dims, get_trt_tensor, ) -from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.converters.converter_utils import ( + Frameworks, + set_layer_name, + unified_dtype_converter, +) from torch_tensorrt.fx.types import TRTTensor @@ -131,3 +138,60 @@ def pixel_unshuffle( permuted_tensor, shape[:-3] + (out_channels, out_height, out_width), ) + + +def resize_( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + sizes: Sequence[int], +) -> TRTTensor: + + input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY) + + input_val = get_trt_tensor(ctx, input, f"{name}_input") + + # Calculate the total number of elements for new and current shape + new_num_elements = np.prod(sizes) + current_num_elements = np.prod(input_val.shape) + + if new_num_elements > current_num_elements: + # Create a padding tensor with the required size and initialize new elements with zeros + padding_size = new_num_elements - current_num_elements + padding_tensor = ctx.net.add_constant( + (padding_size,), trt.Weights(np.zeros(padding_size, dtype=input_np_dtype)) + ).get_output(0) + + # Flatten input tensor to 1D for concatenation + flatten_shape = flatten_dims(input_val, 0, -1) + flattened_input = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape + ) + + # Concatenate the flattened input tensor and padding tensor + concat_layer = 
ctx.net.add_concatenation([flattened_input, padding_tensor]) + concat_layer.axis = 0 + reshaped_tensor = concat_layer.get_output(0) + + elif new_num_elements < current_num_elements: + # Flatten input tensor to 1D for slicing + flatten_shape = flatten_dims(input_val, 0, -1) + flattened_input = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape + ) + + # Slice the flattened input tensor to the desired number of elements + slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1]) + reshaped_tensor = slice_layer.get_output(0) + + else: + reshaped_tensor = input_val + + # Reshape the final output tensor to the target sizes + resized_output = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes + ) + + return resized_output diff --git a/tests/py/dynamo/conversion/test_resize_aten.py b/tests/py/dynamo/conversion/test_resize_aten.py new file mode 100644 index 0000000000..25f1f31fac --- /dev/null +++ b/tests/py/dynamo/conversion/test_resize_aten.py @@ -0,0 +1,117 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestResizeConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3,),), + ((5,),), + ((10,),), + ((3, 5),), + ((8, 3),), + ((7, 7),), + ((5, 5, 5),), + ((3, 3, 5),), + ] + ) + def test_resize_1d_input_float(self, target_shape): + class Resize(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ops.aten.resize_.default(x, target_shape) + + inputs = [torch.randn(5)] + self.run_test( + Resize(), + inputs, + ) + + @parameterized.expand( + [ + ((3,),), + ((5,),), + ((10,),), + ((3, 5),), + ((8, 3),), + ((7, 7),), + ((5, 5, 5),), + ((3, 3, 5),), + ] + ) + def test_resize_1d_input_int(self, target_shape): + class Resize(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ops.aten.resize_.default(x, target_shape) + + inputs = [torch.randint(1, 5, (5,))] + self.run_test( + Resize(), + inputs, + ) + + @parameterized.expand( + [ + ((3,),), + ((5,),), + ((10,),), + ((4, 4),), + ((3, 5),), + ((8, 3),), + ((7, 7),), + ((5, 5, 5),), + ((3, 3, 5),), + ] + ) + def test_resize_2d_input_float(self, target_shape): + class Resize(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ops.aten.resize_.default(x, target_shape) + + inputs = [torch.randn(4, 4)] + self.run_test( + Resize(), + inputs, + ) + + @parameterized.expand( + [ + ((3,),), + ((5,),), + ((10,),), + ((4, 4),), + ((3, 5),), + ((8, 3),), + ((7, 7),), + ((5, 5, 5),), + ((3, 3, 5),), + ] + ) + def test_resize_2d_input_int(self, target_shape): + class Resize(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ops.aten.resize_.default(x, target_shape) + + inputs = [torch.randint(1, 10, (4, 4))] + self.run_test( + Resize(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From c3a2ed256e4a763165cff438cd1a7abfd7182dd1 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 19 Jun 2024 10:14:12 +0900 Subject: [PATCH 2/5] chore: minor naming issues --- .../dynamo/conversion/aten_ops_converters.py | 3 +-- .../dynamo/conversion/impl/shuffle.py | 23 ++++++++++--------- .../py/dynamo/conversion/test_resize_aten.py | 12 ---------- 3 files changed, 13 insertions(+), 25 deletions(-) diff --git 
a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 6b5f4681f0..7c8ecc2b62 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2695,7 +2695,6 @@ def aten_ops_pixel_unshuffle( ) -@dynamo_tensorrt_converter(torch.ops.aten.resize.default) @dynamo_tensorrt_converter(torch.ops.aten.resize_.default) @enforce_tensor_types( { @@ -2709,7 +2708,7 @@ def aten_ops_resize( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.shuffle.resize_( + return impl.shuffle.resize( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py index 8afc871bf4..6d2a1ba7e0 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py @@ -140,7 +140,7 @@ def pixel_unshuffle( ) -def resize_( +def resize( ctx: ConversionContext, target: Union[Target, str], source_ir: Optional[SourceIR], @@ -148,9 +148,7 @@ def resize_( input: TRTTensor, sizes: Sequence[int], ) -> TRTTensor: - input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY) - input_val = get_trt_tensor(ctx, input, f"{name}_input") # Calculate the total number of elements for new and current shape @@ -166,31 +164,34 @@ def resize_( # Flatten input tensor to 1D for concatenation flatten_shape = flatten_dims(input_val, 0, -1) - flattened_input = impl.shuffle.reshape( + flattened_input = reshape( ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape ) # Concatenate the flattened input tensor and padding tensor - concat_layer = ctx.net.add_concatenation([flattened_input, padding_tensor]) - concat_layer.axis = 0 - reshaped_tensor = concat_layer.get_output(0) - + reshaped_tensor = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_cat", + [flattened_input, padding_tensor], + dim=0, + ) elif new_num_elements < current_num_elements: # Flatten input tensor to 1D for slicing flatten_shape = flatten_dims(input_val, 0, -1) - flattened_input = impl.shuffle.reshape( + flattened_input = reshape( ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape ) # Slice the flattened input tensor to the desired number of elements slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1]) reshaped_tensor = slice_layer.get_output(0) - else: reshaped_tensor = input_val # Reshape the final output tensor to the target sizes - resized_output = impl.shuffle.reshape( + resized_output = reshape( ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes ) diff --git a/tests/py/dynamo/conversion/test_resize_aten.py b/tests/py/dynamo/conversion/test_resize_aten.py index 25f1f31fac..12e6cb66a1 100644 --- a/tests/py/dynamo/conversion/test_resize_aten.py +++ b/tests/py/dynamo/conversion/test_resize_aten.py @@ -20,9 +20,6 @@ class TestResizeConverter(DispatchTestCase): ) def test_resize_1d_input_float(self, target_shape): class Resize(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): return torch.ops.aten.resize_.default(x, target_shape) @@ -46,9 +43,6 @@ def forward(self, x): ) def test_resize_1d_input_int(self, target_shape): class Resize(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): return torch.ops.aten.resize_.default(x, target_shape) @@ -73,9 +67,6 @@ def forward(self, x): ) def 
test_resize_2d_input_float(self, target_shape): class Resize(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): return torch.ops.aten.resize_.default(x, target_shape) @@ -100,9 +91,6 @@ def forward(self, x): ) def test_resize_2d_input_int(self, target_shape): class Resize(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): return torch.ops.aten.resize_.default(x, target_shape) From dfa5d162a80400a4834c5f428425d7ae358a8d96 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 19 Jun 2024 10:14:12 +0900 Subject: [PATCH 3/5] feat: revise run_test function to compare only valid elements --- tests/py/dynamo/conversion/harness.py | 20 ++- .../py/dynamo/conversion/test_resize_aten.py | 129 +++++++++++++++--- 2 files changed, 127 insertions(+), 22 deletions(-) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 826d74ac96..e8c9882c1f 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -157,12 +157,20 @@ def run_test_custom_compare_results( res_trt = trt_mod(*cuda_inputs).cpu() res_cpu = mod(*cuda_inputs).cpu() assert len(res_trt) == len(res_cpu) - for output_trt, output_cpu, comparator in zip( - res_trt, res_cpu, comparators - ): - comp_func = comparator[0] - args = comparator[1] - self.assertTrue(comp_func(output_trt, output_cpu, *args)) + comparator = comparators + + if len(cuda_inputs) == 1: + for comparator in comparators: + comp_func = comparator[0] + args = comparator[1] + self.assertTrue(comp_func(res_trt, res_cpu, *args)) + else: + for output_trt, output_cpu, comparator in zip( + res_trt, res_cpu, comparators + ): + comp_func = comparator[0] + args = comparator[1] + self.assertTrue(comp_func(output_trt, output_cpu, *args)) def run_test_with_error(self, mod, inputs, interpreter, expect_error): with self.assertRaises(expect_error): diff --git a/tests/py/dynamo/conversion/test_resize_aten.py b/tests/py/dynamo/conversion/test_resize_aten.py index 12e6cb66a1..3a6a30dae5 100644 --- a/tests/py/dynamo/conversion/test_resize_aten.py +++ b/tests/py/dynamo/conversion/test_resize_aten.py @@ -11,11 +11,13 @@ class TestResizeConverter(DispatchTestCase): ((3,),), ((5,),), ((10,),), + ((2, 2),), ((3, 5),), ((8, 3),), ((7, 7),), ((5, 5, 5),), - ((3, 3, 5),), + ((3, 3, 10),), + ((10, 15, 10),), ] ) def test_resize_1d_input_float(self, target_shape): @@ -23,10 +25,33 @@ class Resize(torch.nn.Module): def forward(self, x): return torch.ops.aten.resize_.default(x, target_shape) - inputs = [torch.randn(5)] - self.run_test( + input_shape = (5,) + inputs = [torch.randn(input_shape)] + + def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape): + # Check if the sizes match + if tensor1.size() != tensor2.size(): + return False + + # Flatten the tensors to ensure we are comparing the valid elements + flat_tensor1 = tensor1.flatten() + flat_tensor2 = tensor2.flatten() + + # Calculate the number of valid elements to compare + input_numel = torch.Size(input_shape).numel() + target_numel = torch.Size(target_shape).numel() + min_size = min(input_numel, target_numel) + + # Compare only the valid elements + return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) + + comparators = [(compare_resized_tensors, [input_shape, target_shape])] + + self.run_test_compare_tensor_attributes_only( Resize(), inputs, + expected_ops=[], + comparators=comparators, ) @parameterized.expand( @@ -39,6 +64,8 @@ def forward(self, x): ((7, 7),), ((5, 5, 5),), 
((3, 3, 5),), + ((15, 10, 3),), + ((15, 10, 12),), ] ) def test_resize_1d_input_int(self, target_shape): @@ -46,10 +73,33 @@ class Resize(torch.nn.Module): def forward(self, x): return torch.ops.aten.resize_.default(x, target_shape) - inputs = [torch.randint(1, 5, (5,))] - self.run_test( + input_shape = (5,) + inputs = [torch.randint(1, 5, input_shape)] + + def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape): + # Check if the sizes match + if tensor1.size() != tensor2.size(): + return False + + # Flatten the tensors to ensure we are comparing the valid elements + flat_tensor1 = tensor1.flatten() + flat_tensor2 = tensor2.flatten() + + # Calculate the number of valid elements to compare + input_numel = torch.Size(input_shape).numel() + target_numel = torch.Size(target_shape).numel() + min_size = min(input_numel, target_numel) + + # Compare only the valid elements + return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) + + comparators = [(compare_resized_tensors, [input_shape, target_shape])] + + self.run_test_compare_tensor_attributes_only( Resize(), inputs, + expected_ops=[], + comparators=comparators, ) @parameterized.expand( @@ -61,8 +111,9 @@ def forward(self, x): ((3, 5),), ((8, 3),), ((7, 7),), - ((5, 5, 5),), + ((20, 12, 13),), ((3, 3, 5),), + ((3, 10, 15),), ] ) def test_resize_2d_input_float(self, target_shape): @@ -70,23 +121,46 @@ class Resize(torch.nn.Module): def forward(self, x): return torch.ops.aten.resize_.default(x, target_shape) - inputs = [torch.randn(4, 4)] - self.run_test( + input_shape = (4, 4) + inputs = [torch.randint(1, 10, input_shape)] + + def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape): + # Check if the sizes match + if tensor1.size() != tensor2.size(): + return False + + # Flatten the tensors to ensure we are comparing the valid elements + flat_tensor1 = tensor1.flatten() + flat_tensor2 = tensor2.flatten() + + # Calculate the number of valid elements to compare + input_numel = torch.Size(input_shape).numel() + target_numel = torch.Size(target_shape).numel() + min_size = min(input_numel, target_numel) + + # Compare only the valid elements + return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) + + comparators = [(compare_resized_tensors, [input_shape, target_shape])] + + self.run_test_compare_tensor_attributes_only( Resize(), inputs, + expected_ops=[], + comparators=comparators, ) @parameterized.expand( [ ((3,),), ((5,),), - ((10,),), + ((20,),), ((4, 4),), - ((3, 5),), - ((8, 3),), - ((7, 7),), - ((5, 5, 5),), - ((3, 3, 5),), + ((3, 12),), + ((12, 3),), + ((15, 15),), + ((20, 20, 20),), + ((3, 3, 10),), ] ) def test_resize_2d_input_int(self, target_shape): @@ -94,10 +168,33 @@ class Resize(torch.nn.Module): def forward(self, x): return torch.ops.aten.resize_.default(x, target_shape) - inputs = [torch.randint(1, 10, (4, 4))] - self.run_test( + input_shape = (4, 4) + inputs = [torch.randint(1, 10, input_shape)] + + def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape): + # Check if the sizes match + if tensor1.size() != tensor2.size(): + return False + + # Flatten the tensors to ensure we are comparing the valid elements + flat_tensor1 = tensor1.flatten() + flat_tensor2 = tensor2.flatten() + + # Calculate the number of valid elements to compare + input_numel = torch.Size(input_shape).numel() + target_numel = torch.Size(target_shape).numel() + min_size = min(input_numel, target_numel) + + # Compare only the valid elements + return torch.equal(flat_tensor1[:min_size], 
flat_tensor2[:min_size]) + + comparators = [(compare_resized_tensors, [input_shape, target_shape])] + + self.run_test_compare_tensor_attributes_only( Resize(), inputs, + expected_ops=[], + comparators=comparators, ) From 777473852943a31a025be1c45cf810bb40307be0 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 19 Jun 2024 10:14:12 +0900 Subject: [PATCH 4/5] chore: Moved comparators function to a class method to avoid redundancy --- .../py/dynamo/conversion/test_resize_aten.py | 94 +++++-------------- 1 file changed, 22 insertions(+), 72 deletions(-) diff --git a/tests/py/dynamo/conversion/test_resize_aten.py b/tests/py/dynamo/conversion/test_resize_aten.py index 3a6a30dae5..8318035d86 100644 --- a/tests/py/dynamo/conversion/test_resize_aten.py +++ b/tests/py/dynamo/conversion/test_resize_aten.py @@ -6,6 +6,24 @@ class TestResizeConverter(DispatchTestCase): + + def compare_resized_tensors(self, tensor1, tensor2, input_shape, target_shape): + # Check if the sizes match + if tensor1.size() != tensor2.size(): + return False + + # Flatten the tensors to ensure we are comparing the valid elements + flat_tensor1 = tensor1.flatten() + flat_tensor2 = tensor2.flatten() + + # Calculate the number of valid elements to compare + input_numel = torch.Size(input_shape).numel() + target_numel = torch.Size(target_shape).numel() + min_size = min(input_numel, target_numel) + + # Compare only the valid elements + return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) + @parameterized.expand( [ ((3,),), @@ -28,24 +46,7 @@ def forward(self, x): input_shape = (5,) inputs = [torch.randn(input_shape)] - def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape): - # Check if the sizes match - if tensor1.size() != tensor2.size(): - return False - - # Flatten the tensors to ensure we are comparing the valid elements - flat_tensor1 = tensor1.flatten() - flat_tensor2 = tensor2.flatten() - - # Calculate the number of valid elements to compare - input_numel = torch.Size(input_shape).numel() - target_numel = torch.Size(target_shape).numel() - min_size = min(input_numel, target_numel) - - # Compare only the valid elements - return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) - - comparators = [(compare_resized_tensors, [input_shape, target_shape])] + comparators = [(self.compare_resized_tensors, [input_shape, target_shape])] self.run_test_compare_tensor_attributes_only( Resize(), @@ -76,24 +77,7 @@ def forward(self, x): input_shape = (5,) inputs = [torch.randint(1, 5, input_shape)] - def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape): - # Check if the sizes match - if tensor1.size() != tensor2.size(): - return False - - # Flatten the tensors to ensure we are comparing the valid elements - flat_tensor1 = tensor1.flatten() - flat_tensor2 = tensor2.flatten() - - # Calculate the number of valid elements to compare - input_numel = torch.Size(input_shape).numel() - target_numel = torch.Size(target_shape).numel() - min_size = min(input_numel, target_numel) - - # Compare only the valid elements - return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) - - comparators = [(compare_resized_tensors, [input_shape, target_shape])] + comparators = [(self.compare_resized_tensors, [input_shape, target_shape])] self.run_test_compare_tensor_attributes_only( Resize(), @@ -124,24 +108,7 @@ def forward(self, x): input_shape = (4, 4) inputs = [torch.randint(1, 10, input_shape)] - def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape): - # 
Check if the sizes match - if tensor1.size() != tensor2.size(): - return False - - # Flatten the tensors to ensure we are comparing the valid elements - flat_tensor1 = tensor1.flatten() - flat_tensor2 = tensor2.flatten() - - # Calculate the number of valid elements to compare - input_numel = torch.Size(input_shape).numel() - target_numel = torch.Size(target_shape).numel() - min_size = min(input_numel, target_numel) - - # Compare only the valid elements - return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) - - comparators = [(compare_resized_tensors, [input_shape, target_shape])] + comparators = [(self.compare_resized_tensors, [input_shape, target_shape])] self.run_test_compare_tensor_attributes_only( Resize(), @@ -171,24 +138,7 @@ def forward(self, x): input_shape = (4, 4) inputs = [torch.randint(1, 10, input_shape)] - def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape): - # Check if the sizes match - if tensor1.size() != tensor2.size(): - return False - - # Flatten the tensors to ensure we are comparing the valid elements - flat_tensor1 = tensor1.flatten() - flat_tensor2 = tensor2.flatten() - - # Calculate the number of valid elements to compare - input_numel = torch.Size(input_shape).numel() - target_numel = torch.Size(target_shape).numel() - min_size = min(input_numel, target_numel) - - # Compare only the valid elements - return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size]) - - comparators = [(compare_resized_tensors, [input_shape, target_shape])] + comparators = [(self.compare_resized_tensors, [input_shape, target_shape])] self.run_test_compare_tensor_attributes_only( Resize(), From 6bcadc5ad00c607fa0c5c305394fdb4f5b5a17c4 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Wed, 19 Jun 2024 10:45:14 +0900 Subject: [PATCH 5/5] chore: remove dependency on fx util, move to dynamo --- .../dynamo/conversion/impl/shuffle.py | 5 +- py/torch_tensorrt/dynamo/utils.py | 85 +++++++++++++++++++ 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py index 6d2a1ba7e0..45927e7709 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py @@ -11,12 +11,9 @@ cast_trt_tensor, flatten_dims, get_trt_tensor, -) -from torch_tensorrt.fx.converters.converter_utils import ( - Frameworks, set_layer_name, - unified_dtype_converter, ) +from torch_tensorrt.dynamo.utils import Frameworks, unified_dtype_converter from torch_tensorrt.fx.types import TRTTensor diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 2c7071f7ad..5e1a18a0c2 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -2,8 +2,11 @@ import logging from dataclasses import fields, replace +from enum import Enum from typing import Any, Callable, Dict, Optional, Sequence, Union +import numpy as np +import tensorrt as trt import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -13,12 +16,63 @@ from packaging import version +from .types import TRTDataType + logger = logging.getLogger(__name__) COSINE_THRESHOLD = 0.99 DYNAMIC_DIM = -1 +class Frameworks(Enum): + NUMPY = "numpy" + TORCH = "torch" + TRT = "trt" + + +DataTypeEquivalence: Dict[ + TRTDataType, Dict[Frameworks, Union[TRTDataType, np.dtype, torch.dtype]] +] = { + trt.int8: { + Frameworks.NUMPY: np.int8, + Frameworks.TORCH: torch.int8, + Frameworks.TRT: 
trt.int8, + }, + trt.int32: { + Frameworks.NUMPY: np.int32, + Frameworks.TORCH: torch.int32, + Frameworks.TRT: trt.int32, + }, + trt.int64: { + Frameworks.NUMPY: np.int64, + Frameworks.TORCH: torch.int64, + Frameworks.TRT: trt.int64, + }, + trt.float16: { + Frameworks.NUMPY: np.float16, + Frameworks.TORCH: torch.float16, + Frameworks.TRT: trt.float16, + }, + trt.float32: { + Frameworks.NUMPY: np.float32, + Frameworks.TORCH: torch.float32, + Frameworks.TRT: trt.float32, + }, + trt.bool: { + Frameworks.NUMPY: bool, + Frameworks.TORCH: torch.bool, + Frameworks.TRT: trt.bool, + }, +} + +if trt.__version__ >= "7.0": + DataTypeEquivalence[trt.bool] = { + Frameworks.NUMPY: np.bool_, + Frameworks.TORCH: torch.bool, + Frameworks.TRT: trt.bool, + } + + def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool: """Parses a user-provided input argument regarding Python runtime @@ -317,3 +371,34 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any: return function_wrapper return nested_decorator + + +def unified_dtype_converter( + dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks +) -> Union[np.dtype, torch.dtype, TRTDataType]: + """ + Convert TensorRT, Numpy, or Torch data types to any other of those data types. + + Args: + dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type. + to (Frameworks): The framework to convert the data type to. + + Returns: + The equivalent data type in the requested framework. + """ + assert to in Frameworks, f"Expected valid Framework for translation, got {to}" + trt_major_version = int(trt.__version__.split(".")[0]) + if dtype in (np.int8, torch.int8, trt.int8): + return DataTypeEquivalence[trt.int8][to] + elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool): + return DataTypeEquivalence[trt.bool][to] + elif dtype in (np.int32, torch.int32, trt.int32): + return DataTypeEquivalence[trt.int32][to] + elif dtype in (np.int64, torch.int64, trt.int64): + return DataTypeEquivalence[trt.int64][to] + elif dtype in (np.float16, torch.float16, trt.float16): + return DataTypeEquivalence[trt.float16][to] + elif dtype in (np.float32, torch.float32, trt.float32): + return DataTypeEquivalence[trt.float32][to] + else: + raise TypeError("%s is not a supported dtype" % dtype)
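
For context on the comparison strategy introduced in patch 3: eager-mode aten.resize_ leaves
any newly allocated elements uninitialized, while the converter added in patch 1 zero-pads
them, so only the leading min(input_numel, target_numel) elements are well defined on both
sides — which is exactly the prefix the test comparator checks. Below is a minimal standalone
sketch of that idea using plain PyTorch only; the helper name valid_prefix_equal and the
zero-padded stand-in tensor are local to this sketch and are not part of the patches.

    import torch

    def valid_prefix_equal(eager_out, trt_like_out, input_shape, target_shape):
        # Only the leading min(input_numel, target_numel) elements are well defined
        # after a resize_, so restrict the comparison to that prefix.
        min_size = min(
            torch.Size(input_shape).numel(), torch.Size(target_shape).numel()
        )
        return torch.equal(
            eager_out.flatten()[:min_size], trt_like_out.flatten()[:min_size]
        )

    x = torch.arange(4.0)             # 4 valid elements
    eager = x.clone().resize_(2, 3)   # grown to 6 elements; the last 2 are uninitialized
    trt_like = torch.zeros(2, 3)      # mimics the converter output: zero-padded
    trt_like.view(-1)[:4] = torch.arange(4.0)

    print(valid_prefix_equal(eager, trt_like, (4,), (2, 3)))  # True: shared prefix matches

The same prefix check would apply when exercising the converter end to end (for example via
torch_tensorrt.compile with ir="dynamo" on a module that calls torch.ops.aten.resize_), but
that wiring is an assumption and is not included in the patches above.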