Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix for ort v1.15 #21

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ option(ONNXRUNTIME_DIR "Path to built ONNX Runtime directory." STRING)
message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}")

find_package(OpenCV REQUIRED)
#find_package(Qt5 COMPONENTS Widgets REQUIRED)

include_directories("include/")

Expand All @@ -16,16 +17,19 @@ add_executable(yolo_ort
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

target_include_directories(yolo_ort PRIVATE "${ONNXRUNTIME_DIR}/include")
target_include_directories(yolo_ort PRIVATE "${ONNXRUNTIME_DIR}/include/onnxruntime")
# link_directories("${ONNXRUNTIME_DIR}/lib")
target_compile_features(yolo_ort PRIVATE cxx_std_14)
target_link_libraries(yolo_ort ${OpenCV_LIBS})

if (WIN32)
target_link_libraries(yolo_ort "${ONNXRUNTIME_DIR}/lib/onnxruntime.lib")
endif(WIN32)

if (UNIX)
target_link_libraries(yolo_ort "${ONNXRUNTIME_DIR}/lib/libonnxruntime.so")
target_link_libraries(yolo_ort "${ONNXRUNTIME_DIR}/lib/libonnxruntime.so"
${OpenCV_LIBS}
# yolo_ort Qt5::Core
# yolo_ort Qt5::Widgets
)
endif(UNIX)

4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ C++ YOLO v5 ONNX Runtime inference code for object detection.

## Dependecies:
- OpenCV 4.x
- ONNXRuntime 1.7+
- OS: Tested on Windows 10 and Ubuntu 20.04
- ONNXRuntime 1.15+
- OS: Tested on centos8 archlinux
- CUDA 11+ [Optional]


Expand Down
10 changes: 7 additions & 3 deletions include/detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ class YOLODetector
static void getBestClassInfo(std::vector<float>::iterator it, const int& numClasses,
float& bestConf, int& bestClassId);

std::vector<const char*> inputNames;
std::vector<const char*> outputNames;
bool isDynamicInputShape{};
cv::Size2f inputImageShape;

};
// Inputs
std::vector<Ort::AllocatedStringPtr> inputNodeNameAllocatedStrings;
std::vector<const char*> inputNames;
// Outputs
std::vector<Ort::AllocatedStringPtr> outputNodeNameAllocatedStrings;
std::vector<const char*> outputNames;
};
23 changes: 23 additions & 0 deletions src/.clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
---
# Defaults for all languages.
BasedOnStyle: LLVM

# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained.
# Developers are responsible for adhering to the 120 character maximum.
ColumnLimit: 0
DerivePointerAlignment: false
# Avoid adding spaces between tokens in GSL_SUPPRESS arguments.
# E.g., don't change "GSL_SUPPRESS(r.11)" to "GSL_SUPPRESS(r .11)".
WhitespaceSensitiveMacros: ["GSL_SUPPRESS"]

# if you want to customize when working locally see https://clang.llvm.org/docs/ClangFormatStyleOptions.html for options.
# See ReformatSource.ps1 for a script to update all source according to the current options in this file.
# e.g. customizations to use Allman bracing and more indenting.
# AccessModifierOffset: -2
# BreakBeforeBraces: Allman
# CompactNamespaces: false
# IndentCaseLabels: true
IndentWidth: 4
# NamespaceIndentation: All

...
11 changes: 9 additions & 2 deletions src/detector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,15 @@ YOLODetector::YOLODetector(const std::string& modelPath,
for (auto shape : inputTensorShape)
std::cout << "Input shape: " << shape << std::endl;

inputNames.push_back(session.GetInputName(0, allocator));
outputNames.push_back(session.GetOutputName(0, allocator));
// inputNames.push_back(session.GetInputName(0, allocator));
// outputNames.push_back(session.GetOutputName(0, allocator));
auto input_name = session.GetInputNameAllocated(0, allocator);
inputNodeNameAllocatedStrings.push_back(std::move(input_name));
inputNames.push_back(inputNodeNameAllocatedStrings.back().get());

auto output_name = session.GetOutputNameAllocated(0, allocator);
outputNodeNameAllocatedStrings.push_back(std::move(output_name));
outputNames.push_back(outputNodeNameAllocatedStrings.back().get());

std::cout << "Input name: " << inputNames[0] << std::endl;
std::cout << "Output name: " << outputNames[0] << std::endl;
Expand Down
74 changes: 55 additions & 19 deletions src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
#include <chrono>
#include <iostream>
#include <opencv2/opencv.hpp>
#include <ostream>

#include "cmdline.h"
#include "utils.h"
#include "detector.h"
#include "utils.h"


int main(int argc, char* argv[])
void Delay(int time) // time*1000为秒数
{
clock_t now = clock();

while (clock() - now < time)
;
}

int main(int argc, char *argv[]) {
const float confThreshold = 0.3f;
const float iouThreshold = 0.4f;

cmdline::parser cmd;
cmd.add<std::string>("model_path", 'm', "Path to onnx model.", true, "yolov5.onnx");
cmd.add<std::string>("image", 'i', "Image source to be detected.", true, "bus.jpg");
cmd.add<std::string>("image", 'i', "Image source to be detected.", false);
cmd.add<std::string>("v4l2", 'v', "video dev node to be detected.", false);
cmd.add<std::string>("class_names", 'c', "Path to class names file.", true, "coco.names");
cmd.add("gpu", '\0', "Inference on cuda device.");

Expand All @@ -23,37 +32,64 @@ int main(int argc, char* argv[])
const std::string classNamesPath = cmd.get<std::string>("class_names");
const std::vector<std::string> classNames = utils::loadNames(classNamesPath);
const std::string imagePath = cmd.get<std::string>("image");
const std::string videoPath = cmd.get<std::string>("v4l2");
const std::string modelPath = cmd.get<std::string>("model_path");

if (classNames.empty())
{
if (classNames.empty()) {
std::cerr << "Error: Empty class names file." << std::endl;
return -1;
}

YOLODetector detector {nullptr};
if (imagePath.empty() && videoPath.empty()) {
std::cerr << "At least give one source! jpg or /dev/videox"<<std::endl;
return -1;
}

YOLODetector detector{nullptr};
cv::Mat image;
std::vector<Detection> result;

try
{
try {
detector = YOLODetector(modelPath, isGPU, cv::Size(640, 640));
std::cout << "Model was initialized." << std::endl;

image = cv::imread(imagePath);
result = detector.detect(image, confThreshold, iouThreshold);
}
catch(const std::exception& e)
{
} catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
return -1;
}

utils::visualizeDetection(image, result, classNames);
if (!imagePath.empty()) {
image = cv::imread(imagePath);
result = detector.detect(image, confThreshold, iouThreshold);
utils::visualizeDetection(image, result, classNames);
cv::imshow("result", image);
// cv::imwrite("result.jpg", image);
cv::waitKey(0);
} else if (!videoPath.empty()) {
cv::VideoCapture cap(0);
if (!cap.isOpened()) {
std::cerr << "Error: Could not open camera." << std::endl;
return -1;
}

cv::imshow("result", image);
// cv::imwrite("result.jpg", image);
cv::waitKey(0);
while (true) {
cv::Mat frame;
cap >> frame;
if (frame.empty()) {
std::cerr << "Error: Could not read frame." << std::endl;
continue;
}

auto start_time = std::chrono::high_resolution_clock::now();
result = detector.detect(frame, confThreshold, iouThreshold);
utils::visualizeDetection(frame, result, classNames);
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
std::cout << "Execution time: " << duration.count() << " ms." << std::endl;

cv::imshow("result", frame);
cv::waitKey(10);
// break;
}
}
return 0;
}