Add dynamic shape support for cumsum/grid (#3051)
lanluo-nvidia authored Aug 2, 2024
1 parent ee16bad commit a5f4d5b
Showing 4 changed files with 209 additions and 13 deletions.
14 changes: 9 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -331,10 +331,14 @@ def aten_ops_fmod(
    return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(
    torch.ops.aten.grid_sampler.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
    torch.ops.aten.grid_sampler_2d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
@@ -922,7 +926,7 @@ def aten_ops_chunk(
    )


@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
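For context, a minimal usage sketch of what these registrations enable (not part of this commit; the module and shape range below are illustrative assumptions): a converter registered with supports_dynamic_shapes=True can serve inputs whose dimensions vary within a declared min/opt/max range.

import torch
import torch_tensorrt


class CumsumModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.cumsum.default(x, 0)


# Declare a shape range: dim 0 may vary between 1 and 8 at runtime.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3),
        opt_shape=(4, 3),
        max_shape=(8, 3),
        dtype=torch.float32,
    )
]
trt_module = torch_tensorrt.compile(
    CumsumModule().eval().cuda(), ir="dynamo", inputs=inputs
)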
39 changes: 34 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -387,17 +387,46 @@ def cumsum(
    input: TRTTensor,
    dim: int,
) -> TRTTensor:
    input_shape = input.shape
    dim = get_positive_dim(dim, len(input_shape))
    if input_shape[dim] < 0:
        trip_limit = impl.shape.shape(
            ctx, target, source_ir, name + "_shape", input, dim
        )
        # The trip limit must be a 0D shape tensor, but impl.shape.shape returns
        # a 1D tensor: for a trip limit of 3, the loop expects tensor(3), not
        # tensor([3]). impl.reduce.sum over axis 0 reduces the 1D tensor to 0D.
        trip_limit = impl.reduce.sum(
            ctx, target, source_ir, name, trip_limit, 0, keepdim=False
        )
    else:
        axis = np.array(input_shape[dim])
        trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")

    loop = ctx.net.add_loop()
    loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
    iterator = loop.add_iterator(input, dim, reverse=False)
    data = iterator.get_output(0)
    if has_dynamic_shape(data.shape):
        data_shape = []
        for i in range(len(input_shape)):
            if i != dim:
                if input_shape[i] < 0:
                    data_shape.append(
                        impl.shape.shape(
                            ctx, target, source_ir, name + f"_{i}_shape", input, i
                        )
                    )
                else:
                    data_shape.append(input_shape[i])
        zero_trttensor = impl.full.full(
            ctx, target, source_ir, name + "_full", data_shape, 0.0
        )
    else:
        new_dims = tuple(data.shape)
        zeros = np.zeros(new_dims)
        zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")

    running_sum = loop.add_recurrence(zero_trttensor)
    set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
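The loop construction above is easier to follow next to its eager-mode equivalent. A reference sketch (an illustration of the same semantics, not code from this commit): the trip count is the size of dim — reduced to a 0D tensor when that size is only known at runtime — the recurrence carries the running sum, and each iteration adds one slice of the input.

import torch


def cumsum_reference(x: torch.Tensor, dim: int) -> torch.Tensor:
    dim = dim % x.dim()
    # Initial recurrence value: zeros shaped like one slice along `dim`.
    running = torch.zeros_like(x.select(dim, 0))
    slices = []
    for i in range(x.shape[dim]):  # trip count == size of `dim`
        running = running + x.select(dim, i)  # the recurrence: a running sum
        slices.append(running)
    return torch.stack(slices, dim=dim)  # loop output: slices re-stacked


# Sanity check against the built-in.
x = torch.randn(2, 5)
assert torch.allclose(cumsum_reference(x, 0), torch.cumsum(x, 0))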
32 changes: 31 additions & 1 deletion tests/py/dynamo/conversion/test_cumsum_aten.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch_tensorrt
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

@@ -46,7 +47,7 @@ def forward(self, x):

    @parameterized.expand(
        [
            ((2, 3, 3), 0),
            ((4, 2, 3), 1),
            ((1, 2, 3), 2),
            ((1, 2, 3), -1),
@@ -64,6 +65,35 @@ def forward(self, x):
            inputs,
        )

    @parameterized.expand(
        [
            ((1,), (2,), (3,), 0),
            ((1,), (2,), (3,), -1),
            ((2, 3), (2, 4), (2, 5), 0),
            ((2, 3), (3, 4), (4, 5), -1),
            ((1, 2, 2), (2, 2, 3), (3, 3, 3), 0),
            ((1, 2, 2), (2, 2, 3), (3, 2, 3), -2),
            ((1, 2, 2, 3), (2, 2, 3, 4), (3, 3, 4, 5), -3),
            ((1, 2, 2, 3), (2, 2, 3, 4), (3, 3, 4, 5), -2),
        ]
    )
    def test_cumsum_dynamic_shape(self, min_shape, opt_shape, max_shape, dims):
        class Cumsum(nn.Module):
            def forward(self, x):
                return torch.ops.aten.cumsum.default(x, dims)

        inputs = [
            torch_tensorrt.Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            Cumsum(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
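To make the parameterized tuples concrete: each case is (min_shape, opt_shape, max_shape, dim), and any runtime shape between min and max must match eager cumsum. A quick eager-mode illustration of the ((2, 3), (2, 4), (2, 5), 0) case (a sketch, not part of the test suite):

import torch

for w in (3, 4, 5):  # any width within the declared min/max range
    x = torch.randn(2, w)
    y = torch.ops.aten.cumsum.default(x, 0)
    assert torch.allclose(y[1], x[0] + x[1])  # row 1 holds the running sum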
137 changes: 135 additions & 2 deletions tests/py/dynamo/conversion/test_grid_aten.py
@@ -1,10 +1,11 @@
import pytest
import torch
import torch.nn as nn
import torch_tensorrt
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

grid_sampler_aten_ops = {
"torch.ops.aten.grid_sampler": torch.ops.aten.grid_sampler,
@@ -185,6 +186,138 @@ def forward(self, x):
        grid_model = TestModule(op)
        self.run_test(grid_model, inputs)

    @parameterized.expand(
        [
            (
                (1, 1, 2, 2),
                (2, 2, 3, 3),
                (3, 3, 5, 5),
                (1, 2, 2, 2),
                (2, 3, 3, 2),
                (3, 5, 5, 2),
                0,
                0,
                True,
            ),
            (
                (1, 1, 2, 2),
                (2, 2, 3, 3),
                (3, 3, 5, 5),
                (1, 2, 2, 2),
                (2, 3, 3, 2),
                (3, 5, 5, 2),
                0,
                2,
                True,
            ),
            (
                (1, 1, 2, 2),
                (1, 1, 3, 3),
                (1, 1, 5, 5),
                (1, 3, 3, 2),
                (1, 4, 4, 2),
                (1, 5, 5, 2),
                0,
                1,
                True,
            ),
            (
                (1, 1, 2, 2),
                (2, 2, 3, 3),
                (3, 3, 5, 5),
                (1, 4, 2, 2),
                (2, 4, 3, 2),
                (3, 4, 5, 2),
                1,
                0,
                True,
            ),
            (
                (1, 1, 2, 2),
                (2, 2, 3, 3),
                (3, 3, 5, 5),
                (1, 4, 2, 2),
                (2, 5, 3, 2),
                (3, 5, 5, 2),
                1,
                1,
                False,
            ),
        ]
    )
    def test_grid_2d_default_dynamic_shape(
        self,
        input_min_shape,
        input_opt_shape,
        input_max_shape,
        grid_min_shape,
        grid_opt_shape,
        grid_max_shape,
        interpolation_mode,
        padding_mode,
        align_corners,
    ):
        class Grid_SAMPLER_2D(nn.Module):
            def forward(self, input, grid):
                return torch.ops.aten.grid_sampler_2d(
                    input, grid, interpolation_mode, padding_mode, align_corners
                )

        class Grid_SAMPLER_2D_default(nn.Module):
            def forward(self, input, grid):
                return torch.ops.aten.grid_sampler_2d.default(
                    input, grid, interpolation_mode, padding_mode, align_corners
                )

        class Grid_SAMPLER(nn.Module):
            def forward(self, input, grid):
                return torch.ops.aten.grid_sampler(
                    input, grid, interpolation_mode, padding_mode, align_corners
                )

        class Grid_SAMPLER_default(nn.Module):
            def forward(self, input, grid):
                return torch.ops.aten.grid_sampler.default(
                    input, grid, interpolation_mode, padding_mode, align_corners
                )

        inputs = [
            torch_tensorrt.Input(
                min_shape=input_min_shape,
                opt_shape=input_opt_shape,
                max_shape=input_max_shape,
                dtype=torch.float32,
                torch_tensor=torch.randn(input_opt_shape, dtype=torch.float32),
            ),
            torch_tensorrt.Input(
                min_shape=grid_min_shape,
                opt_shape=grid_opt_shape,
                max_shape=grid_max_shape,
                dtype=torch.float32,
                torch_tensor=torch.randint(-1, 1, grid_opt_shape, dtype=torch.float32),
            ),
        ]
        self.run_test_with_dynamic_shape(
            Grid_SAMPLER_2D(),
            inputs,
            use_example_tensors=False,
        )
        self.run_test_with_dynamic_shape(
            Grid_SAMPLER_2D_default(),
            inputs,
            use_example_tensors=False,
        )
        self.run_test_with_dynamic_shape(
            Grid_SAMPLER(),
            inputs,
            use_example_tensors=False,
        )
        self.run_test_with_dynamic_shape(
            Grid_SAMPLER_default(),
            inputs,
            use_example_tensors=False,
        )


if __name__ == "__main__":
    run_tests()
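A note on why the grid shapes above vary independently of the input shapes: for grid_sampler_2d, the output takes N and C from the input but H and W from the grid, so both tensors need their own dynamic ranges. A quick eager-mode illustration (the sample values are made up for this sketch):

import torch

inp = torch.arange(4.0).reshape(1, 1, 2, 2)  # N, C, H_in, W_in
grid = torch.zeros(1, 3, 3, 2)  # N, H_out, W_out, 2; zeros sample the center
# interpolation_mode=0 (bilinear), padding_mode=0 (zeros), align_corners=True
out = torch.ops.aten.grid_sampler_2d(inp, grid, 0, 0, True)
print(out.shape)  # torch.Size([1, 1, 3, 3]): N, C from input; H, W from grid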
