From 6978b7944030c56aaacf56cfdeaaa40878c9f465 Mon Sep 17 00:00:00 2001
From: Junpei Kawamoto
Date: Fri, 12 Jul 2024 17:04:50 -0600
Subject: [PATCH] feat: add Whisper

---
 build.rs          |   5 +
 include/whisper.h |  79 ++++++++++
 src/lib.rs        |   2 +-
 src/whisper.cpp   |  77 ++++++++++
 src/whisper.rs    | 361 ++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 523 insertions(+), 1 deletion(-)
 create mode 100644 include/whisper.h
 create mode 100644 src/whisper.cpp
 create mode 100644 src/whisper.rs

diff --git a/build.rs b/build.rs
index 314d1e9..07ae8de 100644
--- a/build.rs
+++ b/build.rs
@@ -21,11 +21,14 @@ fn main() {
     println!("cargo:rerun-if-changed=src/generator.rs");
     println!("cargo:rerun-if-changed=src/generator.cpp");
     println!("cargo:rerun-if-changed=src/storage_view.rs");
+    println!("cargo:rerun-if-changed=src/whisper.rs");
+    println!("cargo:rerun-if-changed=src/whisper.cpp");
     println!("cargo:rerun-if-changed=include/types.h");
     println!("cargo:rerun-if-changed=include/config.h");
     println!("cargo:rerun-if-changed=include/translator.h");
     println!("cargo:rerun-if-changed=include/generator.h");
     println!("cargo:rerun-if-changed=include/storage_view.h");
+    println!("cargo:rerun-if-changed=include/whisper.h");
     println!("cargo:rerun-if-changed=CTranslate2");
     println!("cargo:rerun-if-env-changed=LIBRARY_PATH");
     if let Ok(library_path) = env::var("LIBRARY_PATH") {
@@ -87,9 +90,11 @@ fn main() {
         "src/translator.rs",
         "src/generator.rs",
         "src/storage_view.rs",
+        "src/whisper.rs",
     ])
     .file("src/translator.cpp")
     .file("src/generator.cpp")
+    .file("src/whisper.cpp")
     .include("CTranslate2/include")
     .std("c++17")
     .static_crt(cfg!(target_os = "windows"))
diff --git a/include/whisper.h b/include/whisper.h
new file mode 100644
index 0000000..6c705d1
--- /dev/null
+++ b/include/whisper.h
@@ -0,0 +1,79 @@
+// whisper.h
+//
+// Copyright (c) 2023-2024 Junpei Kawamoto
+//
+// This software is released under the MIT License.
+//
+// http://opensource.org/licenses/mit-license.php
+
+#pragma once
+
+#include <memory>
+
+#include <ctranslate2/models/whisper.h>
+
+#include "rust/cxx.h"
+
+#include "config.h"
+
+using ctranslate2::StorageView;
+
+struct VecStr;
+struct VecDetectionResult;
+struct WhisperOptions;
+struct WhisperGenerationResult;
+
+class Whisper {
+private:
+    std::unique_ptr<ctranslate2::models::Whisper> impl;
+
+public:
+    Whisper(std::unique_ptr<ctranslate2::models::Whisper> impl)
+        : impl(std::move(impl)) { }
+
+    rust::Vec<WhisperGenerationResult>
+    generate(const StorageView& features, const rust::Slice<const VecStr> prompts, const WhisperOptions& options) const;
+
+    rust::Vec<VecDetectionResult>
+    detect_language(const StorageView& features) const;
+
+    inline bool is_multilingual() const {
+        return impl->is_multilingual();
+    }
+
+    inline size_t n_mels() const {
+        return impl->n_mels();
+    }
+
+    inline size_t num_languages() const {
+        return impl->num_languages();
+    }
+
+    inline size_t num_queued_batches() const {
+        return impl->num_queued_batches();
+    }
+
+    inline size_t num_active_batches() const {
+        return impl->num_active_batches();
+    }
+
+    inline size_t num_replicas() const {
+        return impl->num_replicas();
+    }
+};
+
+inline std::unique_ptr<Whisper> whisper(
+    rust::Str model_path,
+    std::unique_ptr<Config> config
+) {
+    return std::make_unique<Whisper>(
+        std::make_unique<ctranslate2::models::Whisper>(
+            static_cast<std::string>(model_path),
+            config->device,
+            config->compute_type,
+            std::vector<int>(config->device_indices.begin(), config->device_indices.end()),
+            config->tensor_parallel,
+            *config->replica_pool_config
+        )
+    );
+}
diff --git a/src/lib.rs b/src/lib.rs
index 59c8bec..dcd9c99 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -141,7 +141,6 @@ use crate::auto::Tokenizer as AutoTokenizer;
 pub use crate::config::{set_log_level, set_random_seed};
 use crate::config::Config;
 pub use crate::generator::GenerationOptions;
-pub use crate::storage_view::StorageView;
 pub use crate::translator::TranslationOptions;
 
 pub mod auto;
@@ -153,6 +152,7 @@ pub mod storage_view;
 pub mod tokenizers;
 pub mod translator;
 mod types;
+pub mod whisper;
 
 /// Defines the necessary functions for a tokenizer.
 ///
diff --git a/src/whisper.cpp b/src/whisper.cpp
new file mode 100644
index 0000000..a93b229
--- /dev/null
+++ b/src/whisper.cpp
@@ -0,0 +1,77 @@
+// whisper.cpp
+//
+// Copyright (c) 2023-2024 Junpei Kawamoto
+//
+// This software is released under the MIT License.
+//
+// http://opensource.org/licenses/mit-license.php
+
+#include <ctranslate2/models/whisper.h>
+
+#include "ct2rs/src/whisper.rs.h"
+
+#include "ct2rs/include/types.h"
+
+using rust::Slice;
+using rust::Vec;
+
+Vec<WhisperGenerationResult> Whisper::generate(
+    const StorageView& features,
+    const Slice<const VecStr> prompts,
+    const WhisperOptions& opts
+) const {
+    auto futures = impl->generate(
+        features,
+        from_rust(prompts),
+        ctranslate2::models::WhisperOptions {
+            opts.beam_size,
+            opts.patience,
+            opts.length_penalty,
+            opts.repetition_penalty,
+            opts.no_repeat_ngram_size,
+            opts.max_length,
+            opts.sampling_topk,
+            opts.sampling_temperature,
+            opts.num_hypotheses,
+            opts.return_scores,
+            opts.return_no_speech_prob,
+            opts.max_initial_timestamp_index,
+            opts.suppress_blank,
+            from_rust(opts.suppress_tokens),
+        }
+    );
+
+    Vec<WhisperGenerationResult> res;
+    for (auto& future : futures) {
+        const auto& r = future.get();
+        res.push_back(WhisperGenerationResult {
+            to_rust(r.sequences),
+            to_rust(r.sequences_ids),
+            to_rust(r.scores),
+            r.no_speech_prob,
+        });
+    }
+
+    return res;
+}
+
+Vec<VecDetectionResult> Whisper::detect_language(const StorageView& features) const {
+    auto futures = impl->detect_language(features);
+
+    Vec<VecDetectionResult> res;
+    for (auto& future : futures) {
+        const auto& r = future.get();
+
+        Vec<DetectionResult> pairs;
+        for (auto& pair : r) {
+            pairs.push_back(DetectionResult {
+                std::get<0>(pair),
+                std::get<1>(pair),
+            });
+        }
+
+        res.push_back(VecDetectionResult { pairs });
+    }
+
+    return res;
+}
diff --git a/src/whisper.rs b/src/whisper.rs
new file mode 100644
index 0000000..c1d5a41
--- /dev/null
+++ b/src/whisper.rs
@@ -0,0 +1,361 @@
+// whisper.rs
+//
+// Copyright (c) 2023-2024 Junpei Kawamoto
+//
+// This software is released under the MIT License.
+//
+// http://opensource.org/licenses/mit-license.php
+
+use std::ffi::OsString;
+use std::fmt::{Debug, Formatter};
+use std::path::Path;
+
+use anyhow::{anyhow, Result};
+use cxx::UniquePtr;
+
+use crate::config::Config;
+use crate::storage_view::StorageView;
+use crate::types::vec_ffi_vecstr;
+pub use crate::whisper::ffi::{DetectionResult, WhisperOptions};
+use crate::whisper::ffi::VecDetectionResult;
+
+#[cxx::bridge]
+mod ffi {
+    #[derive(Clone, Debug)]
+    pub struct WhisperOptions {
+        /// Beam size to use for beam search (set 1 to run greedy search).
+        pub beam_size: usize,
+        /// Beam search patience factor, as described in https://arxiv.org/abs/2204.05424.
+        /// The decoding will continue until beam_size*patience hypotheses are finished.
+        pub patience: f32,
+        /// Exponential penalty applied to the length during beam search.
+        pub length_penalty: f32,
+        /// Penalty applied to the score of previously generated tokens, as described in
+        /// https://arxiv.org/abs/1909.05858 (set > 1 to penalize).
+        pub repetition_penalty: f32,
+        /// Prevent repetitions of ngrams with this size (set 0 to disable).
+        pub no_repeat_ngram_size: usize,
+        /// Maximum generation length.
+        pub max_length: usize,
+        /// Randomly sample from the top K candidates (set 0 to sample from the full distribution).
+        pub sampling_topk: usize,
+        /// High temperatures increase randomness.
+        pub sampling_temperature: f32,
+        /// Number of hypotheses to include in the result.
+        pub num_hypotheses: usize,
+        /// Include scores in the result.
+        pub return_scores: bool,
+        /// Include the probability of the no speech token in the result.
+        pub return_no_speech_prob: bool,
+        /// Maximum index of the first predicted timestamp.
+        pub max_initial_timestamp_index: usize,
+        /// Suppress blank outputs at the beginning of the sampling.
+        pub suppress_blank: bool,
+        /// List of token IDs to suppress.
+        /// -1 will suppress a default set of symbols as defined in the model config.json file.
+        pub suppress_tokens: Vec<i32>,
+    }
+
+    struct WhisperGenerationResult {
+        sequences: Vec<VecString>,
+        sequences_ids: Vec<VecUSize>,
+        scores: Vec<f32>,
+        no_speech_prob: f32,
+    }
+
+    #[derive(PartialEq, Clone, Debug)]
+    pub struct DetectionResult {
+        language: String,
+        probability: f32,
+    }
+
+    #[derive(PartialEq, Clone)]
+    struct VecDetectionResult {
+        v: Vec<DetectionResult>,
+    }
+
+    unsafe extern "C++" {
+        include!("ct2rs/src/types.rs.h");
+        include!("ct2rs/include/whisper.h");
+
+        type VecStr<'a> = crate::types::ffi::VecStr<'a>;
+        type VecString = crate::types::ffi::VecString;
+        type VecUSize = crate::types::ffi::VecUSize;
+
+        type Config = crate::config::ffi::Config;
+
+        type StorageView = crate::storage_view::ffi::StorageView;
+
+        type Whisper;
+
+        fn whisper(model_path: &str, config: UniquePtr<Config>) -> Result<UniquePtr<Whisper>>;
+
+        fn generate(
+            self: &Whisper,
+            features: &StorageView,
+            prompts: &[VecStr],
+            options: &WhisperOptions,
+        ) -> Result<Vec<WhisperGenerationResult>>;
+
+        fn detect_language(
+            self: &Whisper,
+            features: &StorageView,
+        ) -> Result<Vec<VecDetectionResult>>;
+
+        fn is_multilingual(self: &Whisper) -> bool;
+
+        fn n_mels(self: &Whisper) -> usize;
+
+        fn num_languages(self: &Whisper) -> usize;
+
+        fn num_queued_batches(self: &Whisper) -> usize;
+
+        fn num_active_batches(self: &Whisper) -> usize;
+
+        fn num_replicas(self: &Whisper) -> usize;
+    }
+}
+
+impl Default for WhisperOptions {
+    fn default() -> Self {
+        Self {
+            beam_size: 5,
+            patience: 1.,
+            length_penalty: 1.,
+            repetition_penalty: 1.,
+            no_repeat_ngram_size: 0,
+            max_length: 448,
+            sampling_topk: 1,
+            sampling_temperature: 1.,
+            num_hypotheses: 1,
+            return_scores: false,
+            return_no_speech_prob: false,
+            max_initial_timestamp_index: 50,
+            suppress_blank: true,
+            suppress_tokens: vec![-1],
+        }
+    }
+}
+
+/// A generation result from the Whisper model.
+#[derive(Clone, Debug)]
+pub struct WhisperGenerationResult {
+    /// Generated sequences of tokens.
+    pub sequences: Vec<Vec<String>>,
+    /// Generated sequences of token IDs.
+    pub sequences_ids: Vec<Vec<usize>>,
+    /// Score of each sequence (empty if `return_scores` was disabled).
+    pub scores: Vec<f32>,
+    /// Probability of the no speech token (0 if `return_no_speech_prob` was disabled).
+    pub no_speech_prob: f32,
+}
+
+impl WhisperGenerationResult {
+    /// Returns the number of sequences.
+    #[inline]
+    pub fn num_sequences(&self) -> usize {
+        self.sequences.len()
+    }
+
+    /// Returns true if this result includes scores.
+    #[inline]
+    pub fn has_scores(&self) -> bool {
+        !self.scores.is_empty()
+    }
+}
+
+impl From<ffi::WhisperGenerationResult> for WhisperGenerationResult {
+    fn from(r: ffi::WhisperGenerationResult) -> Self {
+        Self {
+            sequences: r.sequences.into_iter().map(Vec::<String>::from).collect(),
+            sequences_ids: r
+                .sequences_ids
+                .into_iter()
+                .map(Vec::<usize>::from)
+                .collect(),
+            scores: r.scores,
+            no_speech_prob: r.no_speech_prob,
+        }
+    }
+}
+
+impl Into<Vec<DetectionResult>> for VecDetectionResult {
+    fn into(self) -> Vec<DetectionResult> {
+        self.v
+    }
+}
+
+/// A Rust binding to the
+/// [`ctranslate2::models::Whisper`](https://opennmt.net/CTranslate2/python/ctranslate2.models.Whisper.html).
+pub struct Whisper {
+    model: OsString,
+    ptr: UniquePtr<ffi::Whisper>,
+}
+
+impl Whisper {
+    /// Initializes a Whisper model from a converted model.
+    pub fn new<T: AsRef<Path>>(model_path: T, config: Config) -> Result<Self> {
+        let model_path = model_path.as_ref();
+        Ok(Self {
+            model: model_path
+                .file_name()
+                .map(|s| s.to_os_string())
+                .unwrap_or_default(),
+            ptr: ffi::whisper(
+                model_path
+                    .to_str()
+                    .ok_or_else(|| anyhow!("invalid path: {}", model_path.display()))?,
+                config.to_ffi(),
+            )?,
+        })
+    }
+
+    /// Encodes the input features and generates from the given prompt.
+    pub fn generate<T: AsRef<str>>(
+        &self,
+        features: &StorageView,
+        prompts: &[Vec<T>],
+        opts: &WhisperOptions,
+    ) -> Result<Vec<WhisperGenerationResult>> {
+        self.ptr
+            .generate(features, &vec_ffi_vecstr(prompts), opts)
+            .map(|res| res.into_iter().map(WhisperGenerationResult::from).collect())
+            .map_err(|e| anyhow!("failed to generate: {e}"))
+    }
+
+    /// Returns the probability of each language.
+    pub fn detect_language(&self, features: &StorageView) -> Result<Vec<Vec<DetectionResult>>> {
+        self.ptr
+            .detect_language(features)
+            .map(|res| res.into_iter().map(VecDetectionResult::into).collect())
+            .map_err(|e| anyhow!("failed to detect language: {e}"))
+    }
+
+    /// Returns `true` if this model is multilingual.
+    #[inline]
+    pub fn is_multilingual(&self) -> bool {
+        self.ptr.is_multilingual()
+    }
+
+    /// Returns dimension of mel input features.
+    #[inline]
+    pub fn n_mels(&self) -> usize {
+        self.ptr.n_mels()
+    }
+
+    /// Returns the number of languages supported.
+    #[inline]
+    pub fn num_languages(&self) -> usize {
+        self.ptr.num_languages()
+    }
+
+    /// Number of batches in the work queue.
+    #[inline]
+    pub fn num_queued_batches(&self) -> usize {
+        self.ptr.num_queued_batches()
+    }
+
+    /// Number of batches in the work queue or currently processed by a worker.
+    #[inline]
+    pub fn num_active_batches(&self) -> usize {
+        self.ptr.num_active_batches()
+    }
+
+    /// Number of parallel replicas.
+    #[inline]
+    pub fn num_replicas(&self) -> usize {
+        self.ptr.num_replicas()
+    }
+}
+
+impl Debug for Whisper {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Whisper")
+            .field("model", &self.model)
+            .field("multilingual", &self.is_multilingual())
+            .field("mels", &self.n_mels())
+            .field("languages", &self.num_languages())
+            .field("queued_batches", &self.num_queued_batches())
+            .field("active_batches", &self.num_active_batches())
+            .field("replicas", &self.num_replicas())
+            .finish()
+    }
+}
+
+unsafe impl Send for Whisper {}
+unsafe impl Sync for Whisper {}
+
+#[cfg(test)]
+mod tests {
+    use crate::whisper::{ffi, WhisperGenerationResult, WhisperOptions};
+
+    #[test]
+    fn test_default_options() {
+        let opts = WhisperOptions::default();
+
+        assert_eq!(opts.beam_size, 5);
+        assert_eq!(opts.patience, 1.);
+        assert_eq!(opts.length_penalty, 1.);
+        assert_eq!(opts.repetition_penalty, 1.);
+        assert_eq!(opts.no_repeat_ngram_size, 0);
+        assert_eq!(opts.max_length, 448);
+        assert_eq!(opts.sampling_topk, 1);
+        assert_eq!(opts.sampling_temperature, 1.);
+        assert_eq!(opts.num_hypotheses, 1);
+        assert!(!opts.return_scores);
+        assert!(!opts.return_no_speech_prob);
+        assert_eq!(opts.max_initial_timestamp_index, 50);
+        assert!(opts.suppress_blank);
+        assert_eq!(opts.suppress_tokens, vec![-1]);
+    }
+
+    #[test]
+    fn test_generation_result() {
+        let sequences = vec![
+            vec!["a".to_string(), "b".to_string()],
+            vec!["x".to_string(), "y".to_string(), "z".to_string()],
+        ];
+        let sequences_ids = vec![vec![1, 2], vec![5, 6, 7]];
+        let scores = vec![9., 8., 7.];
+        let no_speech_prob = 10.;
+
+        let res: WhisperGenerationResult = ffi::WhisperGenerationResult {
+            sequences: sequences
+                .iter()
+                .map(|v| ffi::VecString::from(v.clone()))
.collect(), + sequences_ids: sequences_ids + .iter() + .map(|v| ffi::VecUSize::from(v.clone())) + .collect(), + scores: scores.clone(), + no_speech_prob, + } + .into(); + + assert_eq!(res.sequences, sequences); + assert_eq!(res.sequences_ids, sequences_ids); + assert_eq!(res.scores, scores); + assert_eq!(res.no_speech_prob, no_speech_prob); + assert_eq!(res.num_sequences(), sequences.len()); + assert!(res.has_scores()); + } + + #[test] + fn test_empty_result() { + let res: WhisperGenerationResult = ffi::WhisperGenerationResult { + sequences: vec![], + sequences_ids: vec![], + scores: vec![], + no_speech_prob: 0., + } + .into(); + + assert!(res.sequences.is_empty()); + assert!(res.sequences_ids.is_empty()); + assert!(res.scores.is_empty()); + assert_eq!(res.no_speech_prob, 0.); + assert_eq!(res.num_sequences(), 0); + assert!(!res.has_scores()); + } +}