From 0b1e491057c599b1b8c49487f9a47eacd8c92742 Mon Sep 17 00:00:00 2001
From: tianshuo78520a
Date: Wed, 11 Dec 2024 10:54:56 +0800
Subject: [PATCH] =?UTF-8?q?Revert=20"[Inference]Fix=20PaddleX=20model=20bu?=
 =?UTF-8?q?gs=20when=20convert=20to=20pir-trt=20(Part2)=20(#6=E2=80=A6"=20?=
 =?UTF-8?q?(#70122)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 0f66ede65610a84639aa2ef549ef63a908b92ce9.

---
 .../tensorrt_engine_instruction.cc         | 66 ++++++++++++++++---
 python/paddle/tensorrt/converter.py        | 10 +--
 python/paddle/tensorrt/converter_utils.py  | 36 +++++-----
 python/paddle/tensorrt/impls/common.py     | 41 ++++++++----
 python/paddle/tensorrt/impls/creation.py   | 20 +++---
 test/cpp/inference/tensorrt/CMakeLists.txt | 34 +++++-----
 6 files changed, 130 insertions(+), 77 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/instruction/tensorrt_engine_instruction.cc b/paddle/fluid/framework/new_executor/instruction/tensorrt_engine_instruction.cc
index 1ca2688844c8a..269bc547b35d3 100644
--- a/paddle/fluid/framework/new_executor/instruction/tensorrt_engine_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/tensorrt_engine_instruction.cc
@@ -239,8 +239,10 @@ static phi::DataType TRT2PaddleDataType(nvinfer1::DataType type) {
           "to paddle. Does the downstream paddle op here support int64?";
       return phi::DataType::INT64;
 #endif
+#if IS_TRT_VERSION_GE(7000)
     case nvinfer1::DataType::kBOOL:
       return phi::DataType::BOOL;
+#endif
     default:
       PADDLE_THROW(common::errors::InvalidArgument(
           "unknown fluid datatype in Fluid op converter"));
@@ -487,10 +489,11 @@ void TensorRTEngineInstruction::BindInputTensor(
           bind_index,
           num_bindings));

+#if IS_TRT_VERSION_GE(6000)
+#if IS_TRT_VERSION_GE(8500)
   if (trt_engine_->engine()->isShapeInferenceIO(input_name.c_str()) &&
       trt_engine_->engine()->getTensorIOMode(input_name.c_str()) ==
           nvinfer1::TensorIOMode::kINPUT) {
-    shape_v.resize(input_tensor.numel());
     if (input_tensor.dtype() == phi::DataType::INT32) {
       phi::memory_utils::Copy(phi::CPUPlace(),
                               shape_v.data(),
@@ -521,6 +524,41 @@ void TensorRTEngineInstruction::BindInputTensor(
         input_name.c_str(),
         paddle::platform::Vec2TRT_Dims(input_shape, input_name, true));
   }
+#else
+  trt_context->setBindingDimensions(
+      bind_index,
+      paddle::platform::Vec2TRT_Dims(input_shape, input_name, true));
+  // If this x is a shape tensor, we need call setInputShapeBinding
+  if (trt_engine_->engine()->isShapeBinding(bind_index) &&
+      trt_engine_->engine()->bindingIsInput(bind_index)) {
+    if (input_tensor.dtype() == phi::DataType::INT32) {
+      phi::memory_utils::Copy(phi::CPUPlace(),
+                              shape_v.data(),
+                              input_tensor.place(),
+                              input_tensor.data<int32_t>(),
+                              input_tensor.numel() * sizeof(int),
+                              nullptr);
+    } else if (input_tensor.dtype() == phi::DataType::INT64) {
+      std::string x_t = input_name + "_cast_to_INT32";
+      if (scope.FindVar(x_t) == nullptr) {
+        const_cast<framework::Scope *>(&scope)->Var(x_t);
+      }
+      auto int32_tensor = scope.FindVar(x_t)->GetMutable<phi::DenseTensor>();
+      *int32_tensor = phi::Cast<int64_t>(
+          reinterpret_cast<const phi::GPUContext &>(*dev_ctx_),
+          input_tensor,
+          phi::DataType::INT32);
+      phi::memory_utils::Copy(phi::CPUPlace(),
+                              shape_v.data(),
+                              int32_tensor->place(),
+                              int32_tensor->data<int32_t>(),
+                              int32_tensor->numel() * sizeof(int),
+                              nullptr);
+    }
+    trt_context->setInputShapeBinding(bind_index, shape_v.data());
+  }
+#endif
+#endif

   *runtime_batch = input_shape[0];
   VLOG(1) << "trt input [" << input_name << "] dtype is "
@@ -572,10 +610,11 @@ void TensorRTEngineInstruction::BindInputTensor(
   } else if (input_tensor.dtype() == phi::DataType::FLOAT16) {
     buffers[bind_index] = static_cast<void *>(
         const_cast<phi::dtype::float16 *>(input_tensor.data<phi::dtype::float16>()));
+#if IS_TRT_VERSION_GE(8400)
   } else if (input_tensor.dtype() == phi::DataType::BOOL) {
     buffers[bind_index] =
         static_cast<void *>(const_cast<bool *>(input_tensor.data<bool>()));
-
+#endif
   } else {
     PADDLE_THROW(common::errors::Fatal(
         "The TRT Engine OP only support "
@@ -616,6 +655,7 @@ void TensorRTEngineInstruction::BindOutputTensor(
 #endif

   std::vector<int> ddim;
+#if IS_TRT_VERSION_GE(8500)
   auto x_name = trt_engine_->engine()->getIOTensorName(bind_index);
   auto dims = trt_context->getTensorShape(x_name);
   int nb_dims = dims.nbDims;
   for (; nb_dims > 0; nb_dims--) {
     // some 'x 1' of shape is normal, no need to remove it
     if (dims.d[nb_dims - 1] != 1 || nb_dims == outputs_rank_[output_index])
       break;
   }
   for (int i = 0; i < nb_dims; i++) {
     ddim.push_back(dims.d[i]);
   }
@@ -627,6 +667,18 @@
+#else
+  auto dims = trt_context->getBindingDimensions(bind_index);
+  int nb_dims = dims.nbDims;
+  for (; nb_dims > 0; nb_dims--) {
+    // some 'x 1' of shape is normal, no need to remove it
+    if (dims.d[nb_dims - 1] != 1 || nb_dims == outputs_rank_[output_index])
+      break;
+  }
+  for (int i = 0; i < nb_dims; i++) {
+    ddim.push_back(dims.d[i]);
+  }
+#endif
   auto *fluid_t = output_tensor;
   fluid_t->Resize(common::make_ddim(ddim));
@@ -669,13 +721,14 @@ void TensorRTEngineInstruction::RunTrt() {
                        "can not find var[%s] in scope",
                        in_var_name));
   auto in_var = scope.FindVar(in_var_name);
   auto &in_variable_array = in_var->Get<VariableRefArray>();
-  // we will use shape_input when input is a shape tensor
   std::vector<std::vector<int32_t>> shape_inputs(in_variable_array.size());
   for (const auto &index_name_pair : input_names_) {
     size_t i = index_name_pair.first;
     if (in_variable_array[i]->IsType<phi::DenseTensor>()) {
       auto input_tensor = in_variable_array[i]->Get<phi::DenseTensor>();
+      // we will use shape_input when input is a shape tensor
+      shape_inputs[i].resize(input_tensor.numel());
       // Bind input tensor to TRT.
       BindInputTensor(index_name_pair.second,
                       input_tensor,
@@ -765,13 +818,6 @@ void TensorRTEngineInstruction::RunTrt() {
 }

 void TensorRTEngineInstruction::Run() {
-#if IS_TRT_VERSION_LT(8500)
-  PADDLE_THROW(
-      common::errors::Unimplemented("PIR-TRT only support TensorRT "
-                                    "version that is >= 8.5,"
-                                    "Please check your TensorRT "
-                                    "in your env."));
-#endif
   PrepareDynamicShape();
   RunTrt();
 }
diff --git a/python/paddle/tensorrt/converter.py b/python/paddle/tensorrt/converter.py
index bd646b6560ea7..6b290bbfc2473 100644
--- a/python/paddle/tensorrt/converter.py
+++ b/python/paddle/tensorrt/converter.py
@@ -87,7 +87,6 @@ def __init__(self, paddle_program, scope, trt_config=None):

         self.input_info = {}
         self.trt_output_value_map = {}
-        self.engine_num = 0

     def find_graph_inputs_outputs(self, group_op):
         operations = next(iter(group_op.blocks())).ops
@@ -192,7 +191,7 @@ def convert_subgraph_to_trt(self, program, group_op):
             for operand in op.operands():
                 source = operand.source()
                 if not source.initialized():
-                    operands.append(None)
+                    _logger.warning(f"Skipping uninitialized source: {source}")
                     continue
                 define_op_name = source.get_defining_op().name()
                 if define_op_name == "builtin.combine":
@@ -457,12 +456,10 @@ def convert_subgraph_to_trt(self, program, group_op):
             % 10**8
         )
         CACHE_ROOT = get_cache_path()
-        CACHE_FILE = f"{CACHE_ROOT}/engine_{engine_name}_{self.engine_num}.trt"
+        CACHE_FILE = f"{CACHE_ROOT}/engine_{engine_name}.trt"
         with open(CACHE_FILE, "wb") as f:
             f.write(trt_engine)
-        PIR_DUMP_FILE = (
-            f"{CACHE_ROOT}/engine_{engine_name}_{self.engine_num}.pir"
-        )
+        PIR_DUMP_FILE = f"{CACHE_ROOT}/engine_{engine_name}.pir"
        with open(PIR_DUMP_FILE, "w") as f:
             f.write(group_str)
         trt_params.engine_serialized_data = CACHE_FILE
@@ -523,7 +520,6 @@ def convert_program_to_trt(self):
         for op in self.program.global_block().ops:
             if op.name() == "cinn_op.group" or op.name() == "builtin.group":
                 _logger.info(f"start process {op.name()}")
-                self.engine_num += 1
                 new_out = self.convert_subgraph_to_trt(self.program, op)
                 orin_out_values = op.results()
                 for o_i in range(len(orin_out_values)):
diff --git a/python/paddle/tensorrt/converter_utils.py b/python/paddle/tensorrt/converter_utils.py
index 09e5f3a70d963..b83ffe787f0c3 100644
--- a/python/paddle/tensorrt/converter_utils.py
+++ b/python/paddle/tensorrt/converter_utils.py
@@ -271,21 +271,6 @@ def trt_reshape(network, input, new_shape, name="", is_shape_tensor=False):
     return reshape_layer.get_output(0)


-# resize shape tensor's shape to 1dim
-def resize_to_1d(network, shape_tensor):
-    if shape_tensor is None:
-        return shape_tensor
-    if len(shape_tensor.shape) > 1:
-        # shape_tensor need 1-dim in trt
-        shape_tensor_layer = network.add_shuffle(shape_tensor)
-        numel = 1
-        for ele in shape_tensor.shape:
-            numel *= ele
-        shape_tensor_layer.reshape_dims = [numel]
-        shape_tensor = shape_tensor_layer.get_output(0)
-    return shape_tensor
-
-
 # Get element tensor of 1D shape tensor
 def get_shape_tensor_element(network, x, index, is_scalar=False):
     assert (
@@ -293,8 +278,7 @@ def get_shape_tensor_element(network, x, index, is_scalar=False):
     ), f"The index should be greater or equal than 0, but got {index}"
     index_tensor = add_1D_constant_layer(network, index, is_scalar=is_scalar)
     gather_layer = network.add_gather(input=x, indices=index_tensor, axis=0)
-    shape_tensor = resize_to_1d(network, gather_layer.get_output(0))
-    return shape_tensor
+    return gather_layer.get_output(0)


 def trt_less(network, a, b):
@@ -430,7 +414,7 @@ def map_trt_dtype(trt_dtype):


 # Reduce the given tensor in the TensorRT network to a scalar
-def trt_reduce_to_scalar(network, tensor, dtype=trt.int32):
+def trt_reduce_to_scalar(network, tensor):
     if len(tensor.shape) == 0:
         return tensor
     axes = 0
@@ -439,8 +423,7 @@ def trt_reduce_to_scalar(network, tensor):
     reduce_layer = network.add_reduce(
         tensor, trt.ReduceOperation.SUM, axes, keep_dims=False
     )
-    scalar = trt_cast(network, reduce_layer.get_output(0), dtype)
-    return scalar
+    return reduce_layer.get_output(0)


 def convert_conv2d(network, paddle_op, inputs):
@@ -674,3 +657,16 @@
     reshape_layer = network.add_shuffle(input_tensor)
     reshape_layer.set_input(1, new_shape_tensor)
     return reshape_layer.get_output(0)
+
+
+# resize shape tensor's shape to 1dim
+def resize_to_1d(network, shape_tensor):
+    if len(shape_tensor.shape) > 1:
+        # shape_tensor need 1-dim in trt
+        shape_tensor_layer = network.add_shuffle(shape_tensor)
+        numel = 1
+        for ele in shape_tensor.shape:
+            numel *= ele
+        shape_tensor_layer.reshape_dims = [numel]
+        shape_tensor = shape_tensor_layer.get_output(0)
+    return shape_tensor
diff --git a/python/paddle/tensorrt/impls/common.py b/python/paddle/tensorrt/impls/common.py
index b989fa5142ab8..a4567641fa2ab 100644
--- a/python/paddle/tensorrt/impls/common.py
+++ b/python/paddle/tensorrt/impls/common.py
@@ -16,7 +16,7 @@
 import numpy as np
 import tensorrt as trt

-from paddle.tensorrt.converter_utils import get_shape_tensor_element
+from paddle.tensorrt.converter_utils import get_shape_tensor_element, trt_shape
 from paddle.tensorrt.register import converter_registry
 from paddle.tensorrt.util import get_trt_version_list

@@ -53,10 +53,6 @@ def dropout_converter(network, paddle_op, inputs):
 )
 def bilinear_interp_converter(network, paddle_op, inputs):
     input_tensor = inputs[0]
-    input_shape_tensor = network.add_shape(input_tensor).get_output(0)
-    input_rank = (
-        input_shape_tensor.shape
-    )  # The reason is unknown that adding this unused code make input_shape_tensor maintain the correct result.
     data_format = paddle_op.attrs().get("data_format")
     interp_method = paddle_op.attrs().get("interp_method")
     align_corners = paddle_op.attrs().get("align_corners")
@@ -145,6 +141,7 @@ def bilinear_interp_converter(network, paddle_op, inputs):
     else:
         if outsize_tensor is not None:
             outsize_itensors = []
+            input_shape_tensor = trt_shape(network, input_tensor)
             batch_dim = get_shape_tensor_element(network, input_shape_tensor, 0)
             outsize_itensors.append(batch_dim)
             if data_format == "NCHW":
@@ -172,10 +169,6 @@
 )
 def nearest_interp_converter(network, paddle_op, inputs):
     input_tensor = inputs[0]
-    input_shape_tensor = network.add_shape(input_tensor).get_output(0)
-    input_rank = (
-        input_shape_tensor.shape
-    )  # The reason is unknown that adding this unused code make input_shape_tensor maintain the correct result.
     data_format = paddle_op.attrs().get("data_format")
     interp_method = paddle_op.attrs().get("interp_method")
     align_corners = paddle_op.attrs().get("align_corners")
@@ -222,8 +215,33 @@ def nearest_interp_converter(network, paddle_op, inputs):
         scale_w = float(out_w) / float(in_dim[w_axis])

     outsize_tensor = None
-    if inputs[2] is not None:
-        outsize_tensor = network.add_concatenation(inputs[2]).get_output(0)
+    if trt_version_float >= 8.2:
+        if len(inputs) > 2 and inputs[2] is not None:
+            size_tensor_operand = paddle_op.operands()[2].source()
+            if size_tensor_operand.is_combine():
+                size_tensors = inputs[2]
+                if not isinstance(size_tensors, list):
+                    size_tensors = [size_tensors]
+                if len(size_tensors) >= 2:
+                    # Extract the first two elements representing height and width
+                    outsize_h = size_tensors[0]
+                    outsize_w = size_tensors[1]
+                    outsize_tensor = network.add_concatenation(
+                        [outsize_h, outsize_w]
+                    ).get_output(0)
+            else:
+                size_tensor_shape = size_tensor_operand.source().shape
+                if size_tensor_shape.size >= 2:
+                    size_tensor = inputs[2]
+                    outsize_h = network.add_slice(
+                        size_tensor, start=[0], shape=[1], stride=[1]
+                    ).get_output(0)
+                    outsize_w = network.add_slice(
+                        size_tensor, start=[1], shape=[1], stride=[1]
+                    ).get_output(0)
+                    outsize_tensor = network.add_concatenation(
+                        [outsize_h, outsize_w]
+                    ).get_output(0)

     scales = [1.0] * len(input_tensor.shape)
     if data_format == "NCHW":
@@ -240,6 +258,7 @@
     )
     if outsize_tensor is not None:
         outsize_itensors = []
+        input_shape_tensor = trt_shape(network, input_tensor)
         batch_dim = get_shape_tensor_element(network, input_shape_tensor, 0)
         outsize_itensors.append(batch_dim)
         if data_format == "NCHW":
diff --git a/python/paddle/tensorrt/impls/creation.py b/python/paddle/tensorrt/impls/creation.py
index b6b5e7711d8d8..169cf917ceae2 100644
--- a/python/paddle/tensorrt/impls/creation.py
+++ b/python/paddle/tensorrt/impls/creation.py
@@ -16,11 +16,9 @@
 import tensorrt as trt

 import paddle
-from paddle.pir.core import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE
 from paddle.tensorrt.converter_utils import (
     add_1D_constant_layer,
     cast_tensor,
-    resize_to_1d,
     trt_cast,
     trt_floor_div,
     trt_max,
@@ -48,11 +46,10 @@ def full_converter(network, paddle_op, inputs):
     shape = paddle_op.attrs()["shape"]
     value = paddle_op.attrs().get("value", 1.0)
     dtype = paddle_op.attrs().get("dtype")
-    out_dtype = np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[dtype])
-    if out_dtype == np.dtype("float64"):
-        out_dtype = np.dtype("float32")
-    if out_dtype == np.dtype("int64"):
-        out_dtype = np.dtype("int32")
+    if dtype == paddle.int32 or dtype == paddle.int64:
+        out_dtype = np.int32
+    else:
+        out_dtype = np.float32
     full_layer = network.add_constant(
         shape, np.full(shape, value, dtype=out_dtype)
     )
@@ -116,7 +113,9 @@ def arange_converter(network, paddle_op, inputs):

     number_tensor = trt_max(network, quotient_tensor, zero_tensor)

-    start_tensor = trt_reshape(network, start, ())
+    reshape_start_layer = trt_reshape(network, start, (1,))
+
+    start_tensor = trt_reduce_to_scalar(network, reshape_start_layer)

     fill_layer = network.add_fill(shape=(), op=trt.FillOperation.LINSPACE)
     fill_layer.set_input(0, number_tensor)
@@ -238,6 +237,8 @@ def full_with_tensor_converter(network, paddle_op, inputs):
         shape_tensor = shape_tensor_list[0]
         if not isinstance(shape_tensor, trt.ITensor):
             raise TypeError("shape_tensor must be an ITensor")
+        if len(shape_tensor.shape) != 1:
+            raise ValueError("The rank of shape_tensor must be 1")
         tensor_rank = shape_tensor.shape[0]
         shapes_tensor = shape_tensor
     else:
@@ -251,7 +252,6 @@ def full_with_tensor_converter(network, paddle_op, inputs):
             shapes_tensor = concat_layer.get_output(0)
             tensor_rank = len(shape_tensors)

-    shapes_tensor = resize_to_1d(network, shapes_tensor)
     fill_layer = network.add_fill(shape=(), op=trt.FillOperation.LINSPACE)
     fill_layer.set_input(0, shapes_tensor)
@@ -264,7 +264,7 @@ def full_with_tensor_converter(network, paddle_op, inputs):
         )
     elif dtype == paddle.float32:
         beta_vec = [0.0] * tensor_rank
-        value_input = trt_reduce_to_scalar(network, value_input, trt.float32)
+        value_input = trt_reduce_to_scalar(network, value_input)
         fill_layer.set_input(1, value_input)
         fill_layer.set_input(
             2, add_1D_constant_layer(network, beta_vec, np.float32)
diff --git a/test/cpp/inference/tensorrt/CMakeLists.txt b/test/cpp/inference/tensorrt/CMakeLists.txt
index cb68443c986db..49ee3552e303b 100644
--- a/test/cpp/inference/tensorrt/CMakeLists.txt
+++ b/test/cpp/inference/tensorrt/CMakeLists.txt
@@ -1,20 +1,16 @@
-set(TENSORRT_VERSION_NUMBER
-    "${TENSORRT_MAJOR_VERSION}${TENSORRT_MINOR_VERSION}")
-if(${TENSORRT_VERSION_NUMBER} GREATER_EQUAL 85)
-  nv_test(
-    test_tensorrt_engine_instruction
-    SRCS test_tensorrt_engine_instruction.cc
-    DEPS pir
-         trt_engine
-         naive_executor
-         phi
-         common
-         pir_save_load
-         pir_tensorrt_plugin)
-  set_tests_properties(test_tensorrt_engine_instruction PROPERTIES TIMEOUT 120)
-  if(WITH_ONNXRUNTIME AND WIN32)
-    # Copy onnxruntime for some c++ test in Windows, since the test will
-    # be build only in CI, so suppose the generator in Windows is Ninja.
-    copy_onnx(test_tensorrt_engine_instruction)
-  endif()
+nv_test(
+  test_tensorrt_engine_instruction
+  SRCS test_tensorrt_engine_instruction.cc
+  DEPS pir
+       trt_engine
+       naive_executor
+       phi
+       common
+       pir_save_load
+       pir_tensorrt_plugin)
+set_tests_properties(test_tensorrt_engine_instruction PROPERTIES TIMEOUT 120)
+if(WITH_ONNXRUNTIME AND WIN32)
+  # Copy onnxruntime for some c++ test in Windows, since the test will
+  # be build only in CI, so suppose the generator in Windows is Ninja.
+  copy_onnx(test_tensorrt_engine_instruction)
 endif()
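
For reference, the dtype narrowing that full_converter falls back to after this revert can be illustrated on its own. The snippet below is only a minimal sketch; the helper name narrow_full_dtype is hypothetical and simply mirrors the restored branch (integer fills are materialized as int32 numpy constants, everything else as float32) without touching any TensorRT APIs.

import numpy as np

import paddle


def narrow_full_dtype(dtype):
    # Hypothetical helper mirroring the branch restored in full_converter:
    # integer fill dtypes become int32 weights, everything else float32.
    if dtype == paddle.int32 or dtype == paddle.int64:
        return np.int32
    return np.float32


# Usage sketch: a paddle.int64 fill value ends up as an int32 constant array.
arr = np.full([2, 3], 1, dtype=narrow_full_dtype(paddle.int64))
assert arr.dtype == np.int32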