Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for prelu dynamo converter #2972

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3370,3 +3370,29 @@ def aten_ops_native_dropout(
args[1],
args_bounds_check(args, 2, None),
)


@dynamo_tensorrt_converter(
torch.ops.aten._prelu_kernel.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (TRTTensor,),
}
)
def aten_ops_prelu(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.prelu.prelu(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
pad,
permutation,
pool,
prelu,
quantize,
reduce,
select,
Expand Down
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/prelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Optional

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import set_layer_name
from torch_tensorrt.dynamo.types import TRTTensor


def prelu(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: TRTTensor,
) -> TRTTensor:
layer = ctx.net.add_parametric_relu(input, weight)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@
aten.norm,
aten.ones,
aten.ones_like,
aten._prelu_kernel,
aten._prelu_kernel_backward,
aten._reshape_alias,
aten.rad2deg,
Expand Down
66 changes: 66 additions & 0 deletions tests/py/dynamo/conversion/test_prelu_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestPReLUConverter(DispatchTestCase):
def test_prelu(self):
class TestModule(nn.Module):
def forward(self, x, weight):
return torch.ops.aten._prelu_kernel.default(x, weight)

inputs = [torch.randn(1, 10), torch.randn(1, 1)]
self.run_test(TestModule(), inputs)

def test_prelu_with_dynamic_shape(self):
class TestModule(nn.Module):
def forward(self, x, weight):
return torch.ops.aten._prelu_kernel.default(x, weight)

input_specs = [
Input(
min_shape=(1, 1, 1),
opt_shape=(1, 2, 3),
max_shape=(3, 3, 3),
dtype=torch.float32,
name="x",
),
Input(
min_shape=(1, 1, 1),
opt_shape=(1, 1, 1),
max_shape=(1, 1, 1),
dtype=torch.float32,
name="weight",
),
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)

def test_prelu_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x, weight):
return torch.ops.aten._prelu_kernel.default(x, weight)

input_specs = [
Input(
min_shape=(1, 1, 1, 5),
opt_shape=(1, 2, 3, 5),
max_shape=(3, 3, 3, 5),
dtype=torch.float32,
name="x",
),
Input(
min_shape=(1, 1, 1, 1),
opt_shape=(1, 2, 1, 1),
max_shape=(1, 3, 1, 1),
dtype=torch.float32,
name="weight",
),
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)


if __name__ == "__main__":
run_tests()
Loading