Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update readme and refactor dilation #14

Merged
merged 3 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/builld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ jobs:
uses: dtolnay/rust-toolchain@stable

- name: Run check on Ubuntu
run: cargo check --features=cli --bins
run: cargo check --features=cli --bin surya

- name: Run fmt check
run: cargo fmt --all --check

- name: Run clippy
run: cargo clippy --features=cli --bins
run: cargo clippy --features=cli --bin surya --lib

- name: Run unit tests
run: cargo test --features=cli --bins
run: cargo test --features=cli --bin surya --lib

- name: Test run surya
run: cargo run --features=cli --bin surya -- --help
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ image = { version = "0.24.8", default-features = false, features = [
log = { version = "0.4.20" }
opencv = { version = "0.88.8" }
serde = { version = "1.0.196" }
serde_json = { version = "1.0.112", optional = true }
serde_json = { version = "1.0.112" }

[features]
metal = ["candle-core/metal", "candle-nn/metal"]
cli = ["clap", "anyhow", "hf-hub", "serde_json"]
cli = ["clap", "anyhow", "hf-hub"]

[[bin]]
name = "surya"
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ Options:
model's weights file name [default: model.safetensors]
--config-file-name <CONFIG_FILE_NAME>
model's config file name [default: config.json]
--generate-bbox-image
whether to generate bbox image
--generate-heatmap
whether to generate heatmap
--generate-affinity-map
Expand Down
135 changes: 81 additions & 54 deletions src/bbox.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,45 @@
use log::debug;
use opencv::core::{
max_mat_f64, min_mat_f64, Mat, Point, Point2f, Rect, RotatedRect, Scalar, Size,
self, max_mat_f64, min_mat_f64, Mat, Point, Point2f, Rect, Scalar, Size, Vector, CV_32S, CV_8U,
};
use opencv::prelude::*;
use opencv::{core, imgcodecs, imgproc};
use opencv::types::VectorOfi32;
use opencv::{imgcodecs, imgproc};

#[derive(Debug, Clone)]
pub struct BBox {
pub rect: RotatedRect,
pub polygon: [Point2f; 4],
}

impl BBox {
fn scale_to_rect(&self, original_size: (u32, u32), heatmap: &Mat) -> opencv::Result<Rect> {
fn scale_to_original(&self, original_size: (u32, u32), heatmap: &Mat) -> opencv::Result<Self> {
let (width, height) = original_size;
let w_scaler = width as f32 / heatmap.cols() as f32;
let h_scaler = height as f32 / heatmap.rows() as f32;
let mut point_2fs = [Point2f::default(); 4];
self.rect.points(&mut point_2fs)?;
let mut points = [Point::default(); 4];
for (i, point_2f) in point_2fs.iter().enumerate() {
points[i] = Point::new(
(point_2f.x * w_scaler) as i32,
(point_2f.y * h_scaler) as i32,
);
let mut polygon = [Point2f::default(); 4];
for (i, point_2f) in self.polygon.iter().enumerate() {
polygon[i] = Point2f::new(point_2f.x * w_scaler, point_2f.y * h_scaler);
}
Ok(Rect::from_points(points[0], points[2]))
Ok(BBox { polygon })
}

fn draw_on_image(&self, image: &mut Mat) -> opencv::Result<()> {
// convert each point from Point2f to Point
let points: Vector<Point> = self
.polygon
.iter()
.map(|point| point.to::<i32>().unwrap())
.collect();
imgproc::polylines(
image,
&points,
true,
Scalar::new(0., 0., 255., 0.),
2,
opencv::imgproc::LINE_8,
0,
)?;
Ok(())
}
}

Expand All @@ -49,7 +64,7 @@ fn image_f32_to_u8(mat: Mat) -> opencv::Result<Mat> {
let mut r = Mat::default();
let alpha = 255.0;
let beta = 0.0;
mat.convert_to(&mut r, opencv::core::CV_8U, alpha, beta)?;
mat.convert_to(&mut r, CV_8U, alpha, beta)?;
Ok(r)
}

Expand All @@ -64,7 +79,7 @@ fn image_to_connected_components(mat: Mat) -> opencv::Result<(Mat, Mat, Mat)> {
&mut stats,
&mut centroids,
4,
core::CV_32S,
CV_32S,
)?;
Ok((labels, stats, centroids))
}
Expand All @@ -77,68 +92,79 @@ fn heatmap_label_max(heatmap: &Mat, labels: &Mat, label: i32) -> opencv::Result<
Ok(max_value)
}

fn connected_area_to_bbox(
labels: &Mat,
stats_row: &[i32],
label: i32,
) -> opencv::Result<RotatedRect> {
let (w, h, area) = (
// stats_row[imgproc::CC_STAT_LEFT as usize],
// stats_row[imgproc::CC_STAT_TOP as usize],
fn get_dilation_matrix(segmap: &mut Mat, stats_row: &[i32]) -> opencv::Result<Mat> {
let (x, y, w, h, area) = (
stats_row[imgproc::CC_STAT_LEFT as usize],
stats_row[imgproc::CC_STAT_TOP as usize],
stats_row[imgproc::CC_STAT_WIDTH as usize],
stats_row[imgproc::CC_STAT_HEIGHT as usize],
stats_row[imgproc::CC_STAT_AREA as usize],
);

let mut segmap = Mat::default();
core::compare(&labels, &(label as f64), &mut segmap, opencv::core::CMP_EQ)?;
let niter = (f64::sqrt((area * i32::min(w, h)) as f64 / (w * h) as f64) * 2.0) as i32;

let niter = {
let niter = (area * w.min(h)) as f64 / (w * h) as f64;
(niter.sqrt() * 2.0) as i32
};
let roi = {
let sx = (x - niter).max(0);
let sy = (y - niter).max(0);
let ex = (x + w + niter + 1).min(segmap.cols());
let ey = (y + h + niter + 1).min(segmap.rows());
Rect::new(sx, sy, ex - sx, ey - sy)
};
let mut roi = Mat::roi(segmap, roi)?;
let kernel = imgproc::get_structuring_element(
imgproc::MORPH_RECT,
Size::new(1 + niter, 1 + niter),
Point::new(-1, -1),
)?;

let mut dilated: Mat = Mat::default();
imgproc::dilate(
&segmap,
&mut dilated,
segmap,
&mut roi,
&kernel,
Point::new(-1, -1),
Point::new(-1, -1), // default anchor
1,
core::BORDER_CONSTANT,
Scalar::default(),
core::BORDER_CONSTANT, // border type
Scalar::default(), // border value
)?;
Ok(roi)
}

fn connected_area_to_bbox(
labels: &Mat,
stats_row: &[i32],
label: i32,
) -> opencv::Result<[Point2f; 4]> {
let mut segmap = Mat::default();
core::compare(&labels, &(label as f64), &mut segmap, opencv::core::CMP_EQ)?;

let dilated_roi = get_dilation_matrix(&mut segmap, stats_row)?;
dilated_roi.copy_to(&mut segmap)?;

let mut non_zero = Mat::default();
core::find_non_zero(&dilated, &mut non_zero)?;
imgproc::min_area_rect(&non_zero)
core::find_non_zero(&segmap, &mut non_zero)?;
let rotated_rect = imgproc::min_area_rect(&non_zero)?;
let mut points = [Point2f::default(); 4];
rotated_rect.points(&mut points)?;
Ok(points)
}

pub fn draw_bboxes(image: &mut Mat, bboxes: Vec<Rect>, output_file: &str) -> opencv::Result<()> {
pub fn draw_bboxes(image: &mut Mat, bboxes: Vec<BBox>, output_file: &str) -> opencv::Result<()> {
for bbox in bboxes {
imgproc::rectangle(
image,
bbox,
opencv::core::Scalar::new(255., 0., 0., 0.),
2,
opencv::imgproc::LINE_8,
0,
)?;
bbox.draw_on_image(image)?;
}
let params = opencv::types::VectorOfi32::new();
let params = VectorOfi32::new();
imgcodecs::imwrite(output_file, image, &params)?;
Ok(())
}

/// Generate bounding boxes from the heatmap, rescaled to the original image size.
pub fn generate_bbox(
original_size: (u32, u32),
heatmap: Vec<Vec<f32>>,
non_max_suppression_threshold: f64,
text_threshold: f64,
bbox_area_threshold: i32,
) -> opencv::Result<Vec<Rect>> {
) -> opencv::Result<Vec<BBox>> {
let heatmap = Mat::from_slice_2d(&heatmap)?;
let labels = image_threshold(heatmap.clone(), non_max_suppression_threshold)?;
let labels = image_f32_to_u8(labels)?;
Expand All @@ -156,18 +182,19 @@ pub fn generate_bbox(
assert_eq!(2, centroids.cols(), "centroids must have 2 columns");

let mut bboxes = Vec::new();
for i in 1..stats.rows() {
let stats_row = stats.at_row::<i32>(i)?;
let area = stats_row[opencv::imgproc::CC_STAT_AREA as usize];
// 0 is background so skip it
for label in 1..stats.rows() {
let stats_row = stats.at_row::<i32>(label)?;
let area = stats_row[imgproc::CC_STAT_AREA as usize];
if area < bbox_area_threshold {
continue;
}
let max_value = heatmap_label_max(&heatmap, &labels, i)?;
let max_value = heatmap_label_max(&heatmap, &labels, label)?;
if max_value < text_threshold {
continue;
}
let rect = connected_area_to_bbox(&labels, stats_row, i)?;
bboxes.push(BBox { rect }.scale_to_rect(original_size, &heatmap)?);
let polygon = connected_area_to_bbox(&labels, stats_row, label)?;
bboxes.push(BBox { polygon }.scale_to_original(original_size, &heatmap)?);
}
Ok(bboxes)
}