feat: support aten.resize_ converter #2874

Merged: 5 commits, Jun 19, 2024
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2695,6 +2695,29 @@ def aten_ops_pixel_unshuffle(
)


@dynamo_tensorrt_converter(torch.ops.aten.resize_.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_resize(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shuffle.resize(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
sizes=args[1],
)


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
def aten_ops_argmax(
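For reference (not part of the diff), a minimal eager-mode sketch of the op this converter maps. One caveat: eager `torch.Tensor.resize_` leaves newly allocated elements uninitialized, while the TensorRT lowering below zero-fills them, so only the shapes are guaranteed to match.

```python
# Hedged sketch, not from the PR: the ATen op handled by the new converter.
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Shrinking: keeps the first new_numel elements of the flattened tensor.
smaller = torch.ops.aten.resize_.default(x.clone(), (3,))
print(smaller.shape)  # torch.Size([3])

# Growing: eager resize_ leaves the extra elements uninitialized;
# the converter path below instead zero-pads before the final reshape.
larger = torch.ops.aten.resize_.default(x.clone(), (10, 15, 10))
print(larger.shape)  # torch.Size([10, 15, 10])
```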
64 changes: 63 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -1,15 +1,19 @@
from typing import Optional, Sequence, Union

import numpy as np
import tensorrt as trt
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt import _enums
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
SourceIR,
cast_trt_tensor,
flatten_dims,
get_trt_tensor,
set_layer_name,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.dynamo.utils import Frameworks, unified_dtype_converter
from torch_tensorrt.fx.types import TRTTensor


@@ -131,3 +135,61 @@ def pixel_unshuffle(
permuted_tensor,
shape[:-3] + (out_channels, out_height, out_width),
)


def resize(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
sizes: Sequence[int],
) -> TRTTensor:
input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY)
input_val = get_trt_tensor(ctx, input, f"{name}_input")

# Calculate the total number of elements for new and current shape
new_num_elements = np.prod(sizes)
current_num_elements = np.prod(input_val.shape)

if new_num_elements > current_num_elements:
# Create a padding tensor with the required size and initialize new elements with zeros
padding_size = new_num_elements - current_num_elements
padding_tensor = ctx.net.add_constant(
(padding_size,), trt.Weights(np.zeros(padding_size, dtype=input_np_dtype))
).get_output(0)

# Flatten input tensor to 1D for concatenation
flatten_shape = flatten_dims(input_val, 0, -1)
flattened_input = reshape(
ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
)

# Concatenate the flattened input tensor and padding tensor
reshaped_tensor = impl.cat.cat(
ctx,
target,
source_ir,
f"{name}_cat",
[flattened_input, padding_tensor],
dim=0,
)
elif new_num_elements < current_num_elements:
# Flatten input tensor to 1D for slicing
flatten_shape = flatten_dims(input_val, 0, -1)
flattened_input = reshape(
ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
)

# Slice the flattened input tensor to the desired number of elements
slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1])
reshaped_tensor = slice_layer.get_output(0)
else:
reshaped_tensor = input_val

# Reshape the final output tensor to the target sizes
resized_output = reshape(
ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes
)

return resized_output
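As a rough host-side reference (my own sketch, not part of the diff), the layers above amount to: flatten, zero-pad or truncate to the new element count, then reshape to the target sizes.

```python
# Hedged reference sketch of the lowering logic in resize() above, using numpy.
import numpy as np

def resize_reference(arr: np.ndarray, sizes: tuple) -> np.ndarray:
    new_numel = int(np.prod(sizes))
    cur_numel = arr.size
    flat = arr.reshape(-1)
    if new_numel > cur_numel:
        # Mirror of the add_constant + concat path: pad with zeros.
        flat = np.concatenate([flat, np.zeros(new_numel - cur_numel, dtype=arr.dtype)])
    elif new_numel < cur_numel:
        # Mirror of the add_slice path: keep the first new_numel elements.
        flat = flat[:new_numel]
    return flat.reshape(sizes)

print(resize_reference(np.arange(6, dtype=np.float32), (10, 15, 10)).shape)  # (10, 15, 10)
print(resize_reference(np.arange(6, dtype=np.float32), (3,)))                # [0. 1. 2.]
```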
85 changes: 85 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -2,8 +2,11 @@

import logging
from dataclasses import fields, replace
from enum import Enum
from typing import Any, Callable, Dict, Optional, Sequence, Union

import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
@@ -13,12 +16,63 @@

from packaging import version

from .types import TRTDataType

logger = logging.getLogger(__name__)

COSINE_THRESHOLD = 0.99
DYNAMIC_DIM = -1


class Frameworks(Enum):
NUMPY = "numpy"
TORCH = "torch"
TRT = "trt"


DataTypeEquivalence: Dict[
TRTDataType, Dict[Frameworks, Union[TRTDataType, np.dtype, torch.dtype]]
] = {
trt.int8: {
Frameworks.NUMPY: np.int8,
Frameworks.TORCH: torch.int8,
Frameworks.TRT: trt.int8,
},
trt.int32: {
Frameworks.NUMPY: np.int32,
Frameworks.TORCH: torch.int32,
Frameworks.TRT: trt.int32,
},
trt.int64: {
Frameworks.NUMPY: np.int64,
Frameworks.TORCH: torch.int64,
Frameworks.TRT: trt.int64,
},
trt.float16: {
Frameworks.NUMPY: np.float16,
Frameworks.TORCH: torch.float16,
Frameworks.TRT: trt.float16,
},
trt.float32: {
Frameworks.NUMPY: np.float32,
Frameworks.TORCH: torch.float32,
Frameworks.TRT: trt.float32,
},
trt.bool: {
Frameworks.NUMPY: bool,
Frameworks.TORCH: torch.bool,
Frameworks.TRT: trt.bool,
},
}

if trt.__version__ >= "7.0":
DataTypeEquivalence[trt.bool] = {
Frameworks.NUMPY: np.bool_,
Frameworks.TORCH: torch.bool,
Frameworks.TRT: trt.bool,
}


def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
"""Parses a user-provided input argument regarding Python runtime

@@ -317,3 +371,34 @@ def function_wrapper(*args: Any, **kwargs: Any) -> Any:
return function_wrapper

return nested_decorator


def unified_dtype_converter(
dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
) -> Union[np.dtype, torch.dtype, TRTDataType]:
"""
Convert TensorRT, Numpy, or Torch data types to any other of those data types.

Args:
dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
to (Frameworks): The framework to convert the data type to.

Returns:
The equivalent data type in the requested framework.
"""
assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
trt_major_version = int(trt.__version__.split(".")[0])
if dtype in (np.int8, torch.int8, trt.int8):
return DataTypeEquivalence[trt.int8][to]
elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
return DataTypeEquivalence[trt.bool][to]
elif dtype in (np.int32, torch.int32, trt.int32):
return DataTypeEquivalence[trt.int32][to]
elif dtype in (np.int64, torch.int64, trt.int64):
return DataTypeEquivalence[trt.int64][to]
elif dtype in (np.float16, torch.float16, trt.float16):
return DataTypeEquivalence[trt.float16][to]
elif dtype in (np.float32, torch.float32, trt.float32):
return DataTypeEquivalence[trt.float32][to]
else:
raise TypeError("%s is not a supported dtype" % dtype)
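For illustration (not part of the diff), the new helpers can be exercised on their own. A hedged usage sketch, assuming the module path `torch_tensorrt.dynamo.utils` as added in this PR:

```python
# Hedged usage sketch for the dtype helpers added above.
import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt.dynamo.utils import (
    DataTypeEquivalence,
    Frameworks,
    unified_dtype_converter,
)

# Convert between framework dtypes in any direction.
assert unified_dtype_converter(torch.float16, Frameworks.NUMPY) is np.float16
assert unified_dtype_converter(np.int32, Frameworks.TRT) == trt.int32
assert unified_dtype_converter(trt.float32, Frameworks.TORCH) is torch.float32

# Equivalent lookup through the table, keyed by the TensorRT dtype.
assert DataTypeEquivalence[trt.float16][Frameworks.NUMPY] is np.float16
```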
20 changes: 14 additions & 6 deletions tests/py/dynamo/conversion/harness.py
@@ -157,12 +157,20 @@ def run_test_custom_compare_results(
res_trt = trt_mod(*cuda_inputs).cpu()
res_cpu = mod(*cuda_inputs).cpu()
assert len(res_trt) == len(res_cpu)
for output_trt, output_cpu, comparator in zip(
res_trt, res_cpu, comparators
):
comp_func = comparator[0]
args = comparator[1]
self.assertTrue(comp_func(output_trt, output_cpu, *args))
comparator = comparators

if len(cuda_inputs) == 1:
for comparator in comparators:
comp_func = comparator[0]
args = comparator[1]
self.assertTrue(comp_func(res_trt, res_cpu, *args))
else:
for output_trt, output_cpu, comparator in zip(
res_trt, res_cpu, comparators
):
comp_func = comparator[0]
args = comparator[1]
self.assertTrue(comp_func(output_trt, output_cpu, *args))

Collaborator:

Is there a specific case where the len(cuda_inputs) == 1 check is required? In general I assume len(cuda_inputs) would be 1 in most cases. And since the comparison is conditioned on res_trt, could you highlight the cases where it would be needed?

Collaborator (Author):

Thank you for your comment!

As you mentioned, in most cases, the length of cuda_inputs is 1. Below are the outputs when I set a breakpoint in the original code for two cases where torch.ops.aten.resize_.default(x, target_shape) returns tensors with shapes (3,) and (10, 15, 10).

(breakpoint output screenshots)

In both cases, the length of cuda_inputs is also 1. When cuda_inputs has length 1, res_trt and res_cpu are not lists of length 1 but torch.Tensors with shapes (3,) and (10, 15, 10), respectively. So in `for output_trt, output_cpu, comparator in zip(res_trt, res_cpu, comparators)`, with a comparators list of length 1, zip iterates over the tensors' first dimension and only a single element gets compared, because one dimension of res_trt and res_cpu is effectively lost.
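A small standalone illustration of that zip behavior (my own sketch, not from the test suite): when res_trt is a single Tensor rather than a list of outputs, zip walks its first dimension, so with one comparator only a single scalar element is ever compared.

```python
# Hedged sketch of why the len(cuda_inputs) == 1 branch is needed.
import torch

res_trt = torch.tensor([1.0, 2.0, 3.0])   # single output, shape (3,)
res_cpu = torch.tensor([1.0, 2.0, 3.0])
comparators = [(torch.allclose, [])]       # one (comp_func, args) pair

# zip iterates over the tensors' first dimension, so only one scalar pair
# is produced instead of one whole-tensor comparison.
pairs = list(zip(res_trt, res_cpu, comparators))
print(len(pairs))          # 1
print(pairs[0][0].shape)   # torch.Size([]), a 0-d element, not the full tensor
```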

def run_test_with_error(self, mod, inputs, interpreter, expect_error):
with self.assertRaises(expect_error):