Skip to content

Commit

Permalink
feat: add Whisper
Browse files Browse the repository at this point in the history
  • Loading branch information
jkawamoto committed Jul 12, 2024
1 parent d819cc3 commit 6978b79
Show file tree
Hide file tree
Showing 5 changed files with 523 additions and 1 deletion.
5 changes: 5 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ fn main() {
println!("cargo:rerun-if-changed=src/generator.rs");
println!("cargo:rerun-if-changed=src/generator.cpp");
println!("cargo:rerun-if-changed=src/storage_view.rs");
println!("cargo:rerun-if-changed=src/whisper.rs");
println!("cargo:rerun-if-changed=src/whisper.cpp");
println!("cargo:rerun-if-changed=include/types.h");
println!("cargo:rerun-if-changed=include/config.h");
println!("cargo:rerun-if-changed=include/translator.h");
println!("cargo:rerun-if-changed=include/generator.h");
println!("cargo:rerun-if-changed=include/storage_view.h");
println!("cargo:rerun-if-changed=include/whisper.h");
println!("cargo:rerun-if-changed=CTranslate2");
println!("cargo:rerun-if-env-changed=LIBRARY_PATH");
if let Ok(library_path) = env::var("LIBRARY_PATH") {
Expand Down Expand Up @@ -87,9 +90,11 @@ fn main() {
"src/translator.rs",
"src/generator.rs",
"src/storage_view.rs",
"src/whisper.rs",
])
.file("src/translator.cpp")
.file("src/generator.cpp")
.file("src/whisper.cpp")
.include("CTranslate2/include")
.std("c++17")
.static_crt(cfg!(target_os = "windows"))
Expand Down
79 changes: 79 additions & 0 deletions include/whisper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// whisper.h
//
// Copyright (c) 2023-2024 Junpei Kawamoto
//
// This software is released under the MIT License.
//
// http://opensource.org/licenses/mit-license.php

#pragma once

#include <memory>

#include <ctranslate2/models/whisper.h>

#include "rust/cxx.h"

#include "config.h"

using ctranslate2::StorageView;

struct VecStr;
struct VecDetectionResult;
struct WhisperOptions;
struct WhisperGenerationResult;

class Whisper {
private:
std::unique_ptr<ctranslate2::models::Whisper> impl;

public:
Whisper(std::unique_ptr<ctranslate2::models::Whisper> impl)
: impl(std::move(impl)) { }

rust::Vec<WhisperGenerationResult>
generate(const StorageView& features, const rust::Slice<const VecStr> prompts, const WhisperOptions& options) const;

rust::Vec<VecDetectionResult>
detect_language(const StorageView& features) const;

inline bool is_multilingual() const {
return impl->is_multilingual();
}

inline size_t n_mels() const {
return impl->n_mels();
}

inline size_t num_languages() const {
return impl->num_languages();
}

inline size_t num_queued_batches() const {
return impl->num_queued_batches();
}

inline size_t num_active_batches() const {
return impl->num_active_batches();
}

inline size_t num_replicas() const {
return impl->num_replicas();
}
};

inline std::unique_ptr<Whisper> whisper(
rust::Str model_path,
std::unique_ptr<Config> config
) {
return std::make_unique<Whisper>(
std::make_unique<ctranslate2::models::Whisper>(
static_cast<std::string>(model_path),
config->device,
config->compute_type,
std::vector<int>(config->device_indices.begin(), config->device_indices.end()),
config->tensor_parallel,
*config->replica_pool_config
)
);
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ use crate::auto::Tokenizer as AutoTokenizer;
pub use crate::config::{set_log_level, set_random_seed};
use crate::config::Config;
pub use crate::generator::GenerationOptions;
pub use crate::storage_view::StorageView;
pub use crate::translator::TranslationOptions;

pub mod auto;
Expand All @@ -153,6 +152,7 @@ pub mod storage_view;
pub mod tokenizers;
pub mod translator;
mod types;
pub mod whisper;

/// Defines the necessary functions for a tokenizer.
///
Expand Down
77 changes: 77 additions & 0 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// whisper.cpp
//
// Copyright (c) 2023-2024 Junpei Kawamoto
//
// This software is released under the MIT License.
//
// http://opensource.org/licenses/mit-license.php

#include <utility>

#include "ct2rs/src/whisper.rs.h"

#include "ct2rs/include/types.h"

using rust::Slice;
using rust::Vec;

Vec<WhisperGenerationResult> Whisper::generate(
const StorageView& features,
const Slice<const VecStr> prompts,
const WhisperOptions& opts
) const {
auto futures = impl->generate(
features,
from_rust(prompts),
ctranslate2::models::WhisperOptions {
opts.beam_size,
opts.patience,
opts.length_penalty,
opts.repetition_penalty,
opts.no_repeat_ngram_size,
opts.max_length,
opts.sampling_topk,
opts.sampling_temperature,
opts.num_hypotheses,
opts.return_scores,
opts.return_no_speech_prob,
opts.max_initial_timestamp_index,
opts.suppress_blank,
from_rust(opts.suppress_tokens),
}
);

Vec<WhisperGenerationResult> res;
for (auto& future : futures) {
const auto& r = future.get();
res.push_back(WhisperGenerationResult {
to_rust<VecString>(r.sequences),
to_rust<VecUSize>(r.sequences_ids),
to_rust(r.scores),
r.no_speech_prob,
});
}

return res;
}

Vec<VecDetectionResult> Whisper::detect_language(const StorageView& features) const {
auto futures = impl->detect_language(features);

Vec<VecDetectionResult> res;
for (auto& future : futures) {
const auto& r = future.get();

Vec<DetectionResult> pairs;
for (auto& pair : r) {
pairs.push_back(DetectionResult {
std::get<0>(pair),
std::get<1>(pair),
});
}

res.push_back(VecDetectionResult { pairs });
}

return res;
}
Loading

0 comments on commit 6978b79

Please sign in to comment.