forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#9 from jiweibo/add_yolo_demo
[yolov3] Add yolov3 demo
- Loading branch information
Showing
8 changed files
with
406 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
cmake_minimum_required(VERSION 3.0) | ||
project(cpp_inference_demo CXX C) | ||
option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) | ||
option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) | ||
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) | ||
option(USE_TENSORRT "Compile demo with TensorRT." OFF) | ||
|
||
|
||
macro(safe_set_static_flag) | ||
foreach(flag_var | ||
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE | ||
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) | ||
if(${flag_var} MATCHES "/MD") | ||
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") | ||
endif(${flag_var} MATCHES "/MD") | ||
endforeach(flag_var) | ||
endmacro() | ||
|
||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") | ||
set(CMAKE_STATIC_LIBRARY_PREFIX "") | ||
message("flags" ${CMAKE_CXX_FLAGS}) | ||
|
||
if(NOT DEFINED PADDLE_LIB) | ||
message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib") | ||
endif() | ||
if(NOT DEFINED DEMO_NAME) | ||
message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name") | ||
endif() | ||
|
||
|
||
include_directories("${PADDLE_LIB}") | ||
include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") | ||
include_directories("${PADDLE_LIB}/third_party/install/glog/include") | ||
include_directories("${PADDLE_LIB}/third_party/install/gflags/include") | ||
include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") | ||
include_directories("${PADDLE_LIB}/third_party/install/zlib/include") | ||
include_directories("${PADDLE_LIB}/third_party/boost") | ||
include_directories("${PADDLE_LIB}/third_party/eigen3") | ||
|
||
if (USE_TENSORRT AND WITH_GPU) | ||
include_directories("${TENSORRT_ROOT}/include") | ||
link_directories("${TENSORRT_ROOT}/lib") | ||
endif() | ||
|
||
link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") | ||
|
||
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") | ||
link_directories("${PADDLE_LIB}/third_party/install/glog/lib") | ||
link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") | ||
link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") | ||
link_directories("${PADDLE_LIB}/paddle/lib") | ||
|
||
add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) | ||
|
||
if(WITH_MKL) | ||
include_directories("${PADDLE_LIB}/third_party/install/mklml/include") | ||
set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} | ||
${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) | ||
set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") | ||
if(EXISTS ${MKLDNN_PATH}) | ||
include_directories("${MKLDNN_PATH}/include") | ||
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) | ||
endif() | ||
else() | ||
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) | ||
endif() | ||
|
||
# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a | ||
if(WITH_STATIC_LIB) | ||
set(DEPS | ||
${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX}) | ||
else() | ||
set(DEPS | ||
${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX}) | ||
endif() | ||
|
||
set(EXTERNAL_LIB "-lrt -ldl -lpthread") | ||
set(DEPS ${DEPS} | ||
${MATH_LIB} ${MKLDNN_LIB} | ||
glog gflags protobuf z xxhash | ||
${EXTERNAL_LIB}) | ||
|
||
if(WITH_GPU) | ||
if (USE_TENSORRT) | ||
set(DEPS ${DEPS} | ||
${TENSORRT_ROOT}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}) | ||
set(DEPS ${DEPS} | ||
${TENSORRT_ROOT}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}) | ||
endif() | ||
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) | ||
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX} ) | ||
set(DEPS ${DEPS} ${CUDA_LIB}/libcublas${CMAKE_SHARED_LIBRARY_SUFFIX} ) | ||
set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX} ) | ||
endif() | ||
|
||
target_link_libraries(${DEMO_NAME} ${DEPS}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
## 运行C++ YOLOv3图像检测样例 | ||
|
||
### 一:获取YOLOv3模型 | ||
|
||
点击[链接](https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/yolov3_infer.tar.gz)下载模型, 该模型在imagenet数据集训练得到的,如果你想获取更多的**模型训练信息**,请访问[这里](https://github.com/PaddlePaddle/PaddleDetection)。 | ||
|
||
### 二:**样例编译** | ||
|
||
文件`yolov3_test.cc` 为预测的样例程序(程序中的输入为固定值,如果您有opencv或其他方式进行数据读取的需求,需要对程序进行一定的修改)。 | ||
文件`CMakeLists.txt` 为编译构建文件。 | ||
脚本`run_impl.sh` 包含了第三方库、预编译库的信息配置。 | ||
|
||
编译yolov3样例,我们首先需要对脚本`run_impl.sh` 文件中的配置进行修改。 | ||
|
||
1)**修改`run_impl.sh`** | ||
|
||
打开`run_impl.sh`,我们对以下的几处信息进行修改: | ||
|
||
```shell | ||
# 根据预编译库中的version.txt信息判断是否将以下三个标记打开 | ||
WITH_MKL=ON | ||
WITH_GPU=ON | ||
USE_TENSORRT=OFF | ||
|
||
# 配置预测库的根目录 | ||
LIB_DIR=${YOUR_LIB_DIR}/fluid_inference_install_dir | ||
|
||
# 如果上述的WITH_GPU 或 USE_TENSORRT设为ON,请设置对应的CUDA, CUDNN, TENSORRT的路径。 | ||
CUDNN_LIB=/usr/local/cudnn/lib64 | ||
CUDA_LIB=/usr/local/cuda/lib64 | ||
# TENSORRT_ROOT=/usr/local/TensorRT-6.0.1.5 | ||
``` | ||
|
||
运行 `sh run_impl.sh`, 会在目录下产生build目录。 | ||
|
||
|
||
2) **运行样例** | ||
|
||
```shell | ||
# 进入build目录 | ||
cd build | ||
# 运行样例 | ||
./yolov3_test -model_file ${YOLO_MODEL_PATH}/__model__ --params_file ${YOLO_MODEL_PATH}/__params__ | ||
``` | ||
|
||
运行结束后,程序会将模型输出个数打印到屏幕,说明运行成功。 | ||
|
||
### 更多链接 | ||
- [Paddle Inference使用Quick Start!]() | ||
- [Paddle Inference Python Api使用]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
work_path=$(dirname $(readlink -f $0)) | ||
|
||
mkdir -p build | ||
cd build | ||
rm -rf * | ||
|
||
# same with the yolov3_test.cc | ||
DEMO_NAME=yolov3_test | ||
|
||
WITH_MKL=ON | ||
WITH_GPU=ON | ||
USE_TENSORRT=OFF | ||
|
||
LIB_DIR=${work_path}/fluid_inference_install_dir | ||
CUDNN_LIB=/usr/local/cudnn/lib64 | ||
CUDA_LIB=/usr/local/cuda/lib64 | ||
# TENSORRT_ROOT=/usr/local/TensorRT-6.0.1.5 | ||
|
||
cmake .. -DPADDLE_LIB=${LIB_DIR} \ | ||
-DWITH_MKL=${WITH_MKL} \ | ||
-DDEMO_NAME=${DEMO_NAME} \ | ||
-DWITH_GPU=${WITH_GPU} \ | ||
-DWITH_STATIC_LIB=OFF \ | ||
-DUSE_TENSORRT=${USE_TENSORRT} \ | ||
-DCUDNN_LIB=${CUDNN_LIB} \ | ||
-DCUDA_LIB=${CUDA_LIB} \ | ||
-DTENSORRT_ROOT=${TENSORRT_ROOT} | ||
|
||
make -j |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#include "paddle/include/paddle_inference_api.h" | ||
|
||
#include <numeric> | ||
#include <iostream> | ||
#include <memory> | ||
#include <chrono> | ||
|
||
#include <gflags/gflags.h> | ||
#include <glog/logging.h> | ||
|
||
using paddle::AnalysisConfig; | ||
|
||
DEFINE_string(model_file, "", "Directory of the inference model."); | ||
DEFINE_string(params_file, "", "Directory of the inference model."); | ||
DEFINE_string(model_dir, "", "Directory of the inference model."); | ||
DEFINE_int32(batch_size, 1, "Directory of the inference model."); | ||
DEFINE_bool(use_gpu, false, "enable gpu"); | ||
DEFINE_bool(use_mkldnn, false, "enable mkldnn"); | ||
DEFINE_bool(mem_optim, false, "enable memory optimize"); | ||
|
||
using Time = decltype(std::chrono::high_resolution_clock::now()); | ||
Time time() { return std::chrono::high_resolution_clock::now(); }; | ||
double time_diff(Time t1, Time t2) { | ||
typedef std::chrono::microseconds ms; | ||
auto diff = t2 - t1; | ||
ms counter = std::chrono::duration_cast<ms>(diff); | ||
return counter.count() / 1000.0; | ||
} | ||
|
||
std::unique_ptr<paddle::PaddlePredictor> CreatePredictor() { | ||
AnalysisConfig config; | ||
if (FLAGS_model_dir != "") { | ||
config.SetModel(FLAGS_model_dir); | ||
} else { | ||
config.SetModel(FLAGS_model_file, | ||
FLAGS_params_file); | ||
} | ||
if (FLAGS_use_gpu) { | ||
config.EnableUseGpu(100, 0); | ||
} | ||
if (FLAGS_use_mkldnn) { | ||
config.EnableMKLDNN(); | ||
} | ||
// Open the memory optim. | ||
if (FLAGS_mem_optim) { | ||
config.EnableMemoryOptim(); | ||
} | ||
// We use ZeroCopy, so we set config->SwitchUseFeedFetchOps(false) | ||
config.SwitchUseFeedFetchOps(false); | ||
return CreatePaddlePredictor(config); | ||
} | ||
|
||
void run(paddle::PaddlePredictor *predictor, | ||
const std::vector<float>& input, | ||
const std::vector<int>& input_shape, | ||
const std::vector<int>& input_im, | ||
const std::vector<int>& input_im_shape, | ||
std::vector<float> *out_data) { | ||
auto input_names = predictor->GetInputNames(); | ||
auto input_img = predictor->GetInputTensor(input_names[0]); | ||
input_img->Reshape(input_shape); | ||
input_img->copy_from_cpu(input.data()); | ||
|
||
auto input_size = predictor->GetInputTensor(input_names[1]); | ||
input_size->Reshape(input_im_shape); | ||
input_size->copy_from_cpu(input_im.data()); | ||
|
||
CHECK(predictor->ZeroCopyRun()); | ||
|
||
auto output_names = predictor->GetOutputNames(); | ||
// there is only one output of yolov3 | ||
auto output_t = predictor->GetOutputTensor(output_names[0]); | ||
std::vector<int> output_shape = output_t->shape(); | ||
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>()); | ||
|
||
out_data->resize(out_num); | ||
output_t->copy_to_cpu(out_data->data()); | ||
} | ||
|
||
int main(int argc, char* argv[]) { | ||
google::ParseCommandLineFlags(&argc, &argv, true); | ||
auto predictor = CreatePredictor(); | ||
|
||
const int height = 608; | ||
const int width = 608; | ||
const int channels = 3; | ||
std::vector<int> input_shape = {FLAGS_batch_size, channels, height, width}; | ||
std::vector<float> input_data(FLAGS_batch_size * channels * height * width, 0); | ||
for (size_t i = 0; i < input_data.size(); ++i) { | ||
input_data[i] = i % 255 * 0.13f; | ||
} | ||
std::vector<int> input_im_shape = {FLAGS_batch_size, 2}; | ||
std::vector<int> input_im_data(FLAGS_batch_size * 2, 608); | ||
|
||
std::vector<float> out_data; | ||
run(predictor.get(), input_data, input_shape, input_im_data, input_im_shape, &out_data); | ||
LOG(INFO) << "output num is " << out_data.size(); | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
## 运行YOLOv3图像检测样例 | ||
|
||
|
||
### 一:准备环境 | ||
|
||
请您在环境中安装1.7或以上版本的Paddle,具体的安装方式请参照[飞桨官方页面](https://www.paddlepaddle.org.cn/)的指示方式。 | ||
|
||
|
||
### 二:下载模型以及测试数据 | ||
|
||
|
||
1)**获取预测模型** | ||
|
||
点击[链接](https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/yolov3_infer.tar.gz)下载模型, 该模型在imagenet数据集训练得到的,如果你想获取更多的**模型训练信息**,请访问[这里](https://github.com/PaddlePaddle/PaddleDetection)。 | ||
|
||
|
||
2)**获取预测样例图片** | ||
|
||
下载[样例图片](https://paddle-inference-dist.bj.bcebos.com/inference_demo/images/kite.jpg)。 | ||
|
||
图片如下: | ||
<p align="left"> | ||
<br> | ||
<img src='https://paddle-inference-dist.bj.bcebos.com/inference_demo/images/kite.jpg' width = "200" height = "200"> | ||
<br> | ||
<p> | ||
|
||
|
||
### 三:运行预测 | ||
|
||
文件`utils.py`包含了图像的预处理等帮助函数。 | ||
文件`infer_yolov3.py` 包含了创建predictor,读取示例图片,预测,获取输出的等功能。 | ||
|
||
运行: | ||
``` | ||
python infer_yolov3.py --model_file=./yolov3_infer/__model__ --params_file=./yolov3_infer/__params__ --use_gpu=1 | ||
``` | ||
|
||
输出结果如下所示: | ||
|
||
``` | ||
('category id is ', 0.0, ', bbox is ', array([ 98.47467, 471.34283, 120.73273, 578.5184 ], dtype=float32)) | ||
('category id is ', 0.0, ', bbox is ', array([ 51.752716, 415.51324 , 73.18762 , 515.24005 ], dtype=float32)) | ||
('category id is ', 0.0, ', bbox is ', array([ 37.176304, 343.378 , 46.64221 , 380.92963 ], dtype=float32)) | ||
('category id is ', 0.0, ', bbox is ', array([155.78638, 328.0806 , 159.5393 , 339.37192], dtype=float32)) | ||
('category id is ', 0.0, ', bbox is ', array([233.86328, 339.96912, 239.35403, 355.3322 ], dtype=float32)) | ||
('category id is ', 0.0, ', bbox is ', array([ 16.212902, 344.42365 , 25.193722, 377.97137 ], dtype=float32)) | ||
('category id is ', 0.0, ', bbox is ', array([ 10.583471, 356.67862 , 14.9261 , 372.8137 ], dtype=float32)) | ||
('category id is ', 0.0, ', bbox is ', array([ 79.76479, 364.19492, 86.07656, 385.64255], dtype=float32)) | ||
('category id is ', 0.0, ', bbox is ', array([312.8938 , 311.9908 , 314.58527, 316.60056], dtype=float32)) | ||
('category id is ', 33.0, ', bbox is ', array([266.97925 , 51.70044 , 299.45105 , 99.996414], dtype=float32)) | ||
('category id is ', 33.0, ', bbox is ', array([210.45593, 229.92128, 217.77551, 240.97136], dtype=float32)) | ||
('category id is ', 33.0, ', bbox is ', array([125.36278, 159.80171, 135.49306, 189.8976 ], dtype=float32)) | ||
('category id is ', 33.0, ', bbox is ', array([486.9354 , 266.164 , 494.4437 , 283.84637], dtype=float32)) | ||
('category id is ', 33.0, ', bbox is ', array([259.01584, 232.23044, 270.69266, 248.58704], dtype=float32)) | ||
('category id is ', 33.0, ', bbox is ', array([135.60567, 254.57668, 144.96178, 276.9275 ], dtype=float32)) | ||
('category id is ', 33.0, ', bbox is ', array([341.91315, 255.44394, 345.0335 , 262.3398 ], dtype=float32)) | ||
``` | ||
|
||
<p align="left"> | ||
<br> | ||
<img src='https://paddle-inference-dist.bj.bcebos.com/inference_demo/images/kite_res.jpg' width = "200" height = "200"> | ||
<br> | ||
<p> | ||
|
||
|
||
### 相关链接 | ||
- [Paddle Inference使用Quick Start!]() | ||
- [Paddle Inference Python Api使用]() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.