[Inference] Fix PaddleX model bugs when converting to pir-trt (Part 2) (PaddlePaddle#69885)

* fix trt bugs

* fix bugs

* fix nearest_interp

* fix bugs

* fix bugs

* fix windows bugs
YuanRisheng authored Dec 10, 2024
1 parent 914caad commit 0f66ede
Showing 6 changed files with 77 additions and 130 deletions.
@@ -239,10 +239,8 @@ static phi::DataType TRT2PaddleDataType(nvinfer1::DataType type) {
"to paddle. Does the downstream paddle op here support int64?";
return phi::DataType::INT64;
#endif
#if IS_TRT_VERSION_GE(7000)
case nvinfer1::DataType::kBOOL:
return phi::DataType::BOOL;
#endif
default:
PADDLE_THROW(common::errors::InvalidArgument(
"unknown fluid datatype in Fluid op converter"));
@@ -489,11 +487,10 @@ void TensorRTEngineInstruction::BindInputTensor(
bind_index,
num_bindings));

#if IS_TRT_VERSION_GE(6000)
#if IS_TRT_VERSION_GE(8500)
if (trt_engine_->engine()->isShapeInferenceIO(input_name.c_str()) &&
trt_engine_->engine()->getTensorIOMode(input_name.c_str()) ==
nvinfer1::TensorIOMode::kINPUT) {
shape_v.resize(input_tensor.numel());
if (input_tensor.dtype() == phi::DataType::INT32) {
phi::memory_utils::Copy(phi::CPUPlace(),
shape_v.data(),
@@ -524,41 +521,6 @@ void TensorRTEngineInstruction::BindInputTensor(
input_name.c_str(),
paddle::platform::Vec2TRT_Dims(input_shape, input_name, true));
}
#else
trt_context->setBindingDimensions(
bind_index,
paddle::platform::Vec2TRT_Dims(input_shape, input_name, true));
// If this x is a shape tensor, we need call setInputShapeBinding
if (trt_engine_->engine()->isShapeBinding(bind_index) &&
trt_engine_->engine()->bindingIsInput(bind_index)) {
if (input_tensor.dtype() == phi::DataType::INT32) {
phi::memory_utils::Copy(phi::CPUPlace(),
shape_v.data(),
input_tensor.place(),
input_tensor.data<int32_t>(),
input_tensor.numel() * sizeof(int),
nullptr);
} else if (input_tensor.dtype() == phi::DataType::INT64) {
std::string x_t = input_name + "_cast_to_INT32";
if (scope.FindVar(x_t) == nullptr) {
const_cast<framework::Scope *>(&scope)->Var(x_t);
}
auto int32_tensor = scope.FindVar(x_t)->GetMutable<phi::DenseTensor>();
*int32_tensor = phi::Cast<int64_t>(
reinterpret_cast<const phi::GPUContext &>(*dev_ctx_),
input_tensor,
phi::DataType::INT32);
phi::memory_utils::Copy(phi::CPUPlace(),
shape_v.data(),
int32_tensor->place(),
int32_tensor->data<int32_t>(),
int32_tensor->numel() * sizeof(int),
nullptr);
}
trt_context->setInputShapeBinding(bind_index, shape_v.data());
}
#endif
#endif

*runtime_batch = input_shape[0];
VLOG(1) << "trt input [" << input_name << "] dtype is "
@@ -610,11 +572,10 @@ void TensorRTEngineInstruction::BindInputTensor(
} else if (input_tensor.dtype() == phi::DataType::FLOAT16) {
buffers[bind_index] = static_cast<void *>(
const_cast<float16 *>(input_tensor.data<float16>()));
#if IS_TRT_VERSION_GE(8400)
} else if (input_tensor.dtype() == phi::DataType::BOOL) {
buffers[bind_index] =
static_cast<void *>(const_cast<bool *>(input_tensor.data<bool>()));
#endif

} else {
PADDLE_THROW(common::errors::Fatal(
"The TRT Engine OP only support "
@@ -655,7 +616,6 @@ void TensorRTEngineInstruction::BindOutputTensor(
#endif
std::vector<int> ddim;

#if IS_TRT_VERSION_GE(8500)
auto x_name = trt_engine_->engine()->getIOTensorName(bind_index);
auto dims = trt_context->getTensorShape(x_name);
int nb_dims = dims.nbDims;
@@ -667,18 +627,6 @@
for (int i = 0; i < nb_dims; i++) {
ddim.push_back(dims.d[i]);
}
#else
auto dims = trt_context->getBindingDimensions(bind_index);
int nb_dims = dims.nbDims;
for (; nb_dims > 0; nb_dims--) {
// some 'x 1' of shape is normal, no need to remove it
if (dims.d[nb_dims - 1] != 1 || nb_dims == outputs_rank_[output_index])
break;
}
for (int i = 0; i < nb_dims; i++) {
ddim.push_back(dims.d[i]);
}
#endif

auto *fluid_t = output_tensor;
fluid_t->Resize(common::make_ddim(ddim));
@@ -721,14 +669,13 @@ void TensorRTEngineInstruction::RunTrt() {
"can not find var[%s] in scope", in_var_name));
auto in_var = scope.FindVar(in_var_name);
auto &in_variable_array = in_var->Get<VariableRefArray>();
// we will use shape_input when input is a shape tensor
std::vector<std::vector<int>> shape_inputs(in_variable_array.size());

for (const auto &index_name_pair : input_names_) {
size_t i = index_name_pair.first;
if (in_variable_array[i]->IsType<phi::DenseTensor>()) {
auto input_tensor = in_variable_array[i]->Get<phi::DenseTensor>();
// we will use shape_input when input is a shape tensor
shape_inputs[i].resize(input_tensor.numel());
// Bind input tensor to TRT.
BindInputTensor(index_name_pair.second,
input_tensor,
@@ -818,6 +765,13 @@ void TensorRTEngineInstruction::RunTrt() {
}

void TensorRTEngineInstruction::Run() {
#if IS_TRT_VERSION_LT(8500)
PADDLE_THROW(
common::errors::Unimplemented("PIR-TRT only supports TensorRT "
"versions >= 8.5; please check the "
"TensorRT version in your environment."));
#endif
PrepareDynamicShape();
RunTrt();
}
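
For context (not part of this commit): the deleted branches above used the pre-8.5 binding-index APIs (`setBindingDimensions`, `isShapeBinding`, `setInputShapeBinding`), while the retained path uses the name-based I/O introduced in TensorRT 8.5. Below is a minimal sketch of that flow with the TensorRT Python API; `bind_inputs`, `feed_dict`, and the host/device pointer layout are illustrative assumptions, not Paddle code.

```python
import tensorrt as trt


def bind_inputs(engine, context, feed_dict):
    """Sketch of TensorRT >= 8.5 name-based input binding.

    feed_dict maps tensor names to (host numpy array, device pointer) pairs.
    Shape tensors are passed by host address; execution tensors get their
    runtime dims via set_input_shape and their data via set_tensor_address.
    """
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) != trt.TensorIOMode.INPUT:
            continue
        host_array, device_ptr = feed_dict[name]
        if engine.is_shape_inference_io(name):
            # Shape tensor values must be readable on the host.
            context.set_tensor_address(name, host_array.ctypes.data)
        else:
            context.set_input_shape(name, tuple(host_array.shape))
            context.set_tensor_address(name, device_ptr)
```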
10 changes: 7 additions & 3 deletions python/paddle/tensorrt/converter.py
@@ -87,6 +87,7 @@ def __init__(self, paddle_program, scope, trt_config=None):

self.input_info = {}
self.trt_output_value_map = {}
self.engine_num = 0

def find_graph_inputs_outputs(self, group_op):
operations = next(iter(group_op.blocks())).ops
@@ -191,7 +192,7 @@ def convert_subgraph_to_trt(self, program, group_op):
for operand in op.operands():
source = operand.source()
if not source.initialized():
_logger.warning(f"Skipping uninitialized source: {source}")
operands.append(None)
continue
define_op_name = source.get_defining_op().name()
if define_op_name == "builtin.combine":
@@ -456,10 +457,12 @@ def convert_subgraph_to_trt(self, program, group_op):
% 10**8
)
CACHE_ROOT = get_cache_path()
CACHE_FILE = f"{CACHE_ROOT}/engine_{engine_name}.trt"
CACHE_FILE = f"{CACHE_ROOT}/engine_{engine_name}_{self.engine_num}.trt"
with open(CACHE_FILE, "wb") as f:
f.write(trt_engine)
PIR_DUMP_FILE = f"{CACHE_ROOT}/engine_{engine_name}.pir"
PIR_DUMP_FILE = (
f"{CACHE_ROOT}/engine_{engine_name}_{self.engine_num}.pir"
)
with open(PIR_DUMP_FILE, "w") as f:
f.write(group_str)
trt_params.engine_serialized_data = CACHE_FILE
@@ -520,6 +523,7 @@ def convert_program_to_trt(self):
for op in self.program.global_block().ops:
if op.name() == "cinn_op.group" or op.name() == "builtin.group":
_logger.info(f"start process {op.name()}")
self.engine_num += 1
new_out = self.convert_subgraph_to_trt(self.program, op)
orin_out_values = op.results()
for o_i in range(len(orin_out_values)):
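
For context (not part of this commit): the new `self.engine_num` counter is appended to the cache file names so that two subgraphs whose hashed `engine_name` happens to collide no longer overwrite each other's serialized engine or PIR dump. A sketch of the resulting naming scheme; `serialized_paths` is an illustrative helper, not a Paddle API.

```python
import os


def serialized_paths(cache_root, engine_name, engine_num):
    """Illustrative: one .trt / .pir pair per converted group op.

    engine_num is a running counter incremented for every group op that is
    converted, so identical engine_name hashes still map to distinct files.
    """
    trt_file = os.path.join(cache_root, f"engine_{engine_name}_{engine_num}.trt")
    pir_file = os.path.join(cache_root, f"engine_{engine_name}_{engine_num}.pir")
    return trt_file, pir_file
```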
36 changes: 20 additions & 16 deletions python/paddle/tensorrt/converter_utils.py
@@ -271,14 +271,30 @@ def trt_reshape(network, input, new_shape, name="", is_shape_tensor=False):
return reshape_layer.get_output(0)


# Resize a shape tensor's shape to 1-D
def resize_to_1d(network, shape_tensor):
if shape_tensor is None:
return shape_tensor
if len(shape_tensor.shape) > 1:
# shape tensors must be 1-D in TRT
shape_tensor_layer = network.add_shuffle(shape_tensor)
numel = 1
for ele in shape_tensor.shape:
numel *= ele
shape_tensor_layer.reshape_dims = [numel]
shape_tensor = shape_tensor_layer.get_output(0)
return shape_tensor


# Get element tensor of 1D shape tensor
def get_shape_tensor_element(network, x, index, is_scalar=False):
assert (
index >= 0
), f"The index should be greater or equal than 0, but got {index}"
index_tensor = add_1D_constant_layer(network, index, is_scalar=is_scalar)
gather_layer = network.add_gather(input=x, indices=index_tensor, axis=0)
return gather_layer.get_output(0)
shape_tensor = resize_to_1d(network, gather_layer.get_output(0))
return shape_tensor


def trt_less(network, a, b):
@@ -414,7 +430,7 @@ def map_trt_dtype(trt_dtype):


# Reduce the given tensor in the TensorRT network to a scalar
def trt_reduce_to_scalar(network, tensor):
def trt_reduce_to_scalar(network, tensor, dtype=trt.int32):
if len(tensor.shape) == 0:
return tensor
axes = 0
@@ -423,7 +439,8 @@ def trt_reduce_to_scalar(network, tensor):
reduce_layer = network.add_reduce(
tensor, trt.ReduceOperation.SUM, axes, keep_dims=False
)
return reduce_layer.get_output(0)
scalar = trt_cast(network, reduce_layer.get_output(0), dtype)
return scalar


def convert_conv2d(network, paddle_op, inputs):
@@ -657,16 +674,3 @@ def squeeze_trt(network, input_tensor, axes):
reshape_layer = network.add_shuffle(input_tensor)
reshape_layer.set_input(1, new_shape_tensor)
return reshape_layer.get_output(0)


# resize shape tensor's shape to 1dim
def resize_to_1d(network, shape_tensor):
if len(shape_tensor.shape) > 1:
# shape_tensor need 1-dim in trt
shape_tensor_layer = network.add_shuffle(shape_tensor)
numel = 1
for ele in shape_tensor.shape:
numel *= ele
shape_tensor_layer.reshape_dims = [numel]
shape_tensor = shape_tensor_layer.get_output(0)
return shape_tensor
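
For context (not part of this commit): `resize_to_1d` was moved above `get_shape_tensor_element` because the gather output is now flattened before being returned, and shape tensors consumed by TensorRT layers must be rank-1 int32. A standalone sketch of that flattening with the TensorRT 8.x Python API (building the network requires a CUDA-capable environment):

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# A (1, 4) int32 constant standing in for a shape tensor that is not 1-D.
shape_2d = network.add_constant(
    (1, 4), np.array([[2, 3, 4, 5]], dtype=np.int32)
).get_output(0)

# The same flattening resize_to_1d performs: shuffle to [numel], because
# layers that consume shape tensors (e.g. IFillLayer, ISliceLayer) expect
# a rank-1 int32 tensor.
flatten = network.add_shuffle(shape_2d)
flatten.reshape_dims = [4]
shape_1d = flatten.get_output(0)
print(shape_1d.shape)  # (4,)
```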
41 changes: 11 additions & 30 deletions python/paddle/tensorrt/impls/common.py
@@ -16,7 +16,7 @@
import numpy as np
import tensorrt as trt

from paddle.tensorrt.converter_utils import get_shape_tensor_element, trt_shape
from paddle.tensorrt.converter_utils import get_shape_tensor_element
from paddle.tensorrt.register import converter_registry
from paddle.tensorrt.util import get_trt_version_list

@@ -53,6 +53,10 @@ def dropout_converter(network, paddle_op, inputs):
)
def bilinear_interp_converter(network, paddle_op, inputs):
input_tensor = inputs[0]
input_shape_tensor = network.add_shape(input_tensor).get_output(0)
input_rank = (
input_shape_tensor.shape
)  # The reason is unknown, but reading this otherwise unused attribute keeps input_shape_tensor correct.
data_format = paddle_op.attrs().get("data_format")
interp_method = paddle_op.attrs().get("interp_method")
align_corners = paddle_op.attrs().get("align_corners")
@@ -141,7 +145,6 @@ def bilinear_interp_converter(network, paddle_op, inputs):
else:
if outsize_tensor is not None:
outsize_itensors = []
input_shape_tensor = trt_shape(network, input_tensor)
batch_dim = get_shape_tensor_element(network, input_shape_tensor, 0)
outsize_itensors.append(batch_dim)
if data_format == "NCHW":
@@ -169,6 +172,10 @@
)
def nearest_interp_converter(network, paddle_op, inputs):
input_tensor = inputs[0]
input_shape_tensor = network.add_shape(input_tensor).get_output(0)
input_rank = (
input_shape_tensor.shape
)  # The reason is unknown, but reading this otherwise unused attribute keeps input_shape_tensor correct.
data_format = paddle_op.attrs().get("data_format")
interp_method = paddle_op.attrs().get("interp_method")
align_corners = paddle_op.attrs().get("align_corners")
@@ -215,33 +222,8 @@ def nearest_interp_converter(network, paddle_op, inputs):
scale_w = float(out_w) / float(in_dim[w_axis])

outsize_tensor = None
if trt_version_float >= 8.2:
if len(inputs) > 2 and inputs[2] is not None:
size_tensor_operand = paddle_op.operands()[2].source()
if size_tensor_operand.is_combine():
size_tensors = inputs[2]
if not isinstance(size_tensors, list):
size_tensors = [size_tensors]
if len(size_tensors) >= 2:
# Extract the first two elements representing height and width
outsize_h = size_tensors[0]
outsize_w = size_tensors[1]
outsize_tensor = network.add_concatenation(
[outsize_h, outsize_w]
).get_output(0)
else:
size_tensor_shape = size_tensor_operand.source().shape
if size_tensor_shape.size >= 2:
size_tensor = inputs[2]
outsize_h = network.add_slice(
size_tensor, start=[0], shape=[1], stride=[1]
).get_output(0)
outsize_w = network.add_slice(
size_tensor, start=[1], shape=[1], stride=[1]
).get_output(0)
outsize_tensor = network.add_concatenation(
[outsize_h, outsize_w]
).get_output(0)
if inputs[2] is not None:
outsize_tensor = network.add_concatenation(inputs[2]).get_output(0)

scales = [1.0] * len(input_tensor.shape)
if data_format == "NCHW":
@@ -258,7 +240,6 @@
)
if outsize_tensor is not None:
outsize_itensors = []
input_shape_tensor = trt_shape(network, input_tensor)
batch_dim = get_shape_tensor_element(network, input_shape_tensor, 0)
outsize_itensors.append(batch_dim)
if data_format == "NCHW":
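
For context (not part of this commit): with TensorRT >= 8.5 guaranteed, the nearest_interp size handling reduces to concatenating the size tensors and feeding the result to the resize layer as a dynamic output shape. A sketch of that pattern, assuming NCHW layout; `set_resize_output_shape` and `out_hw_tensors` are illustrative names.

```python
import numpy as np
import tensorrt as trt


def set_resize_output_shape(network, input_tensor, out_hw_tensors):
    """Sketch: dynamic output shape for IResizeLayer (NCHW, TRT >= 8.5).

    out_hw_tensors is assumed to be a list of two rank-1 int32 ITensors
    holding the target height and width (the converted size inputs).
    """
    shape = network.add_shape(input_tensor).get_output(0)

    def element(index):
        # Gather one element of the rank-1 shape tensor, keeping rank 1.
        idx = network.add_constant((1,), np.array([index], dtype=np.int32))
        return network.add_gather(shape, idx.get_output(0), 0).get_output(0)

    out_shape = network.add_concatenation(
        [element(0), element(1), out_hw_tensors[0], out_hw_tensors[1]]
    ).get_output(0)

    resize = network.add_resize(input_tensor)
    resize.resize_mode = trt.ResizeMode.NEAREST
    resize.set_input(1, out_shape)  # rank-1 int32 dynamic output shape
    return resize.get_output(0)
```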
20 changes: 10 additions & 10 deletions python/paddle/tensorrt/impls/creation.py
@@ -16,9 +16,11 @@
import tensorrt as trt

import paddle
from paddle.pir.core import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE
from paddle.tensorrt.converter_utils import (
add_1D_constant_layer,
cast_tensor,
resize_to_1d,
trt_cast,
trt_floor_div,
trt_max,
@@ -46,10 +48,11 @@ def full_converter(network, paddle_op, inputs):
shape = paddle_op.attrs()["shape"]
value = paddle_op.attrs().get("value", 1.0)
dtype = paddle_op.attrs().get("dtype")
if dtype == paddle.int32 or dtype == paddle.int64:
out_dtype = np.int32
else:
out_dtype = np.float32
out_dtype = np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[dtype])
if out_dtype == np.dtype("float64"):
out_dtype = np.dtype("float32")
if out_dtype == np.dtype("int64"):
out_dtype = np.dtype("int32")
full_layer = network.add_constant(
shape, np.full(shape, value, dtype=out_dtype)
)
@@ -113,9 +116,7 @@ def arange_converter(network, paddle_op, inputs):

number_tensor = trt_max(network, quotient_tensor, zero_tensor)

reshape_start_layer = trt_reshape(network, start, (1,))

start_tensor = trt_reduce_to_scalar(network, reshape_start_layer)
start_tensor = trt_reshape(network, start, ())

fill_layer = network.add_fill(shape=(), op=trt.FillOperation.LINSPACE)
fill_layer.set_input(0, number_tensor)
@@ -237,8 +238,6 @@ def full_with_tensor_converter(network, paddle_op, inputs):
shape_tensor = shape_tensor_list[0]
if not isinstance(shape_tensor, trt.ITensor):
raise TypeError("shape_tensor must be an ITensor")
if len(shape_tensor.shape) != 1:
raise ValueError("The rank of shape_tensor must be 1")
tensor_rank = shape_tensor.shape[0]
shapes_tensor = shape_tensor
else:
@@ -252,6 +251,7 @@
shapes_tensor = concat_layer.get_output(0)
tensor_rank = len(shape_tensors)

shapes_tensor = resize_to_1d(network, shapes_tensor)
fill_layer = network.add_fill(shape=(), op=trt.FillOperation.LINSPACE)
fill_layer.set_input(0, shapes_tensor)

@@ -264,7 +264,7 @@
)
elif dtype == paddle.float32:
beta_vec = [0.0] * tensor_rank
value_input = trt_reduce_to_scalar(network, value_input)
value_input = trt_reduce_to_scalar(network, value_input, trt.float32)
fill_layer.set_input(1, value_input)
fill_layer.set_input(
2, add_1D_constant_layer(network, beta_vec, np.float32)
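
For context (not part of this commit): the changes above follow the IFillLayer contract — the dynamic shape input must be rank-1 (hence the `resize_to_1d` call) and the LINSPACE start value must be 0-D (hence `trt_reshape(network, start, ())` and the typed `trt_reduce_to_scalar`). A standalone sketch under those assumptions, using the TensorRT 8.x Python API:

```python
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Rank-1 int32 shape tensor: IFillLayer's dynamic shape input must be 1-D.
shape_tensor = network.add_constant(
    (2,), np.array([3, 4], dtype=np.int32)
).get_output(0)

# 0-D start value, built the same way as trt_reshape(network, x, ()).
start = network.add_constant((1,), np.array([0.0], dtype=np.float32))
to_scalar = network.add_shuffle(start.get_output(0))
to_scalar.reshape_dims = ()

# Per-axis deltas of zero reproduce a "full"-style constant fill.
delta = network.add_constant((2,), np.zeros((2,), dtype=np.float32))

fill = network.add_fill(shape=(), op=trt.FillOperation.LINSPACE)
fill.set_input(0, shape_tensor)             # dynamic output shape (rank 1)
fill.set_input(1, to_scalar.get_output(0))  # scalar start value
fill.set_input(2, delta.get_output(0))      # per-axis delta (rank 1)
full_like = fill.get_output(0)              # 3x4 tensor filled with 0.0
```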