From 8faf7ecbd5fe67ee996fb2066b92465b631d3893 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 20 Mar 2024 09:57:54 -0700 Subject: [PATCH] chore: trt 10 fixes --- py/torch_tensorrt/dynamo/_compiler.py | 2 +- .../dynamo/conversion/_TRTInterpreter.py | 2 +- .../dynamo/conversion/impl/elementwise/base.py | 11 ----------- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index b321eabcb2..ac430bf883 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -634,7 +634,7 @@ def convert_module_to_trt_engine( import io with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine.serialize()) + engine_bytes.write(interpreter_result.engine) engine_bytearray = engine_bytes.getvalue() return engine_bytearray diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 8b9d730a72..ffcc9c195e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -172,7 +172,7 @@ def run( if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( - trt.ProfilingVerbosity.VERBOSE + trt.ProfilingVerbosity.DETAILED if self.compilation_settings.debug else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 8282ee8698..6664a67cfe 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -147,17 +147,6 @@ def convert_binary_elementwise( ctx, rhs_val, trt_promoted_type, name, target, source_ir ) - # Check the limitation in the doc string. - if ctx.net.has_implicit_batch_dimension: - if is_lhs_trt_tensor and not is_rhs_trt_tensor: - assert len(lhs_val.shape) >= len( - rhs_val.shape - ), f"{lhs_val.shape} >= {rhs_val.shape}" - elif not is_lhs_trt_tensor and is_rhs_trt_tensor: - assert len(rhs_val.shape) >= len( - lhs_val.shape - ), f"{rhs_val.shape} >= {lhs_val.shape}" - lhs_val, rhs_val = broadcast( ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" )