From 78ceda5c07557943c558423dbe0ee7d3b24b49ef Mon Sep 17 00:00:00 2001 From: "Zewen (Evan) Li" Date: Fri, 17 May 2024 17:32:57 -0700 Subject: [PATCH] fix: bugs in TRT 10 upgrade (#2832) --- .../dynamo/conversion/_TRTInterpreter.py | 12 +++++++----- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 56e7e069c5..422842f644 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set import numpy as np +import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -25,7 +26,6 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -316,8 +316,10 @@ def run( ) timing_cache = self._create_timing_cache(builder_config, existing_cache) - engine = self.builder.build_serialized_network(self.ctx.net, builder_config) - assert engine + serialized_engine = self.builder.build_serialized_network( + self.ctx.net, builder_config + ) + assert serialized_engine serialized_cache = ( bytearray(timing_cache.serialize()) @@ -327,10 +329,10 @@ def run( _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory") + _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") return TRTInterpreterResult( - engine, self._input_names, self._output_names, serialized_cache + serialized_engine, self._input_names, self._output_names, serialized_cache ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 0c152e15f1..1fcb765b47 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -29,7 +29,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc] def __init__( self, - engine: trt.ICudaEngine, + engine: bytes, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, target_device: Device = Device._current_device(), @@ -60,9 +60,9 @@ def _initialize(self) -> None: self.engine = runtime.deserialize_cuda_engine(self.engine) self.context = self.engine.create_execution_context() - assert ( - self.engine.num_io_tensors // self.engine.num_optimization_profiles - ) == (len(self.input_names) + len(self.output_names)) + assert self.engine.num_io_tensors == ( + len(self.input_names) + len(self.output_names) + ) self.input_dtypes = [ dtype._from(self.engine.get_tensor_dtype(input_name))