
Commit

fix the global partitioner bug (#3195)
lanluo-nvidia authored Oct 7, 2024
1 parent ededc0b commit ce38387
Showing 3 changed files with 56 additions and 3 deletions.
py/torch_tensorrt/dynamo/_compiler.py (4 changes: 3 additions, 1 deletion)
@@ -366,7 +366,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
-
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
@@ -408,6 +407,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
+        # filter on the GraphModule
+        if not isinstance(submodule, torch.fx.graph_module.GraphModule):
+            continue
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
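With the global partitioner, the partitioned module's named_children() can include plain nn.Module leaves (original layers that were not fused into an accelerated block) alongside the fused GraphModules, and the conversion loop above would previously hand those to the TRT path as well; the added isinstance guard skips them. A minimal sketch of the same filtering pattern, using a hypothetical toy parent module in place of a real partitioned graph (no TensorRT required):

import torch

class ToyPartitionedModule(torch.nn.Module):
    """Hypothetical stand-in for a partitioned module with mixed child types."""

    def __init__(self) -> None:
        super().__init__()
        # An fx-traced child, as the partitioner would produce for an accelerated block.
        self._run_on_acc_0 = torch.fx.symbolic_trace(torch.nn.ReLU())
        # A plain nn.Module child that must not be handed to the TRT conversion path.
        self.plain_child = torch.nn.ReLU()

parent = ToyPartitionedModule()
for name, _ in parent.named_children():
    submodule = getattr(parent, name)
    # Same guard as in the fix: only GraphModule children are conversion candidates.
    if not isinstance(submodule, torch.fx.graph_module.GraphModule):
        continue
    print(f"would convert: {name}")  # prints only "_run_on_acc_0"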
py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py (3 changes: 1 addition, 2 deletions)
@@ -228,8 +228,7 @@ def partition(
     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    fused_graph = partitioner.fuse_partitions(partitions)
-
+    fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
     if verbose:
         supported_ops.print_support_overview(len(partitions))

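This is the core of the fix: the global partitioner delegates to torch.fx's CapabilityBasedPartitioner, and without an explicit prefix the fused submodules come back with torch.fx's default naming (to my understanding, fused_*), so code and tests that look for the _run_on_acc naming produced by the fast partitioner never match them. A small sketch of the naming effect, assuming a recent PyTorch whose fuse_partitions accepts the prefix argument; AcceptAllOps and the toy model here are hypothetical:

import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

class AcceptAllOps(OperatorSupportBase):
    """Hypothetical support checker that marks every compute node as supported."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op in ("call_function", "call_method", "call_module")

gm = torch.fx.symbolic_trace(
    torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
)
partitioner = CapabilityBasedPartitioner(
    gm, AcceptAllOps(), allows_single_node_partition=True
)
partitions = partitioner.propose_partitions()

# With the explicit prefix, the fused accelerated block is a child named "_run_on_acc_0";
# the original leaf modules may also remain as plain (non-GraphModule) children, which is
# exactly what the isinstance guard in _compiler.py filters out.
fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
acc_children = [
    name
    for name, child in fused_graph.named_children()
    if isinstance(child, torch.fx.graph_module.GraphModule) and "_run_on_acc" in name
]
print(acc_children)  # expected: ['_run_on_acc_0']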
tests/py/dynamo/partitioning/test_global_partitioning.py (52 changes: 52 additions, 0 deletions)
@@ -1,14 +1,66 @@
from copy import deepcopy

import numpy as np
import pytest
import torch
import torch.nn.functional as F
import torch_tensorrt
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import partitioning

from ..testing_utilities import lower_graph_testing


class TestGlobalPartitioning(TestCase):
    @parameterized.expand(
        [
            ({}, 1),
            ({"torch.ops.aten.relu.default"}, 3),
        ]
    )
    def test_end2end_global_partition(self, torch_executed_ops, trt_mod_cnt):
        class SimpleCNN(torch.nn.Module):
            def __init__(self):
                super(SimpleCNN, self).__init__()
                self.conv1 = torch.nn.Conv2d(3, 12, 3, padding=1)
                self.bn = torch.nn.BatchNorm2d(12)
                self.conv2 = torch.nn.Conv2d(12, 12, 3, padding=1)
                self.fc1 = torch.nn.Linear(12 * 56 * 56, 10)

            def forward(self, x, b=5):
                x = self.conv1(x)
                x = F.relu(x)
                x = self.bn(x)
                x = F.max_pool2d(x, (2, 2))
                x = self.conv2(x)
                x = F.relu(x)
                x = F.max_pool2d(x, (2, 2))
                x = torch.flatten(x, 1)
                x = x + b
                return self.fc1(x)

        mod = SimpleCNN().to("cuda")
        mod.eval()
        with torch.no_grad():
            inputs = torch.rand((1, 3, 224, 224)).to("cuda")
            try:
                trt_mod = torch_tensorrt.compile(
                    mod,
                    ir="dynamo",
                    inputs=[inputs],
                    min_block_size=1,
                    torch_executed_ops=torch_executed_ops,
                    use_fast_partitioner=False,
                )
                cnt = 0
                for name, _ in trt_mod.named_children():
                    if "_run_on_acc" in name:
                        cnt += 1
                self.assertEqual(cnt, trt_mod_cnt)
            except Exception as e:
                pytest.fail(f"unexpected exception raised: {e}")

    def test_partition_fully_supported_one_op(self):
        class FullySupportedOneOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
