From d2a0a7cab871ce98e9d575bb0e6bcf34a2c55018 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Fri, 13 Jan 2023 07:35:56 +0000 Subject: [PATCH 1/3] [Model] Support PaddleYOLOv8 model --- .../paddledetection/cpp/CMakeLists.txt | 3 + .../paddledetection/cpp/infer_yolov8.cc | 159 ++++++++++++++++++ .../paddledetection/python/infer_yolov8.py | 62 +++++++ fastdeploy/vision/detection/ppdet/model.h | 17 ++ .../vision/detection/ppdet/__init__.py | 25 +++ 5 files changed, 266 insertions(+) create mode 100755 examples/vision/detection/paddledetection/cpp/infer_yolov8.cc create mode 100755 examples/vision/detection/paddledetection/python/infer_yolov8.py diff --git a/examples/vision/detection/paddledetection/cpp/CMakeLists.txt b/examples/vision/detection/paddledetection/cpp/CMakeLists.txt index 6dcbb7cc88..3eb3af1e99 100644 --- a/examples/vision/detection/paddledetection/cpp/CMakeLists.txt +++ b/examples/vision/detection/paddledetection/cpp/CMakeLists.txt @@ -42,6 +42,9 @@ target_link_libraries(infer_yolov6_demo ${FASTDEPLOY_LIBS}) add_executable(infer_yolov7_demo ${PROJECT_SOURCE_DIR}/infer_yolov7.cc) target_link_libraries(infer_yolov7_demo ${FASTDEPLOY_LIBS}) +add_executable(infer_yolov8_demo ${PROJECT_SOURCE_DIR}/infer_yolov8.cc) +target_link_libraries(infer_yolov8_demo ${FASTDEPLOY_LIBS}) + add_executable(infer_rtmdet_demo ${PROJECT_SOURCE_DIR}/infer_rtmdet.cc) target_link_libraries(infer_rtmdet_demo ${FASTDEPLOY_LIBS}) diff --git a/examples/vision/detection/paddledetection/cpp/infer_yolov8.cc b/examples/vision/detection/paddledetection/cpp/infer_yolov8.cc new file mode 100755 index 0000000000..2b11ecec9a --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/infer_yolov8.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto option = fastdeploy::RuntimeOption(); + option.UseCpu(); + auto model = fastdeploy::vision::detection::PaddleYOLOv8(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." 
<< std::endl; + return; + } + + std::cout << res.Str() << std::endl; + auto vis_im = fastdeploy::vision::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void KunlunXinInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto option = fastdeploy::RuntimeOption(); + option.UseKunlunXin(); + auto model = fastdeploy::vision::detection::PaddleYOLOv8(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + std::cout << res.Str() << std::endl; + auto vis_im = fastdeploy::vision::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::PaddleYOLOv8(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + std::cout << res.Str() << std::endl; + auto vis_im = fastdeploy::vision::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void TrtInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + option.UseTrtBackend(); + auto model = fastdeploy::vision::detection::PaddleYOLOv8(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + std::cout << res.Str() << std::endl; + auto vis_im = fastdeploy::vision::VisDetection(im, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./ppyolo_dirname ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu; 2: run with gpu and use tensorrt backend; 3: run with kunlunxin." 
+ << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 2) { + TrtInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 3) { + KunlunXinInfer(argv[1], argv[2]); + } + return 0; +}
diff --git a/examples/vision/detection/paddledetection/python/infer_yolov8.py b/examples/vision/detection/paddledetection/python/infer_yolov8.py new file mode 100755 index 0000000000..d1479f5c46 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/infer_yolov8.py
@@ -0,0 +1,62 @@ +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", + required=True, + help="Path of PaddleDetection model directory") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'kunlunxin', 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Whether to use TensorRT.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "kunlunxin": + option.use_kunlunxin() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + return option + + +args = parse_arguments() + +model_file = os.path.join(args.model_dir, "model.pdmodel") +params_file = os.path.join(args.model_dir, "model.pdiparams") +config_file = os.path.join(args.model_dir, "infer_cfg.yml") + +# Configure the runtime and load the model +runtime_option = build_option(args) +model = fd.vision.detection.PaddleYOLOv8( + model_file, params_file, config_file, runtime_option=runtime_option) + +# Predict the detection result for an image +im = cv2.imread(args.image) +result = model.predict(im.copy()) +print(result) + +# Visualize the prediction result +vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result saved in ./visualized_result.jpg")
diff --git a/fastdeploy/vision/detection/ppdet/model.h b/fastdeploy/vision/detection/ppdet/model.h index 812286d755..a3797bdb8b 100755 --- a/fastdeploy/vision/detection/ppdet/model.h +++ b/fastdeploy/vision/detection/ppdet/model.h
@@ -245,6 +245,23 @@ class FASTDEPLOY_DECL PaddleYOLOv7 : public PPDetBase { virtual std::string ModelName() const { return "PaddleDetection/YOLOv7"; } }; +class FASTDEPLOY_DECL PaddleYOLOv8 : public PPDetBase { + public: + PaddleYOLOv8(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) + : PPDetBase(model_file, params_file, config_file, custom_option, + model_format) { + valid_cpu_backends = {Backend::OPENVINO, Backend::ORT, Backend::PDINFER}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; + valid_kunlunxin_backends = {Backend::LITE}; + initialized = Initialize(); + } + + virtual std::string ModelName() const { return "PaddleDetection/YOLOv8"; } +}; + class FASTDEPLOY_DECL RTMDet : public PPDetBase { public: RTMDet(const std::string& model_file, const std::string& params_file,
diff --git a/python/fastdeploy/vision/detection/ppdet/__init__.py b/python/fastdeploy/vision/detection/ppdet/__init__.py index 9f4ad75bc4..ad0f8e51d4 100644 ---
a/python/fastdeploy/vision/detection/ppdet/__init__.py +++ b/python/fastdeploy/vision/detection/ppdet/__init__.py
@@ -490,6 +490,31 @@ def __init__(self, assert self.initialized, "PaddleYOLOv7 model initialize failed." +class PaddleYOLOv8(PPYOLOE): + def __init__(self, + model_file, + params_file, + config_file, + runtime_option=None, + model_format=ModelFormat.PADDLE): + """Load a YOLOv8 model exported by PaddleDetection. + + :param model_file: (str)Path of model file, e.g yolov8/model.pdmodel + :param params_file: (str)Path of parameters file, e.g yolov8/model.pdiparams, if the model_format is ModelFormat.ONNX, this param will be ignored, can be set as empty string + :param config_file: (str)Path of configuration file for deployment, e.g yolov8/infer_cfg.yml + :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU + :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model + """ + + super(PPYOLOE, self).__init__(runtime_option) + + assert model_format == ModelFormat.PADDLE, "PaddleYOLOv8 model only supports model format of ModelFormat.Paddle now." + self._model = C.vision.detection.PaddleYOLOv8( + model_file, params_file, config_file, self._runtime_option, + model_format) + assert self.initialized, "PaddleYOLOv8 model initialize failed." + + class RTMDet(PPYOLOE): def __init__(self, model_file,
From 2bfbeb967ab263524fde6fd25f0248feb5cd1214 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Sat, 14 Jan 2023 05:47:08 +0000 Subject: [PATCH 2/3] [YOLOv8] Add PaddleYOLOv8 pybind --- fastdeploy/vision/detection/ppdet/ppdet_pybind.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+)
diff --git a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc index 800d656b37..67f63eb23b 100644 --- a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc +++ b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc
@@ -134,30 +134,42 @@ void BindPPDet(pybind11::module& m) { .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::PaddleYOLOv8, vision::detection::PPDetBase>(m, "PaddleYOLOv8") + .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::RTMDet, vision::detection::PPDetBase>(m, "RTMDet") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::CascadeRCNN, vision::detection::PPDetBase>(m, "CascadeRCNN") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::PSSDet, vision::detection::PPDetBase>(m, "PSSDet") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::RetinaNet, vision::detection::PPDetBase>(m, "RetinaNet") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::PPYOLOESOD, vision::detection::PPDetBase>(m, "PPYOLOESOD") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::FCOS, vision::detection::PPDetBase>(m, "FCOS") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::TTFNet, vision::detection::PPDetBase>(m, "TTFNet") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::TOOD, vision::detection::PPDetBase>(m, "TOOD") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>()); + pybind11::class_<vision::detection::GFL, vision::detection::PPDetBase>(m, "GFL") .def(pybind11::init<std::string, std::string, std::string, RuntimeOption, ModelFormat>());
From c47a2fa2e07e28174b5481c448803742594ae7f6 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Sat, 14 Jan 2023 13:52:31 +0800 Subject: [PATCH 3/3] [Other] update from latest develop (#30) * [Backend] Remove all lite options in RuntimeOption (#1109) * Remove all lite options in RuntimeOption * Fix code error * move pybind * Fix build error * [Backend] Add TensorRT FP16 support for AdaptivePool2d (#1116) * add fp16 cuda kernel * fix code bug * update code * [Doc] Fix KunlunXin doc (#1139) fix kunlunxin doc * [Model] Support PaddleYOLOv8 model (#1136) Co-authored-by: Jason Co-authored-by: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> --- docs/cn/build_and_install/kunlunxin.md | 5 +- fastdeploy/pybind/runtime.cc | 37 +++++------ .../common/cuda/adaptive_pool2d_kernel.cu | 54 +++++++++++++----- .../common/cuda/adaptive_pool2d_kernel.h | 9 ++-
.../backends/tensorrt/ops/adaptive_pool2d.cc | 26 ++++++--- fastdeploy/runtime/runtime.cc | 33 +---------- fastdeploy/runtime/runtime_option.cc | 57 +++++++++++-------- fastdeploy/runtime/runtime_option.h | 39 ++----------- 8 files changed, 121 insertions(+), 139 deletions(-) mode change 100755 => 100644 fastdeploy/pybind/runtime.cc mode change 100755 => 100644 fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu diff --git a/docs/cn/build_and_install/kunlunxin.md b/docs/cn/build_and_install/kunlunxin.md index a5e9e9d066..375626dab0 100755 --- a/docs/cn/build_and_install/kunlunxin.md +++ b/docs/cn/build_and_install/kunlunxin.md @@ -23,13 +23,16 @@ FastDeploy 基于 Paddle Lite 后端支持在昆仑芯 XPU 上进行部署推理 | ORT_DIRECTORY | 当开启ONNX Runtime后端时,用于指定用户本地的ONNX Runtime库路径;如果不指定,编译过程会自动下载ONNX Runtime库 | | OPENCV_DIRECTORY | 当ENABLE_VISION=ON时,用于指定用户本地的OpenCV库路径;如果不指定,编译过程会自动下载OpenCV库 | | OPENVINO_DIRECTORY | 当开启OpenVINO后端时, 用于指定用户本地的OpenVINO库路径;如果不指定,编译过程会自动下载OpenVINO库 | + 更多编译选项请参考[FastDeploy编译选项说明](./README.md) ## 基于 Paddle Lite 的 C++ FastDeploy 库编译 - OS: Linux - gcc/g++: version >= 8.2 - cmake: version >= 3.15 + 此外更推荐开发者自行安装,编译时通过`-DOPENCV_DIRECTORY`来指定环境中的OpenCV(如若不指定-DOPENCV_DIRECTORY,会自动下载FastDeploy提供的预编译的OpenCV,但在**Linux平台**无法支持Video的读取,以及imshow等可视化界面功能) + ``` sudo apt-get install libopencv-dev ``` @@ -47,7 +50,7 @@ cmake -DWITH_KUNLUNXIN=ON \ -DENABLE_PADDLE_BACKEND=ON \ # 可选择开启 Paddle 后端 -DCMAKE_INSTALL_PREFIX=fastdeploy-kunlunxin \ -DENABLE_VISION=ON \ # 是否编译集成视觉模型的部署模块,可选择开启 - -DOPENCV_DIRECTORY=/usr/lib/x86_64-linux-gnu/cmake/opencv4 \ + -DOPENCV_DIRECTORY=/usr/lib/x86_64-linux-gnu/cmake/opencv4 \ # 指定系统自带的opencv路径 .. # Build FastDeploy KunlunXin XPU C++ SDK diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc old mode 100755 new mode 100644 index 3402dd896f..3c7b4f7a97 --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -37,12 +37,17 @@ void BindRuntime(pybind11::module& m) { .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend) .def("use_lite_backend", &RuntimeOption::UseLiteBackend) .def("set_lite_device_names", &RuntimeOption::SetLiteDeviceNames) - .def("set_lite_context_properties", &RuntimeOption::SetLiteContextProperties) + .def("set_lite_context_properties", + &RuntimeOption::SetLiteContextProperties) .def("set_lite_model_cache_dir", &RuntimeOption::SetLiteModelCacheDir) - .def("set_lite_dynamic_shape_info", &RuntimeOption::SetLiteDynamicShapeInfo) - .def("set_lite_subgraph_partition_path", &RuntimeOption::SetLiteSubgraphPartitionPath) - .def("set_lite_mixed_precision_quantization_config_path", &RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath) - .def("set_lite_subgraph_partition_config_buffer", &RuntimeOption::SetLiteSubgraphPartitionConfigBuffer) + .def("set_lite_dynamic_shape_info", + &RuntimeOption::SetLiteDynamicShapeInfo) + .def("set_lite_subgraph_partition_path", + &RuntimeOption::SetLiteSubgraphPartitionPath) + .def("set_lite_mixed_precision_quantization_config_path", + &RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath) + .def("set_lite_subgraph_partition_config_buffer", + &RuntimeOption::SetLiteSubgraphPartitionConfigBuffer) .def("set_paddle_mkldnn", &RuntimeOption::SetPaddleMKLDNN) .def("set_openvino_device", &RuntimeOption::SetOpenVINODevice) .def("set_openvino_shape_info", &RuntimeOption::SetOpenVINOShapeInfo) @@ -114,21 +119,7 @@ void BindRuntime(pybind11::module& m) { .def_readwrite("ipu_available_memory_proportion", &RuntimeOption::ipu_available_memory_proportion) 
.def_readwrite("ipu_enable_half_partial", - &RuntimeOption::ipu_enable_half_partial) - .def_readwrite("kunlunxin_l3_workspace_size", - &RuntimeOption::kunlunxin_l3_workspace_size) - .def_readwrite("kunlunxin_locked", - &RuntimeOption::kunlunxin_locked) - .def_readwrite("kunlunxin_autotune", - &RuntimeOption::kunlunxin_autotune) - .def_readwrite("kunlunxin_autotune_file", - &RuntimeOption::kunlunxin_autotune_file) - .def_readwrite("kunlunxin_precision", - &RuntimeOption::kunlunxin_precision) - .def_readwrite("kunlunxin_adaptive_seqlen", - &RuntimeOption::kunlunxin_adaptive_seqlen) - .def_readwrite("kunlunxin_enable_multi_stream", - &RuntimeOption::kunlunxin_enable_multi_stream); + &RuntimeOption::ipu_enable_half_partial); pybind11::class_(m, "TensorInfo") .def_readwrite("name", &TensorInfo::name) @@ -151,9 +142,9 @@ void BindRuntime(pybind11::module& m) { auto dtype = NumpyDataTypeToFDDataType(warm_datas[i][j].dtype()); std::vector data_shape; - data_shape.insert(data_shape.begin(), warm_datas[i][j].shape(), - warm_datas[i][j].shape() + - warm_datas[i][j].ndim()); + data_shape.insert( + data_shape.begin(), warm_datas[i][j].shape(), + warm_datas[i][j].shape() + warm_datas[i][j].ndim()); warm_tensors[i][j].Resize(data_shape, dtype); memcpy(warm_tensors[i][j].MutableData(), warm_datas[i][j].mutable_data(), diff --git a/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu b/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu old mode 100755 new mode 100644 index 560dc561fd..78f9d227b4 --- a/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu +++ b/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.cu @@ -17,8 +17,8 @@ #include "adaptive_pool2d_kernel.h" namespace fastdeploy { - -__global__ void CudaCastKernel(const float* in, float* out, int edge, +template +__global__ void CudaCastKernel(const T1* in, T2* out, int edge, int out_bc_offset, int in_bc_offset, int ih, int iw, int oh, int ow, bool is_avg) { int position = blockDim.x * blockIdx.x + threadIdx.x; @@ -32,29 +32,37 @@ __global__ void CudaCastKernel(const float* in, float* out, int edge, int hend = ceilf(static_cast((h + 1) * ih) / oh); int wstart = floorf(static_cast(w * iw) / ow); int wend = ceilf(static_cast((w + 1) * iw) / ow); + float ele_val = 0.0; if (is_avg) { - out[position] = 0.0; + ele_val = 0.0; } else { - out[position] = in[offset * in_bc_offset + hstart * iw + wstart]; + ele_val = + static_cast(in[offset * in_bc_offset + hstart * iw + wstart]); } for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { int input_idx = h * iw + w; if (is_avg) { - out[position] = out[position] + in[offset * in_bc_offset + input_idx]; + ele_val = + ele_val + static_cast(in[offset * in_bc_offset + input_idx]); } else { - out[position] = - max(out[position], in[offset * in_bc_offset + input_idx]); + ele_val = + (ele_val > + static_cast(in[offset * in_bc_offset + input_idx])) + ? 
ele_val + : static_cast(in[offset * in_bc_offset + input_idx]); } } } - out[position] = out[position] / ((hend - hstart) * (wend - wstart)); + out[position] = static_cast( + ele_val / static_cast(((hend - hstart) * (wend - wstart)))); } void CudaAdaptivePool(const std::vector& input_dims, - const std::vector& output_dims, float* output, - const float* input, void* compute_stream, - const std::string& pooling_type) { + const std::vector& output_dims, void* output, + const void* input, void* compute_stream, + const std::string& pooling_type, const std::string& dtype, + const std::string& out_dtype) { auto casted_compute_stream = reinterpret_cast(compute_stream); int out_bc_offset = output_dims[2] * output_dims[3]; int in_bc_offset = input_dims[2] * input_dims[3]; @@ -65,9 +73,27 @@ void CudaAdaptivePool(const std::vector& input_dims, bool is_avg = pooling_type == "avg"; int threads = 256; int blocks = ceil(jobs / static_cast(threads)); - CudaCastKernel<<>>( - input, output, jobs, out_bc_offset, in_bc_offset, int(input_dims[2]), - int(input_dims[3]), int(output_dims[2]), int(output_dims[3]), is_avg); + if (dtype == "float") { + CudaCastKernel<<>>( + static_cast(input), static_cast(output), jobs, + out_bc_offset, in_bc_offset, int(input_dims[2]), int(input_dims[3]), + int(output_dims[2]), int(output_dims[3]), is_avg); + } else if (dtype == "half") { + if (out_dtype == "half") { + CudaCastKernel<<>>( + static_cast(input), static_cast(output), jobs, + out_bc_offset, in_bc_offset, int(input_dims[2]), int(input_dims[3]), + int(output_dims[2]), int(output_dims[3]), is_avg); + } + if (out_dtype == "float") { + CudaCastKernel + <<>>( + static_cast(input), static_cast(output), + jobs, out_bc_offset, in_bc_offset, int(input_dims[2]), + int(input_dims[3]), int(output_dims[2]), int(output_dims[3]), + is_avg); + } + } } } // namespace fastdeploy #endif diff --git a/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h b/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h index dc29c07dc0..ddb7cb8155 100755 --- a/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h +++ b/fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h @@ -15,6 +15,7 @@ #pragma once +#include #include #include #include @@ -25,8 +26,10 @@ namespace fastdeploy { void CudaAdaptivePool(const std::vector& input_dims, - const std::vector& output_dims, float* output, - const float* input, void* compute_stream, - const std::string& pooling_type); + const std::vector& output_dims, void* output, + const void* input, void* compute_stream, + const std::string& pooling_type, + const std::string& dtype = "float", + const std::string& out_dtype = "float"); } // namespace fastdeploy diff --git a/fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc b/fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc index ae7cef7f41..a977944750 100644 --- a/fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc +++ b/fastdeploy/runtime/backends/tensorrt/ops/adaptive_pool2d.cc @@ -63,11 +63,6 @@ int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { - if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) { - return -1; - } - auto const* data = static_cast(inputs[0]); - auto* result = static_cast(outputs[0]); int nums = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] * outputDesc[0].dims.d[2] * outputDesc[0].dims.d[3]; std::vector input_size, 
output_size; @@ -75,8 +70,18 @@ int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, input_size.push_back(inputDesc[0].dims.d[i]); output_size.push_back(outputDesc[0].dims.d[i]); } - CudaAdaptivePool(input_size, output_size, result, data, stream, - pooling_type_); + if (inputDesc[0].type == nvinfer1::DataType::kHALF) { + if (outputDesc[0].type == nvinfer1::DataType::kHALF) { + CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream, + pooling_type_, "half", "half"); + } else if (outputDesc[0].type == nvinfer1::DataType::kFLOAT) { + CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream, + pooling_type_, "half", "float"); + } + } else if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) { + CudaAdaptivePool(input_size, output_size, outputs[0], inputs[0], stream, + pooling_type_, "float", "float"); + } return cudaPeekAtLastError(); } @@ -106,7 +111,12 @@ nvinfer1::DataType AdaptivePool2d::getOutputDataType( bool AdaptivePool2d::supportsFormatCombination( int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept { - return (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR); + if ((inOut[pos].format == nvinfer1::PluginFormat::kLINEAR) && + (inOut[pos].type == nvinfer1::DataType::kFLOAT || + inOut[pos].type == nvinfer1::DataType::kHALF)) { + return true; + } + return false; } int AdaptivePool2d::initialize() noexcept { return 0; } diff --git a/fastdeploy/runtime/runtime.cc b/fastdeploy/runtime/runtime.cc index cc4b23da94..ffa135a3a7 100644 --- a/fastdeploy/runtime/runtime.cc +++ b/fastdeploy/runtime/runtime.cc @@ -390,43 +390,12 @@ void Runtime::CreateTrtBackend() { void Runtime::CreateLiteBackend() { #ifdef ENABLE_LITE_BACKEND - auto lite_option = LiteBackendOption(); - lite_option.threads = option.cpu_thread_num; - lite_option.enable_int8 = option.lite_enable_int8; - lite_option.enable_fp16 = option.lite_enable_fp16; - lite_option.power_mode = static_cast(option.lite_power_mode); - lite_option.optimized_model_dir = option.lite_optimized_model_dir; - lite_option.nnadapter_subgraph_partition_config_path = - option.lite_nnadapter_subgraph_partition_config_path; - lite_option.nnadapter_subgraph_partition_config_buffer = - option.lite_nnadapter_subgraph_partition_config_buffer; - lite_option.nnadapter_device_names = option.lite_nnadapter_device_names; - lite_option.nnadapter_context_properties = - option.lite_nnadapter_context_properties; - lite_option.nnadapter_model_cache_dir = option.lite_nnadapter_model_cache_dir; - lite_option.nnadapter_dynamic_shape_info = - option.lite_nnadapter_dynamic_shape_info; - lite_option.nnadapter_mixed_precision_quantization_config_path = - option.lite_nnadapter_mixed_precision_quantization_config_path; - lite_option.enable_timvx = option.enable_timvx; - lite_option.enable_ascend = option.enable_ascend; - lite_option.enable_kunlunxin = option.enable_kunlunxin; - lite_option.device_id = option.device_id; - lite_option.kunlunxin_l3_workspace_size = option.kunlunxin_l3_workspace_size; - lite_option.kunlunxin_locked = option.kunlunxin_locked; - lite_option.kunlunxin_autotune = option.kunlunxin_autotune; - lite_option.kunlunxin_autotune_file = option.kunlunxin_autotune_file; - lite_option.kunlunxin_precision = option.kunlunxin_precision; - lite_option.kunlunxin_adaptive_seqlen = option.kunlunxin_adaptive_seqlen; - lite_option.kunlunxin_enable_multi_stream = - option.kunlunxin_enable_multi_stream; - FDASSERT(option.model_format == ModelFormat::PADDLE, "LiteBackend only support model 
format of ModelFormat::PADDLE"); backend_ = utils::make_unique(); auto casted_backend = dynamic_cast(backend_.get()); FDASSERT(casted_backend->InitFromPaddle(option.model_file, option.params_file, - lite_option), + option.paddle_lite_option), "Load model from nb file failed while initializing LiteBackend."); #else FDASSERT(false, diff --git a/fastdeploy/runtime/runtime_option.cc b/fastdeploy/runtime/runtime_option.cc index ad467c0458..1f1bfa1ad0 100644 --- a/fastdeploy/runtime/runtime_option.cc +++ b/fastdeploy/runtime/runtime_option.cc @@ -85,8 +85,8 @@ void RuntimeOption::UseRKNPU2(fastdeploy::rknpu2::CpuName rknpu2_name, } void RuntimeOption::UseTimVX() { - enable_timvx = true; device = Device::TIMVX; + paddle_lite_option.enable_timvx = true; } void RuntimeOption::UseKunlunXin(int kunlunxin_id, int l3_workspace_size, @@ -95,21 +95,21 @@ void RuntimeOption::UseKunlunXin(int kunlunxin_id, int l3_workspace_size, const std::string& precision, bool adaptive_seqlen, bool enable_multi_stream) { - enable_kunlunxin = true; - device_id = kunlunxin_id; - kunlunxin_l3_workspace_size = l3_workspace_size; - kunlunxin_locked = locked; - kunlunxin_autotune = autotune; - kunlunxin_autotune_file = autotune_file; - kunlunxin_precision = precision; - kunlunxin_adaptive_seqlen = adaptive_seqlen; - kunlunxin_enable_multi_stream = enable_multi_stream; device = Device::KUNLUNXIN; + paddle_lite_option.enable_kunlunxin = true; + paddle_lite_option.device_id = kunlunxin_id; + paddle_lite_option.kunlunxin_l3_workspace_size = l3_workspace_size; + paddle_lite_option.kunlunxin_locked = locked; + paddle_lite_option.kunlunxin_autotune = autotune; + paddle_lite_option.kunlunxin_autotune_file = autotune_file; + paddle_lite_option.kunlunxin_precision = precision; + paddle_lite_option.kunlunxin_adaptive_seqlen = adaptive_seqlen; + paddle_lite_option.kunlunxin_enable_multi_stream = enable_multi_stream; } void RuntimeOption::UseAscend() { - enable_ascend = true; device = Device::ASCEND; + paddle_lite_option.enable_ascend = true; } void RuntimeOption::UseSophgo() { @@ -124,6 +124,7 @@ void RuntimeOption::SetExternalStream(void* external_stream) { void RuntimeOption::SetCpuThreadNum(int thread_num) { FDASSERT(thread_num > 0, "The thread_num must be greater than 0."); cpu_thread_num = thread_num; + paddle_lite_option.threads = thread_num; } void RuntimeOption::SetOrtGraphOptLevel(int level) { @@ -231,57 +232,65 @@ void RuntimeOption::SetOpenVINODevice(const std::string& name) { openvino_device = name; } -void RuntimeOption::EnableLiteFP16() { lite_enable_fp16 = true; } +void RuntimeOption::EnableLiteFP16() { paddle_lite_option.enable_fp16 = true; } -void RuntimeOption::DisableLiteFP16() { lite_enable_fp16 = false; } -void RuntimeOption::EnableLiteInt8() { lite_enable_int8 = true; } +void RuntimeOption::DisableLiteFP16() { + paddle_lite_option.enable_fp16 = false; +} + +void RuntimeOption::EnableLiteInt8() { paddle_lite_option.enable_int8 = true; } + +void RuntimeOption::DisableLiteInt8() { + paddle_lite_option.enable_int8 = false; +} -void RuntimeOption::DisableLiteInt8() { lite_enable_int8 = false; } void RuntimeOption::SetLitePowerMode(LitePowerMode mode) { - lite_power_mode = mode; + paddle_lite_option.power_mode = mode; } void RuntimeOption::SetLiteOptimizedModelDir( const std::string& optimized_model_dir) { - lite_optimized_model_dir = optimized_model_dir; + paddle_lite_option.optimized_model_dir = optimized_model_dir; } void RuntimeOption::SetLiteSubgraphPartitionPath( const std::string& 
nnadapter_subgraph_partition_config_path) { - lite_nnadapter_subgraph_partition_config_path = + paddle_lite_option.nnadapter_subgraph_partition_config_path = nnadapter_subgraph_partition_config_path; } void RuntimeOption::SetLiteSubgraphPartitionConfigBuffer( const std::string& nnadapter_subgraph_partition_config_buffer) { - lite_nnadapter_subgraph_partition_config_buffer = + paddle_lite_option.nnadapter_subgraph_partition_config_buffer = nnadapter_subgraph_partition_config_buffer; } void RuntimeOption::SetLiteDeviceNames( const std::vector& nnadapter_device_names) { - lite_nnadapter_device_names = nnadapter_device_names; + paddle_lite_option.nnadapter_device_names = nnadapter_device_names; } void RuntimeOption::SetLiteContextProperties( const std::string& nnadapter_context_properties) { - lite_nnadapter_context_properties = nnadapter_context_properties; + paddle_lite_option.nnadapter_context_properties = + nnadapter_context_properties; } void RuntimeOption::SetLiteModelCacheDir( const std::string& nnadapter_model_cache_dir) { - lite_nnadapter_model_cache_dir = nnadapter_model_cache_dir; + paddle_lite_option.nnadapter_model_cache_dir = nnadapter_model_cache_dir; } void RuntimeOption::SetLiteDynamicShapeInfo( const std::map>>& nnadapter_dynamic_shape_info) { - lite_nnadapter_dynamic_shape_info = nnadapter_dynamic_shape_info; + paddle_lite_option.nnadapter_dynamic_shape_info = + nnadapter_dynamic_shape_info; } void RuntimeOption::SetLiteMixedPrecisionQuantizationConfigPath( const std::string& nnadapter_mixed_precision_quantization_config_path) { - lite_nnadapter_mixed_precision_quantization_config_path = + paddle_lite_option.nnadapter_mixed_precision_quantization_config_path = nnadapter_mixed_precision_quantization_config_path; } diff --git a/fastdeploy/runtime/runtime_option.h b/fastdeploy/runtime/runtime_option.h index 785cbcb785..fd4f2a6b11 100644 --- a/fastdeploy/runtime/runtime_option.h +++ b/fastdeploy/runtime/runtime_option.h @@ -350,7 +350,8 @@ struct FASTDEPLOY_DECL RuntimeOption { bool enable_half_partial = false); Backend backend = Backend::UNKNOWN; - // for cpu inference and preprocess + + // for cpu inference // default will let the backend choose their own default value int cpu_thread_num = -1; int device_id = 0; @@ -388,31 +389,6 @@ struct FASTDEPLOY_DECL RuntimeOption { float ipu_available_memory_proportion = 1.0; bool ipu_enable_half_partial = false; - // ======Only for Paddle Lite Backend===== - // 0: LITE_POWER_HIGH 1: LITE_POWER_LOW 2: LITE_POWER_FULL - // 3: LITE_POWER_NO_BIND 4: LITE_POWER_RAND_HIGH - // 5: LITE_POWER_RAND_LOW - LitePowerMode lite_power_mode = LitePowerMode::LITE_POWER_NO_BIND; - // enable int8 or not - bool lite_enable_int8 = false; - // enable fp16 or not - bool lite_enable_fp16 = false; - // optimized model dir for CxxConfig - std::string lite_optimized_model_dir = ""; - std::string lite_nnadapter_subgraph_partition_config_path = ""; - // and other nnadapter settings for CxxConfig - std::string lite_nnadapter_subgraph_partition_config_buffer = ""; - std::string lite_nnadapter_context_properties = ""; - std::string lite_nnadapter_model_cache_dir = ""; - std::string lite_nnadapter_mixed_precision_quantization_config_path = ""; - std::map>> - lite_nnadapter_dynamic_shape_info = {{"", {{0}}}}; - std::vector lite_nnadapter_device_names = {}; - - bool enable_timvx = false; - bool enable_ascend = false; - bool enable_kunlunxin = false; - // ======Only for Trt Backend======= std::map> trt_max_shape; std::map> trt_min_shape; @@ -444,14 +420,9 @@ struct 
FASTDEPLOY_DECL RuntimeOption { fastdeploy::rknpu2::CoreMask rknpu2_core_mask_ = fastdeploy::rknpu2::CoreMask::RKNN_NPU_CORE_AUTO; - // ======Only for KunlunXin XPU Backend======= - int kunlunxin_l3_workspace_size = 0xfffc00; - bool kunlunxin_locked = false; - bool kunlunxin_autotune = true; - std::string kunlunxin_autotune_file = ""; - std::string kunlunxin_precision = "int16"; - bool kunlunxin_adaptive_seqlen = false; - bool kunlunxin_enable_multi_stream = false; + + /// Option to configure Paddle Lite backend + LiteBackendOption paddle_lite_option; std::string model_file = ""; // Path of model file std::string params_file = ""; // Path of parameters file, can be empty