Skip to content

Commit

Permalink
code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Feb 19, 2024
1 parent 15bd436 commit 9388680
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 45 deletions.
75 changes: 45 additions & 30 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ extern crate accelerate_src;
extern crate intel_mkl_src;

use candle_core::{Device, IndexOp, Module, Tensor};
use candle_nn::VarBuilder;
use clap::{Parser, ValueEnum};
use env_logger::Env;
use hf_hub::api::sync::ApiBuilder;
use log::{debug, info};
use opencv::hub_prelude::MatTraitConst;
use std::fs::File;
Expand All @@ -16,9 +14,12 @@ use std::io::Write;
use std::path::PathBuf;
use std::time::Instant;
use surya::bbox::{draw_bboxes, generate_bbox};
use surya::detection::SemanticSegmentationModel;
use surya::hf::HfModel;
use surya::hf::HfModelInfo;
use surya::postprocess::save_image;
use surya::preprocess::{image_to_tensor, read_chunked_resized_image, read_image};
use surya::segformer::SemanticSegmentationModel;
use surya::recognition::RecognitionModel;

#[derive(Debug, ValueEnum, Clone, Copy)]
enum DeviceType {
Expand Down Expand Up @@ -65,14 +66,14 @@ struct Cli {
default_value = "model.safetensors",
help = "detection model's weights file name"
)]
weights_file_name: String,
detection_weights_file_name: String,

#[arg(
long,
default_value = "config.json",
help = "detection model's config file name"
)]
config_file_name: String,
detection_config_file_name: String,

#[arg(
long,
Expand Down Expand Up @@ -108,6 +109,20 @@ struct Cli {
)]
recognition_model_repo: String,

#[arg(
long,
default_value = "model.safetensors",
help = "recognition model's weights file name"
)]
recognition_weights_file_name: String,

#[arg(
long,
default_value = "config.json",
help = "recognition model's config file name"
)]
recognition_config_file_name: String,

#[arg(
long,
default_value = "./surya_output",
Expand Down Expand Up @@ -155,31 +170,30 @@ struct Cli {
}

impl Cli {
fn get_detection_model(
&self,
device: &Device,
num_labels: usize,
) -> surya::Result<SemanticSegmentationModel> {
let api = ApiBuilder::new().with_progress(true).build()?;
let repo = api.model(self.detection_model_repo.clone());
debug!(
"using model from HuggingFace repo {0}",
self.detection_model_repo
);
let model_file = repo.get(&self.weights_file_name)?;
debug!("using weights file '{0}'", self.weights_file_name);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], candle_core::DType::F32, device)?
};
let config_file = repo.get(&self.config_file_name)?;
debug!("using config file '{0}'", self.config_file_name);
let config = serde_json::from_str(&std::fs::read_to_string(config_file)?)?;
debug!("loaded config: {:?}, num_labels {}", config, num_labels);
Ok(SemanticSegmentationModel::new(&config, num_labels, vb)?)
fn get_detection_model(&self, device: &Device) -> surya::Result<SemanticSegmentationModel> {
SemanticSegmentationModel::from_hf(
HfModelInfo {
model_type: "detection",
repo: self.detection_model_repo.clone(),
weights_file: self.detection_weights_file_name.clone(),
config_file: self.detection_config_file_name.clone(),
},
device,
)
}
}

const NUM_LABELS: usize = 2;
fn get_recognition_model(&self, device: &Device) -> surya::Result<RecognitionModel> {
RecognitionModel::from_hf(
HfModelInfo {
model_type: "recognition",
repo: self.recognition_model_repo.clone(),
weights_file: self.recognition_weights_file_name.clone(),
config_file: self.recognition_config_file_name.clone(),
},
device,
)
}
}

fn main() -> surya::Result<()> {
let args = Cli::parse();
Expand Down Expand Up @@ -218,7 +232,8 @@ fn main() -> surya::Result<()> {
.create(output_dir.clone())?;
info!("generating output to {:?}", output_dir);

let model = args.get_detection_model(&device, NUM_LABELS)?;
let detection_model = args.get_detection_model(&device)?;
// let recognition_model = args.get_recognition_model(&device)?;

let batch_size = args.detection_batch_size.unwrap_or(match device {
Device::Cpu => 2,
Expand All @@ -240,7 +255,7 @@ fn main() -> surya::Result<()> {
batch_size,
);
let now = Instant::now();
let segmentation = model.forward(&batch)?;
let segmentation = detection_model.forward(&batch)?;
info!("inference took {:.3}s", now.elapsed().as_secs_f32());
for i in 0..batch_size {
let heatmap: Tensor = segmentation.i(i)?.squeeze(0)?.i(0)?;
Expand Down
19 changes: 19 additions & 0 deletions src/detection/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
mod segformer;

use crate::error::Result;
use crate::hf::HfModel;
use candle_core::Device;
use candle_nn::VarBuilder;
pub use segformer::Config;
pub use segformer::SemanticSegmentationModel;
use std::path::PathBuf;

impl HfModel for SemanticSegmentationModel {
    /// Builds the segmentation model from a JSON config file and a safetensors
    /// weights file already available on the local filesystem.
    ///
    /// # Errors
    /// Propagates failures from reading or parsing the config, mapping the
    /// weights, or constructing the model.
    fn from_hf_files(config: PathBuf, weights: PathBuf, device: &Device) -> Result<Self> {
        // Label count handed to the model constructor.
        const NUM_LABELS: usize = 2;

        let config_json = std::fs::read_to_string(config)?;
        let config = serde_json::from_str(&config_json)?;
        // SAFETY: NOTE(review): the weights file is memory-mapped; presumably this is
        // sound only while the file is not modified during the mapping's lifetime —
        // confirm against the candle documentation.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights], candle_core::DType::F32, device)?
        };
        let model = Self::new(&config, NUM_LABELS, vb)?;
        Ok(model)
    }
}
File renamed without changes.
43 changes: 43 additions & 0 deletions src/hf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//! HuggingFace API
use crate::error::Result;
use candle_core::Device;
use hf_hub::api::sync::ApiBuilder;
use log::debug;
use std::path::PathBuf;

/// Describes where a model's files live in a HuggingFace repository.
#[derive(Debug, Clone)]
pub struct HfModelInfo {
    /// Short label for the model kind; only used in log messages.
    pub model_type: &'static str,
    /// Repository identifier passed to the HuggingFace API.
    pub repo: String,
    /// Name of the weights file requested from the repository.
    pub weights_file: String,
    /// Name of the config file requested from the repository.
    pub config_file: String,
}

pub trait HfModel {
fn from_hf(info: HfModelInfo, device: &Device) -> Result<Self>
where
Self: Sized,
{
let api = ApiBuilder::new().with_progress(true).build()?;
let repo = api.model(info.repo.clone());
debug!(
"using {} model from HuggingFace repo '{}'",
info.model_type, info.repo,
);
let model_file = repo.get(&info.weights_file)?;
debug!(
"using {} weights file '{}'",
info.model_type, info.weights_file
);
let config_file = repo.get(&info.config_file)?;
debug!(
"using {} config file '{}'",
info.model_type, info.config_file
);
Self::from_hf_files(config_file, model_file, device)
}

fn from_hf_files(config: PathBuf, weights: PathBuf, device: &Device) -> Result<Self>
where
Self: Sized;
}
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
pub mod bbox;
pub mod detection;
pub mod error;
pub mod hf;
pub mod postprocess;
pub mod preprocess;
pub mod segformer;
pub mod swin_transformer;
pub mod recognition;
pub mod tensor_roll;

pub use error::Result;
23 changes: 23 additions & 0 deletions src/recognition/mbart.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//! MBart with MOE
use candle_core::{Module, Result, Tensor};
use candle_nn::VarBuilder;

// TODO this is a placeholder

/// Placeholder configuration for the MBart decoder; it declares no fields yet.
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct MBartConfig {}

/// Placeholder MBart decoder; it currently holds no state.
#[derive(Debug, Clone)]
pub(crate) struct MBart {}

impl MBart {
    /// Creates the placeholder decoder.
    ///
    /// Neither `config` nor `vb` is consumed yet; both are accepted so callers
    /// do not need to change once the real implementation lands.
    pub(crate) fn new(config: &MBartConfig, vb: VarBuilder) -> Result<Self> {
        // Discard the arguments explicitly so the stub does not trigger
        // `unused_variables` warnings.
        let _ = (config, vb);
        Ok(Self {})
    }
}

impl Module for MBart {
    /// Placeholder forward pass: returns a clone of `input`.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let passthrough = input.clone();
        Ok(passthrough)
    }
}
55 changes: 55 additions & 0 deletions src/recognition/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//! The recognition module consists of a Donut encoder and an MBart decoder.
mod mbart;
mod swin_transformer;

use crate::hf::HfModel;
use candle_core::{Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use mbart::MBart;
use mbart::MBartConfig;
use std::path::PathBuf;
use swin_transformer::SwinConfig;
use swin_transformer::SwinModel;

/// Deserialized configuration for the recognition model.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
// Settings passed to the Swin encoder constructor.
encoder: SwinConfig,
// Settings passed to the MBart decoder constructor.
decoder: MBartConfig,
// Not read by any code visible in this module.
pad_token_id: i32,
}

/// Recognition model that pairs a Swin encoder with an MBart decoder.
#[derive(Debug, Clone)]
pub struct RecognitionModel {
// Runs first in `forward`.
encoder: SwinModel,
// Receives the encoder output in `forward`.
decoder: MBart,
}

impl RecognitionModel {
    /// Builds the encoder and decoder from `config`, reading their weights
    /// under the `"encoder"` and `"decoder"` prefixes of `vb` respectively.
    pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            encoder: SwinModel::new(&config.encoder, vb.pp("encoder"))?,
            decoder: MBart::new(&config.decoder, vb.pp("decoder"))?,
        })
    }
}

