Skip to content

Commit

Permalink
Merge pull request #73 from jkawamoto/whisper-wrapper
Browse files Browse the repository at this point in the history
Whisper wrapper
  • Loading branch information
jkawamoto authored Jul 19, 2024
2 parents 6da247d + 53eddee commit 42b4ae7
Show file tree
Hide file tree
Showing 9 changed files with 486 additions and 211 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ jobs:
submodules: recursive
- uses: Swatinem/rust-cache@v2
- name: Build
run: cargo build -vv --no-default-features -F "${{ matrix.feature }}"
run: cargo build -vv --no-default-features -F "whisper,${{ matrix.feature }}"
- name: Run tests
run: cargo test -vv --no-default-features -F "${{ matrix.feature }}"
run: cargo test -vv --no-default-features -F "whisper,${{ matrix.feature }}"

build-linux:
strategy:
Expand Down Expand Up @@ -75,9 +75,9 @@ jobs:
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- name: Build
run: cargo build -vv --no-default-features -F "${{ matrix.feature }}"
run: cargo build -vv --no-default-features -F "whisper,${{ matrix.feature }}"
- name: Run tests
run: cargo test -vv --no-default-features -F "${{ matrix.feature }}"
run: cargo test -vv --no-default-features -F "whisper,${{ matrix.feature }}"

build-windows:
strategy:
Expand All @@ -93,6 +93,6 @@ jobs:
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- name: Build
run: cargo build -vv --no-default-features -F "${{ matrix.feature }}"
run: cargo build -vv --no-default-features -F "whisper,${{ matrix.feature }}"
- name: Run tests
run: cargo test -vv --no-default-features -F "${{ matrix.feature }}"
run: cargo test -vv --no-default-features -F "whisper,${{ matrix.feature }}"
20 changes: 15 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ cxx = { version = "1.0.123", features = ["c++17"] }
sentencepiece = "0.11.2"
tokenizers = "0.19.1"

# Dependencies for Whisper model
ndarray = { version = "0.15.6", optional = true }
rustfft = { version = "6.2.0", optional = true }
serde = { version = "1.0.204", features = ["derive"], optional = true }
serde_json = { version = "1.0.120", optional = true }

[target.'cfg(windows)'.dependencies]
intel-mkl-src = { version = "0.8.1", optional = true, features = ["mkl-static-ilp64-seq"] }

Expand All @@ -43,12 +49,10 @@ intel-mkl-src = { version = "0.8.1", optional = true }

[dev-dependencies]
clap = { version = "4.5.7", features = ["derive"] }
hound = "3.5.1"
ndarray = "0.15.6"
rand = "0.8.5"
rustfft = "6.2.0"
serde = { version = "1.0.202", features = ["derive"] }
serde_json = "1.0.117"

# Dependencies for Whisper example
hound = { version = "3.5.1" }

[build-dependencies]
cmake = "0.1.50"
Expand All @@ -57,6 +61,9 @@ walkdir = "2.5.0"

[features]
default = ["ruy", "accelerate"]
whisper = ["dep:ndarray", "dep:rustfft", "dep:serde", "dep:serde_json"]

# Features to select backends.
mkl = ["dep:intel-mkl-src"]
openblas = []
ruy = []
Expand Down Expand Up @@ -102,3 +109,6 @@ name = "stream"

[[example]]
name = "whisper"

[package.metadata.docs.rs]
features = ["whisper"]
188 changes: 9 additions & 179 deletions examples/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,19 @@
//! Then, execute the sample code below with the following command:
//!
//! ```bash
//! cargo run --example whisper -- ./whisper-tiny-ct2 audio.wav
//! cargo run -F whisper --example whisper -- ./whisper-tiny-ct2 audio.wav
//! ```
//!
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::time;

use anyhow::Result;
use clap::Parser;
use hound::WavReader;
use ndarray::{Array2, Ix3};
use rustfft::num_complex::Complex;
use rustfft::FftPlanner;
use serde::Deserialize;

use ct2rs::sys::{StorageView, Whisper};
use ct2rs::tokenizers::auto;
use ct2rs::Tokenizer;
use ct2rs::Whisper;

const PREPROCESSOR_CONFIG_FILE: &str = "preprocessor_config.json";
#[cfg(not(feature = "whisper"))]
compile_error!("This example requires 'whisper' feature.");

