feat: dynamic shape support for tan, sinh, cosh, asin and acos (#2941)
chohk88 authored Jun 24, 2024
1 parent 7b825f5 commit 1f10647
Showing 6 changed files with 257 additions and 5 deletions.
10 changes: 5 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1511,7 +1511,7 @@ def aten_ops_cos(
)


-@dynamo_tensorrt_converter(torch.ops.aten.tan.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tan.default, supports_dynamic_shapes=True)
def aten_ops_tan(
ctx: ConversionContext,
target: Target,
@@ -1528,7 +1528,7 @@ def aten_ops_tan(
)


-@dynamo_tensorrt_converter(torch.ops.aten.sinh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sinh.default, supports_dynamic_shapes=True)
def aten_ops_sinh(
ctx: ConversionContext,
target: Target,
@@ -1545,7 +1545,7 @@ def aten_ops_sinh(
)


-@dynamo_tensorrt_converter(torch.ops.aten.cosh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.cosh.default, supports_dynamic_shapes=True)
def aten_ops_cosh(
ctx: ConversionContext,
target: Target,
@@ -1562,7 +1562,7 @@ def aten_ops_cosh(
)


-@dynamo_tensorrt_converter(torch.ops.aten.asin.default)
+@dynamo_tensorrt_converter(torch.ops.aten.asin.default, supports_dynamic_shapes=True)
def aten_ops_asin(
ctx: ConversionContext,
target: Target,
@@ -1579,7 +1579,7 @@ def aten_ops_asin(
)


-@dynamo_tensorrt_converter(torch.ops.aten.acos.default)
+@dynamo_tensorrt_converter(torch.ops.aten.acos.default, supports_dynamic_shapes=True)
def aten_ops_acos(
ctx: ConversionContext,
target: Target,
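With these converters registered with supports_dynamic_shapes=True, the dynamo frontend can build engines whose input shapes are bounded by a min/opt/max range rather than fixed. A minimal usage sketch (the module and shape values are hypothetical; it assumes the public torch_tensorrt.compile dynamo path):

import torch
import torch_tensorrt

class TanModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return torch.tan(x)

model = TanModule().eval().cuda()

# min/opt/max bound the shapes the engine accepts at runtime,
# mirroring the Input specs used in the new tests below.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 1, 1),
        opt_shape=(1, 2, 3),
        max_shape=(3, 3, 3),
        dtype=torch.float,
    )
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
out = trt_model(torch.randn(2, 2, 3, device="cuda"))  # any shape within the range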
51 changes: 51 additions & 0 deletions tests/py/dynamo/conversion/test_acos_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -44,6 +45,56 @@ def forward(self, input):
inputs,
)

@parameterized.expand(
[
(
"3d_dim_dtype_int32",
(3, 2, 1),
(3, 2, 3),
(3, 3, 4),
torch.int32,
torch.float32,
),
(
"2d_dim_dtype_float16",
(1, 1),
(2, 2),
(4, 4),
torch.float16,
torch.float16,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_acos_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class acos(nn.Module):
def forward(self, input):
return torch.ops.aten.acos.default(input)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(
acos(),
input_specs,
output_dtypes=[output_type],
)


if __name__ == "__main__":
run_tests()
50 changes: 50 additions & 0 deletions tests/py/dynamo/conversion/test_asin_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -44,6 +45,55 @@ def forward(self, input):
inputs,
)

@parameterized.expand(
[
(
"3d_dim_dtype_int32",
(3, 2, 1),
(3, 2, 3),
(3, 3, 4),
torch.int32,
torch.float32,
),
(
"2d_dim_dtype_float16",
(1, 1),
(2, 2),
(4, 4),
torch.float16,
torch.float16,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_asin_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class asin(nn.Module):
def forward(self, input):
return torch.ops.aten.asin.default(input)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
asin(),
input_specs,
output_dtypes=[output_type],
)


if __name__ == "__main__":
run_tests()
50 changes: 50 additions & 0 deletions tests/py/dynamo/conversion/test_cosh_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -44,6 +45,55 @@ def forward(self, input):
inputs,
)

@parameterized.expand(
[
(
"3d_dim_dtype_int32",
(3, 2, 1),
(3, 2, 3),
(3, 3, 4),
torch.int32,
torch.float32,
),
(
"2d_dim_dtype_float16",
(1, 1),
(2, 2),
(4, 4),
torch.float16,
torch.float16,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_cosh_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class cosh(nn.Module):
def forward(self, input):
return torch.ops.aten.cosh.default(input)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
cosh(),
input_specs,
output_dtypes=[output_type],
)


if __name__ == "__main__":
run_tests()
50 changes: 50 additions & 0 deletions tests/py/dynamo/conversion/test_sinh_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -44,6 +45,55 @@ def forward(self, input):
inputs,
)

@parameterized.expand(
[
(
"3d_dim_dtype_int32",
(3, 2, 1),
(3, 2, 3),
(3, 3, 4),
torch.int32,
torch.float32,
),
(
"2d_dim_dtype_float16",
(1, 1),
(2, 2),
(4, 4),
torch.float16,
torch.float16,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_sinh_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class sinh(nn.Module):
def forward(self, input):
return torch.ops.aten.sinh.default(input)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
sinh(),
input_specs,
output_dtypes=[output_type],
)


if __name__ == "__main__":
run_tests()
51 changes: 51 additions & 0 deletions tests/py/dynamo/conversion/test_tan_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -44,6 +45,56 @@ def forward(self, input):
inputs,
)

@parameterized.expand(
[
(
"3d_dim_dtype_int32",
(3, 2, 1),
(3, 2, 3),
(3, 3, 4),
torch.int32,
torch.float32,
),
(
"2d_dim_dtype_float16",
(1, 1),
(2, 2),
(4, 4),
torch.float16,
torch.float16,
),
(
"3d_dim_dtype_float",
(1, 1, 1),
(1, 2, 3),
(3, 3, 3),
torch.float,
torch.float,
),
]
)
def test_tan_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, output_type
):
class tan(nn.Module):
def forward(self, input):
return torch.ops.aten.tan.default(input)

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]

self.run_test_with_dynamic_shape(
tan(),
input_specs,
output_dtypes=[output_type],
)


if __name__ == "__main__":
run_tests()
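All five test files follow the same pattern: each case pairs a min/opt/max shape range with an input dtype and an expected output dtype. Conceptually, such a range maps to a TensorRT optimization profile; a hedged sketch using the TensorRT Python API directly (the tensor name "input_0" and the builder/config wiring are assumptions for illustration, not torch_tensorrt internals):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# One set_shape call per dynamic input; the name must match the
# network's input tensor.
profile = builder.create_optimization_profile()
profile.set_shape("input_0", min=(1, 1, 1), opt=(1, 2, 3), max=(3, 3, 3))
config.add_optimization_profile(profile)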
