chore: minor naming issues
chohk88 committed Jun 5, 2024
1 parent b606306 commit eaecfb2
Showing 3 changed files with 13 additions and 25 deletions.
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2610,7 +2610,6 @@ def aten_ops_pixel_unshuffle(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.resize.default)
 @dynamo_tensorrt_converter(torch.ops.aten.resize_.default)
 @enforce_tensor_types(
     {
@@ -2624,7 +2623,7 @@ def aten_ops_resize(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.shuffle.resize_(
+    return impl.shuffle.resize(
         ctx,
         target,
         SourceIR.ATEN,
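For context, the decorator above registers aten_ops_resize as the TensorRT converter for torch.ops.aten.resize_.default, and the body now forwards to the renamed impl.shuffle.resize helper. The snippet below is a minimal, hypothetical sketch of how a graph containing this op might be lowered through the Torch-TensorRT dynamo path; the module mirrors the test modules later in this commit, and the compile settings (ir="dynamo", min_block_size=1) are illustrative assumptions, not part of this change.

# Hypothetical usage sketch (not part of this commit): exercising the
# aten.resize_ converter through the Torch-TensorRT dynamo frontend.
import torch
import torch_tensorrt


class Resize(torch.nn.Module):
    # Mirrors the test modules in this commit: forward calls the in-place
    # resize_ op, which the registered converter lowers to TensorRT layers.
    def forward(self, x):
        return torch.ops.aten.resize_.default(x, (10,))


model = Resize().eval().cuda()
inputs = [torch.randn((5,), device="cuda")]

# Assumed invocation; settings here are illustrative only.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print(trt_model(*inputs).shape)  # expected: torch.Size([10])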
23 changes: 12 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -132,17 +132,15 @@ def pixel_unshuffle(
     )


-def resize_(
+def resize(
     ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
     sizes: Sequence[int],
 ) -> TRTTensor:
-
-    input_np_dtype = unified_dtype_converter(input.dtype, Frameworks.NUMPY)

     input_val = get_trt_tensor(ctx, input, f"{name}_input")

     # Calculate the total number of elements for new and current shape
@@ -158,31 +156,34 @@ def resize_(
         # Flatten input tensor to 1D for concatenation
         flatten_shape = flatten_dims(input_val, 0, -1)
-        flattened_input = impl.shuffle.reshape(
+        flattened_input = reshape(
             ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
         )

         # Concatenate the flattened input tensor and padding tensor
-        concat_layer = ctx.net.add_concatenation([flattened_input, padding_tensor])
-        concat_layer.axis = 0
-        reshaped_tensor = concat_layer.get_output(0)
-
+        reshaped_tensor = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_cat",
+            [flattened_input, padding_tensor],
+            dim=0,
+        )
     elif new_num_elements < current_num_elements:
         # Flatten input tensor to 1D for slicing
         flatten_shape = flatten_dims(input_val, 0, -1)
-        flattened_input = impl.shuffle.reshape(
+        flattened_input = reshape(
             ctx, target, source_ir, f"{name}_flatten_input", input_val, flatten_shape
         )

         # Slice the flattened input tensor to the desired number of elements
         slice_layer = ctx.net.add_slice(flattened_input, [0], [new_num_elements], [1])
         reshaped_tensor = slice_layer.get_output(0)

     else:
         reshaped_tensor = input_val

     # Reshape the final output tensor to the target sizes
-    resized_output = impl.shuffle.reshape(
+    resized_output = reshape(
         ctx, target, source_ir, f"{name}_final_reshape", reshaped_tensor, sizes
     )

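Aside from the rename, the hunk above also swaps the hand-rolled ctx.net.add_concatenation calls for the existing impl.cat.cat helper. The overall semantics of resize are unchanged: flatten the input, pad it with a padding tensor or truncate it to the requested number of elements, then reshape to the target sizes. A rough eager-mode PyTorch equivalent of that logic is sketched below for reference; it is an illustrative reconstruction (zero padding is assumed), not code from the repository.

# Illustrative eager-mode reference for the converter's resize semantics:
# flatten, then pad or truncate to the new element count, then reshape.
import math
import torch


def resize_reference(x: torch.Tensor, sizes) -> torch.Tensor:
    new_numel = math.prod(sizes)
    flat = x.reshape(-1)
    if new_numel > flat.numel():
        # Grow: append padding (zeros assumed here; the converter concatenates a padding tensor).
        padding = torch.zeros(new_numel - flat.numel(), dtype=x.dtype, device=x.device)
        flat = torch.cat([flat, padding], dim=0)
    elif new_numel < flat.numel():
        # Shrink: keep only the first new_numel elements (the converter slices).
        flat = flat[:new_numel]
    return flat.reshape(sizes)


print(resize_reference(torch.arange(6.0), (2, 4)))  # last two elements are padded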
12 changes: 0 additions & 12 deletions tests/py/dynamo/conversion/test_resize_aten.py
@@ -20,9 +20,6 @@ class TestResizeConverter(DispatchTestCase):
     )
     def test_resize_1d_input_float(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)

@@ -46,9 +43,6 @@ def forward(self, x):
     )
     def test_resize_1d_input_int(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)

@@ -73,9 +67,6 @@ def forward(self, x):
     )
     def test_resize_2d_input_float(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)

@@ -100,9 +91,6 @@ def forward(self, x):
     )
     def test_resize_2d_input_int(self, target_shape):
         class Resize(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
             def forward(self, x):
                 return torch.ops.aten.resize_.default(x, target_shape)

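The test changes only delete boilerplate __init__ methods: since the Resize test modules add no state of their own, torch.nn.Module.__init__ is invoked through normal inheritance and the no-op overrides were redundant. A small standalone illustration of that point is sketched below; the (10,) target shape is arbitrary and only for demonstration.

# Illustration (not repository code): an nn.Module subclass without an
# explicit __init__ still gets nn.Module.__init__ via inheritance.
import torch


class Resize(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.resize_.default(x, (10,))


m = Resize()  # works without a user-defined __init__
print(isinstance(m, torch.nn.Module), m(torch.randn(5)).shape)  # True torch.Size([10])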
