feat: dynamic shape support for adaptive_avg_poolNd (partially) (#3021)
chohk88 authored Jul 31, 2024
1 parent 8536289 commit 5bd948f
Showing 3 changed files with 112 additions and 35 deletions.
20 changes: 15 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2603,7 +2603,9 @@ def aten_ops_avg_pool(
)


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool1d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -2626,10 +2628,18 @@ def aten_ops_adaptive_avg_pool1d(
)


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool2d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten._adaptive_avg_pool2d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool3d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten._adaptive_avg_pool3d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
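The only change in this file is registration metadata: each adaptive pooling converter now opts in to dynamic-shape compilation via `supports_dynamic_shapes=True`. As a rough mental model (a minimal sketch, not the actual Torch-TensorRT implementation; the registry name and dictionary layout below are assumptions), the decorator records the flag alongside the converter so graph partitioning can decide whether a node with symbolic shapes may be lowered to TensorRT:

```python
# Minimal sketch of a flag-carrying converter registry. The real
# dynamo_tensorrt_converter in torch_tensorrt.dynamo.conversion is richer;
# CONVERTER_REGISTRY and its layout here are illustrative assumptions.
from typing import Any, Callable, Dict

CONVERTER_REGISTRY: Dict[Any, Dict[str, Any]] = {}


def dynamo_tensorrt_converter(
    aten_op: Any, supports_dynamic_shapes: bool = False
) -> Callable[[Callable], Callable]:
    def decorator(converter: Callable) -> Callable:
        # Record the converter together with its dynamic-shape capability,
        # so partitioning can keep the op in PyTorch when shapes are
        # symbolic and the flag is False.
        CONVERTER_REGISTRY[aten_op] = {
            "converter": converter,
            "supports_dynamic_shapes": supports_dynamic_shapes,
        }
        return converter

    return decorator
```

Stacking several such decorators, as the diff does for the 2D and 3D variants, simply registers the same converter function under multiple ATen overloads.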
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/pool.py
@@ -124,6 +124,11 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
"""Calculate the end index of each pooling window"""
return math.ceil((float(idx + 1) * float(in_dim)) / out_dim)

if has_dynamic_shape(input.shape):
assert (
input.shape[-1] != -1 and input.shape[-2] != -1
), "Last 2 dimensions can't be dynamic for adaptive_avg_pool1d."

in_dim = input.shape[-1]
out_dim = output_size if isinstance(output_size, int) else output_size[0]
output_list = []
@@ -179,6 +184,18 @@ def adaptive_avg_poolNd(
input: TRTTensor,
output_size: Sequence[int],
) -> TRTTensor:
if has_dynamic_shape(input.shape):
if len(output_size) == 2: # adaptive_avg_pool2d
assert (
input.shape[-1] != -1 and input.shape[-2] != -1
), "Last 2 dimensions can't be dynamic for adaptive_avg_pool2d."
elif len(output_size) == 3: # adaptive_avg_pool3d
assert (
input.shape[-1] != -1
and input.shape[-2] != -1
and input.shape[-3] != -1
), "Last 3 dimensions can't be dynamic for adaptive_avg_pool3d."

input_shape = input.shape
input_rank = len(input_shape)
output_rank = len(output_size)
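The guards above explain the "partially" in the commit title: batch and channel dimensions may be dynamic, but the dimensions actually being pooled must be static, because the converter computes each pooling window's boundaries at conversion time from concrete sizes. A pure-Python sketch of that window arithmetic (`end_index` uses the same formula as the helper shown in the diff; `start_index` is assumed by symmetry, and the reference function is illustrative, not the converter itself):

```python
import math


def start_index(idx: int, out_dim: int, in_dim: int) -> int:
    # First input element covered by output window idx.
    return math.floor((float(idx) * float(in_dim)) / out_dim)


def end_index(idx: int, out_dim: int, in_dim: int) -> int:
    # One past the last input element covered by output window idx
    # (same formula as the end_index helper in the diff).
    return math.ceil((float(idx + 1) * float(in_dim)) / out_dim)


def adaptive_avg_pool1d_ref(x: list, out_dim: int) -> list:
    # Reference semantics: average over variable-width, possibly
    # overlapping windows. The boundaries depend on in_dim, which is why
    # the converter asserts the pooled dimensions are not dynamic.
    in_dim = len(x)
    out = []
    for i in range(out_dim):
        s, e = start_index(i, out_dim, in_dim), end_index(i, out_dim, in_dim)
        out.append(sum(x[s:e]) / (e - s))
    return out


# in_dim=5 pooled to out_dim=2 -> windows [0, 3) and [2, 5), which overlap:
print(adaptive_avg_pool1d_ref([1.0, 2.0, 3.0, 4.0, 5.0], 2))  # [2.0, 4.0]
```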
110 changes: 80 additions & 30 deletions tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
@@ -79,6 +79,40 @@ def forward(self, x):
enable_passes=True,
)

@parameterized.expand(
[
(
(1, 3, 3),
(2, 3, 3),
(3, 3, 3),
torch.float,
(2,),
),
]
)
def test_dynamic_shape_adaptive_pool1d(
self,
min_shape,
opt_shape,
max_shape,
type,
output_size,
):
class adaptive_pool1d(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.adaptive_avg_pool1d.default(x, output_size)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(adaptive_pool1d(), input_specs)

@parameterized.expand(
[
# 3d input
@@ -159,29 +193,37 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2),),
(
(1, 1, 3, 3),
(2, 2, 3, 3),
(3, 3, 3, 3),
torch.float,
(2, 2),
),
]
)
def test_adaptive_avg_pool2d_dynamic(self, output_size):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def test_dynamic_shape_adaptive_pool2d(
self,
min_shape,
opt_shape,
max_shape,
type,
output_size,
):
class adaptive_pool2d(torch.nn.Module):
def forward(self, x):
out = torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)
return out
return torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)

input_specs = [
Input(
shape=(-1, 2, 3, 2),
dtype=torch.float32,
shape_ranges=[((1, 2, 3, 2), (3, 2, 3, 2), (10, 2, 3, 2))],
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
)

self.run_test_with_dynamic_shape(adaptive_pool2d(), input_specs)

@parameterized.expand(
[
@@ -271,29 +313,37 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2, 3),),
(
(1, 1, 3, 3, 3),
(2, 2, 3, 3, 3),
(3, 3, 3, 3, 3),
torch.float,
(2, 2, 2),
),
]
)
def test_adaptive_avg_pool3d_dynamic(self, output_size):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def test_dynamic_shape_adaptive_pool3d(
self,
min_shape,
opt_shape,
max_shape,
type,
output_size,
):
class adaptive_pool3d(torch.nn.Module):
def forward(self, x):
out = torch.ops.aten.adaptive_avg_pool3d.default(x, output_size)
return out
return torch.ops.aten.adaptive_avg_pool3d.default(x, output_size)

input_specs = [
Input(
shape=(-1, 2, 3, 1, 4),
dtype=torch.float32,
shape_ranges=[((1, 2, 3, 1, 4), (3, 2, 3, 1, 4), (10, 2, 3, 1, 4))],
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
TestModule(),
input_specs,
)

self.run_test_with_dynamic_shape(adaptive_pool3d(), input_specs)


if __name__ == "__main__":
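The rewritten tests replace the legacy `shape`/`shape_ranges` input spec with explicit `min_shape`/`opt_shape`/`max_shape`, and keep the pooled trailing dimensions fixed across the range, matching the converter's guard. Outside the test harness, the same spec drives compilation roughly like this (a usage sketch; exact `torch_tensorrt.compile` arguments vary by release):

```python
import torch
import torch_tensorrt


class AdaptivePool2d(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.adaptive_avg_pool2d.default(x, (2, 2))


# Batch and channel vary from 1 to 3; H and W stay fixed at 3, since the
# converter requires the pooled dimensions to be static.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 1, 3, 3),
        opt_shape=(2, 2, 3, 3),
        max_shape=(3, 3, 3, 3),
        dtype=torch.float,
    )
]

trt_module = torch_tensorrt.compile(
    AdaptivePool2d().eval().cuda(), ir="dynamo", inputs=inputs
)
# Any batch/channel size inside the declared range works at runtime:
print(trt_module(torch.randn(2, 2, 3, 3).cuda()).shape)  # torch.Size([2, 2, 2, 2])
```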
