From b755b883ac90925f177064b5db7b084419929550 Mon Sep 17 00:00:00 2001 From: zhoushunjie Date: Sat, 3 Dec 2022 16:38:36 +0000 Subject: [PATCH 1/2] Add DisablePaddleTrtOPs --- fastdeploy/backends/paddle/paddle_backend.cc | 183 +++++++++++-------- fastdeploy/backends/paddle/paddle_backend.h | 26 +-- fastdeploy/runtime.cc | 137 +++++++------- fastdeploy/runtime.h | 52 +++--- 4 files changed, 213 insertions(+), 185 deletions(-) mode change 100755 => 100644 fastdeploy/backends/paddle/paddle_backend.h mode change 100755 => 100644 fastdeploy/runtime.cc diff --git a/fastdeploy/backends/paddle/paddle_backend.cc b/fastdeploy/backends/paddle/paddle_backend.cc index 866bf578e2..49abf02b11 100644 --- a/fastdeploy/backends/paddle/paddle_backend.cc +++ b/fastdeploy/backends/paddle/paddle_backend.cc @@ -22,24 +22,34 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) { option_ = option; if (option.use_gpu) { config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id); - if(option_.external_stream_) { + if (option_.external_stream_) { config_.SetExecStream(option_.external_stream_); } if (option.enable_trt) { #ifdef ENABLE_TRT_BACKEND + config_.Exp_DisableTensorRtOPs(option.trt_disabled_ops_); auto precision = paddle_infer::PrecisionType::kFloat32; if (option.trt_option.enable_fp16) { precision = paddle_infer::PrecisionType::kHalf; } bool use_static = false; if (option.trt_option.serialize_file != "") { - FDWARNING << "Detect that tensorrt cache file has been set to " << option.trt_option.serialize_file << ", but while enable paddle2trt, please notice that the cache file will save to the directory where paddle model saved." << std::endl; + FDWARNING + << "Detect that tensorrt cache file has been set to " + << option.trt_option.serialize_file + << ", but while enable paddle2trt, please notice that the cache " + "file will save to the directory where paddle model saved." + << std::endl; use_static = true; } - config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, option.trt_option.max_batch_size, 3, precision, use_static); + config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, + option.trt_option.max_batch_size, 3, + precision, use_static); SetTRTDynamicShapeToConfig(option); #else - FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so will fallback to GPU with Paddle Inference Backend." << std::endl; + FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so " + "will fallback to GPU with Paddle Inference Backend." + << std::endl; #endif } } else if (option.use_ipu) { @@ -98,39 +108,48 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file, if (!ReadBinaryFromFile(model_file, &contents)) { return false; } - auto reader = - paddle2onnx::PaddleReader(contents.c_str(), contents.size()); + auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size()); // If it's a quantized model, and use cpu with mkldnn, automaticaly switch to int8 mode if (reader.is_quantize_model) { if (option.use_gpu) { - FDWARNING << "The loaded model is a quantized model, while inference on GPU, please use TensorRT backend to get better performance." << std::endl; + FDWARNING << "The loaded model is a quantized model, while inference on " + "GPU, please use TensorRT backend to get better performance." 
+ << std::endl; if (option.enable_trt) { #ifdef ENABLE_TRT_BACKEND bool use_static = false; if (option.trt_option.serialize_file != "") { - FDWARNING << "Detect that tensorrt cache file has been set to " << option.trt_option.serialize_file << ", but while enable paddle2trt, please notice that the cache file will save to the directory where paddle model saved." << std::endl; + FDWARNING + << "Detect that tensorrt cache file has been set to " + << option.trt_option.serialize_file + << ", but while enable paddle2trt, please notice that the cache " + "file will save to the directory where paddle model saved." + << std::endl; use_static = true; } - config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, option.trt_option.max_batch_size, 3, paddle_infer::PrecisionType::kInt8, use_static, false); + config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, + option.trt_option.max_batch_size, 3, + paddle_infer::PrecisionType::kInt8, + use_static, false); SetTRTDynamicShapeToConfig(option); - #endif } } if (option.enable_mkldnn) { config_.EnableMkldnnInt8(); } else { - FDWARNING << "The loaded model is a quantized model, while inference on CPU, please enable MKLDNN to get better performance." << std::endl; + FDWARNING << "The loaded model is a quantized model, while inference on " + "CPU, please enable MKLDNN to get better performance." + << std::endl; } } inputs_desc_.resize(reader.num_inputs); for (int i = 0; i < reader.num_inputs; ++i) { std::string name(reader.inputs[i].name); - std::vector shape( - reader.inputs[i].shape, - reader.inputs[i].shape + reader.inputs[i].rank); + std::vector shape(reader.inputs[i].shape, + reader.inputs[i].shape + reader.inputs[i].rank); inputs_desc_[i].name = name; inputs_desc_[i].shape.assign(shape.begin(), shape.end()); inputs_desc_[i].dtype = ReaderDataTypeToFD(reader.inputs[i].dtype); @@ -138,7 +157,9 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file, outputs_desc_.resize(reader.num_outputs); for (int i = 0; i < reader.num_outputs; ++i) { std::string name(reader.outputs[i].name); - std::vector shape(reader.outputs[i].shape, reader.outputs[i].shape + reader.outputs[i].rank); + std::vector shape(reader.outputs[i].shape, + reader.outputs[i].shape + + reader.outputs[i].rank); outputs_desc_[i].name = name; outputs_desc_[i].shape.assign(shape.begin(), shape.end()); outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype); @@ -147,7 +168,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file, if (option.collect_shape) { // Set the shape info file. auto curr_model_dir = GetDirFromPath(model_file); - std::string shape_range_info = PathJoin(curr_model_dir, "shape_range_info.pbtxt"); + std::string shape_range_info = + PathJoin(curr_model_dir, "shape_range_info.pbtxt"); if (!CheckFileExists(shape_range_info)) { FDINFO << "Start generating shape range info file." << std::endl; paddle_infer::Config analysis_config; @@ -164,7 +186,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file, CollectShapeRun(predictor_tmp.get(), opt_shape); FDINFO << "Finish generating shape range info file." << std::endl; } - FDINFO << "Start loading shape range info file "<< shape_range_info << " to set TensorRT dynamic shape." << std::endl; + FDINFO << "Start loading shape range info file " << shape_range_info + << " to set TensorRT dynamic shape." 
<< std::endl; config_.EnableTunedTensorRtDynamicShape(shape_range_info, false); } #endif @@ -194,8 +217,7 @@ std::vector PaddleBackend::GetOutputInfos() { } bool PaddleBackend::Infer(std::vector& inputs, - std::vector* outputs, - bool copy_to_fd) { + std::vector* outputs, bool copy_to_fd) { if (inputs.size() != inputs_desc_.size()) { FDERROR << "[PaddleBackend] Size of inputs(" << inputs.size() << ") should keep same with the inputs of this model(" @@ -211,13 +233,13 @@ bool PaddleBackend::Infer(std::vector& inputs, predictor_->Run(); // output share backend memory only support CPU or GPU - if(option_.use_ipu) { + if (option_.use_ipu) { copy_to_fd = true; } outputs->resize(outputs_desc_.size()); for (size_t i = 0; i < outputs_desc_.size(); ++i) { auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name); - if(copy_to_fd) { + if (copy_to_fd) { (*outputs)[i].is_pinned_memory = option_.enable_pinned_memory; } PaddleTensorToFDTensor(handle, &((*outputs)[i]), copy_to_fd); @@ -225,47 +247,47 @@ bool PaddleBackend::Infer(std::vector& inputs, return true; } -std::unique_ptr PaddleBackend::Clone(void *stream, int device_id) { - std::unique_ptr new_backend = utils::make_unique(); +std::unique_ptr PaddleBackend::Clone(void* stream, int device_id) { + std::unique_ptr new_backend = + utils::make_unique(); auto casted_backend = dynamic_cast(new_backend.get()); - if(device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) { + if (device_id > 0 && option_.use_gpu == true && device_id != option_.gpu_id) { auto clone_option = option_; clone_option.gpu_id = device_id; clone_option.external_stream_ = stream; casted_backend->InitFromPaddle(clone_option.model_file, - clone_option.params_file, - clone_option); - FDWARNING << "The target device id:" - << device_id - << " is different from current device id:" - << option_.gpu_id - << ", cannot share memory with current engine." - << std::endl; + clone_option.params_file, clone_option); + FDWARNING << "The target device id:" << device_id + << " is different from current device id:" << option_.gpu_id + << ", cannot share memory with current engine." << std::endl; return new_backend; } casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end()); - casted_backend->outputs_desc_.assign(outputs_desc_.begin(), outputs_desc_.end()); + casted_backend->outputs_desc_.assign(outputs_desc_.begin(), + outputs_desc_.end()); casted_backend->predictor_ = std::move(predictor_->Clone(stream)); return new_backend; } #ifdef ENABLE_TRT_BACKEND -void PaddleBackend::SetTRTDynamicShapeToConfig(const PaddleBackendOption& option) { - std::map> max_shape; - std::map> min_shape; - std::map> opt_shape; - GetDynamicShapeFromOption(option, &max_shape, &min_shape, &opt_shape); +void PaddleBackend::SetTRTDynamicShapeToConfig( + const PaddleBackendOption& option) { + std::map> max_shape; + std::map> min_shape; + std::map> opt_shape; + GetDynamicShapeFromOption(option, &max_shape, &min_shape, &opt_shape); + if (min_shape.size() > 0) { FDINFO << "Start setting trt dynamic shape." << std::endl; - if (min_shape.size() > 0) { - config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape); - } + config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape); FDINFO << "Finish setting trt dynamic shape." 
<< std::endl; + } } -void PaddleBackend::GetDynamicShapeFromOption(const PaddleBackendOption& option, - std::map>* max_shape, - std::map>* min_shape, - std::map>* opt_shape) const { +void PaddleBackend::GetDynamicShapeFromOption( + const PaddleBackendOption& option, + std::map>* max_shape, + std::map>* min_shape, + std::map>* opt_shape) const { auto print_shape = [](const std::vector& shape) -> std::string { std::ostringstream oss; oss << "["; @@ -281,24 +303,35 @@ void PaddleBackend::GetDynamicShapeFromOption(const PaddleBackendOption& option, for (const auto& item : option.trt_option.min_shape) { auto max_iter = option.trt_option.max_shape.find(item.first); auto opt_iter = option.trt_option.opt_shape.find(item.first); - FDASSERT(max_iter != option.trt_option.max_shape.end(), "Cannot find %s in TrtBackendOption::min_shape.", item.first.c_str()); - FDASSERT(opt_iter != option.trt_option.opt_shape.end(), "Cannot find %s in TrtBackendOption::opt_shape.", item.first.c_str()); - (*max_shape)[item.first].assign(max_iter->second.begin(), max_iter->second.end()); - (*opt_shape)[item.first].assign(opt_iter->second.begin(), opt_iter->second.end()); + FDASSERT(max_iter != option.trt_option.max_shape.end(), + "Cannot find %s in TrtBackendOption::min_shape.", + item.first.c_str()); + FDASSERT(opt_iter != option.trt_option.opt_shape.end(), + "Cannot find %s in TrtBackendOption::opt_shape.", + item.first.c_str()); + (*max_shape)[item.first].assign(max_iter->second.begin(), + max_iter->second.end()); + (*opt_shape)[item.first].assign(opt_iter->second.begin(), + opt_iter->second.end()); (*min_shape)[item.first].assign(item.second.begin(), item.second.end()); - FDINFO << item.first << ": the max shape = " << print_shape(max_iter->second) + FDINFO << item.first + << ": the max shape = " << print_shape(max_iter->second) << ", the min shape = " << print_shape(item.second) - << ", the opt shape = " << print_shape(opt_iter->second) << std::endl; + << ", the opt shape = " << print_shape(opt_iter->second) + << std::endl; } } -void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor, +void PaddleBackend::CollectShapeRun( + paddle_infer::Predictor* predictor, const std::map>& shape) const { auto input_names = predictor->GetInputNames(); auto input_type = predictor->GetInputTypes(); - for(auto name : input_names) { - FDASSERT(shape.find(name) != shape.end() && input_type.find(name) != input_type.end(), - "Paddle Input name [%s] is not one of the trt dynamic shape.", name.c_str()); + for (auto name : input_names) { + FDASSERT(shape.find(name) != shape.end() && + input_type.find(name) != input_type.end(), + "Paddle Input name [%s] is not one of the trt dynamic shape.", + name.c_str()); auto tensor = predictor->GetInputHandle(name); auto shape_value = shape.at(name); int shape_num = std::accumulate(shape_value.begin(), shape_value.end(), 1, @@ -306,30 +339,30 @@ void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor, tensor->Reshape(shape_value); auto dtype = input_type[name]; switch (dtype) { - case paddle_infer::DataType::FLOAT32: { - std::vector input_data(shape_num, 1.0); - tensor->CopyFromCpu(input_data.data()); - break; - } - case paddle_infer::DataType::INT32: { - std::vector input_data(shape_num, 1); - tensor->CopyFromCpu(input_data.data()); - break; - } - case paddle_infer::DataType::INT64: { - std::vector input_data(shape_num, 1); - tensor->CopyFromCpu(input_data.data()); - break; - } - default: { - FDASSERT(false, "Input data Paddle backend only supports FP32/INT32/INT64 
currently."); - break; - } + case paddle_infer::DataType::FLOAT32: { + std::vector input_data(shape_num, 1.0); + tensor->CopyFromCpu(input_data.data()); + break; + } + case paddle_infer::DataType::INT32: { + std::vector input_data(shape_num, 1); + tensor->CopyFromCpu(input_data.data()); + break; + } + case paddle_infer::DataType::INT64: { + std::vector input_data(shape_num, 1); + tensor->CopyFromCpu(input_data.data()); + break; + } + default: { + FDASSERT(false, "Input data Paddle backend only supports " + "FP32/INT32/INT64 currently."); + break; + } } } predictor->Run(); } #endif - } // namespace fastdeploy diff --git a/fastdeploy/backends/paddle/paddle_backend.h b/fastdeploy/backends/paddle/paddle_backend.h old mode 100755 new mode 100644 index ba083ae431..2df0c67399 --- a/fastdeploy/backends/paddle/paddle_backend.h +++ b/fastdeploy/backends/paddle/paddle_backend.h @@ -23,8 +23,8 @@ #ifdef ENABLE_PADDLE_FRONTEND #include "paddle2onnx/converter.h" #endif -#include "paddle_inference_api.h" // NOLINT #include "fastdeploy/utils/unique_ptr.h" +#include "paddle_inference_api.h" // NOLINT #ifdef ENABLE_TRT_BACKEND #include "fastdeploy/backends/tensorrt/trt_backend.h" @@ -60,6 +60,7 @@ struct PaddleBackendOption { #ifdef ENABLE_TRT_BACKEND TrtBackendOption trt_option; bool collect_shape = false; + std::vector trt_disabled_ops_{}; #endif #ifdef WITH_IPU @@ -91,8 +92,7 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor); // if copy_to_fd is true, copy memory data to FDTensor /// else share memory to FDTensor void PaddleTensorToFDTensor(std::unique_ptr& tensor, - FDTensor* fd_tensor, - bool copy_to_fd); + FDTensor* fd_tensor, bool copy_to_fd); // Convert data type from paddle inference to fastdeploy FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype); @@ -106,20 +106,18 @@ class PaddleBackend : public BaseBackend { virtual ~PaddleBackend() = default; void BuildOption(const PaddleBackendOption& option); - bool InitFromPaddle( - const std::string& model_file, const std::string& params_file, - const PaddleBackendOption& option = PaddleBackendOption()); + bool + InitFromPaddle(const std::string& model_file, const std::string& params_file, + const PaddleBackendOption& option = PaddleBackendOption()); - bool Infer(std::vector& inputs, - std::vector* outputs, + bool Infer(std::vector& inputs, std::vector* outputs, bool copy_to_fd = true) override; - int NumInputs() const override { return inputs_desc_.size(); } int NumOutputs() const override { return outputs_desc_.size(); } - std::unique_ptr Clone(void *stream = nullptr, + std::unique_ptr Clone(void* stream = nullptr, int device_id = -1) override; TensorInfo GetInputInfo(int index) override; @@ -129,9 +127,11 @@ class PaddleBackend : public BaseBackend { private: #ifdef ENABLE_TRT_BACKEND - void CollectShapeRun(paddle_infer::Predictor* predictor, - const std::map>& shape) const; - void GetDynamicShapeFromOption(const PaddleBackendOption& option, + void + CollectShapeRun(paddle_infer::Predictor* predictor, + const std::map>& shape) const; + void GetDynamicShapeFromOption( + const PaddleBackendOption& option, std::map>* max_shape, std::map>* min_shape, std::map>* opt_shape) const; diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc old mode 100755 new mode 100644 index 088cf273ba..1a51cebea8 --- a/fastdeploy/runtime.cc +++ b/fastdeploy/runtime.cc @@ -94,7 +94,7 @@ std::string Str(const Backend& b) { return "Backend::POROS"; } else if (b == Backend::RKNPU2) { return "Backend::RKNPU2"; - }else if (b == 
Backend::OPENVINO) { + } else if (b == Backend::OPENVINO) { return "Backend::OPENVINO"; } else if (b == Backend::LITE) { return "Backend::PDLITE"; @@ -113,7 +113,7 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) { out << "Backend::OPENVINO"; } else if (backend == Backend::RKNPU2) { out << "Backend::RKNPU2"; - }else if (backend == Backend::POROS) { + } else if (backend == Backend::POROS) { out << "Backend::POROS"; } else if (backend == Backend::LITE) { out << "Backend::PDLITE"; @@ -152,15 +152,17 @@ bool CheckModelFormat(const std::string& model_file, } else if (model_format == ModelFormat::TORCHSCRIPT) { if (model_file.size() < 3 || model_file.substr(model_file.size() - 3, 3) != ".pt") { - FDERROR << "With model format of ModelFormat::TORCHSCRIPT, the model file " - "should ends with `.pt`, but now it's " - << model_file << std::endl; + FDERROR + << "With model format of ModelFormat::TORCHSCRIPT, the model file " + "should ends with `.pt`, but now it's " + << model_file << std::endl; return false; } } else { - FDERROR << "Only support model format with frontend ModelFormat::PADDLE / " - "ModelFormat::ONNX / ModelFormat::RKNN / ModelFormat::TORCHSCRIPT." - << std::endl; + FDERROR + << "Only support model format with frontend ModelFormat::PADDLE / " + "ModelFormat::ONNX / ModelFormat::RKNN / ModelFormat::TORCHSCRIPT." + << std::endl; return false; } return true; @@ -205,9 +207,9 @@ void RuntimeOption::SetModelPath(const std::string& model_path, model_file = model_path; model_format = ModelFormat::TORCHSCRIPT; } else { - FDASSERT( - false, - "The model format only can be ModelFormat::PADDLE/ModelFormat::ONNX/ModelFormat::TORCHSCRIPT."); + FDASSERT(false, + "The model format only can be " + "ModelFormat::PADDLE/ModelFormat::ONNX/ModelFormat::TORCHSCRIPT."); } } @@ -317,13 +319,18 @@ void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; } void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; } void RuntimeOption::EnablePaddleToTrt() { - FDASSERT(backend == Backend::TRT, "Should call UseTrtBackend() before call EnablePaddleToTrt()."); + FDASSERT(backend == Backend::TRT, + "Should call UseTrtBackend() before call EnablePaddleToTrt()."); #ifdef ENABLE_PADDLE_BACKEND - FDINFO << "While using TrtBackend with EnablePaddleToTrt, FastDeploy will change to use Paddle Inference Backend." << std::endl; + FDINFO << "While using TrtBackend with EnablePaddleToTrt, FastDeploy will " + "change to use Paddle Inference Backend." 
+ << std::endl; backend = Backend::PDINFER; pd_enable_trt = true; #else - FDASSERT(false, "While using TrtBackend with EnablePaddleToTrt, require the FastDeploy is compiled with Paddle Inference Backend, please rebuild your FastDeploy."); + FDASSERT(false, "While using TrtBackend with EnablePaddleToTrt, require the " + "FastDeploy is compiled with Paddle Inference Backend, " + "please rebuild your FastDeploy."); #endif } @@ -336,20 +343,12 @@ void RuntimeOption::SetOpenVINODevice(const std::string& name) { openvino_device = name; } -void RuntimeOption::EnableLiteFP16() { - lite_enable_fp16 = true; -} +void RuntimeOption::EnableLiteFP16() { lite_enable_fp16 = true; } -void RuntimeOption::DisableLiteFP16() { - lite_enable_fp16 = false; -} -void RuntimeOption::EnableLiteInt8() { - lite_enable_int8 = true; -} +void RuntimeOption::DisableLiteFP16() { lite_enable_fp16 = false; } +void RuntimeOption::EnableLiteInt8() { lite_enable_int8 = true; } -void RuntimeOption::DisableLiteInt8() { - lite_enable_int8 = false; -} +void RuntimeOption::DisableLiteInt8() { lite_enable_int8 = false; } void RuntimeOption::SetLitePowerMode(LitePowerMode mode) { lite_power_mode = mode; } @@ -361,7 +360,8 @@ void RuntimeOption::SetLiteOptimizedModelDir( void RuntimeOption::SetLiteSubgraphPartitionPath( const std::string& nnadapter_subgraph_partition_config_path) { - lite_nnadapter_subgraph_partition_config_path = nnadapter_subgraph_partition_config_path; + lite_nnadapter_subgraph_partition_config_path = + nnadapter_subgraph_partition_config_path; } void RuntimeOption::SetTrtInputShape(const std::string& input_name, @@ -387,8 +387,8 @@ void RuntimeOption::SetTrtInputShape(const std::string& input_name, void RuntimeOption::SetTrtMaxWorkspaceSize(size_t max_workspace_size) { trt_max_workspace_size = max_workspace_size; } -void RuntimeOption::SetTrtMaxBatchSize(size_t max_batch_size){ - trt_max_batch_size = max_batch_size; +void RuntimeOption::SetTrtMaxBatchSize(size_t max_batch_size) { + trt_max_batch_size = max_batch_size; } void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; } @@ -422,27 +422,27 @@ bool Runtime::Compile(std::vector>& prewarm_tensors, poros_option.enable_fp16 = option.trt_enable_fp16; poros_option.max_batch_size = option.trt_max_batch_size; poros_option.max_workspace_size = option.trt_max_workspace_size; - FDASSERT(option.model_format == ModelFormat::TORCHSCRIPT, - "PorosBackend only support model format of ModelFormat::TORCHSCRIPT."); + FDASSERT( + option.model_format == ModelFormat::TORCHSCRIPT, + "PorosBackend only support model format of ModelFormat::TORCHSCRIPT."); backend_ = utils::make_unique(); auto casted_backend = dynamic_cast(backend_.get()); FDASSERT( casted_backend->Compile(option.model_file, prewarm_tensors, poros_option), "Load model from Torchscript failed while initliazing PorosBackend."); #else - FDASSERT(false, - "PorosBackend is not available, please compiled with " - "ENABLE_POROS_BACKEND=ON."); + FDASSERT(false, "PorosBackend is not available, please compiled with " + "ENABLE_POROS_BACKEND=ON."); #endif return true; } -void RuntimeOption::EnablePaddleTrtCollectShape() { - pd_collect_shape = true; -} +void RuntimeOption::EnablePaddleTrtCollectShape() { pd_collect_shape = true; } -void RuntimeOption::DisablePaddleTrtCollectShape() { - pd_collect_shape = false; +void RuntimeOption::DisablePaddleTrtCollectShape() { pd_collect_shape = false; } + +void RuntimeOption::DisablePaddleTrtOPs(const std::vector& ops) { + trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), 
ops.end()); } void RuntimeOption::UseIpu(int device_num, int micro_batch_size, @@ -519,9 +519,9 @@ bool Runtime::Init(const RuntimeOption& _option) { } else if (option.backend == Backend::POROS) { FDASSERT(option.device == Device::CPU || option.device == Device::GPU, "Backend::POROS only supports Device::CPU/Device::GPU."); - FDASSERT( - option.model_format == ModelFormat::TORCHSCRIPT, - "Backend::POROS only supports model format of ModelFormat::TORCHSCRIPT."); + FDASSERT(option.model_format == ModelFormat::TORCHSCRIPT, + "Backend::POROS only supports model format of " + "ModelFormat::TORCHSCRIPT."); FDINFO << "Runtime initialized with Backend::POROS in " << Str(option.device) << "." << std::endl; return true; @@ -572,7 +572,7 @@ std::vector Runtime::GetOutputInfos() { bool Runtime::Infer(std::vector& input_tensors, std::vector* output_tensors) { - for (auto& tensor: input_tensors) { + for (auto& tensor : input_tensors) { FDASSERT(tensor.device_id < 0 || tensor.device_id == option.device_id, "Device id of input tensor(%d) and runtime(%d) are not same.", tensor.device_id, option.device_id); @@ -589,17 +589,15 @@ void Runtime::BindInputTensor(const std::string& name, FDTensor& input) { for (auto& t : input_tensors_) { if (t.name == name) { is_exist = true; - t.SetExternalData(input.shape, input.dtype, - input.MutableData(), input.device, - input.device_id); + t.SetExternalData(input.shape, input.dtype, input.MutableData(), + input.device, input.device_id); break; } } - if(!is_exist) { + if (!is_exist) { FDTensor new_tensor(name); - new_tensor.SetExternalData(input.shape, input.dtype, - input.MutableData(), input.device, - input.device_id); + new_tensor.SetExternalData(input.shape, input.dtype, input.MutableData(), + input.device, input.device_id); input_tensors_.emplace_back(std::move(new_tensor)); } } @@ -644,6 +642,7 @@ void Runtime::CreatePaddleBackend() { trt_option.serialize_file = option.trt_serialize_file; trt_option.enable_pinned_memory = option.enable_pinned_memory; pd_option.trt_option = trt_option; + pd_option.trt_disabled_ops_ = option.trt_disabled_ops_; } #endif #ifdef WITH_IPU @@ -669,9 +668,8 @@ void Runtime::CreatePaddleBackend() { pd_option), "Load model from Paddle failed while initliazing PaddleBackend."); #else - FDASSERT(false, - "PaddleBackend is not available, please compiled with " - "ENABLE_PADDLE_BACKEND=ON."); + FDASSERT(false, "PaddleBackend is not available, please compiled with " + "ENABLE_PADDLE_BACKEND=ON."); #endif } @@ -701,9 +699,8 @@ void Runtime::CreateOpenVINOBackend() { "Load model from Paddle failed while initliazing OrtBackend."); } #else - FDASSERT(false, - "OpenVINOBackend is not available, please compiled with " - "ENABLE_OPENVINO_BACKEND=ON."); + FDASSERT(false, "OpenVINOBackend is not available, please compiled with " + "ENABLE_OPENVINO_BACKEND=ON."); #endif } @@ -733,9 +730,8 @@ void Runtime::CreateOrtBackend() { "Load model from Paddle failed while initliazing OrtBackend."); } #else - FDASSERT(false, - "OrtBackend is not available, please compiled with " - "ENABLE_ORT_BACKEND=ON."); + FDASSERT(false, "OrtBackend is not available, please compiled with " + "ENABLE_ORT_BACKEND=ON."); #endif } @@ -772,9 +768,8 @@ void Runtime::CreateTrtBackend() { "Load model from Paddle failed while initliazing TrtBackend."); } #else - FDASSERT(false, - "TrtBackend is not available, please compiled with " - "ENABLE_TRT_BACKEND=ON."); + FDASSERT(false, "TrtBackend is not available, please compiled with " + "ENABLE_TRT_BACKEND=ON."); #endif } @@ -786,7 +781,8 @@ 
void Runtime::CreateLiteBackend() { lite_option.enable_fp16 = option.lite_enable_fp16; lite_option.power_mode = static_cast(option.lite_power_mode); lite_option.optimized_model_dir = option.lite_optimized_model_dir; - lite_option.nnadapter_subgraph_partition_config_path = option.lite_nnadapter_subgraph_partition_config_path; + lite_option.nnadapter_subgraph_partition_config_path = + option.lite_nnadapter_subgraph_partition_config_path; lite_option.enable_timvx = option.enable_timvx; FDASSERT(option.model_format == ModelFormat::PADDLE, "LiteBackend only support model format of ModelFormat::PADDLE"); @@ -796,9 +792,8 @@ void Runtime::CreateLiteBackend() { lite_option), "Load model from nb file failed while initializing LiteBackend."); #else - FDASSERT(false, - "LiteBackend is not available, please compiled with " - "ENABLE_LITE_BACKEND=ON."); + FDASSERT(false, "LiteBackend is not available, please compiled with " + "ENABLE_LITE_BACKEND=ON."); #endif } @@ -821,10 +816,8 @@ void Runtime::CreateRKNPU2Backend() { Runtime* Runtime::Clone(void* stream, int device_id) { Runtime* runtime = new Runtime(); - if (option.backend != Backend::OPENVINO - && option.backend != Backend::PDINFER - && option.backend != Backend::TRT - ) { + if (option.backend != Backend::OPENVINO && + option.backend != Backend::PDINFER && option.backend != Backend::TRT) { runtime->Init(option); FDWARNING << "Only OpenVINO/Paddle Inference/TensorRT support \ clone engine to reduce CPU/GPU memory usage now. For " @@ -834,8 +827,8 @@ Runtime* Runtime::Clone(void* stream, int device_id) { << std::endl; return runtime; } - FDINFO << "Runtime Clone with Backend:: " << Str(option.backend) << " in " << Str(option.device) - << "." << std::endl; + FDINFO << "Runtime Clone with Backend:: " << Str(option.backend) << " in " + << Str(option.device) << "." << std::endl; runtime->option = option; runtime->backend_ = backend_->Clone(stream, device_id); return runtime; diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h index e96643345b..e53c7ca1ed 100644 --- a/fastdeploy/runtime.h +++ b/fastdeploy/runtime.h @@ -24,9 +24,9 @@ #include #include +#include "backends/rknpu/rknpu2/rknpu2_config.h" #include "fastdeploy/backends/backend.h" #include "fastdeploy/utils/perf.h" -#include "backends/rknpu/rknpu2/rknpu2_config.h" /** \brief All C++ FastDeploy APIs are defined inside this namespace * @@ -35,14 +35,14 @@ namespace fastdeploy { /*! 
Inference backend supported in FastDeploy */ enum Backend { - UNKNOWN, ///< Unknown inference backend - ORT, ///< ONNX Runtime, support Paddle/ONNX format model, CPU / Nvidia GPU - TRT, ///< TensorRT, support Paddle/ONNX format model, Nvidia GPU only + UNKNOWN, ///< Unknown inference backend + ORT, ///< ONNX Runtime, support Paddle/ONNX format model, CPU / Nvidia GPU + TRT, ///< TensorRT, support Paddle/ONNX format model, Nvidia GPU only PDINFER, ///< Paddle Inference, support Paddle format model, CPU / Nvidia GPU POROS, ///< Poros, support TorchScript format model, CPU / Nvidia GPU OPENVINO, ///< Intel OpenVINO, support Paddle/ONNX format, CPU only - LITE, ///< Paddle Lite, support Paddle format model, ARM CPU only - RKNPU2, ///< RKNPU2, support RKNN format model, Rockchip NPU only + LITE, ///< Paddle Lite, support Paddle format model, ARM CPU only + RKNPU2, ///< RKNPU2, support RKNN format model, Rockchip NPU only }; FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out, @@ -94,10 +94,10 @@ struct FASTDEPLOY_DECL RuntimeOption { /// Use Nvidia GPU to inference void UseGpu(int gpu_id = 0); - void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name - = fastdeploy::rknpu2::CpuName::RK3588, - fastdeploy::rknpu2::CoreMask rknpu2_core - = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0); + void UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name = + fastdeploy::rknpu2::CpuName::RK3588, + fastdeploy::rknpu2::CoreMask rknpu2_core = + fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_0); /// Use TimVX to inference void UseTimVX(); @@ -116,9 +116,7 @@ struct FASTDEPLOY_DECL RuntimeOption { void UsePaddleBackend(); /// Wrapper function of UsePaddleBackend() - void UsePaddleInferBackend() { - return UsePaddleBackend(); - } + void UsePaddleInferBackend() { return UsePaddleBackend(); } /// Set ONNX Runtime as inference backend, support CPU/GPU void UseOrtBackend(); @@ -136,9 +134,7 @@ struct FASTDEPLOY_DECL RuntimeOption { void UseLiteBackend(); /// Wrapper function of UseLiteBackend() - void UsePaddleLiteBackend() { - return UseLiteBackend(); - } + void UsePaddleLiteBackend() { return UseLiteBackend(); } /// Set mkldnn switch while using Paddle Inference as inference backend void SetPaddleMKLDNN(bool pd_mkldnn = true); @@ -177,7 +173,7 @@ struct FASTDEPLOY_DECL RuntimeOption { * @brief Set shape info for OpenVINO */ void SetOpenVINOShapeInfo( - const std::map>& shape_info) { + const std::map>& shape_info) { ov_shape_infos = shape_info; } @@ -197,7 +193,7 @@ struct FASTDEPLOY_DECL RuntimeOption { * @brief Set nnadapter subgraph partition path for Paddle Lite backend. 
*/ void SetLiteSubgraphPartitionPath( - const std::string& nnadapter_subgraph_partition_config_path); + const std::string& nnadapter_subgraph_partition_config_path); /** * @brief enable half precision while use paddle lite backend @@ -275,6 +271,11 @@ struct FASTDEPLOY_DECL RuntimeOption { */ void DisablePaddleTrtCollectShape(); + /** + * @brief Prevent ops running in paddle trt backend + */ + void DisablePaddleTrtOPs(const std::vector& ops); + /* * @brief Set number of streams by the OpenVINO backends */ @@ -363,6 +364,8 @@ struct FASTDEPLOY_DECL RuntimeOption { bool trt_enable_int8 = false; size_t trt_max_batch_size = 32; size_t trt_max_workspace_size = 1 << 30; + // ======Only for PaddleTrt Backend======= + std::vector trt_disabled_ops_{}; // ======Only for Poros Backend======= bool is_dynamic = false; @@ -378,12 +381,12 @@ struct FASTDEPLOY_DECL RuntimeOption { std::vector ov_cpu_operators; // ======Only for RKNPU2 Backend======= - fastdeploy::rknpu2::CpuName rknpu2_cpu_name_ - = fastdeploy::rknpu2::CpuName::RK3588; - fastdeploy::rknpu2::CoreMask rknpu2_core_mask_ - = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO; + fastdeploy::rknpu2::CpuName rknpu2_cpu_name_ = + fastdeploy::rknpu2::CpuName::RK3588; + fastdeploy::rknpu2::CoreMask rknpu2_core_mask_ = + fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO; - std::string model_file = ""; // Path of model file + std::string model_file = ""; // Path of model file std::string params_file = ""; // Path of parameters file, can be empty // format of input model ModelFormat model_format = ModelFormat::AUTOREC; @@ -450,8 +453,7 @@ struct FASTDEPLOY_DECL Runtime { * \param[in] stream CUDA Stream, defualt param is nullptr * \return new Runtime* by this clone */ - Runtime* Clone(void* stream = nullptr, - int device_id = -1); + Runtime* Clone(void* stream = nullptr, int device_id = -1); RuntimeOption option; From 414cbc21987ddf21dc9948ccb6220db574760086 Mon Sep 17 00:00:00 2001 From: zhoushunjie Date: Sun, 4 Dec 2022 10:25:36 +0000 Subject: [PATCH 2/2] Add delete_paddle_backend_pass disable_paddle_trt_ops pybind --- fastdeploy/pybind/runtime.cc | 79 ++++++++++++++++++++---------------- python/fastdeploy/runtime.py | 10 +++++ 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc index 502c030fcf..75767c6657 100644 --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -35,7 +35,8 @@ void BindRuntime(pybind11::module& m) { .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN) .def("set_openvino_device", &RuntimeOption::SetOpenVINODevice) .def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo) - .def("set_openvino_cpu_operators", &RuntimeOption::SetOpenVINOCpuOperators) + .def("set_openvino_cpu_operators", + &RuntimeOption::SetOpenVINOCpuOperators) .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo) .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo) .def("set_paddle_mkldnn_cache_size", @@ -52,10 +53,15 @@ void BindRuntime(pybind11::module& m) { .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile) .def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory) .def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory) - .def("enable_paddle_trt_collect_shape", &RuntimeOption::EnablePaddleTrtCollectShape) - .def("disable_paddle_trt_collect_shape", &RuntimeOption::DisablePaddleTrtCollectShape) + .def("enable_paddle_trt_collect_shape", + &RuntimeOption::EnablePaddleTrtCollectShape) + 
.def("disable_paddle_trt_collect_shape", + &RuntimeOption::DisablePaddleTrtCollectShape) .def("use_ipu", &RuntimeOption::UseIpu) .def("set_ipu_config", &RuntimeOption::SetIpuConfig) + .def("delete_paddle_backend_pass", + &RuntimeOption::DeletePaddleBackendPass) + .def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs) .def_readwrite("model_file", &RuntimeOption::model_file) .def_readwrite("params_file", &RuntimeOption::params_file) .def_readwrite("model_format", &RuntimeOption::model_format) @@ -117,9 +123,9 @@ void BindRuntime(pybind11::module& m) { auto dtype = NumpyDataTypeToFDDataType(warm_datas[i][j].dtype()); std::vector data_shape; - data_shape.insert( - data_shape.begin(), warm_datas[i][j].shape(), - warm_datas[i][j].shape() + warm_datas[i][j].ndim()); + data_shape.insert(data_shape.begin(), warm_datas[i][j].shape(), + warm_datas[i][j].shape() + + warm_datas[i][j].ndim()); warm_tensors[i][j].Resize(data_shape, dtype); memcpy(warm_tensors[i][j].MutableData(), warm_datas[i][j].mutable_data(), @@ -160,36 +166,39 @@ void BindRuntime(pybind11::module& m) { } return results; }) - .def("infer", [](Runtime& self, std::map& data) { - std::vector inputs; - inputs.reserve(data.size()); - for (auto iter = data.begin(); iter != data.end(); ++iter) { - FDTensor tensor; - tensor.SetExternalData(iter->second.Shape(), iter->second.Dtype(), iter->second.Data(), iter->second.device); - tensor.name = iter->first; - inputs.push_back(tensor); - } - std::vector outputs; - if (!self.Infer(inputs, &outputs)) { - throw std::runtime_error("Failed to inference with Runtime."); - } - return outputs; - }) - .def("infer", [](Runtime& self, std::vector& inputs) { - std::vector outputs; - return self.Infer(inputs, &outputs); - }) + .def("infer", + [](Runtime& self, std::map& data) { + std::vector inputs; + inputs.reserve(data.size()); + for (auto iter = data.begin(); iter != data.end(); ++iter) { + FDTensor tensor; + tensor.SetExternalData(iter->second.Shape(), + iter->second.Dtype(), iter->second.Data(), + iter->second.device); + tensor.name = iter->first; + inputs.push_back(tensor); + } + std::vector outputs; + if (!self.Infer(inputs, &outputs)) { + throw std::runtime_error("Failed to inference with Runtime."); + } + return outputs; + }) + .def("infer", + [](Runtime& self, std::vector& inputs) { + std::vector outputs; + return self.Infer(inputs, &outputs); + }) .def("bind_input_tensor", &Runtime::BindInputTensor) - .def("infer", [](Runtime& self) { - self.Infer(); - }) - .def("get_output_tensor", [](Runtime& self, const std::string& name) { - FDTensor* output = self.GetOutputTensor(name); - if(output == nullptr) { - return pybind11::cast(nullptr); - } - return pybind11::cast(*output); - }) + .def("infer", [](Runtime& self) { self.Infer(); }) + .def("get_output_tensor", + [](Runtime& self, const std::string& name) { + FDTensor* output = self.GetOutputTensor(name); + if (output == nullptr) { + return pybind11::cast(nullptr); + } + return pybind11::cast(*output); + }) .def("num_inputs", &Runtime::NumInputs) .def("num_outputs", &Runtime::NumOutputs) .def("get_input_info", &Runtime::GetInputInfo) diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py index 71bfe06771..6461da66de 100755 --- a/python/fastdeploy/runtime.py +++ b/python/fastdeploy/runtime.py @@ -435,6 +435,16 @@ def disable_paddle_trt_collect_shape(self): """ return self._option.disable_paddle_trt_collect_shape() + def delete_paddle_backend_pass(self, pass_name): + """Delete pass by name in paddle backend + """ + return 
self._option.delete_paddle_backend_pass(pass_name) + + def disable_paddle_trt_ops(self, ops): + """Disable some ops in paddle trt backend + """ + return self._option.disable_paddle_trt_ops(ops) + def use_ipu(self, device_num=1, micro_batch_size=1,
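
Taken together, the two patches let a user keep selected ops out of the Paddle-TRT engine (DisablePaddleTrtOPs / disable_paddle_trt_ops) and drop a Paddle Inference IR pass by name (delete_paddle_backend_pass). The following is a minimal usage sketch, not part of the patch: the model paths, op names, and pass name are placeholders, and the surrounding helpers (set_model_path, use_gpu, use_trt_backend, enable_paddle_to_trt, Runtime) are assumed from the existing FastDeploy Python API.

    # Usage sketch only; paths, op names, and the pass name below are examples.
    import fastdeploy as fd

    option = fd.RuntimeOption()
    option.set_model_path("model/inference.pdmodel", "model/inference.pdiparams")
    option.use_gpu(0)
    # Per EnablePaddleToTrt's assertion, use_trt_backend() must be called first;
    # enable_paddle_to_trt() then routes TensorRT through Paddle Inference.
    option.use_trt_backend()
    option.enable_paddle_to_trt()

    # New in this patch: keep these ops out of the Paddle-TRT engine so they
    # fall back to the native Paddle Inference kernels (op names are examples).
    option.disable_paddle_trt_ops(["pool2d", "elementwise_add"])

    # Also newly bound: remove a Paddle Inference IR pass by name (example name).
    option.delete_paddle_backend_pass("conv_bn_fuse_pass")

    runtime = fd.Runtime(option)

Excluding an op this way is the escape hatch the patch provides when a particular op misbehaves or loses accuracy inside the Paddle-TRT engine; everything else in the graph still runs through TensorRT.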