From ce38387b6c49d127bddfff9b6c335b261f461bdd Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 7 Oct 2024 10:15:39 -0700 Subject: [PATCH] fix the global partitioner bug (#3195) --- py/torch_tensorrt/dynamo/_compiler.py | 4 +- .../partitioning/_global_partitioner.py | 3 +- .../partitioning/test_global_partitioning.py | 52 +++++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 97aa2ec443..d213cca638 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -366,7 +366,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: # Partition module into components that can be TRT-accelerated fast_partitioner_failed = False - # If specified, try using the fast partitioner and fall back to the global one on failure if settings.use_fast_partitioner: try: @@ -408,6 +407,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: # Generate the corresponding TRT Module for those for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) + # filter on the GraphModule + if not isinstance(submodule, torch.fx.graph_module.GraphModule): + continue # Criteria for a module to be convertible to TRT if settings.use_fast_partitioner and "_run_on_acc" not in name: dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule)) diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 823a43beb8..bdca0e1e1d 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -228,8 +228,7 @@ def partition( # Determine partitions based on user specifications and operator support # Then, fuse partitions and display overview of supported/unsupported operators partitions = partitioner.propose_partitions() - fused_graph = partitioner.fuse_partitions(partitions) - + fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_") if verbose: supported_ops.print_support_overview(len(partitions)) diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_global_partitioning.py index cd9da7521c..80b6716d20 100644 --- a/tests/py/dynamo/partitioning/test_global_partitioning.py +++ b/tests/py/dynamo/partitioning/test_global_partitioning.py @@ -1,7 +1,11 @@ from copy import deepcopy import numpy as np +import pytest import torch +import torch.nn.functional as F +import torch_tensorrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo import partitioning @@ -9,6 +13,54 @@ class TestGlobalPartitioning(TestCase): + @parameterized.expand( + [ + ({}, 1), + ({"torch.ops.aten.relu.default"}, 3), + ] + ) + def test_end2end_global_partition(self, torch_executed_ops, trt_mod_cnt): + class SimpleCNN(torch.nn.Module): + def __init__(self): + super(SimpleCNN, self).__init__() + self.conv1 = torch.nn.Conv2d(3, 12, 3, padding=1) + self.bn = torch.nn.BatchNorm2d(12) + self.conv2 = torch.nn.Conv2d(12, 12, 3, padding=1) + self.fc1 = torch.nn.Linear(12 * 56 * 56, 10) + + def forward(self, x, b=5): + x = self.conv1(x) + x = F.relu(x) + x = self.bn(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + x = x + b + return self.fc1(x) + + mod = SimpleCNN().to("cuda") + mod.eval() + with torch.no_grad(): + inputs = torch.rand((1, 3, 224, 224)).to("cuda") + try: + trt_mod = torch_tensorrt.compile( + mod, + ir="dynamo", + inputs=[inputs], + min_block_size=1, + torch_executed_ops=torch_executed_ops, + use_fast_partitioner=False, + ) + cnt = 0 + for name, _ in trt_mod.named_children(): + if "_run_on_acc" in name: + cnt += 1 + self.assertEqual(cnt, trt_mod_cnt) + except Exception as e: + pytest.fail(f"unexpected exception raised: {e}") + def test_partition_fully_supported_one_op(self): class FullySupportedOneOp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: