diff --git a/deployment/src/main.cpp b/deployment/src/main.cpp index 6788dfdd..3c1a0ea5 100644 --- a/deployment/src/main.cpp +++ b/deployment/src/main.cpp @@ -45,6 +45,40 @@ torch::Tensor ReadImage(const std::string& loc) { return img_tensor.clone(); }; +void OverlayBoxes(cv::Mat& img, + const std::vector& detections, + const std::vector& class_names, + const std::string& img_name, + bool label = true) { + for (const auto& detection : detections) { + const auto& box = detection.bbox; + float score = detection.score; + int class_idx = detection.class_idx; + + cv::rectangle(img, box, cv::Scalar(0, 0, 255), 2); + + if (label) { + std::stringstream ss; + ss << std::fixed << std::setprecision(2) << score; + std::string s = class_names[class_idx] + " " + ss.str(); + + auto font_face = cv::FONT_HERSHEY_DUPLEX; + auto font_scale = 1.0; + int thickness = 1; + int baseline=0; + auto s_size = cv::getTextSize(s, font_face, font_scale, thickness, &baseline); + cv::rectangle(img, + cv::Point(box.tl().x, box.tl().y - s_size.height - 5), + cv::Point(box.tl().x + s_size.width, box.tl().y), + cv::Scalar(0, 0, 255), -1); + cv::putText(img, s, cv::Point(box.tl().x, box.tl().y - 5), + font_face , font_scale, cv::Scalar(255, 255, 255), thickness); + } + } + + cv::imwrite(img_name, img); +} + int main(int argc, const char* argv[]) { cxxopts::Options parser(argv[0], "A LibTorch inference implementation of the yolov5"); @@ -94,6 +128,10 @@ int main(int argc, const char* argv[]) { std::string weights = opt["checkpoint"].as(); module = torch::jit::load(weights); module.to(device_type); + if (is_gpu) { + module.to(torch::kHalf); + } + module.eval(); std::cout << ">>> Model loaded" << std::endl; } catch (const torch::Error& e) { @@ -114,7 +152,11 @@ int main(int argc, const char* argv[]) { // Run once to warm up std::cout << ">>> Run once on empty image" << std::endl; - images.push_back(torch::rand({3, 416, 320}, options)); + auto img_dumy = torch::rand({3, 416, 320}, options); + if (is_gpu) { + img_dumy = img_dumy.to(torch::kHalf); + } + images.push_back(img_dumy); inputs.push_back(images); auto output = module.forward(inputs); @@ -128,6 +170,9 @@ int main(int argc, const char* argv[]) { // Read image auto img = ReadImage(image_path); img = img.to(device_type); + if (is_gpu) { + img = img.to(torch::kHalf); + } images.push_back(img); inputs.push_back(images);