feat: add dynamic support for eq/ne/lt/le #2979

Merged · 1 commit · Jul 8, 2024
16 changes: 8 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2253,8 +2253,8 @@ def aten_ops_bitwise_not(
)


-@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -2277,8 +2277,8 @@ def aten_ops_eq(
)


-@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -2349,8 +2349,8 @@ def aten_ops_ge(
)


-@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -2373,8 +2373,8 @@ def aten_ops_lt(
)


-@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
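A note on the change above: the only edit in this file is the added supports_dynamic_shapes=True flag, which tells the converter registry that these comparison converters may be selected when an input dimension is a range rather than a fixed size. The converter bodies are collapsed in this diff; purely as orientation, here is a hedged sketch of the pattern such a converter follows, assuming the shared impl.elementwise helpers and the imports already present in aten_ops_converters.py (the body shown is a reconstruction, not part of this diff):

# Hedged sketch: assumes the surrounding module's imports (impl, SourceIR,
# ConversionContext, TRTTensor, Target, Argument, and the two decorators).
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar, supports_dynamic_shapes=True)
@enforce_tensor_types({0: (TRTTensor,)})
def aten_ops_eq(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Delegate to the shared elementwise implementation, which handles
    # broadcasting and scalar promotion for both overloads; signature assumed.
    return impl.elementwise.eq(ctx, target, SourceIR.ATEN, name, args[0], args[1])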
85 changes: 85 additions & 0 deletions tests/py/dynamo/conversion/test_eq_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -61,6 +62,90 @@ def forward(self, lhs_val):
inputs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_eq_tensor_dynamic_shape(self, min_shape, opt_shape, max_shape):
class eq(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.eq.Tensor(lhs_val, rhs_val)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
eq(),
input_specs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_eq_tensor_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
class eq(nn.Module):
def forward(self, lhs_val):
return torch.ops.aten.eq.Tensor(lhs_val, torch.tensor(1))

input_specs = [
Input(
dtype=torch.int32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
eq(),
input_specs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_eq_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
class eq(nn.Module):
def forward(self, lhs_val):
return torch.ops.aten.eq.Scalar(lhs_val, 1.0)

input_specs = [
Input(
dtype=torch.int32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
eq(),
input_specs,
)


if __name__ == "__main__":
run_tests()
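The harness drives these converters directly; end to end, the same dynamic-shape path is reached through the public torch_tensorrt.compile API with an Input range spec like the ones used in the tests above. A hedged usage sketch (requires a CUDA GPU and a TensorRT install; the module and shapes are illustrative):

import torch
import torch_tensorrt

class Eq(torch.nn.Module):
    def forward(self, lhs, rhs):
        return torch.ops.aten.eq.Tensor(lhs, rhs)

# Compile once against a shape range; the resulting module then accepts any
# first dimension between min_shape and max_shape at runtime.
spec = torch_tensorrt.Input(
    min_shape=(1, 20), opt_shape=(2, 20), max_shape=(3, 20), dtype=torch.float32
)
trt_mod = torch_tensorrt.compile(Eq().eval().cuda(), ir="dynamo", inputs=[spec, spec])
print(trt_mod(torch.rand(2, 20).cuda(), torch.rand(2, 20).cuda()))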
85 changes: 85 additions & 0 deletions tests/py/dynamo/conversion/test_le_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -61,6 +62,90 @@ def forward(self, lhs_val):
inputs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_le_tensor_dynamic_shape(self, min_shape, opt_shape, max_shape):
class le(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.le.Tensor(lhs_val, rhs_val)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
le(),
input_specs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_le_tensor_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
class le(nn.Module):
def forward(self, lhs_val):
return torch.ops.aten.le.Tensor(lhs_val, torch.tensor(1))

input_specs = [
Input(
dtype=torch.int32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
le(),
input_specs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_le_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
class le(nn.Module):
def forward(self, lhs_val):
return torch.ops.aten.le.Scalar(lhs_val, 1.0)

input_specs = [
Input(
dtype=torch.int32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
le(),
input_specs,
)


if __name__ == "__main__":
run_tests()
85 changes: 85 additions & 0 deletions tests/py/dynamo/conversion/test_lt_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -58,6 +59,90 @@ def forward(self, lhs_val):
inputs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_lt_tensor_dynamic_shape(self, min_shape, opt_shape, max_shape):
class lt(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.aten.lt.Tensor(lhs_val, rhs_val)

input_specs = [
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
Input(
dtype=torch.float32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
lt(),
input_specs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_lt_tensor_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
class lt(nn.Module):
def forward(self, lhs_val):
return torch.ops.aten.lt.Tensor(lhs_val, torch.tensor(1))

input_specs = [
Input(
dtype=torch.int32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
lt(),
input_specs,
)

@parameterized.expand(
[
((1,), (3,), (5,)),
((1, 20), (2, 20), (3, 20)),
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
]
)
def test_lt_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
class lt(nn.Module):
def forward(self, lhs_val):
return torch.ops.aten.lt.Scalar(lhs_val, 1.0)

input_specs = [
Input(
dtype=torch.int32,
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
),
]
self.run_test_with_dynamic_shape(
lt(),
input_specs,
)


if __name__ == "__main__":
run_tests()
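For reference, the two overloads exercised throughout these files differ only in how the right-hand side is passed, and in eager PyTorch both return a bool tensor, which is the behavior the TRT converters must reproduce. A quick eager-mode check (plain PyTorch, no TensorRT needed):

import torch

lhs = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32)

# aten.lt.Scalar takes a Python number on the right-hand side...
print(torch.ops.aten.lt.Scalar(lhs, 1.0))
# tensor([[ True, False],
#         [False, False]])

# ...while aten.lt.Tensor takes a tensor and broadcasts, matching the
# tensor-vs-0-d-tensor case in test_lt_tensor_scalar_dynamic_shape above.
print(torch.ops.aten.lt.Tensor(lhs, torch.tensor(1)))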