impl Module for RecognitionModel {
    /// Runs the encoder on `input` and feeds its output to the decoder.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        self.encoder
            .forward(input)
            .and_then(|encoded| self.decoder.forward(&encoded))
    }
}

impl HfModel for RecognitionModel {
fn from_hf_files(
config: PathBuf,
weights: PathBuf,
device: &Device,
) -> crate::error::Result<Self> {
let config = serde_json::from_str(&std::fs::read_to_string(config)?)?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights], candle_core::DType::F16, device)?
};
Self::new(&config, vb).map_err(Into::into)
}
}
26 changes: 13 additions & 13 deletions src/swin_transformer.rs → src/recognition/swin_transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use candle_nn::{
};

#[derive(Debug, Clone, serde::Deserialize)]
pub struct SwinConfig {
pub image_size: usize,
pub(crate) struct SwinConfig {
pub image_size: (usize, usize),
pub patch_size: usize,
pub num_channels: usize,
pub embed_dim: usize,
Expand All @@ -37,7 +37,7 @@ impl Default for SwinConfig {
/// this defaults to the Swin-Tiny model with window size 7 and patch size 4 and 224x224 images
fn default() -> Self {
Self {
image_size: 224,
image_size: (224, 224),
patch_size: 4,
num_channels: 3,
embed_dim: 96,
Expand Down Expand Up @@ -320,7 +320,7 @@ impl SwinSelfAttention {

fn generate_relative_position_index(window_size: usize, device: &Device) -> Result<Tensor> {
debug_assert!(window_size > 1, "window_size must be greater than 1");
let window_size: i64 = window_size as i64;
let window_size = window_size as i64;
let h = Tensor::arange(0, window_size, device)?;
let w = Tensor::arange(0, window_size, device)?;
let xy_indexing = false; // use ij indexing
Expand Down Expand Up @@ -705,21 +705,18 @@ impl Module for SwinEncoder {
}

#[derive(Debug, Clone)]
struct SwinModel {
pub(crate) struct SwinModel {
embeddings: SwinEmbeddings,
encoder: SwinEncoder,
layernorm: LayerNorm,
}

impl SwinModel {
pub fn new(config: &SwinConfig, vb: VarBuilder) -> Result<Self> {
pub(crate) fn new(config: &SwinConfig, vb: VarBuilder) -> Result<Self> {
let embeddings = SwinEmbeddings::new(config, vb.pp("embeddings"))?;
let encoder = SwinEncoder::new(config, vb.pp("encoder"))?;
let layernorm = layer_norm(config.embed_dim, config.layer_norm_eps, vb.pp("layernorm"))?;
Ok(Self {
embeddings,
encoder,
layernorm,
})
}
}
Expand All @@ -728,8 +725,8 @@ impl Module for SwinModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.embeddings.forward(x)?;
let x = self.encoder.forward(&x)?;
let x = self.layernorm.forward(&x)?;
Ok(x)
// this is the same as adaptive avg pool with output size 1
x.mean(1)
}
}

Expand All @@ -755,7 +752,10 @@ mod test {
"embed_dim": 96,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"image_size": 224,
"image_size": [
224,
224
],
"initializer_range": 0.02,
"layer_norm_eps": 1e-05,
"mlp_ratio": 4.0,
Expand All @@ -776,7 +776,7 @@ mod test {
"window_size": 7
}"#;
let config: SwinConfig = serde_json::from_str(config_raw).unwrap();
assert_eq!(config.image_size, 224);
assert_eq!(config.image_size, (224, 224));
assert_eq!(config.patch_size, 4);
assert_eq!(config.num_channels, 3);
assert_eq!(config.embed_dim, 96);
Expand Down

0 comments on commit 9388680

Please sign in to comment.