From e99514b7daa29d052c37305372a7a5e9c2bd383a Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 18 Jun 2024 08:54:07 -0700 Subject: [PATCH] Add custom model for dynamic model checking --- .github/workflows/build-test-linux.yml | 1 - .github/workflows/build-test-windows.yml | 1 - tests/py/dynamo/models/test_dyn_models.py | 56 ++++++++ .../dynamo/models/test_hf_generate_dynamic.py | 132 ------------------ tests/py/requirements.txt | 2 +- 5 files changed, 57 insertions(+), 135 deletions(-) delete mode 100644 tests/py/dynamo/models/test_hf_generate_dynamic.py diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index 7d49440c86..22f03e6218 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -144,7 +144,6 @@ jobs: cd tests/py/dynamo python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_generation_compile.xml models/test_hf_generate_dynamic.py popd tests-py-dynamo-serde: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index 9cb1087c7a..1bdb52ae8a 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -143,7 +143,6 @@ jobs: cd tests/py/dynamo python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py - python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_generation_compile.xml models/test_hf_generate_dynamic.py popd tests-py-dynamo-serde: diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index 3fd34de2ea..3332a84c5c 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -310,3 +310,59 @@ def forward(self, x): cos_sim > COSINE_THRESHOLD, msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + + +@pytest.mark.unit +def test_dynamic_with_fallback_shape_tensor_pass_through(ir): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + x = x + 2 + x = x * 2 + out = torch.reshape(x, (-1, 2)) + out = self.relu(out) + return out + + model = MyModule().eval().cuda() + input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda") + + compile_spec = { + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "min_block_size": 1, + "torch_executed_ops": {"torch.ops.aten.add.Tensor"}, + } + + # Compile the model + if ir == "torch_compile": + torch._dynamo.mark_dynamic(input_bs4, 0, min=1, max=8) + # Compile the model + trt_model = torch.compile(model, backend="tensorrt", options=compile_spec) + trt_model(input_bs4) + elif ir == "dynamo": + compile_spec["inputs"] = [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + name="x", + ) + ] + trt_model = torchtrt.compile(model, **compile_spec) + + trt_model(input_bs4) + + input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda") + cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_dynamic_with_fallback_shape_tensor_pass_through model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) diff --git a/tests/py/dynamo/models/test_hf_generate_dynamic.py b/tests/py/dynamo/models/test_hf_generate_dynamic.py deleted file mode 100644 index f118b103fe..0000000000 --- a/tests/py/dynamo/models/test_hf_generate_dynamic.py +++ /dev/null @@ -1,132 +0,0 @@ -import pytest -import torch -import torch_tensorrt -from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList -from transformers.generation.stopping_criteria import ( - EosTokenCriteria, - MaxLengthCriteria, -) - - -@pytest.mark.unit -def test_dynamic_generation_python_rt(): - """ - Tests HuggingFace Generate Code with dynamic shapes - Code Credit: @peri044 - """ - # Define tokenizer and model - tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = ( - AutoModelForCausalLM.from_pretrained( - "gpt2", pad_token_id=tokenizer.eos_token_id, use_cache=False - ) - .eval() - .to("cuda") - ) - - # Input prompt - model_inputs = tokenizer(("Repeat " * 128)[:-1], return_tensors="pt").to("cuda") - input_ids = model_inputs["input_ids"] - max_tokens = 40 - - # Pyt model outputs - greedy_output = model.generate(**model_inputs, max_new_tokens=max_tokens) - print( - "Pytorch model generated text: ", - tokenizer.decode(greedy_output[0], skip_special_tokens=True), - ) - - # Compile Torch-TRT model - torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023) - model.forward = torch.compile( - model.forward, - backend="tensorrt", - dynamic=None, - options={ - "enabled_precisions": {torch.float}, - "torch_executed_ops": {"torch.ops.aten.slice.Tensor"}, - "use_python_runtime": True, - "optimization_level": 0, - "min_block_size": 29, - }, - ) - - # Auto-regressive generation loop for greedy search - stopping_criteria = StoppingCriteriaList( - [ - MaxLengthCriteria(max_length=max_tokens), - EosTokenCriteria(eos_token_id=tokenizer.eos_token_id), - ] - ) - while True: - trt_outputs = model(input_ids) - logits = trt_outputs.logits - next_token_logits = logits[:, -1, :] - next_tokens = torch.argmax(next_token_logits, dim=-1) - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if stopping_criteria(input_ids, logits).item(): - break - - # TODO: Add test for correctness - - -@pytest.mark.unit -def test_dynamic_generation_cpp_rt(): - """ - Tests HuggingFace Generate Code with dynamic shapes - Code Credit: @peri044 - """ - # Define tokenizer and model - tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = ( - AutoModelForCausalLM.from_pretrained( - "gpt2", pad_token_id=tokenizer.eos_token_id, use_cache=False - ) - .eval() - .to("cuda") - ) - - # Input prompt - model_inputs = tokenizer(("Repeat " * 128)[:-1], return_tensors="pt").to("cuda") - input_ids = model_inputs["input_ids"] - max_tokens = 40 - - # Pyt model outputs - greedy_output = model.generate(**model_inputs, max_new_tokens=max_tokens) - print( - "Pytorch model generated text: ", - tokenizer.decode(greedy_output[0], skip_special_tokens=True), - ) - - # Compile Torch-TRT model - torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023) - model.forward = torch.compile( - model.forward, - backend="tensorrt", - dynamic=None, - options={ - "enabled_precisions": {torch.float}, - "torch_executed_ops": {"torch.ops.aten.slice.Tensor"}, - "use_python_runtime": False, - "optimization_level": 0, - "min_block_size": 29, - }, - ) - - # Auto-regressive generation loop for greedy search - stopping_criteria = StoppingCriteriaList( - [ - MaxLengthCriteria(max_length=max_tokens), - EosTokenCriteria(eos_token_id=tokenizer.eos_token_id), - ] - ) - while True: - trt_outputs = model(input_ids) - logits = trt_outputs.logits - next_token_logits = logits[:, -1, :] - next_tokens = torch.argmax(next_token_logits, dim=-1) - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if stopping_criteria(input_ids, logits).item(): - break - - # TODO: Add test for correctness diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt index a77769a908..36dbda57bc 100644 --- a/tests/py/requirements.txt +++ b/tests/py/requirements.txt @@ -9,5 +9,5 @@ pytest-xdist>=3.6.1 pyyaml tensorrt==10.0.1 timm>=1.0.3 -transformers==4.41.0 +transformers==4.40.2 --extra-index-url https://pypi.nvidia.com