Skip to content

Commit

Permalink
Renamed tests and added test skip
Browse files Browse the repository at this point in the history
  • Loading branch information
cehongwang committed Aug 13, 2024
1 parent 7235de6 commit b3aa04f
Showing 1 changed file with 57 additions and 14 deletions.
71 changes: 57 additions & 14 deletions tests/py/dynamo/models/test_model_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import torch
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch import nn

# from torch import nn
from torch_tensorrt.dynamo import refit_module_weights
from torch_tensorrt.dynamo._refit import (
construct_refit_mapping,
Expand All @@ -29,6 +28,10 @@
assertions = unittest.TestCase()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_mapping():

Expand Down Expand Up @@ -81,8 +84,12 @@ def test_mapping():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_fast_refit_one_engine():
def test_refit_one_engine_with_weightmap():

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -126,8 +133,12 @@ def test_fast_refit_one_engine():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_fast_refit_one_engine_no_map():
def test_refit_one_engine_no_map_with_weightmap():

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -173,8 +184,12 @@ def test_fast_refit_one_engine_no_map():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_fast_refit_one_engine_wrong_map():
def test_refit_one_engine_with_wrong_weightmap():

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -224,8 +239,12 @@ def test_fast_refit_one_engine_wrong_map():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_fast_refit_one_engine_bert():
def test_refit_one_engine_bert_with_weightmap():
inputs = [
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
Expand Down Expand Up @@ -275,8 +294,12 @@ def test_fast_refit_one_engine_bert():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_fast_refit_one_engine_inline_runtime():
def test_refit_one_engine_inline_runtime__with_weightmap():
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -322,7 +345,7 @@ def test_fast_refit_one_engine_inline_runtime():


@pytest.mark.unit
def test_fast_refit_one_engine_python_runtime():
def test_refit_one_engine_python_runtime_with_weightmap():

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -366,8 +389,12 @@ def test_fast_refit_one_engine_python_runtime():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_fast_refit_multiple_engine():
def test_refit_multiple_engine_with_weightmap():

class net(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -433,8 +460,12 @@ def forward(self, x):
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_refit_one_engine():
def test_refit_one_engine_without_weightmap():

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -478,8 +509,12 @@ def test_refit_one_engine():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_refit_one_engine_bert():
def test_refit_one_engine_bert_without_weightmap():
inputs = [
torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
Expand Down Expand Up @@ -529,8 +564,12 @@ def test_refit_one_engine_bert():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_refit_one_engine_inline_runtime():
def test_refit_one_engine_inline_runtime_without_weightmap():
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -576,7 +615,7 @@ def test_refit_one_engine_inline_runtime():


@pytest.mark.unit
def test_refit_one_engine_python_runtime():
def test_refit_one_engine_python_runtime_without_weightmap():

model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -620,8 +659,12 @@ def test_refit_one_engine_python_runtime():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@pytest.mark.unit
def test_refit_multiple_engine():
def test_refit_multiple_engine_without_weightmap():

class net(nn.Module):
def __init__(self):
Expand Down

0 comments on commit b3aa04f

Please sign in to comment.