diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 97f4ad2ba0..f4d39bd056 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -264,6 +264,7 @@ jobs: pre-script: ${{ matrix.pre-script }} script: | export USE_HOST_DEPS=1 + export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH pushd . cd tests/py/core ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver diff --git a/README.md b/README.md index aaf28f0b66..69c8f2716c 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass. - Bazel 5.2.0 -- Libtorch 2.4.0.dev (latest nightly) (built with CUDA 12.1) +- Libtorch 2.3.0 (built with CUDA 12.1) - CUDA 12.1 - cuDNN 8.9.5 - TensorRT 10.0.0.6 diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 5c16cd03cd..062abb9a87 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -107,7 +107,7 @@ def _from( return dtype.f16 elif t == trt.float32: return dtype.f32 - elif trt.__version__ >= "7.0" and t == trt.bool: + elif t == trt.bool: return dtype.b else: raise TypeError( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 05808dd37c..9a75add755 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -313,7 +313,7 @@ def run( ) timing_cache = self._create_timing_cache(builder_config, existing_cache) - engine = self.builder.build_engine(self.ctx.net, builder_config) + engine = self.builder.build_serialized_network(self.ctx.net, builder_config) assert engine serialized_cache = ( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py index 61c1fb99d7..8c5ee6a26a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py @@ -9,6 +9,7 @@ from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, get_positive_dim, get_trt_tensor, ) @@ -38,6 +39,12 @@ def shape( """ shape_layer = ctx.net.add_shape(input_val) input_shape = shape_layer.get_output(0) + input_shape = cast_trt_tensor( + ctx, + input_shape, + trt.int32, + name + "_shape_casted", + ) set_layer_name(shape_layer, target, name + "_shape", source_ir) n_dims = len(input_val.shape) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index b891ea44cb..0c152e15f1 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -15,6 +15,7 @@ _select_rt_device, multi_gpu_device_check, ) +from torch_tensorrt.logging import TRT_LOGGER logger = logging.getLogger(__name__) @@ -64,35 +65,19 @@ def _initialize(self) -> None: ) == (len(self.input_names) + len(self.output_names)) self.input_dtypes = [ - dtype._from(self.engine.get_binding_dtype(idx)) - for idx in self.input_binding_indices_in_order + dtype._from(self.engine.get_tensor_dtype(input_name)) + for input_name in self.input_names ] self.input_shapes = [ self.engine.get_tensor_shape(input_name) for input_name in self.input_names ] self.output_dtypes = [ - dtype._from(self.engine.get_binding_dtype(idx)) - for idx in self.output_binding_indices_in_order + dtype._from(self.engine.get_tensor_dtype(output_name)) + for output_name in self.output_names ] self.output_shapes = [ - ( - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() - ) - for idx in self.output_binding_indices_in_order - ] - self.hidden_output_dtypes = [ - dtype._from(self.engine.get_binding_dtype(idx)) - for idx in self.hidden_output_binding_indices_in_order - ] - self.hidden_output_shapes = [ - ( - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() - ) - for idx in self.hidden_output_binding_indices_in_order + self.engine.get_tensor_shape(output_name) + for output_name in self.output_names ] def _check_initialized(self) -> None: @@ -234,15 +219,11 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . bindings.append(output.data_ptr()) outputs.append(output) - for i, idx in enumerate(self.hidden_output_binding_indices_in_order): - shape = tuple(self.context.get_binding_shape(idx)) - - output = torch.empty( - size=shape, - dtype=self.hidden_output_dtypes[i].to(torch.dtype), - device=torch.cuda.current_device(), - ) - bindings[idx] = output.data_ptr() + # Assign tensor address appropriately + for idx in range(self.engine.num_io_tensors): + self.context.set_tensor_address( + self.engine.get_tensor_name(idx), bindings[idx] + ) with ( torch.autograd.profiler.record_function( diff --git a/pyproject.toml b/pyproject.toml index 2496491cf8..ec6c0fe19c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "cffi>=1.15.1", "typing-extensions>=4.7.0", "future>=0.18.3", - "tensorrt>=8.6,<8.7", + "tensorrt", "torch==2.3.0", "pybind11==2.6.2", "numpy", @@ -42,7 +42,7 @@ requires-python = ">=3.8" keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"] dependencies = [ "torch==2.3.0", - "tensorrt>=8.6,<8.7", + "tensorrt", "packaging>=23", "numpy", "typing-extensions>=4.7.0",