/// Transcribe a file using Whisper models.
#[derive(Parser, Debug)]
Expand All @@ -61,70 +52,16 @@ struct Args {

fn main() -> Result<()> {
let args = Args::parse();
let cfg = PreprocessorConfig::read(args.model_dir.join(PREPROCESSOR_CONFIG_FILE))?;

let mut samples = read_audio(args.audio_file, cfg.sampling_rate)?;
if samples.len() < cfg.n_samples {
samples.append(&mut vec![0f32; cfg.n_samples - samples.len()]);
} else {
samples.truncate(cfg.n_samples);
}

// Compute STFT
let stft = stft(&samples, cfg.n_fft, cfg.hop_length);
let whisper = Whisper::new(args.model_dir, Default::default())?;

// Compute Mel Spectrogram
let mel_spectrogram = mel_spectrogram(&stft, &cfg.mel_filters);
let samples = read_audio(args.audio_file, whisper.sampling_rate())?;

let shape = mel_spectrogram.shape();
let new_shape = Ix3(1, shape[0], shape[1]);

let mut mel_spectrogram = mel_spectrogram.into_shape(new_shape)?;
if !mel_spectrogram.is_standard_layout() {
mel_spectrogram = mel_spectrogram.as_standard_layout().into_owned()
let res = whisper.generate(&samples, None, false, &Default::default())?;
for r in res {
println!("{}", r);
}

let shape = mel_spectrogram.shape().to_vec();
let storage_view = StorageView::new(
&shape,
mel_spectrogram.as_slice_mut().unwrap(),
Default::default(),
)?;

// Load the model.
let model = Whisper::new(&args.model_dir, Default::default()).unwrap();
let tokenizer = auto::Tokenizer::new(&args.model_dir)?;

let now = time::Instant::now();

// Detect language.
let lang = model.detect_language(&storage_view)?;
println!("Detected language: {:?}", lang[0][0]);

// Transcribe.
let res = model.generate(
&storage_view,
&[vec![
"<|startoftranscript|>",
&lang[0][0].language,
"<|transcribe|>",
"<|notimestamps|>",
]],
&Default::default(),
)?;

let elapsed = now.elapsed();

match res.into_iter().next() {
None => println!("Empty result"),
Some(r) => {
for v in r.sequences.into_iter() {
println!("{:?}", tokenizer.decode(v));
}
}
}
println!("Time taken: {:?}", elapsed);

Ok(())
}

Expand Down Expand Up @@ -159,110 +96,3 @@ fn read_audio<T: AsRef<Path>>(path: T, sample_rate: usize) -> Result<Vec<f32>> {

Ok(resample(mono, spec.sample_rate as usize, sample_rate))
}

fn stft(samples: &[f32], n_fft: usize, hop_length: usize) -> Array2<Complex<f32>> {
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(n_fft);

let n_frames = (samples.len() - 1) / hop_length + 1;
let mut stft = Array2::zeros((n_fft / 2 + 1, n_frames));

let mut padded_samples = samples.to_vec();
padded_samples.extend(vec![0.0; n_fft]);

for (i, frame) in padded_samples
.windows(n_fft)
.step_by(hop_length)
.take(n_frames)
.enumerate()
{
let mut fft_input: Vec<Complex<f32>> =
frame.iter().map(|&x| Complex::new(x, 0.0)).collect();
fft.process(&mut fft_input);
for (j, value) in fft_input.iter().take(n_fft / 2 + 1).enumerate() {
stft[[j, i]] = *value;
}
}

stft
}

fn mel_spectrogram(stft: &Array2<Complex<f32>>, mel_filter_bank: &Array2<f32>) -> Array2<f32> {
let spectrum = stft.mapv(|x| x.norm_sqr());

let mut res = mel_filter_bank.dot(&spectrum).mapv(|x| x.log10());
let global_max = res.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
res.mapv_inplace(|x| x.max(global_max - 8.0));
res.mapv_inplace(|x| (x + 4.0) / 4.0);

res
}

#[allow(dead_code)]
#[derive(Debug)]
struct PreprocessorConfig {
chunk_length: usize,
feature_extractor_type: String,
feature_size: usize,
hop_length: usize,
n_fft: usize,
n_samples: usize,
nb_max_frames: usize,
padding_side: String,
padding_value: f32,
processor_class: String,
return_attention_mask: bool,
sampling_rate: usize,
mel_filters: Array2<f32>,
}

impl PreprocessorConfig {
fn read<T: AsRef<Path>>(path: T) -> Result<Self> {
let file = File::open(path)?;
let reader = BufReader::new(file);

#[derive(Deserialize)]
struct PreprocessorConfigAux {
chunk_length: usize,
feature_extractor_type: String,
feature_size: usize,
hop_length: usize,
n_fft: usize,
n_samples: usize,
nb_max_frames: usize,
padding_side: String,
padding_value: f32,
processor_class: String,
return_attention_mask: bool,
sampling_rate: usize,
mel_filters: Vec<Vec<f32>>,
}
let aux: PreprocessorConfigAux = serde_json::from_reader(reader)?;

let rows = aux.mel_filters.len();
let cols = aux
.mel_filters
.first()
.map(|row| row.len())
.unwrap_or_default();

Ok(Self {
chunk_length: aux.chunk_length,
feature_extractor_type: aux.feature_extractor_type,
feature_size: aux.feature_size,
hop_length: aux.hop_length,
n_fft: aux.n_fft,
n_samples: aux.n_samples,
nb_max_frames: aux.nb_max_frames,
padding_side: aux.padding_side,
padding_value: aux.padding_value,
processor_class: aux.processor_class,
return_attention_mask: aux.return_attention_mask,
sampling_rate: aux.sampling_rate,
mel_filters: Array2::from_shape_vec(
(rows, cols),
aux.mel_filters.into_iter().flatten().collect(),
)?,
})
}
}
4 changes: 3 additions & 1 deletion src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ use std::path::Path;

use anyhow::{anyhow, Result};

pub use sys::GenerationOptions;

use crate::tokenizer::encode_all;

use super::{sys, Config, GenerationOptions, GenerationStepResult, Tokenizer};
use super::{sys, Config, GenerationStepResult, Tokenizer};

/// A text generator with a tokenizer.
///
Expand Down
24 changes: 16 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
//! [ctranslate2::Whisper](https://opennmt.net/CTranslate2/python/ctranslate2.models.Whisper.html)
//! provided by CTranslate2, specifically [`sys::Translator`], [`sys::Generator`], and
//! [`sys::Whisper`].
//! * More user-friendly versions of these, [`Translator`] and [`Generator`],
//! * More user-friendly versions of these, [`Translator`], [`Generator`],
//! and [`Whisper`] (`whisper` feature is required),
//! which incorporate tokenizers for easier handling.
//!
//! # Basic Usage
Expand All @@ -27,7 +28,7 @@
//! ```no_run
//! # use anyhow::Result;
//! #
//! use ct2rs::{Config, Translator, TranslationOptions, GenerationStepResult};
//! use ct2rs::{Config, Translator, TranslationOptions};
//!
//! # fn main() -> Result<()> {
//! let sources = vec![
Expand Down Expand Up @@ -67,22 +68,29 @@
//! [the example code](https://github.com/jkawamoto/ctranslate2-rs/blob/main/examples/stream.rs)
//! for more information.
//!
//!
#![cfg_attr(docsrs, feature(doc_cfg))]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

pub use generator::Generator;
pub use generator::{GenerationOptions, Generator};
pub use result::GenerationStepResult;
pub use sys::{
set_log_level, set_random_seed, BatchType, ComputeType, Config, Device, GenerationOptions,
LogLevel, TranslationOptions,
};
pub use sys::{set_log_level, set_random_seed, BatchType, ComputeType, Config, Device, LogLevel};
pub use tokenizer::Tokenizer;
pub use translator::Translator;
pub use translator::{TranslationOptions, Translator};
#[cfg(feature = "whisper")]
#[cfg_attr(docsrs, doc(cfg(feature = "whisper")))]
pub use whisper::{Whisper, WhisperOptions};

mod generator;
mod result;
pub mod sys;
mod tokenizer;
pub mod tokenizers;
mod translator;

#[cfg(feature = "whisper")]
#[cfg_attr(docsrs, doc(cfg(feature = "whisper")))]
mod whisper;
2 changes: 0 additions & 2 deletions src/sys/storage_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,5 @@ mod tests {
assert_eq!(v.rank(), rank);
assert!(!v.empty());
assert_eq!(v.device(), Device::CPU);

println!("{:?}", v);
}
}
Loading

0 comments on commit 42b4ae7

Please sign in to comment.