From 81f8b6e7be62eca12c48a690fb3b43cfe9069fd4 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Sun, 29 Sep 2024 16:41:08 -0700
Subject: [PATCH 1/3] fix the globalpartitioner bug
 https://github.com/pytorch/TensorRT/issues/3157

---
 py/torch_tensorrt/dynamo/_compiler.py          |  4 +++-
 .../dynamo/partitioning/_global_partitioner.py | 18 ++++++++++++++++--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 97aa2ec443..d213cca638 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -366,7 +366,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
-
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
@@ -408,6 +407,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
+        # filter on the GraphModule
+        if not isinstance(submodule, torch.fx.graph_module.GraphModule):
+            continue
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index 823a43beb8..6086d7c707 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -228,8 +228,22 @@ def partition(
     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    fused_graph = partitioner.fuse_partitions(partitions)
-
+    # TODO: confirm with Naren whether this change is required or not
+    # tested both with and without this change, it both works
+    # the only difference is the graph node name, an example is as below:
+    # graph():
+    #     %x : [num_users=1] = placeholder[target=x]
+    #     %_run_on_acc_0 : [num_users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
+    #     return (_run_on_acc_0,)
+
+    # or
+
+    # graph():
+    #     %x : [num_users=1] = placeholder[target=x]
+    #     %fused_0 : [num_users=1] = call_module[target=fused_0](args = (%x,), kwargs = {})
+    #     return (fused_0,)
+
+    fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
     if verbose:
         supported_ops.print_support_overview(len(partitions))
 

From 0402ba76debb8cd6b2fd32cd9be15b268178c50a Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Wed, 2 Oct 2024 16:12:15 -0700
Subject: [PATCH 2/3] add end2end test cases for global partition

---
 .../partitioning/_global_partitioner.py      | 15 --------
 .../partitioning/test_global_partitioning.py | 38 +++++++++++++++++++
 2 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index 6086d7c707..bdca0e1e1d 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -228,21 +228,6 @@ def partition(
     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    # TODO: confirm with Naren whether this change is required or not
-    # tested both with and without this change, it both works
-    # the only difference is the graph node name, an example is as below:
-    # graph():
-    #     %x : [num_users=1] = placeholder[target=x]
-    #     %_run_on_acc_0 : [num_users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
-    #     return (_run_on_acc_0,)
-
-    # or
-
-    # graph():
-    #     %x : [num_users=1] = placeholder[target=x]
-    #     %fused_0 : [num_users=1] = call_module[target=fused_0](args = (%x,), kwargs = {})
-    #     return (fused_0,)
-
     fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
     if verbose:
         supported_ops.print_support_overview(len(partitions))
diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_global_partitioning.py
index cd9da7521c..5426caeb07 100644
--- a/tests/py/dynamo/partitioning/test_global_partitioning.py
+++ b/tests/py/dynamo/partitioning/test_global_partitioning.py
@@ -1,7 +1,10 @@
 from copy import deepcopy
 
 import numpy as np
+import pytest
 import torch
+import torch.nn.functional as F
+import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch_tensorrt.dynamo import partitioning
 
@@ -9,6 +12,41 @@
 
 
 class TestGlobalPartitioning(TestCase):
+    def test_end2end_global_partition(self):
+        class SimpleCNN(torch.nn.Module):
+            def __init__(self):
+                super(SimpleCNN, self).__init__()
+                self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+                self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+                self.fc1 = torch.nn.Linear(32 * 134 * 134, 10)
+
+            def forward(self, x):
+                x = F.relu(self.conv1(x))
+                x = F.max_pool2d(x, kernel_size=2, stride=2)
+                x = F.relu(self.conv2(x))
+                x = F.max_pool2d(x, kernel_size=2, stride=2)
+                x = x.view(x.size(0), -1)
+                x = self.fc1(x)
+                return x
+
+        mod = SimpleCNN().to(dtype=torch.float16, device=torch.device("cuda"))
+        mod.eval()
+        batch_size, tile_size = 1, 538
+        with torch.no_grad():
+            inputs = torch.randn(
+                batch_size, 3, tile_size, tile_size, device="cuda", dtype=torch.float16
+            )
+            try:
+                torch_tensorrt.compile(
+                    mod,
+                    ir="dynamo",
+                    inputs=[inputs],
+                    enabled_precisions={torch.float16},
+                    use_fast_partitioner=False,
+                )
+            except Exception as e:
+                pytest.fail(f"unexpected exception raised: {e}")
+
     def test_partition_fully_supported_one_op(self):
         class FullySupportedOneOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:

From e6f35087ac4728d54c7b132f15469c063314ce7a Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Thu, 3 Oct 2024 17:54:08 -0700
Subject: [PATCH 3/3] add test case

---
 .../partitioning/test_global_partitioning.py | 56 ++++++++++++-------
 1 file changed, 35 insertions(+), 21 deletions(-)

diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_global_partitioning.py
index 5426caeb07..80b6716d20 100644
--- a/tests/py/dynamo/partitioning/test_global_partitioning.py
+++ b/tests/py/dynamo/partitioning/test_global_partitioning.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn.functional as F
 import torch_tensorrt
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch_tensorrt.dynamo import partitioning
 
@@ -12,38 +13,51 @@
 
 
 class TestGlobalPartitioning(TestCase):
-    def test_end2end_global_partition(self):
+    @parameterized.expand(
+        [
+            ({}, 1),
+            ({"torch.ops.aten.relu.default"}, 3),
+        ]
+    )
+    def test_end2end_global_partition(self, torch_executed_ops, trt_mod_cnt):
         class SimpleCNN(torch.nn.Module):
             def __init__(self):
                 super(SimpleCNN, self).__init__()
-                self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
-                self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
-                self.fc1 = torch.nn.Linear(32 * 134 * 134, 10)
-
-            def forward(self, x):
-                x = F.relu(self.conv1(x))
-                x = F.max_pool2d(x, kernel_size=2, stride=2)
-                x = F.relu(self.conv2(x))
-                x = F.max_pool2d(x, kernel_size=2, stride=2)
-                x = x.view(x.size(0), -1)
-                x = self.fc1(x)
-                return x
-
-        mod = SimpleCNN().to(dtype=torch.float16, device=torch.device("cuda"))
+                self.conv1 = torch.nn.Conv2d(3, 12, 3, padding=1)
+                self.bn = torch.nn.BatchNorm2d(12)
+                self.conv2 = torch.nn.Conv2d(12, 12, 3, padding=1)
+                self.fc1 = torch.nn.Linear(12 * 56 * 56, 10)
+
+            def forward(self, x, b=5):
+                x = self.conv1(x)
+                x = F.relu(x)
+                x = self.bn(x)
+                x = F.max_pool2d(x, (2, 2))
+                x = self.conv2(x)
+                x = F.relu(x)
+                x = F.max_pool2d(x, (2, 2))
+                x = torch.flatten(x, 1)
+                x = x + b
+                return self.fc1(x)
+
+        mod = SimpleCNN().to("cuda")
         mod.eval()
-        batch_size, tile_size = 1, 538
         with torch.no_grad():
-            inputs = torch.randn(
-                batch_size, 3, tile_size, tile_size, device="cuda", dtype=torch.float16
-            )
+            inputs = torch.rand((1, 3, 224, 224)).to("cuda")
             try:
-                torch_tensorrt.compile(
+                trt_mod = torch_tensorrt.compile(
                     mod,
                     ir="dynamo",
                     inputs=[inputs],
-                    enabled_precisions={torch.float16},
+                    min_block_size=1,
+                    torch_executed_ops=torch_executed_ops,
                     use_fast_partitioner=False,
                 )
+                cnt = 0
+                for name, _ in trt_mod.named_children():
+                    if "_run_on_acc" in name:
+                        cnt += 1
+                self.assertEqual(cnt, trt_mod_cnt)
             except Exception as e:
                 pytest.fail(f"unexpected exception raised: {e}")
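
For reference, a minimal sketch (not part of the patch series) of exercising the code path these patches change and test: it compiles a small CUDA model through the dynamo frontend with the fast partitioner disabled, then counts the "_run_on_acc" submodules produced by the global partitioner, mirroring the assertions in the test above. The Sequential model and the 1x3x224x224 input shape are illustrative assumptions; the compile arguments and the "_run_on_acc" naming come directly from the tests.

    import torch
    import torch_tensorrt

    # Illustrative model and input; any CUDA-resident module would do here.
    model = (
        torch.nn.Sequential(
            torch.nn.Conv2d(3, 12, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
        )
        .eval()
        .cuda()
    )
    inputs = [torch.rand((1, 3, 224, 224)).cuda()]

    # Force the global (non-fast) partitioner, as the new tests do.
    trt_mod = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        min_block_size=1,
        use_fast_partitioner=False,
    )

    # With the fix, TRT-accelerated submodules carry the "_run_on_acc" prefix,
    # so they can be located by name on the compiled module.
    n_trt = sum(1 for name, _ in trt_mod.named_children() if "_run_on_acc" in name)
    print(f"TRT submodules: {n_trt}")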