-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
202 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
mod segformer; | ||
|
||
use crate::error::Result; | ||
use crate::hf::HfModel; | ||
use candle_core::Device; | ||
use candle_nn::VarBuilder; | ||
pub use segformer::Config; | ||
pub use segformer::SemanticSegmentationModel; | ||
use std::path::PathBuf; | ||
|
||
impl HfModel for SemanticSegmentationModel { | ||
fn from_hf_files(config: PathBuf, weights: PathBuf, device: &Device) -> Result<Self> { | ||
let config = serde_json::from_str(&std::fs::read_to_string(config)?)?; | ||
let vb = unsafe { | ||
VarBuilder::from_mmaped_safetensors(&[weights], candle_core::DType::F32, device)? | ||
}; | ||
Self::new(&config, 2, vb).map_err(Into::into) | ||
} | ||
} |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
//! HuggingFace API | ||
use crate::error::Result; | ||
use candle_core::Device; | ||
use hf_hub::api::sync::ApiBuilder; | ||
use log::debug; | ||
use std::path::PathBuf; | ||
|
||
pub struct HfModelInfo { | ||
pub model_type: &'static str, | ||
pub repo: String, | ||
pub weights_file: String, | ||
pub config_file: String, | ||
} | ||
|
||
pub trait HfModel { | ||
fn from_hf(info: HfModelInfo, device: &Device) -> Result<Self> | ||
where | ||
Self: Sized, | ||
{ | ||
let api = ApiBuilder::new().with_progress(true).build()?; | ||
let repo = api.model(info.repo.clone()); | ||
debug!( | ||
"using {} model from HuggingFace repo '{}'", | ||
info.model_type, info.repo, | ||
); | ||
let model_file = repo.get(&info.weights_file)?; | ||
debug!( | ||
"using {} weights file '{}'", | ||
info.model_type, info.weights_file | ||
); | ||
let config_file = repo.get(&info.config_file)?; | ||
debug!( | ||
"using {} config file '{}'", | ||
info.model_type, info.config_file | ||
); | ||
Self::from_hf_files(config_file, model_file, device) | ||
} | ||
|
||
fn from_hf_files(config: PathBuf, weights: PathBuf, device: &Device) -> Result<Self> | ||
where | ||
Self: Sized; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
pub mod bbox; | ||
pub mod detection; | ||
pub mod error; | ||
pub mod hf; | ||
pub mod postprocess; | ||
pub mod preprocess; | ||
pub mod segformer; | ||
pub mod swin_transformer; | ||
pub mod recognition; | ||
pub mod tensor_roll; | ||
|
||
pub use error::Result; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
//! MBart with MOE | ||
use candle_core::{Module, Result, Tensor}; | ||
use candle_nn::VarBuilder; | ||
|
||
// TODO this is a placeholder | ||
|
||
#[derive(Debug, Clone, serde::Deserialize)] | ||
pub(crate) struct MBartConfig {} | ||
|
||
#[derive(Debug, Clone)] | ||
pub(crate) struct MBart {} | ||
|
||
impl MBart { | ||
pub(crate) fn new(config: &MBartConfig, vb: VarBuilder) -> Result<Self> { | ||
Ok(Self {}) | ||
} | ||
} | ||
|
||
impl Module for MBart { | ||
fn forward(&self, input: &Tensor) -> Result<Tensor> { | ||
Ok(input.clone()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
//! The recognition module consists of donut encoder and an MBart decoder | ||
mod mbart; | ||
mod swin_transformer; | ||
|
||
use crate::hf::HfModel; | ||
use candle_core::{Device, Module, Result, Tensor}; | ||
use candle_nn::VarBuilder; | ||
use mbart::MBart; | ||
use mbart::MBartConfig; | ||
use std::path::PathBuf; | ||
use swin_transformer::SwinConfig; | ||
use swin_transformer::SwinModel; | ||
|
||
#[derive(Debug, Clone, serde::Deserialize)] | ||
pub struct Config { | ||
encoder: SwinConfig, | ||
decoder: MBartConfig, | ||
pad_token_id: i32, | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct RecognitionModel { | ||
encoder: SwinModel, | ||
decoder: MBart, | ||
} | ||
|
||
impl RecognitionModel { | ||
pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> { | ||
let encoder = SwinModel::new(&config.encoder, vb.pp("encoder"))?; | ||
let decoder = MBart::new(&config.decoder, vb.pp("decoder"))?; | ||
Ok(Self { encoder, decoder }) | ||
} | ||
} | ||
|
||
impl Module for RecognitionModel { | ||
fn forward(&self, input: &Tensor) -> Result<Tensor> { | ||
let encoded = self.encoder.forward(input)?; | ||
self.decoder.forward(&encoded) | ||
} | ||
} | ||
|
||
impl HfModel for RecognitionModel { | ||
fn from_hf_files( | ||
config: PathBuf, | ||
weights: PathBuf, | ||
device: &Device, | ||
) -> crate::error::Result<Self> { | ||
let config = serde_json::from_str(&std::fs::read_to_string(config)?)?; | ||
let vb = unsafe { | ||
VarBuilder::from_mmaped_safetensors(&[weights], candle_core::DType::F16, device)? | ||
}; | ||
Self::new(&config, vb).map_err(Into::into) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters