Skip to content

Commit

Permalink
Add half precision inference in libtorch
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Nov 27, 2020
1 parent ef3981a commit 288ddd0
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion deployment/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,40 @@ torch::Tensor ReadImage(const std::string& loc) {
return img_tensor.clone();
};

void OverlayBoxes(cv::Mat& img,
const std::vector<Detection>& detections,
const std::vector<std::string>& class_names,
const std::string& img_name,
bool label = true) {
for (const auto& detection : detections) {
const auto& box = detection.bbox;
float score = detection.score;
int class_idx = detection.class_idx;

cv::rectangle(img, box, cv::Scalar(0, 0, 255), 2);

if (label) {
std::stringstream ss;
ss << std::fixed << std::setprecision(2) << score;
std::string s = class_names[class_idx] + " " + ss.str();

auto font_face = cv::FONT_HERSHEY_DUPLEX;
auto font_scale = 1.0;
int thickness = 1;
int baseline=0;
auto s_size = cv::getTextSize(s, font_face, font_scale, thickness, &baseline);
cv::rectangle(img,
cv::Point(box.tl().x, box.tl().y - s_size.height - 5),
cv::Point(box.tl().x + s_size.width, box.tl().y),
cv::Scalar(0, 0, 255), -1);
cv::putText(img, s, cv::Point(box.tl().x, box.tl().y - 5),
font_face , font_scale, cv::Scalar(255, 255, 255), thickness);
}
}

cv::imwrite(img_name, img);
}

int main(int argc, const char* argv[]) {
cxxopts::Options parser(argv[0], "A LibTorch inference implementation of the yolov5");

Expand Down Expand Up @@ -94,6 +128,10 @@ int main(int argc, const char* argv[]) {
std::string weights = opt["checkpoint"].as<std::string>();
module = torch::jit::load(weights);
module.to(device_type);
if (is_gpu) {
module.to(torch::kHalf);
}

module.eval();
std::cout << ">>> Model loaded" << std::endl;
} catch (const torch::Error& e) {
Expand All @@ -114,7 +152,11 @@ int main(int argc, const char* argv[]) {

// Run once to warm up
std::cout << ">>> Run once on empty image" << std::endl;
images.push_back(torch::rand({3, 416, 320}, options));
auto img_dumy = torch::rand({3, 416, 320}, options);
if (is_gpu) {
img_dumy = img_dumy.to(torch::kHalf);
}
images.push_back(img_dumy);
inputs.push_back(images);

auto output = module.forward(inputs);
Expand All @@ -128,6 +170,9 @@ int main(int argc, const char* argv[]) {
// Read image
auto img = ReadImage(image_path);
img = img.to(device_type);
if (is_gpu) {
img = img.to(torch::kHalf);
}

images.push_back(img);
inputs.push_back(images);
Expand Down

0 comments on commit 288ddd0

Please sign in to comment.