Skip to content

Commit

Permalink
feat: update audio apis
Browse files Browse the repository at this point in the history
  • Loading branch information
YanceyOfficial committed Oct 23, 2024
1 parent 8975c57 commit 062668b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 21 deletions.
1 change: 1 addition & 0 deletions rs_openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ thiserror = "1.0.40"
tokio = { version = "1.26.0", features = ["full"] }
tokio-stream = "0.1.12"
tracing = "0.1.37"
rand = "0.8.5"
23 changes: 17 additions & 6 deletions rs_openai/src/apis/audio.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
//! Learn how to turn audio into text.
//!
//! Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
//! Learn how to turn audio into text or text into audio. Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
use crate::client::OpenAI;
use crate::interfaces::audio;
use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse};
use crate::shared::utils::generate_random_string;
use reqwest::multipart::Form;

pub struct Audio<'a> {
Expand All @@ -18,16 +17,18 @@ impl<'a> Audio<'a> {

/// Generates audio from the input text.
pub async fn create_speech(&self, req: &audio::CreateSpeechRequest) -> OpenAIResponse<()> {
let format = req.response_format.clone().unwrap_or_default();
let random_filename = generate_random_string(16);
self.openai
.post_with_file_response("/audio/speech", req, "")
.post_with_file_response("/audio/speech", req, &format!("{random_filename}.{format}"))
.await
}

/// Transcribes audio into the input language, response is `application/json`.
pub async fn create_transcription(
&self,
req: &audio::CreateTranscriptionRequest,
) -> OpenAIResponse<audio::VerboseJsonForAudioResponse> {
) -> OpenAIResponse<audio::SttResponse> {
if !self.is_json_type(req.response_format.clone()) {
return Err(OpenAIError::InvalidArgument(
"When `response_format` is set to `SttResponseFormat::Text` or `SttResponseFormat::Vtt or `SttResponseFormat::Srt`, use Audio::create_transcription_with_text_response".into(),
Expand All @@ -42,7 +43,7 @@ impl<'a> Audio<'a> {
pub async fn create_translation(
&self,
req: &audio::CreateTranslationRequest,
) -> OpenAIResponse<audio::VerboseJsonForAudioResponse> {
) -> OpenAIResponse<audio::SttResponse> {
if !self.is_json_type(req.response_format.clone()) {
return Err(OpenAIError::InvalidArgument(
"When `response_format` is set to `SttResponseFormat::Text` or `SttResponseFormat::Vtt or `SttResponseFormat::Srt`, use Audio::create_translation_with_text_response".into(),
Expand Down Expand Up @@ -112,6 +113,16 @@ impl<'a> Audio<'a> {
if let Some(language) = req.language.clone() {
form = form.text("laguage", language.to_string());
}

if let Some(timestamp_granularities) = req.timestamp_granularities.clone() {
for (index, value) in timestamp_granularities.iter().enumerate() {
form = form.text(
format!("timestamp_granularities[{}]", index),
value.to_string(),
);
}
}

form
}

Expand Down
73 changes: 58 additions & 15 deletions rs_openai/src/interfaces/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ pub enum SttModel {
}

#[derive(Debug, Serialize, Default, Clone, strum::Display)]
pub enum AudioSpeechModel {
pub enum TtsModel {
#[default]
#[strum(serialize = "tts-1")]
Whisper1,
Expand All @@ -279,7 +279,7 @@ pub enum AudioSpeechModel {
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateSpeechRequest {
/// One of the available [TTS models](https://platform.openai.com/docs/models/tts): `tts-1` or `tts-1-hd`
pub model: AudioSpeechModel,
pub model: TtsModel,

/// The text to generate audio for. The maximum length is 4096 characters.
pub input: String,
Expand All @@ -290,7 +290,7 @@ pub struct CreateSpeechRequest {

/// The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`.
/// #[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<SttResponseFormat>, // default: mp3
pub response_format: Option<TtsResponseFormat>, // default: mp3

/// The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -304,12 +304,16 @@ pub struct CreateSpeechRequest {
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateTranscriptionRequest {
/// The audio file to transcribe, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm.
/// The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
pub file: FileMeta,

/// ID of the model to use. Only `whisper-1` is currently available.
/// ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.
pub model: SttModel,

/// The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<Language>,

/// An optional text to guide the model's style or continue a previous audio segment.
/// The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting) should match the audio language.
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -325,9 +329,11 @@ pub struct CreateTranscriptionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>, // min: 0, max: 1, default: 0

/// The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.
/// The timestamp granularities to populate for this transcription.
/// `response_format` must be set `verbose_json` to use timestamp granularities.
/// Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency.
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<Language>,
pub timestamp_granularities: Option<Vec<TimestampGranularity>>, // Defaults to segment
}

#[derive(Builder, Clone, Debug, Default, Serialize)]
Expand All @@ -337,10 +343,10 @@ pub struct CreateTranscriptionRequest {
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateTranslationRequest {
/// The audio file to transcribe, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm.
/// The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
pub file: FileMeta,

/// ID of the model to use. Only `whisper-1` is currently available.
/// ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.
pub model: SttModel,

/// An optional text to guide the model's style or continue a previous audio segment.
Expand All @@ -350,34 +356,71 @@ pub struct CreateTranslationRequest {

/// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<SttResponseFormat>, // default: json
pub response_format: Option<SttResponseFormat>, // Defaults to json

/// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
/// while lower values like 0.2 will make it more focused and deterministic.
/// The sampling temperature, between 0 and 1.
/// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
/// If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>, // min: 0, max: 1, default: 0
pub temperature: Option<f32>, // Defaults to 0
}

#[derive(Debug, Serialize, Default, Clone, strum::Display)]
pub enum TimestampGranularity {
#[default]
#[strum(serialize = "segment")]
Segment,
#[strum(serialize = "word")]
Word,
}

/// Represents a verbose json transcription response returned by model, based on the provided input.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct VerboseJsonForAudioResponse {
pub struct SttResponse {
/// The transcribed text.
pub text: String,
/// Always `transcribe`.
pub task: Option<String>,
/// The language of the input audio.
pub language: Option<String>,
/// The duration of the input audio.
pub duration: Option<f32>,
/// Segments of the transcribed text and their corresponding details.
pub segments: Option<Vec<Segment>>,
pub text: String,
/// Extracted words and their corresponding timestamps.
pub words: Option<Vec<Word>>,
}

/// One segment of a transcription, populated when `response_format` is `verbose_json`
/// (carried in `SttResponse::segments`).
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Segment {
    /// Unique identifier of the segment.
    pub id: u32,
    /// Seek offset of the segment.
    pub seek: u32,
    /// Start time of the segment in seconds.
    pub start: f32,
    /// End time of the segment in seconds.
    pub end: f32,
    /// Text content of the segment.
    pub text: String,
    /// Array of token IDs for the text content.
    pub tokens: Vec<u32>,
    /// Temperature parameter used for generating the segment.
    pub temperature: f32,
    /// Average logprob of the segment. If the value is lower than -1, consider the logprobs failed.
    pub avg_logprob: f32,
    /// Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed.
    pub compression_ratio: f32,
    /// Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent.
    pub no_speech_prob: f32,
}

/// One word of a transcription with its timing, populated when `response_format` is
/// `verbose_json` (carried in `SttResponse::words`).
// NOTE(review): per the request docs above, word entries presumably require
// `timestamp_granularities` to include `word` — confirm against the API.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Word {
    /// The text content of the word.
    pub word: String,
    /// Start time of the word in seconds.
    pub start: f32,
    /// End time of the word in seconds.
    pub end: f32,
}
10 changes: 10 additions & 0 deletions rs_openai/src/shared/utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
use rand::{distributions::Alphanumeric, Rng};

/// Returns `true` only when `stream` is `Some(true)`.
///
/// `None` (caller did not set the flag) and `Some(false)` both mean "not streaming".
pub fn is_stream(stream: Option<bool>) -> bool {
    // Idiomatic replacement for `stream.is_some() && stream.unwrap()`:
    // same truth table, no unwrap.
    stream.unwrap_or(false)
}

/// Produces a uniformly random alphanumeric (`A-Z`, `a-z`, `0-9`) string of exactly
/// `length` characters, e.g. for one-off output filenames.
pub fn generate_random_string(length: usize) -> String {
    // Preallocate the exact capacity, then fill from a thread-local RNG sampling
    // the alphanumeric distribution; each sampled byte is widened to a char.
    let mut result = String::with_capacity(length);
    let random_chars = rand::thread_rng()
        .sample_iter(&Alphanumeric)
        .take(length)
        .map(char::from);
    result.extend(random_chars);
    result
}

0 comments on commit 062668b

Please sign in to comment.