From f61c99f5a2e11ca76c95cb89c82f7b58cf010026 Mon Sep 17 00:00:00 2001
From: YanceyOfficial
Date: Mon, 21 Oct 2024 15:42:45 +0800
Subject: [PATCH] feat: separate apis and interfaces

---
 README.md                               |   2 +-
 rs_openai/Cargo.toml                    |  10 +-
 rs_openai/src/apis/audio.rs             | 351 ++--------------------
 rs_openai/src/apis/chat.rs              | 176 +----------
 rs_openai/src/apis/completions.rs       | 177 +----------
 rs_openai/src/apis/edits.rs             |  67 +----
 rs_openai/src/apis/embeddings.rs        |  60 +---
 rs_openai/src/apis/engines.rs           |  20 +-
 rs_openai/src/apis/files.rs             |  59 +---
 rs_openai/src/apis/fine_tunes.rs        | 193 ++----------
 rs_openai/src/apis/images.rs            | 154 +---------
 rs_openai/src/apis/models.rs            |  39 +--
 rs_openai/src/apis/moderations.rs       |  89 +-----
 rs_openai/src/client.rs                 |  33 ++
 rs_openai/src/interfaces/audio.rs       | 383 ++++++++++++++++++++++++
 rs_openai/src/interfaces/chat.rs        | 162 ++++++++++
 rs_openai/src/interfaces/completions.rs | 169 +++++++++++
 rs_openai/src/interfaces/edits.rs       |  60 ++++
 rs_openai/src/interfaces/embeddings.rs  |  57 ++++
 rs_openai/src/interfaces/engines.rs     |  15 +
 rs_openai/src/interfaces/files.rs       |  46 +++
 rs_openai/src/interfaces/fine_tunes.rs  | 163 ++++++++++
 rs_openai/src/interfaces/images.rs      | 138 +++++++++
 rs_openai/src/interfaces/mod.rs         |  11 +
 rs_openai/src/interfaces/models.rs      |  34 +++
 rs_openai/src/interfaces/moderations.rs |  83 +++++
 rs_openai/src/lib.rs                    |   3 +-
 rs_openai/src/shared/macro.rs           |   6 +-
 28 files changed, 1470 insertions(+), 1290 deletions(-)
 create mode 100644 rs_openai/src/interfaces/audio.rs
 create mode 100644 rs_openai/src/interfaces/chat.rs
 create mode 100644 rs_openai/src/interfaces/completions.rs
 create mode 100644 rs_openai/src/interfaces/edits.rs
 create mode 100644 rs_openai/src/interfaces/embeddings.rs
 create mode 100644 rs_openai/src/interfaces/engines.rs
 create mode 100644 rs_openai/src/interfaces/files.rs
 create mode 100644 rs_openai/src/interfaces/fine_tunes.rs
 create mode 100644 rs_openai/src/interfaces/images.rs
 create mode 100644 rs_openai/src/interfaces/mod.rs
 create mode 100644 rs_openai/src/interfaces/models.rs
 create mode 100644 rs_openai/src/interfaces/moderations.rs

diff --git a/README.md b/README.md
index 62b5545..78ad1db 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ The OpenAI Rust library provides convenient access to the OpenAI API from Rust a
 
 ```toml
 [dependencies]
-rs_openai = { version = "0.4.1" }
+rs_openai = { version = "0.5.0" }
 ```
 
 ## Features
diff --git a/rs_openai/Cargo.toml b/rs_openai/Cargo.toml
index 174020b..4772469 100644
--- a/rs_openai/Cargo.toml
+++ b/rs_openai/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "rs_openai"
-version = "0.4.1"
+version = "0.5.0"
 edition = "2021"
 authors = ["Yancey Leo "]
 description = "The OpenAI Rust library provides convenient access to the OpenAI API from Rust applications."
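Bumping to 0.5.0 is a breaking change for imports: request builders and response types move out of `rs_openai::apis` into the new `rs_openai::interfaces` module, while the endpoint methods on the client keep their names. A minimal sketch of a call site after the split; the exact re-export paths under `interfaces` are an assumption based on the file list above, and the client constructor follows the pattern already shown in this README:

```rust
use rs_openai::{
    interfaces::chat::{ChatCompletionMessageRequestBuilder, CreateChatRequestBuilder, Role},
    OpenAI,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The client itself is unchanged; only the type imports move.
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    // Builders now come from `interfaces::chat` instead of `apis::chat`.
    let req = CreateChatRequestBuilder::default()
        .model("gpt-3.5-turbo")
        .messages(vec![ChatCompletionMessageRequestBuilder::default()
            .role(Role::User)
            .content("Hello!")
            .build()?])
        .build()?;

    let res = client.chat().create(&req).await?;
    println!("{:#?}", res);
    Ok(())
}
```

Existing call sites should only need their `use` paths updated; the endpoint methods themselves are unchanged.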
@@ -14,14 +14,14 @@ readme = "../README.md" [dependencies] backoff = "0.4.0" -derive_builder = "0.12.0" +derive_builder = "0.20.2" dotenvy = "0.15.6" futures = "0.3.27" -reqwest = { version = "0.11.14", features = ["json", "stream", "multipart"] } -reqwest-eventsource = "0.4.0" +reqwest = { version = "0.12.8", features = ["json", "stream", "multipart"] } +reqwest-eventsource = "0.6.0" serde = { version = "1.0.156", features = ["derive"] } serde_json = "1.0.94" -strum = { version = "0.24.1", features = ["derive"] } +strum = { version = "0.26.3", features = ["derive"] } thiserror = "1.0.40" tokio = { version = "1.26.0", features = ["full"] } tokio-stream = "0.1.12" diff --git a/rs_openai/src/apis/audio.rs b/rs_openai/src/apis/audio.rs index 08eaf0a..039e8ec 100644 --- a/rs_openai/src/apis/audio.rs +++ b/rs_openai/src/apis/audio.rs @@ -3,321 +3,9 @@ //! Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text) use crate::client::OpenAI; +use crate::interfaces::audio; use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse}; -use crate::shared::types::FileMeta; -use derive_builder::Builder; use reqwest::multipart::Form; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Default, Clone, strum::Display)] -pub enum ResponseFormat { - #[default] - #[strum(serialize = "json")] - Json, - #[strum(serialize = "text")] - Text, - #[strum(serialize = "srt")] - Srt, - #[strum(serialize = "verbose_json")] - VerboseJson, - #[strum(serialize = "vtt")] - Vtt, -} - -#[derive(Debug, Serialize, Default, Clone, strum::Display)] -pub enum Language { - #[default] - #[strum(serialize = "en")] - English, - #[strum(serialize = "zh")] - Chinese, - #[strum(serialize = "de")] - German, - #[strum(serialize = "es")] - Spanish, - #[strum(serialize = "ru")] - Russian, - #[strum(serialize = "ko")] - Korean, - #[strum(serialize = "fr")] - French, - #[strum(serialize = "ja")] - Japanese, - #[strum(serialize = "pt")] - Portuguese, - #[strum(serialize = "tr")] - Turkish, - #[strum(serialize = "pl")] - Polish, - #[strum(serialize = "ca")] - Catalan, - #[strum(serialize = "nl")] - Dutch, - #[strum(serialize = "ar")] - Arabic, - #[strum(serialize = "sv")] - Swedish, - #[strum(serialize = "it")] - Italian, - #[strum(serialize = "id")] - Indonesian, - #[strum(serialize = "hi")] - Hindi, - #[strum(serialize = "fi")] - Finnish, - #[strum(serialize = "vi")] - Vietnamese, - #[strum(serialize = "he")] - Hebrew, - #[strum(serialize = "uk")] - Ukrainian, - #[strum(serialize = "el")] - Greek, - #[strum(serialize = "ms")] - Malay, - #[strum(serialize = "cs")] - Czech, - #[strum(serialize = "ro")] - Romanian, - #[strum(serialize = "da")] - Danish, - #[strum(serialize = "hu")] - Hungarian, - #[strum(serialize = "ta")] - Tamil, - #[strum(serialize = "no")] - Norwegian, - #[strum(serialize = "th")] - Thai, - #[strum(serialize = "ur")] - Urdu, - #[strum(serialize = "hr")] - Croatian, - #[strum(serialize = "bg")] - Bulgarian, - #[strum(serialize = "lt")] - Lithuanian, - #[strum(serialize = "la")] - Latin, - #[strum(serialize = "mi")] - Maori, - #[strum(serialize = "ml")] - Malayalam, - #[strum(serialize = "cy")] - Welsh, - #[strum(serialize = "sk")] - Slovak, - #[strum(serialize = "te")] - Telugu, - #[strum(serialize = "fa")] - Persian, - #[strum(serialize = "lv")] - Latvian, - #[strum(serialize = "bn")] - Bengali, - #[strum(serialize = "sr")] - Serbian, - #[strum(serialize = "az")] - Azerbaijani, - #[strum(serialize = "sl")] - Slovenian, - #[strum(serialize = "kn")] - Kannada, - 
#[strum(serialize = "et")] - Estonian, - #[strum(serialize = "mk")] - Macedonian, - #[strum(serialize = "br")] - Breton, - #[strum(serialize = "eu")] - Basque, - #[strum(serialize = "is")] - Icelandic, - #[strum(serialize = "hy")] - Armenian, - #[strum(serialize = "ne")] - Nepali, - #[strum(serialize = "mn")] - Mongolian, - #[strum(serialize = "bs")] - Bosnian, - #[strum(serialize = "kk")] - Kazakh, - #[strum(serialize = "sq")] - Albanian, - #[strum(serialize = "sw")] - Swahili, - #[strum(serialize = "gl")] - Galician, - #[strum(serialize = "mr")] - Marathi, - #[strum(serialize = "pa")] - Punjabi, - #[strum(serialize = "si")] - Sinhala, - #[strum(serialize = "km")] - Khmer, - #[strum(serialize = "sn")] - Shona, - #[strum(serialize = "yo")] - Yoruba, - #[strum(serialize = "so")] - Somali, - #[strum(serialize = "af")] - Afrikaans, - #[strum(serialize = "oc")] - Occitan, - #[strum(serialize = "ka")] - Georgian, - #[strum(serialize = "be")] - Belarusian, - #[strum(serialize = "tg")] - Tajik, - #[strum(serialize = "sd")] - Sindhi, - #[strum(serialize = "gu")] - Gujarati, - #[strum(serialize = "am")] - Amharic, - #[strum(serialize = "yi")] - Yiddish, - #[strum(serialize = "lo")] - Lao, - #[strum(serialize = "uz")] - Uzbek, - #[strum(serialize = "fo")] - Faroese, - #[strum(serialize = "ht")] - HaitianCreole, - #[strum(serialize = "ps")] - Pashto, - #[strum(serialize = "tk")] - Turkmen, - #[strum(serialize = "nn")] - Nynorsk, - #[strum(serialize = "mt")] - Maltese, - #[strum(serialize = "sa")] - Sanskrit, - #[strum(serialize = "lb")] - Luxembourgish, - #[strum(serialize = "my")] - Myanmar, - #[strum(serialize = "bo")] - Tibetan, - #[strum(serialize = "tl")] - Tagalog, - #[strum(serialize = "mg")] - Malagasy, - #[strum(serialize = "as")] - Assamese, - #[strum(serialize = "tt")] - Tatar, - #[strum(serialize = "haw")] - Hawaiian, - #[strum(serialize = "ln")] - Lingala, - #[strum(serialize = "ha")] - Hausa, - #[strum(serialize = "ba")] - Bashkir, - #[strum(serialize = "jw")] - Javanese, - #[strum(serialize = "su")] - Sundanese, -} - -#[derive(Debug, Serialize, Default, Clone, strum::Display)] -pub enum AudioModel { - #[default] - #[strum(serialize = "whisper-1")] - Whisper1, -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateTranscriptionRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateTranscriptionRequest { - /// The audio file to transcribe, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm. - pub file: FileMeta, - - /// ID of the model to use. Only `whisper-1` is currently available. - pub model: AudioModel, - - /// An optional text to guide the model's style or continue a previous audio segment. - /// The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting) should match the audio language. - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, - - /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, // default: "json" - - /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, - /// while lower values like 0.2 will make it more focused and deterministic. 
- /// If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, // min: 0, max: 1, default: 0 - - /// The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency. - #[serde(skip_serializing_if = "Option::is_none")] - pub language: Option, -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateTranslationRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateTranslationRequest { - /// The audio file to transcribe, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm. - pub file: FileMeta, - - /// ID of the model to use. Only `whisper-1` is currently available. - pub model: AudioModel, - - /// An optional text to guide the model's style or continue a previous audio segment. - /// The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting) should be in English. - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, - - /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, // default: json - - /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, - /// while lower values like 0.2 will make it more focused and deterministic. - /// If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, // min: 0, max: 1, default: 0 -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct VerboseJsonForAudioResponse { - pub task: Option, - pub language: Option, - pub duration: Option, - pub segments: Option>, - pub text: String, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct Segment { - pub id: u32, - pub seek: u32, - pub start: f32, - pub end: f32, - pub text: String, - pub tokens: Vec, - pub temperature: f32, - pub avg_logprob: f32, - pub compression_ratio: f32, - pub no_speech_prob: f32, -} pub struct Audio<'a> { openai: &'a OpenAI, @@ -328,14 +16,21 @@ impl<'a> Audio<'a> { Self { openai } } + /// Generates audio from the input text. + pub async fn create_speech(&self, req: &audio::CreateSpeechRequest) -> OpenAIResponse<()> { + self.openai + .post_with_file_response("/audio/speech", req, "") + .await + } + /// Transcribes audio into the input language, response is `application/json`. 
     pub async fn create_transcription(
         &self,
-        req: &CreateTranscriptionRequest,
-    ) -> OpenAIResponse<VerboseJsonForAudioResponse> {
+        req: &audio::CreateTranscriptionRequest,
+    ) -> OpenAIResponse<audio::VerboseJsonForAudioResponse> {
         if !self.is_json_type(req.response_format.clone()) {
             return Err(OpenAIError::InvalidArgument(
-                "When `response_format` is set to `ResponseFormat::Text` or `ResponseFormat::Vtt or `ResponseFormat::Srt`, use Audio::create_transcription_with_text_response".into(),
+                "When `response_format` is set to `SttResponseFormat::Text` or `SttResponseFormat::Vtt` or `SttResponseFormat::Srt`, use Audio::create_transcription_with_text_response".into(),
             ));
         }
 
@@ -346,11 +41,11 @@ impl<'a> Audio<'a> {
     /// Translates audio into English, response is `application/json`.
     pub async fn create_translation(
         &self,
-        req: &CreateTranslationRequest,
-    ) -> OpenAIResponse<VerboseJsonForAudioResponse> {
+        req: &audio::CreateTranslationRequest,
+    ) -> OpenAIResponse<audio::VerboseJsonForAudioResponse> {
         if !self.is_json_type(req.response_format.clone()) {
             return Err(OpenAIError::InvalidArgument(
-                "When `response_format` is set to `ResponseFormat::Text` or `ResponseFormat::Vtt or `ResponseFormat::Srt`, use Audio::create_translation_with_text_response".into(),
+                "When `response_format` is set to `SttResponseFormat::Text` or `SttResponseFormat::Vtt` or `SttResponseFormat::Srt`, use Audio::create_translation_with_text_response".into(),
             ));
         }
 
@@ -361,11 +56,11 @@ impl<'a> Audio<'a> {
     /// Transcribes audio into the input language, response is `text/plain`.
     pub async fn create_transcription_with_text_response(
         &self,
-        req: &CreateTranscriptionRequest,
+        req: &audio::CreateTranscriptionRequest,
     ) -> OpenAIResponse<String> {
         if self.is_json_type(req.response_format.clone()) {
             return Err(OpenAIError::InvalidArgument(
-                "When `response_format` is `None` or `ResponseFormat::Json` or `ResponseFormat::VerboseJson`, use Audio::create_transcription".into(),
+                "When `response_format` is `None` or `SttResponseFormat::Json` or `SttResponseFormat::VerboseJson`, use Audio::create_transcription".into(),
             ));
         }
 
@@ -378,11 +73,11 @@ impl<'a> Audio<'a> {
     /// Translates audio into English, response is `text/plain`.
     pub async fn create_translation_with_text_response(
         &self,
-        req: &CreateTranslationRequest,
+        req: &audio::CreateTranslationRequest,
     ) -> OpenAIResponse<String> {
         if self.is_json_type(req.response_format.clone()) {
             return Err(OpenAIError::InvalidArgument(
-                "When response_format is `None` or `ResponseFormat::Json` or `ResponseFormat::VerboseJson`, use Audio::create_translation".into(),
+                "When `response_format` is `None` or `SttResponseFormat::Json` or `SttResponseFormat::VerboseJson`, use Audio::create_translation".into(),
             ));
         }
 
@@ -392,7 +87,7 @@ impl<'a> Audio<'a> {
             .await
     }
 
-    fn create_transcription_form(&self, req: &CreateTranscriptionRequest) -> Form {
+    fn create_transcription_form(&self, req: &audio::CreateTranscriptionRequest) -> Form {
         let file_part = reqwest::multipart::Part::stream(req.file.buffer.clone())
             .file_name(req.file.filename.clone())
             .mime_str("application/octet-stream")
             .unwrap();
@@ -420,7 +115,7 @@ impl<'a> Audio<'a> {
         form
     }
 
-    fn create_translation_form(&self, req: &CreateTranslationRequest) -> Form {
+    fn create_translation_form(&self, req: &audio::CreateTranslationRequest) -> Form {
         let file_part = reqwest::multipart::Part::stream(req.file.buffer.clone())
             .file_name(req.file.filename.clone())
             .mime_str("application/octet-stream")
             .unwrap();
@@ -445,14 +140,14 @@ impl<'a> Audio<'a> {
         form
     }
 
-    fn is_json_type(&self, format_type: Option<ResponseFormat>) -> bool {
+    fn is_json_type(&self, format_type: Option<audio::SttResponseFormat>) -> bool {
         if format_type.is_none() {
             return true;
         }
 
         let format_type_display = format_type.unwrap().to_string();
 
-        if format_type_display == ResponseFormat::Json.to_string()
-            || format_type_display == ResponseFormat::VerboseJson.to_string()
+        if format_type_display == audio::SttResponseFormat::Json.to_string()
+            || format_type_display == audio::SttResponseFormat::VerboseJson.to_string()
         {
             return true;
         }
diff --git a/rs_openai/src/apis/chat.rs b/rs_openai/src/apis/chat.rs
index ba7bf3b..a7b11f3 100644
--- a/rs_openai/src/apis/chat.rs
+++ b/rs_openai/src/apis/chat.rs
@@ -1,170 +1,11 @@
 //! Given a chat conversation, the model will return a chat completion response.
 
 use crate::client::OpenAI;
+use crate::interfaces::chat;
 use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse};
-use crate::shared::types::Stop;
 use crate::shared::utils::is_stream;
-use derive_builder::Builder;
 use futures::Stream;
-use serde::{Deserialize, Serialize};
-use std::{collections::HashMap, pin::Pin};
-
-#[derive(Debug, Serialize, Deserialize, Clone, Default, strum::Display)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    #[strum(serialize = "system")]
-    System,
-    #[default]
-    #[strum(serialize = "user")]
-    User,
-    #[strum(serialize = "assistant")]
-    Assistant,
-}
-
-#[derive(Builder, Default, Debug, Clone, Deserialize, Serialize)]
-#[builder(name = "ChatCompletionMessageRequestBuilder")]
-#[builder(pattern = "mutable")]
-#[builder(setter(into, strip_option), default)]
-#[builder(derive(Debug))]
-#[builder(build_fn(error = "OpenAIError"))]
-pub struct ChatCompletionMessage {
-    /// The role of the author of this message. One of `system`, `user`, or `assistant`.
-    pub role: Role,
-
-    /// The contents of the message.
-    pub content: String,
-
-    /// The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
- pub name: Option, -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateChatRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateChatRequest { - /// ID of the model to use. - /// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API. - pub model: String, - - /// A list of messages describing the conversation so far. - pub messages: Vec, - - /// What sampling temperature to use, between 0 and 2. - /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. - /// - /// We generally recommend altering this or `top_p` but not both. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, // min: 0, max: 2, default: 1 - - /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. - /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. - /// - /// We generally recommend altering this or `temperature` but not both. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, // default: 1 - - /// How many chat completion choices to generate for each input message. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, // default: 1 - - /// If set, partial message deltas will be sent, like in ChatGPT. - /// Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. - /// See the OpenAI Cookbook for [example code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb). - /// - /// For streamed progress, use [`create_with_stream`](Chat::create_with_stream). - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, // default: false - - /// Up to 4 sequences where the API will stop generating further tokens. - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, // default: null - - /// The maximum number of tokens to generate in the chat completion. - /// - /// The total length of input tokens and generated tokens is limited by the model's context length. - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - - /// Number between -2.0 and 2.0. - /// Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. - /// - /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, // min: -2.0, max: 2.0, default: 0 - - /// Number between -2.0 and 2.0. - /// Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. 
- /// - /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, // min: -2.0, max: 2.0, default: 0 - - /// Modify the likelihood of specified tokens appearing in the completion. - /// - /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. - /// Mathematically, the bias is added to the logits generated by the model prior to sampling. - /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; - /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token. - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, // default: null - - /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct Message { - pub role: String, - pub content: String, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ChatUsage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ChatChoice { - pub message: ChatCompletionMessage, - pub finish_reason: String, - pub index: u32, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ChatResponse { - pub id: String, - pub object: String, - pub created: u32, - pub choices: Vec, - pub usage: ChatUsage, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct Delta { - pub content: Option, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ChatChoiceStream { - pub delta: Delta, - pub finish_reason: Option, - pub index: u32, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ChatStreamResponse { - pub id: String, - pub object: String, - pub model: String, - pub created: u32, - pub choices: Vec, -} +use std::pin::Pin; pub struct Chat<'a> { openai: &'a OpenAI, @@ -176,7 +17,10 @@ impl<'a> Chat<'a> { } /// Creates a completion for the chat message. - pub async fn create(&self, req: &CreateChatRequest) -> OpenAIResponse { + pub async fn create( + &self, + req: &chat::CreateChatRequest, + ) -> OpenAIResponse { if is_stream(req.stream) { return Err(OpenAIError::InvalidArgument( "When stream is true, use Chat::create_with_stream".into(), @@ -189,9 +33,11 @@ impl<'a> Chat<'a> { /// Creates a completion for the chat message. pub async fn create_with_stream( &self, - req: &CreateChatRequest, - ) -> Result> + Send>>, OpenAIError> - { + req: &chat::CreateChatRequest, + ) -> Result< + Pin> + Send>>, + OpenAIError, + > { if !is_stream(req.stream) { return Err(OpenAIError::InvalidArgument( "When stream is false, use Chat::create".into(), diff --git a/rs_openai/src/apis/completions.rs b/rs_openai/src/apis/completions.rs index 1dcbd67..acd52c1 100644 --- a/rs_openai/src/apis/completions.rs +++ b/rs_openai/src/apis/completions.rs @@ -1,179 +1,12 @@ //! Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position. 
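As with chat above, the request and response types for completions move to `interfaces::completions`, and only the endpoint wrapper stays here. A rough sketch of driving the streaming variant after this change; `client.completions()` as the accessor, the explicit `Prompt::String` variant, and the `Result`-shaped stream items (via the `OpenAIResponse<T>` alias) are assumptions based on the signatures in this patch:

```rust
use futures::StreamExt;
use rs_openai::interfaces::completions::{CreateCompletionRequestBuilder, Prompt};

// `client` is an `rs_openai::OpenAI` built as in the README example.
async fn stream_completion(client: &rs_openai::OpenAI) -> Result<(), Box<dyn std::error::Error>> {
    let req = CreateCompletionRequestBuilder::default()
        .model("text-davinci-003")
        .prompt(Prompt::String("Write a haiku about the sea.".to_string()))
        .stream(true) // without this, create_with_stream returns InvalidArgument
        .build()?;

    let mut stream = client.completions().create_with_stream(&req).await?;

    // Each item is one server-sent chunk; `text` carries the next token(s).
    while let Some(chunk) = stream.next().await {
        for choice in chunk?.choices {
            print!("{}", choice.text);
        }
    }
    Ok(())
}
```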
use crate::client::OpenAI; +use crate::interfaces::completions; use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse}; -use crate::shared::types::Stop; use crate::shared::utils::is_stream; -use derive_builder::Builder; use futures::Stream; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::pin::Pin; -#[derive(Debug, Clone, Serialize)] -#[serde(untagged)] -pub enum Prompt { - String(String), - ArrayOfString(Vec), - ArrayOfTokens(Vec), - ArrayOfTokenArrays(Vec>), -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateCompletionRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateCompletionRequest { - /// ID of the model to use. - /// You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them. - pub model: String, - - /// The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays. - /// - /// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document. - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, // default: <|endoftext|> - - /// The suffix that comes after a completion of inserted text. - #[serde(skip_serializing_if = "Option::is_none")] - pub suffix: Option, // default: null - - /// The maximum number of [token](https://platform.openai.com/tokenizer) to generate in the completion - /// - /// The token count of your prompt plus `max_tokens` cannot exceed the model's context length. - /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, // default: 16 - - /// What sampling temperature to use, between 0 and 2. - /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. - /// - /// We generally recommend altering this or `top_p` but not both. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, // min: 0, max: 2, default: 1 - - /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. - /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. - /// - /// We generally recommend altering this or `temperature` but not both. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, // default: 1 - - /// How many completions to generate for each prompt - /// - /// **Note**: Because this parameter generates many completions, it can quickly consume your token quota. - /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, // default: 1 - - /// Whether to stream back partial progress. 
- /// If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. - /// - /// For streamed progress, use [`create_with_stream`](Completions::create_with_stream). - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, // default: false - - /// Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. - /// For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. - /// The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response. - /// - /// The maximum value for `logprobs` is 5. If you need more than this, please contact us through our [Help center](https://help.openai.com/) and describe your use case. - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, // default: null - - /// Echo back the prompt in addition to the completion. - #[serde(skip_serializing_if = "Option::is_none")] - pub echo: Option, // default: false - - /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, // default: null - - /// Number between -2.0 and 2.0. - /// Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. - /// - /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, // min: -2.0, max: 2.0, default: 0 - - /// Number between -2.0 and 2.0. - /// Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - /// - /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, // min: -2.0, max: 2.0, default: 0 - - /// Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). - /// Results cannot be streamed. - /// - /// When used with `n`, `best_of` controls the number of candidate completions and n specifies how many to return – `best_of` must be greater than `n`. - /// - /// **Note**: Because this parameter generates many completions, it can quickly consume your token quota. - /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. - #[serde(skip_serializing_if = "Option::is_none")] - pub best_of: Option, // default: 1 - - /// Modify the likelihood of specified tokens appearing in the completion. - /// - /// Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. - /// You can use this [tokenizer](https://platform.openai.com/tokenizer?view=bpe) tool (which works for both GPT-2 and GPT-3) to convert text to token IDs. - /// Mathematically, the bias is added to the logits generated by the model prior to sampling. 
- /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; - /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token. - /// - /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated. - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, // default: null - - /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct CompletionResponse { - pub id: String, - pub object: String, - pub created: u32, - pub model: String, - pub choices: Vec, - pub usage: Usage, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct CompletionChoice { - pub text: String, - pub index: u32, - pub logprobs: Option, - pub finish_reason: String, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct CompletionChoiceStream { - pub text: String, - pub index: usize, - pub logprobs: Option, - pub finish_reason: Option, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct CompletionStreamResponse { - pub id: String, - pub object: String, - pub created: u32, - pub model: String, - pub choices: Vec, -} - pub struct Completions<'a> { openai: &'a OpenAI, } @@ -186,8 +19,8 @@ impl<'a> Completions<'a> { /// Creates a completion for the provided prompt and parameters. pub async fn create( &self, - req: &CreateCompletionRequest, - ) -> OpenAIResponse { + req: &completions::CreateCompletionRequest, + ) -> OpenAIResponse { if is_stream(req.stream) { return Err(OpenAIError::InvalidArgument( "When stream is true, use Completions::create_with_stream".into(), @@ -200,9 +33,9 @@ impl<'a> Completions<'a> { /// Creates a completion for the provided prompt and parameters. pub async fn create_with_stream( &self, - req: &CreateCompletionRequest, + req: &completions::CreateCompletionRequest, ) -> Result< - Pin> + Send>>, + Pin> + Send>>, OpenAIError, > { if !is_stream(req.stream) { diff --git a/rs_openai/src/apis/edits.rs b/rs_openai/src/apis/edits.rs index 1321e2f..6eea141 100644 --- a/rs_openai/src/apis/edits.rs +++ b/rs_openai/src/apis/edits.rs @@ -1,66 +1,8 @@ //! Given a prompt and an instruction, the model will return an edited version of the prompt. use crate::client::OpenAI; -use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse}; -use derive_builder::Builder; -use serde::{Deserialize, Serialize}; - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateEditRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateEditRequest { - /// ID of the model to use. You can use the `text-davinci-edit-001` or `code-davinci-edit-001` model with this endpoint. - pub model: String, - - /// The input text to use as a starting point for the edit. - #[serde(skip_serializing_if = "Option::is_none")] - pub input: Option, - - /// The instruction that tells the model how to edit the prompt. 
- pub instruction: String, - - /// How many edits to generate for the input and instruction. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, // default: 1 - - /// What sampling temperature to use, between 0 and 2. - /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. - /// - /// We generally recommend altering this or `top_p` but not both. - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, // min: 0, max: 2, default: 1 - - /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. - /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. - /// - /// We generally recommend altering this or `temperature` but not both. - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, // default: 1 -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct EditResponse { - pub object: String, - pub created: u32, - pub choices: Vec, - pub usage: Usage, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct Choice { - pub text: String, - pub index: u32, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} +use crate::interfaces::edits; +use crate::shared::response_wrapper::OpenAIResponse; pub struct Edits<'a> { openai: &'a OpenAI, @@ -72,7 +14,10 @@ impl<'a> Edits<'a> { } /// Creates a new edit for the provided input, instruction, and parameters. - pub async fn create(&self, req: &CreateEditRequest) -> OpenAIResponse { + pub async fn create( + &self, + req: &edits::CreateEditRequest, + ) -> OpenAIResponse { self.openai.post("/edits", req).await } } diff --git a/rs_openai/src/apis/embeddings.rs b/rs_openai/src/apis/embeddings.rs index 7b66cb8..83f137f 100644 --- a/rs_openai/src/apis/embeddings.rs +++ b/rs_openai/src/apis/embeddings.rs @@ -3,59 +3,8 @@ //! Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings) use crate::client::OpenAI; -use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse}; -use derive_builder::Builder; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Clone)] -#[serde(untagged)] -pub enum EmbeddingInput { - String(String), - ArrayOfString(Vec), - ArrayOfTokens(Vec), - ArrayOfTokenArrays(Vec>), -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateEmbeddingRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateEmbeddingRequest { - /// ID of the model to use. - /// Use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them. - pub model: String, - - /// Input text to get embeddings for, encoded as a string or array of tokens. - /// To get embeddings for multiple inputs in a single request, pass an array of strings or array of token arrays. Each input must not exceed 8192 tokens in length. - pub input: EmbeddingInput, - - /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). 
- #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct EmbeddingResponse { - pub object: String, - pub data: Vec, - pub model: String, - pub usage: Usage, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct EmbeddingData { - pub object: String, - pub embedding: Vec, - pub index: u32, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub total_tokens: u32, -} +use crate::interfaces::embeddings; +use crate::shared::response_wrapper::OpenAIResponse; pub struct Embeddings<'a> { openai: &'a OpenAI, @@ -67,7 +16,10 @@ impl<'a> Embeddings<'a> { } /// Creates an embedding vector representing the input text. - pub async fn create(&self, req: &CreateEmbeddingRequest) -> OpenAIResponse { + pub async fn create( + &self, + req: &embeddings::CreateEmbeddingRequest, + ) -> OpenAIResponse { self.openai.post("/embeddings", req).await } } diff --git a/rs_openai/src/apis/engines.rs b/rs_openai/src/apis/engines.rs index 70db401..1a69738 100644 --- a/rs_openai/src/apis/engines.rs +++ b/rs_openai/src/apis/engines.rs @@ -7,22 +7,8 @@ //! Please use their replacement, [Models](https://platform.openai.com/docs/api-reference/models), instead. [Learn more](https://help.openai.com/TODO). use crate::client::OpenAI; +use crate::interfaces::engines; use crate::shared::response_wrapper::OpenAIResponse; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct EngineResponse { - pub id: String, - pub object: String, - pub owner: String, - pub ready: bool, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct EngineListResponse { - pub data: Vec, - pub object: String, -} pub struct Engines<'a> { openai: &'a OpenAI, @@ -38,7 +24,7 @@ impl<'a> Engines<'a> { #[deprecated( note = "The Engines endpoints are deprecated. Please use their replacement, Models, instead." )] - pub async fn list(&self) -> OpenAIResponse { + pub async fn list(&self) -> OpenAIResponse { self.openai.get("/engines", &()).await } @@ -50,7 +36,7 @@ impl<'a> Engines<'a> { #[deprecated( note = "The Engines endpoints are deprecated. Please use their replacement, Models, instead." )] - pub async fn retrieve(&self, engine_id: &str) -> OpenAIResponse { + pub async fn retrieve(&self, engine_id: &str) -> OpenAIResponse { self.openai.get(&format!("/engines/{engine_id}"), &()).await } } diff --git a/rs_openai/src/apis/files.rs b/rs_openai/src/apis/files.rs index 0fb3af9..2a745f3 100644 --- a/rs_openai/src/apis/files.rs +++ b/rs_openai/src/apis/files.rs @@ -1,53 +1,9 @@ //! Files are used to upload documents that can be used with features like [Fine-tuning](https://platform.openai.com/docs/api-reference/fine-tunes). use crate::client::OpenAI; -use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse}; -use crate::shared::types::FileMeta; -use derive_builder::Builder; +use crate::interfaces::files; +use crate::shared::response_wrapper::OpenAIResponse; use reqwest::multipart::Form; -use serde::{Deserialize, Serialize}; - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "UploadFileRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct UploadFileRequest { - /// Name of the [JSON Lines](https://jsonlines.readthedocs.io/en/latest/) file to be uploaded. 
- /// - /// If the `purpose` is set to "fine-tune", each line is a JSON record with "prompt" and "completion" fields representing your [training examples](https://platform.openai.com/docs/guides/fine-tuning/prepare-training-data). - pub file: FileMeta, - - /// The intended purpose of the uploaded documents. - /// - /// Use "fine-tune" for [Fine-tuning](https://platform.openai.com/docs/api-reference/fine-tunes). - /// This allows us to validate the format of the uploaded file. - pub purpose: String, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct FileResponse { - pub id: String, - pub object: String, - pub bytes: u64, - pub created_at: u32, - pub filename: String, - pub purpose: String, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct FileListResponse { - pub data: Vec, - pub object: String, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct DeleteFileResponse { - pub id: String, - pub object: String, - pub deleted: bool, -} pub struct Files<'a> { openai: &'a OpenAI, @@ -58,14 +14,17 @@ impl<'a> Files<'a> { Self { openai } } /// Returns a list of files that belong to the user's organization. - pub async fn list(&self) -> OpenAIResponse { + pub async fn list(&self) -> OpenAIResponse { self.openai.get("/files", &()).await } /// Upload a file that contains document(s) to be used across various endpoints/features. /// Currently, the size of all the files uploaded by one organization can be up to 1 GB. /// Please contact us if you need to increase the storage limit. - pub async fn upload(&self, req: &UploadFileRequest) -> OpenAIResponse { + pub async fn upload( + &self, + req: &files::UploadFileRequest, + ) -> OpenAIResponse { let file_part = reqwest::multipart::Part::stream(req.file.buffer.clone()) .file_name(req.file.filename.clone()) .mime_str("application/octet-stream") @@ -83,7 +42,7 @@ impl<'a> Files<'a> { /// # Path parameters /// /// - `file_id` - The ID of the file to use for this request - pub async fn delete(&self, file_id: &str) -> OpenAIResponse { + pub async fn delete(&self, file_id: &str) -> OpenAIResponse { self.openai.delete(&format!("/files/{file_id}"), &()).await } @@ -92,7 +51,7 @@ impl<'a> Files<'a> { /// # Path parameters /// /// - `file_id` - The ID of the file to use for this request - pub async fn retrieve(&self, file_id: &str) -> OpenAIResponse { + pub async fn retrieve(&self, file_id: &str) -> OpenAIResponse { self.openai.get(&format!("/files/{file_id}"), &()).await } diff --git a/rs_openai/src/apis/fine_tunes.rs b/rs_openai/src/apis/fine_tunes.rs index 60a3c05..58477b5 100644 --- a/rs_openai/src/apis/fine_tunes.rs +++ b/rs_openai/src/apis/fine_tunes.rs @@ -3,172 +3,11 @@ //! Related guide: [Fine-tune models](https://platform.openai.com/docs/guides/fine-tuning) use crate::client::OpenAI; +use crate::interfaces::fine_tunes; use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse}; -use derive_builder::Builder; use futures::Stream; -use serde::{Deserialize, Serialize}; use std::pin::Pin; -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateFineTuneRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateFineTuneRequest { - /// The ID of an uploaded file that contains training data. - /// - /// See [upload file](https://platform.openai.com/docs/api-reference/files/upload) for how to upload a file. 
- /// - /// - /// Your dataset must be formatted as a JSONL file, where each training example is a JSON object with the keys "prompt" and "completion". - /// Additionally, you must upload your file with the purpose `fine-tune`. - /// - /// - /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning/creating-training-data) for more details. - pub training_file: String, - - /// The ID of an uploaded file that contains validation data. - /// - /// If you provide this file, the data is used to generate validation metrics periodically during fine-tuning. - /// These metrics can be viewed in the [fine-tuning results file](https://platform.openai.com/docs/guides/fine-tuning/analyzing-your-fine-tuned-model). - /// Your train and validation data should be mutually exclusive. - /// - /// Your dataset must be formatted as a JSONL file, where each validation example is a JSON object with the keys "prompt" and "completion". - /// Additionally, you must upload your file with the purpose `fine-tune`. - /// - /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning/creating-training-data) for more details. - #[serde(skip_serializing_if = "Option::is_none")] - pub validation_file: Option, - - /// The name of the base model to fine-tune. - /// You can select one of "ada", "babbage", "curie", "davinci", or a fine-tuned model created after 2022-04-21. - /// To learn more about these models, see the [Models](https://platform.openai.com/docs/models) documentation. - pub model: Option, - - /// The number of epochs to train the model for. - /// An epoch refers to one full cycle through the training dataset. - #[serde(skip_serializing_if = "Option::is_none")] - pub n_epochs: Option, - - /// The batch size to use for training. - /// The batch size is the number of training examples used to train a single forward and backward pass. - /// - /// By default, the batch size will be dynamically configured to be ~0.2% of the number of examples in the training set, capped at 256. - /// In general, we've found that larger batch sizes tend to work better for larger datasets. - #[serde(skip_serializing_if = "Option::is_none")] - pub batch_size: Option, - - /// The learning rate multiplier to use for training. - /// The fine-tuning learning rate is the original learning rate used for pretraining multiplied by this value. - /// - /// By default, the learning rate multiplier is 0.05, 0.1, or 0.2 depending on final `batch_size` (larger learning rates tend to perform better with larger batch sizes). - /// We recommend experimenting with values in the range 0.02 to 0.2 to see what produces the best results. - #[serde(skip_serializing_if = "Option::is_none")] - pub learning_rate_multiplier: Option, - - /// The weight to use for loss on the prompt tokens. - /// This controls how much the model tries to learn to generate the prompt (as compared to the completion which always has a weight of 1.0), and can add a stabilizing effect to training when completions are short. - /// - /// If prompts are extremely long (relative to completions), it may make sense to reduce this weight so as to avoid over-prioritizing learning the prompt. - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt_loss_weight: Option, - - /// If set, we calculate classification-specific metrics such as accuracy and F-1 score using the validation set at the end of every epoch. 
- /// These metrics can be viewed in the [results file](https://platform.openai.com/docs/guides/fine-tuning/analyzing-your-fine-tuned-model). - /// - /// In order to compute classification metrics, you must provide a `validation_file`. - /// Additionally, you must specify `classification_n_classes` for multiclass classification or `classification_positive_class` for binary classification. - #[serde(skip_serializing_if = "Option::is_none")] - pub compute_classification_metrics: Option, - - /// The number of classes in a classification task. - /// - /// This parameter is required for multiclass classification. - #[serde(skip_serializing_if = "Option::is_none")] - pub classification_n_classes: Option, - - /// The positive class in binary classification. - /// - /// This parameter is needed to generate precision, recall, and F1 metrics when doing binary classification. - #[serde(skip_serializing_if = "Option::is_none")] - pub classification_positive_class: Option, - - /// If provided, we calculate F-beta scores at the specified beta values. The F-beta score is a generalization of F-1 score. This is only used for binary classification. - /// - /// With a beta of 1 (i.e. the F-1 score), precision and recall are given the same weight. A larger beta score puts more weight on recall and less on precision. A smaller beta score puts more weight on precision and less on recall. - #[serde(skip_serializing_if = "Option::is_none")] - classification_betas: Option>, - - /// A string of up to 40 characters that will be added to your fine-tuned model name. - /// - /// For example, a `suffix` of "custom-model-name" would produce a model name like `ada:ft-your-org:custom-model-name-2022-02-15-04-21-04`. - #[serde(skip_serializing_if = "Option::is_none")] - suffix: Option, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct FineTuneResponse { - pub id: String, - pub object: String, - pub model: String, - pub created_at: u32, - pub events: Option>, - pub fine_tuned_model: Option, - pub hyperparams: HyperParams, - pub organization_id: String, - pub result_files: Vec, - pub status: String, - pub validation_files: Vec, - pub training_files: Vec, - pub updated_at: u32, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct FineTuneEvent { - pub object: String, - pub created_at: u32, - pub level: String, - pub message: String, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct HyperParams { - pub batch_size: u32, - pub learning_rate_multiplier: f32, - pub n_epochs: u32, - pub prompt_loss_weight: f32, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct TrainingFile { - pub id: String, - pub object: String, - pub bytes: u32, - pub created_at: u32, - pub filename: String, - pub purpose: String, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct FineTuneListResponse { - pub object: String, - pub data: Vec, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct EventListResponse { - pub object: String, - pub data: Vec, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct DeleteFileResponse { - pub id: String, - pub object: String, - pub deleted: bool, -} - pub struct FineTunes<'a> { openai: &'a OpenAI, } @@ -183,7 +22,10 @@ impl<'a> FineTunes<'a> { /// OpenAIResponse includes details of the enqueued job including job status and the name of the fine-tuned models once complete. 
/// /// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning) - pub async fn create(&self, req: &CreateFineTuneRequest) -> OpenAIResponse { + pub async fn create( + &self, + req: &fine_tunes::CreateFineTuneRequest, + ) -> OpenAIResponse { self.openai.post("/fine-tunes", req).await } @@ -194,7 +36,10 @@ impl<'a> FineTunes<'a> { /// - `fine_tune_id` - The ID of the fine-tune job /// /// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning) - pub async fn retrieve(&self, fine_tune_id: &str) -> OpenAIResponse { + pub async fn retrieve( + &self, + fine_tune_id: &str, + ) -> OpenAIResponse { self.openai .get(&format!("/fine-tunes/{fine_tune_id}"), &()) .await @@ -205,14 +50,14 @@ impl<'a> FineTunes<'a> { /// # Path parameters /// /// - `fine_tune_id` - The ID of the fine-tune job to cancel - pub async fn cancel(&self, fine_tune_id: &str) -> OpenAIResponse { + pub async fn cancel(&self, fine_tune_id: &str) -> OpenAIResponse { self.openai .post(&format!("/fine-tunes/{fine_tune_id}/cancel"), &()) .await } /// List your organization's fine-tuning jobs - pub async fn list(&self) -> OpenAIResponse { + pub async fn list(&self) -> OpenAIResponse { self.openai.get("/fine-tunes", &()).await } @@ -225,7 +70,10 @@ impl<'a> FineTunes<'a> { /// - `fine_tune_id` - The ID of the fine-tune job to get events for. /// /// TODO: Since free accounts cannot read fine-tune event content, I have to verify this api until purchase a Plus. - pub async fn retrieve_content(&self, fine_tune_id: &str) -> OpenAIResponse { + pub async fn retrieve_content( + &self, + fine_tune_id: &str, + ) -> OpenAIResponse { self.openai .get(&format!("/fine-tunes/{fine_tune_id}/events"), &()) .await @@ -244,8 +92,10 @@ impl<'a> FineTunes<'a> { pub async fn retrieve_content_stream( &self, fine_tune_id: &str, - ) -> Result> + Send>>, OpenAIError> - { + ) -> Result< + Pin> + Send>>, + OpenAIError, + > { Ok(self .openai .get_stream( @@ -260,7 +110,10 @@ impl<'a> FineTunes<'a> { /// # Path parameters /// /// - `model` - The model to delete - pub async fn delete_model(&self, model: &str) -> OpenAIResponse { + pub async fn delete_model( + &self, + model: &str, + ) -> OpenAIResponse { self.openai.delete(&format!("/models/{model}"), &()).await } } diff --git a/rs_openai/src/apis/images.rs b/rs_openai/src/apis/images.rs index 770e781..36f3ef7 100644 --- a/rs_openai/src/apis/images.rs +++ b/rs_openai/src/apis/images.rs @@ -3,145 +3,9 @@ //! 
Related guide: [Image generation](https://platform.openai.com/docs/guides/images) use crate::client::OpenAI; -use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse}; -use crate::shared::types::FileMeta; -use derive_builder::Builder; +use crate::interfaces::images; +use crate::shared::response_wrapper::OpenAIResponse; use reqwest::multipart::Form; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Default, Clone, strum::Display)] -#[serde(rename_all = "snake_case")] -pub enum ResponseFormat { - #[default] - #[strum(serialize = "url")] - Url, - - #[strum(serialize = "b64_json")] - #[serde(rename = "b64_json")] - B64Json, -} - -#[derive(Default, Debug, Serialize, Clone, strum::Display)] -pub enum ImageSize { - #[strum(serialize = "256x256")] - #[serde(rename = "256x256")] - S256x256, - - #[strum(serialize = "512x512")] - #[serde(rename = "256x256")] - S512x512, - - #[default] - #[strum(serialize = "1024x1024")] - #[serde(rename = "256x256")] - S1024x1024, -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateImageRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateImageRequest { - /// A text description of the desired image(s). The maximum length is 1000 characters. - pub prompt: String, - - /// The number of images to generate. Must be between 1 and 10. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, // default: 1, min: 1, max: 10 - - /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. - #[serde(skip_serializing_if = "Option::is_none")] - pub size: Option, // default: "1024x1024" - - /// The format in which the generated images are returned. Must be one of `url` or `b64_json`. - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, // default: "url" - - /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://beta.openai.com/docs/api-reference/authentication) - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateImageEditRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateImageEditRequest { - /// The image to edit. Must be a valid PNG file, less than 4MB, and square. - /// If mask is not provided, image must have transparency, which will be used as the mask. - pub image: FileMeta, - - /// A text description of the desired image(s). The maximum length is 1000 characters. - pub prompt: String, - - /// An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. - /// Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`. - #[serde(skip_serializing_if = "Option::is_none")] - pub mask: Option, - - /// The number of images to generate. Must be between 1 and 10. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, // default: 1, min: 1, max: 10 - - /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. - #[serde(skip_serializing_if = "Option::is_none")] - pub size: Option, // default: "1024x1024" - - /// The format in which the generated images are returned. 
Must be one of `url` or `b64_json`. - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, // default: "url" - - /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. - /// [Learn more](https://beta.openai.com/docs/api-reference/authentication) - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateImageVariationRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateImageVariationRequest { - /// The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. - pub image: FileMeta, - - /// The number of images to generate. Must be between 1 and 10. - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, // default: 1, min: 1, max: 10 - - /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. - #[serde(skip_serializing_if = "Option::is_none")] - pub size: Option, // default: "1024x1024" - - /// The format in which the generated images are returned. Must be one of `url` or `b64_json`. - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, // default: "url" - - /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. - /// [Learn more](https://beta.openai.com/docs/api-reference/authentication) - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -#[serde(rename_all = "lowercase")] -pub enum ImageData { - Url(String), - - #[serde(rename = "b64_json")] - B64Json(String), -} -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ImageResponse { - pub created: i64, - pub data: Vec, -} pub struct Images<'a> { openai: &'a OpenAI, @@ -153,12 +17,18 @@ impl<'a> Images<'a> { } /// Creates an image given a prompt. - pub async fn create(&self, req: &CreateImageRequest) -> OpenAIResponse { + pub async fn create( + &self, + req: &images::CreateImageRequest, + ) -> OpenAIResponse { self.openai.post("/images/generations", req).await } /// Creates an edited or extended image given an original image and a prompt. - pub async fn create_edit(&self, req: &CreateImageEditRequest) -> OpenAIResponse { + pub async fn create_edit( + &self, + req: &images::CreateImageEditRequest, + ) -> OpenAIResponse { let file_part = reqwest::multipart::Part::stream(req.image.buffer.clone()) .file_name(req.image.filename.clone()) .mime_str("application/octet-stream") @@ -199,8 +69,8 @@ impl<'a> Images<'a> { /// Creates a variation of a given image. pub async fn create_variations( &self, - req: &CreateImageVariationRequest, - ) -> OpenAIResponse { + req: &images::CreateImageVariationRequest, + ) -> OpenAIResponse { let file_part = reqwest::multipart::Part::stream(req.image.buffer.clone()) .file_name(req.image.filename.clone()) .mime_str("application/octet-stream") diff --git a/rs_openai/src/apis/models.rs b/rs_openai/src/apis/models.rs index 957be14..85f36c9 100644 --- a/rs_openai/src/apis/models.rs +++ b/rs_openai/src/apis/models.rs @@ -2,41 +2,8 @@ //! You can refer to the [Models](https://platform.openai.com/docs/models/overview) documentation to understand what models are available and the differences between them. 
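With the request and response types now living in `interfaces::images`, call sites only need an import change. A minimal sketch of a consumer (not part of the patch; it assumes the README-style client construction and the `images()` accessor the crate's examples use):

```rust
use rs_openai::interfaces::images::{CreateImageRequestBuilder, ImageSize};
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Client construction as in the crate README; org_id is optional.
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    // The builder now comes from interfaces::images rather than apis::images.
    let req = CreateImageRequestBuilder::default()
        .prompt("A watercolor fox reading a book")
        .size(ImageSize::S512x512)
        .build()?;

    let res = client.images().create(&req).await?;
    println!("{} image(s) generated", res.data.len());
    Ok(())
}
```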
use crate::client::OpenAI; +use crate::interfaces::models; use crate::shared::response_wrapper::OpenAIResponse; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ModelPermission { - pub id: String, - pub object: String, - pub created: u32, - pub allow_create_engine: bool, - pub allow_sampling: bool, - pub allow_logprobs: bool, - pub allow_search_indices: bool, - pub allow_view: bool, - pub allow_fine_tuning: bool, - pub organization: String, - pub group: Option, - pub is_blocking: bool, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ModelResponse { - pub id: String, - pub object: String, - pub created: u32, - pub owned_by: String, - pub permission: Vec, - pub root: String, - pub parent: Option, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ListModelResponse { - pub object: String, - pub data: Vec, -} pub struct Models<'a> { openai: &'a OpenAI, @@ -52,12 +19,12 @@ impl<'a> Models<'a> { /// # Path parameters /// /// - `model` - The ID of the model to use for this request. - pub async fn retrieve(&self, model: &str) -> OpenAIResponse { + pub async fn retrieve(&self, model: &str) -> OpenAIResponse { self.openai.get(&format!("/models/{model}"), &()).await } /// Lists the currently available models, and provides basic information about each one such as the owner and availability. - pub async fn list(&self) -> OpenAIResponse { + pub async fn list(&self) -> OpenAIResponse { self.openai.get("/models", &()).await } } diff --git a/rs_openai/src/apis/moderations.rs b/rs_openai/src/apis/moderations.rs index 2c9b775..9645bff 100644 --- a/rs_openai/src/apis/moderations.rs +++ b/rs_openai/src/apis/moderations.rs @@ -3,89 +3,8 @@ //! Related guide: [Moderations](https://platform.openai.com/docs/guides/moderation) use crate::client::OpenAI; -use crate::shared::response_wrapper::{OpenAIError, OpenAIResponse}; -use derive_builder::Builder; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Clone)] -#[serde(untagged)] -pub enum ModerationInput { - String(String), - ArrayOfString(Vec), -} - -#[derive(Debug, Serialize, Default, Clone)] -pub enum ModerationModel { - #[default] - #[serde(rename = "text-moderation-latest")] - Latest, - #[serde(rename = "text-moderation-stable")] - Stable, -} - -#[derive(Builder, Clone, Debug, Default, Serialize)] -#[builder(name = "CreateModerationRequestBuilder")] -#[builder(pattern = "mutable")] -#[builder(setter(into, strip_option), default)] -#[builder(derive(Debug))] -#[builder(build_fn(error = "OpenAIError"))] -pub struct CreateModerationRequest { - /// The input text to classify. - pub input: ModerationInput, - - /// Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`. - /// - /// The default is `text-moderation-latest` which will be automatically upgraded over time. - /// This ensures you are always using our most accurate model. - /// If you use `text-moderation-stable`, we will provide advanced notice before updating the model. - /// Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`. 
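The Models client above keeps the same two calls; only the response types moved. A quick sketch, under the same client-construction assumption:

```rust
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    // list() now deserializes into interfaces::models::ListModelResponse.
    let models = client.models().list().await?;
    println!("{} models visible", models.data.len());

    // retrieve() returns interfaces::models::ModelResponse.
    let model = client.models().retrieve("gpt-3.5-turbo").await?;
    println!("{} is owned by {}", model.id, model.owned_by);
    Ok(())
}
```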
- #[serde(skip_serializing_if = "Option::is_none")] - pub model: Option, // default: "text-moderation-latest" -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ModerationResponse { - pub id: String, - pub model: String, - pub results: Vec, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ModerationCategory { - pub categories: ModerationCategories, - pub category_scores: ModerationCategoryScores, - pub flagged: bool, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ModerationCategories { - pub sexual: bool, - pub hate: bool, - pub violence: bool, - #[serde(rename = "self-harm")] - pub self_harm: bool, - #[serde(rename = "sexual/minors")] - pub sexual_minors: bool, - #[serde(rename = "hate/threatening")] - pub hate_threatening: bool, - #[serde(rename = "violence/graphic")] - pub violence_graphic: bool, -} - -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct ModerationCategoryScores { - pub sexual: f32, - pub hate: f32, - pub violence: f32, - #[serde(rename = "self-harm")] - pub self_harm: f32, - #[serde(rename = "sexual/minors")] - pub sexual_minors: f32, - #[serde(rename = "hate/threatening")] - pub hate_threatening: f32, - #[serde(rename = "violence/graphic")] - pub violence_graphic: f32, -} +use crate::interfaces::moderations; +use crate::shared::response_wrapper::OpenAIResponse; pub struct Moderations<'a> { openai: &'a OpenAI, @@ -99,8 +18,8 @@ impl<'a> Moderations<'a> { /// Classifies if text violates OpenAI's Content Policy. pub async fn create( &self, - req: &CreateModerationRequest, - ) -> OpenAIResponse { + req: &moderations::CreateModerationRequest, + ) -> OpenAIResponse { self.openai.post("/moderations", req).await } } diff --git a/rs_openai/src/client.rs b/rs_openai/src/client.rs index ef80710..5003701 100644 --- a/rs_openai/src/client.rs +++ b/rs_openai/src/client.rs @@ -7,6 +7,8 @@ use futures::{stream::StreamExt, Stream}; use reqwest::{header::HeaderMap, multipart::Form, Client, Method, RequestBuilder}; use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; +use std::fs::File; +use std::io::{self}; use std::{fmt::Debug, pin::Pin}; // Default v1 API base url @@ -89,6 +91,24 @@ impl OpenAI { Ok(text) } + async fn resolve_file_response(request: RequestBuilder, filename: &str) -> OpenAIResponse<()> { + let response = request.send().await?; + let status = response.status(); + let text = response.text().await?; + + if !status.is_success() { + let api_error: ApiErrorResponse = + serde_json::from_slice(text.as_ref()).map_err(OpenAIError::JSONDeserialize)?; + + return Err(OpenAIError::ApiError(api_error)); + } + + let mut file = File::create(filename).expect("failed to create file"); + io::copy(&mut text.as_bytes(), &mut file).expect("failed to copy content"); + + Ok(()) + } + pub(crate) async fn get(&self, route: &str, query: &F) -> OpenAIResponse where T: DeserializeOwned + Debug, @@ -142,6 +162,19 @@ impl OpenAI { Self::resolve_text_response(request).await } + pub(crate) async fn post_with_file_response( + &self, + route: &str, + json: &T, + filename: &str, + ) -> OpenAIResponse<()> + where + T: Serialize, + { + let request = self.openai_request(Method::POST, route, |request| request.json(json)); + Self::resolve_file_response(request, filename).await + } + pub(crate) async fn post_stream( &self, route: &str, diff --git a/rs_openai/src/interfaces/audio.rs b/rs_openai/src/interfaces/audio.rs new file mode 100644 index 0000000..2f02913 --- /dev/null +++ 
b/rs_openai/src/interfaces/audio.rs @@ -0,0 +1,383 @@ +use crate::shared::response_wrapper::OpenAIError; +use crate::shared::types::FileMeta; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Default, Clone, strum::Display)] +pub enum SttResponseFormat { + #[default] + #[strum(serialize = "json")] + Json, + #[strum(serialize = "text")] + Text, + #[strum(serialize = "srt")] + Srt, + #[strum(serialize = "verbose_json")] + VerboseJson, + #[strum(serialize = "vtt")] + Vtt, +} + +#[derive(Debug, Serialize, Default, Clone, strum::Display)] +pub enum TtsResponseFormat { + #[default] + #[strum(serialize = "mp3")] + Mp3, + #[strum(serialize = "opus")] + Opus, + #[strum(serialize = "aac")] + Aac, + #[strum(serialize = "flac")] + Flac, + #[strum(serialize = "wav")] + Wav, + #[strum(serialize = "pcm")] + Pcm, +} + +#[derive(Debug, Serialize, Default, Clone, strum::Display)] +pub enum Language { + #[default] + #[strum(serialize = "en")] + English, + #[strum(serialize = "zh")] + Chinese, + #[strum(serialize = "de")] + German, + #[strum(serialize = "es")] + Spanish, + #[strum(serialize = "ru")] + Russian, + #[strum(serialize = "ko")] + Korean, + #[strum(serialize = "fr")] + French, + #[strum(serialize = "ja")] + Japanese, + #[strum(serialize = "pt")] + Portuguese, + #[strum(serialize = "tr")] + Turkish, + #[strum(serialize = "pl")] + Polish, + #[strum(serialize = "ca")] + Catalan, + #[strum(serialize = "nl")] + Dutch, + #[strum(serialize = "ar")] + Arabic, + #[strum(serialize = "sv")] + Swedish, + #[strum(serialize = "it")] + Italian, + #[strum(serialize = "id")] + Indonesian, + #[strum(serialize = "hi")] + Hindi, + #[strum(serialize = "fi")] + Finnish, + #[strum(serialize = "vi")] + Vietnamese, + #[strum(serialize = "he")] + Hebrew, + #[strum(serialize = "uk")] + Ukrainian, + #[strum(serialize = "el")] + Greek, + #[strum(serialize = "ms")] + Malay, + #[strum(serialize = "cs")] + Czech, + #[strum(serialize = "ro")] + Romanian, + #[strum(serialize = "da")] + Danish, + #[strum(serialize = "hu")] + Hungarian, + #[strum(serialize = "ta")] + Tamil, + #[strum(serialize = "no")] + Norwegian, + #[strum(serialize = "th")] + Thai, + #[strum(serialize = "ur")] + Urdu, + #[strum(serialize = "hr")] + Croatian, + #[strum(serialize = "bg")] + Bulgarian, + #[strum(serialize = "lt")] + Lithuanian, + #[strum(serialize = "la")] + Latin, + #[strum(serialize = "mi")] + Maori, + #[strum(serialize = "ml")] + Malayalam, + #[strum(serialize = "cy")] + Welsh, + #[strum(serialize = "sk")] + Slovak, + #[strum(serialize = "te")] + Telugu, + #[strum(serialize = "fa")] + Persian, + #[strum(serialize = "lv")] + Latvian, + #[strum(serialize = "bn")] + Bengali, + #[strum(serialize = "sr")] + Serbian, + #[strum(serialize = "az")] + Azerbaijani, + #[strum(serialize = "sl")] + Slovenian, + #[strum(serialize = "kn")] + Kannada, + #[strum(serialize = "et")] + Estonian, + #[strum(serialize = "mk")] + Macedonian, + #[strum(serialize = "br")] + Breton, + #[strum(serialize = "eu")] + Basque, + #[strum(serialize = "is")] + Icelandic, + #[strum(serialize = "hy")] + Armenian, + #[strum(serialize = "ne")] + Nepali, + #[strum(serialize = "mn")] + Mongolian, + #[strum(serialize = "bs")] + Bosnian, + #[strum(serialize = "kk")] + Kazakh, + #[strum(serialize = "sq")] + Albanian, + #[strum(serialize = "sw")] + Swahili, + #[strum(serialize = "gl")] + Galician, + #[strum(serialize = "mr")] + Marathi, + #[strum(serialize = "pa")] + Punjabi, + #[strum(serialize = "si")] + Sinhala, + #[strum(serialize = "km")] + 
Khmer, + #[strum(serialize = "sn")] + Shona, + #[strum(serialize = "yo")] + Yoruba, + #[strum(serialize = "so")] + Somali, + #[strum(serialize = "af")] + Afrikaans, + #[strum(serialize = "oc")] + Occitan, + #[strum(serialize = "ka")] + Georgian, + #[strum(serialize = "be")] + Belarusian, + #[strum(serialize = "tg")] + Tajik, + #[strum(serialize = "sd")] + Sindhi, + #[strum(serialize = "gu")] + Gujarati, + #[strum(serialize = "am")] + Amharic, + #[strum(serialize = "yi")] + Yiddish, + #[strum(serialize = "lo")] + Lao, + #[strum(serialize = "uz")] + Uzbek, + #[strum(serialize = "fo")] + Faroese, + #[strum(serialize = "ht")] + HaitianCreole, + #[strum(serialize = "ps")] + Pashto, + #[strum(serialize = "tk")] + Turkmen, + #[strum(serialize = "nn")] + Nynorsk, + #[strum(serialize = "mt")] + Maltese, + #[strum(serialize = "sa")] + Sanskrit, + #[strum(serialize = "lb")] + Luxembourgish, + #[strum(serialize = "my")] + Myanmar, + #[strum(serialize = "bo")] + Tibetan, + #[strum(serialize = "tl")] + Tagalog, + #[strum(serialize = "mg")] + Malagasy, + #[strum(serialize = "as")] + Assamese, + #[strum(serialize = "tt")] + Tatar, + #[strum(serialize = "haw")] + Hawaiian, + #[strum(serialize = "ln")] + Lingala, + #[strum(serialize = "ha")] + Hausa, + #[strum(serialize = "ba")] + Bashkir, + #[strum(serialize = "jw")] + Javanese, + #[strum(serialize = "su")] + Sundanese, +} + +#[derive(Debug, Serialize, Default, Clone, strum::Display)] +pub enum Voice { + #[default] + #[strum(serialize = "alloy")] + Alloy, + #[strum(serialize = "echo")] + Echo, + #[strum(serialize = "fable")] + Fable, + #[strum(serialize = "onyx")] + Onyx, + #[strum(serialize = "nova")] + Nova, + #[strum(serialize = "shimmer")] + Shimmer, +} + +#[derive(Debug, Serialize, Default, Clone, strum::Display)] +pub enum SttModel { + #[default] + #[strum(serialize = "whisper-1")] + Whisper1, +} + +#[derive(Debug, Serialize, Default, Clone, strum::Display)] +pub enum AudioSpeechModel { + #[default] + #[strum(serialize = "tts-1")] + Tts1, + #[strum(serialize = "tts-1-hd")] + Tts1Hd, +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateSpeechRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateSpeechRequest { + /// One of the available [TTS models](https://platform.openai.com/docs/models/tts): `tts-1` or `tts-1-hd` + pub model: AudioSpeechModel, + + /// The text to generate audio for. The maximum length is 4096 characters. + pub input: String, + + /// The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. + /// Previews of the voices are available in the [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech/voice-options). + pub voice: Voice, + + /// The format to generate audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, // default: mp3 + + /// The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default. + #[serde(skip_serializing_if = "Option::is_none")] + pub speed: Option, // min: 0.25, max: 4.0, default: 1.0 +}
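`CreateSpeechRequest` is the one request here whose response is raw audio rather than JSON, which is what the new `post_with_file_response` helper on the client appears to be for. A sketch of how that might be wired up, assuming a `create_speech` method on the Audio API (hypothetical; the corresponding audio.rs hunk is not part of this excerpt):

```rust
use rs_openai::interfaces::audio::{AudioSpeechModel, CreateSpeechRequestBuilder, Voice};
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    let req = CreateSpeechRequestBuilder::default()
        .model(AudioSpeechModel::Tts1)
        .input("Hello from rs_openai!")
        .voice(Voice::Nova)
        .build()?;

    // Hypothetical method, presumably wired to post_with_file_response: the
    // mp3 bytes are written straight to the given path and () comes back.
    client.audio().create_speech(&req, "hello.mp3").await?;
    Ok(())
}
```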
+ +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateTranscriptionRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateTranscriptionRequest { + /// The audio file to transcribe, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm. + pub file: FileMeta, + + /// ID of the model to use. Only `whisper-1` is currently available. + pub model: SttModel, + + /// An optional text to guide the model's style or continue a previous audio segment. + /// The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting) should match the audio language. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + + /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, // default: "json" + + /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, + /// while lower values like 0.2 will make it more focused and deterministic. + /// If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, // min: 0, max: 1, default: 0 + + /// The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency. + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateTranslationRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateTranslationRequest { + /// The audio file to translate, in one of these formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm. + pub file: FileMeta, + + /// ID of the model to use. Only `whisper-1` is currently available. + pub model: SttModel, + + /// An optional text to guide the model's style or continue a previous audio segment. + /// The [prompt](https://platform.openai.com/docs/guides/speech-to-text/prompting) should be in English. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + + /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, // default: json + + /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, + /// while lower values like 0.2 will make it more focused and deterministic. + /// If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.
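Before the translation request wraps up just below, a sketch of the transcription counterpart in use; the `create_transcription` name and return shape are assumptions based on the pre-0.5 Audio API, and the `FileMeta` field layout is inferred from how multipart parts are built elsewhere in this patch:

```rust
use rs_openai::interfaces::audio::{CreateTranscriptionRequestBuilder, SttModel};
use rs_openai::shared::types::FileMeta;
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    // FileMeta is assumed to carry the raw bytes plus a filename.
    let req = CreateTranscriptionRequestBuilder::default()
        .file(FileMeta {
            buffer: std::fs::read("meeting.mp3")?,
            filename: "meeting.mp3".to_string(),
        })
        .model(SttModel::Whisper1)
        .build()?;

    let transcript = client.audio().create_transcription(&req).await?; // name assumed
    println!("{transcript:?}");
    Ok(())
}
```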
+ #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, // min: 0, max: 1, default: 0 +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct VerboseJsonForAudioResponse { + pub task: Option, + pub language: Option, + pub duration: Option, + pub segments: Option>, + pub text: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Segment { + pub id: u32, + pub seek: u32, + pub start: f32, + pub end: f32, + pub text: String, + pub tokens: Vec, + pub temperature: f32, + pub avg_logprob: f32, + pub compression_ratio: f32, + pub no_speech_prob: f32, +} diff --git a/rs_openai/src/interfaces/chat.rs b/rs_openai/src/interfaces/chat.rs new file mode 100644 index 0000000..6f0e6be --- /dev/null +++ b/rs_openai/src/interfaces/chat.rs @@ -0,0 +1,162 @@ +use crate::shared::response_wrapper::OpenAIError; +use crate::shared::types::Stop; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Serialize, Deserialize, Clone, Default, strum::Display)] +#[serde(rename_all = "lowercase")] +pub enum Role { + #[strum(serialize = "system")] + System, + #[default] + #[strum(serialize = "user")] + User, + #[strum(serialize = "assistant")] + Assistant, +} + +#[derive(Builder, Default, Debug, Clone, Deserialize, Serialize)] +#[builder(name = "ChatCompletionMessageRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatCompletionMessage { + /// The role of the author of this message. One of `system`, `user`, or `assistant`. + pub role: Role, + + /// The contents of the message. + pub content: String, + + /// The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters. + pub name: Option, +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateChatRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateChatRequest { + /// ID of the model to use. + /// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API. + pub model: String, + + /// A list of messages describing the conversation so far. + pub messages: Vec, + + /// What sampling temperature to use, between 0 and 2. + /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// We generally recommend altering this or `top_p` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, // min: 0, max: 2, default: 1 + + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. + /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or `temperature` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, // default: 1 + + /// How many chat completion choices to generate for each input message. + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // default: 1 + + /// If set, partial message deltas will be sent, like in ChatGPT. 
+ /// Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. + /// See the OpenAI Cookbook for [example code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb). + /// + /// For streamed progress, use [`create_with_stream`](Chat::create_with_stream). + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, // default: false + + /// Up to 4 sequences where the API will stop generating further tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, // default: null + + /// The maximum number of tokens to generate in the chat completion. + /// + /// The total length of input tokens and generated tokens is limited by the model's context length. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// Number between -2.0 and 2.0. + /// Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + /// + /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, // min: -2.0, max: 2.0, default: 0 + + /// Number between -2.0 and 2.0. + /// Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + /// + /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, // min: -2.0, max: 2.0, default: 0 + + /// Modify the likelihood of specified tokens appearing in the completion. + /// + /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. + /// Mathematically, the bias is added to the logits generated by the model prior to sampling. + /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; + /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, // default: null + + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Message { + pub role: String, + pub content: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ChatUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ChatChoice { + pub message: ChatCompletionMessage, + pub finish_reason: String, + pub index: u32, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ChatResponse { + pub id: String, + pub object: String, + pub created: u32, + pub choices: Vec, + pub usage: ChatUsage, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Delta { + pub content: Option, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ChatChoiceStream { + pub delta: Delta, + pub finish_reason: Option, + pub index: u32, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ChatStreamResponse { + pub id: String, + pub object: String, + pub model: String, + pub created: u32, + pub choices: Vec, +}
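That closes out the chat types. The `create_with_stream` entry point referenced in the doc comments above can then be driven roughly like this (a sketch: the accessor name and the stream's item type are inferred from the fine-tunes streaming signature earlier in this patch):

```rust
use futures::StreamExt;
use rs_openai::interfaces::chat::{
    ChatCompletionMessageRequestBuilder, CreateChatRequestBuilder, Role,
};
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    let req = CreateChatRequestBuilder::default()
        .model("gpt-3.5-turbo")
        .messages(vec![ChatCompletionMessageRequestBuilder::default()
            .role(Role::User)
            .content("Tell me a one-line joke.")
            .build()?])
        .build()?;

    // Each item is assumed to be an OpenAIResponse<ChatStreamResponse>;
    // the deltas carry the tokens as they arrive.
    let mut stream = client.chat().create_with_stream(&req).await?;
    while let Some(chunk) = stream.next().await {
        for choice in chunk?.choices {
            if let Some(token) = choice.delta.content {
                print!("{token}");
            }
        }
    }
    Ok(())
}
```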
diff --git a/rs_openai/src/interfaces/completions.rs b/rs_openai/src/interfaces/completions.rs new file mode 100644 index 0000000..a3a5966 --- /dev/null +++ b/rs_openai/src/interfaces/completions.rs @@ -0,0 +1,169 @@ +use crate::shared::response_wrapper::OpenAIError; +use crate::shared::types::Stop; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum Prompt { + String(String), + ArrayOfString(Vec), + ArrayOfTokens(Vec), + ArrayOfTokenArrays(Vec>), +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateCompletionRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateCompletionRequest { + /// ID of the model to use. + /// You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them. + pub model: String, + + /// The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays. + /// + /// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, // default: <|endoftext|> + + /// The suffix that comes after a completion of inserted text. + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, // default: null + + /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the completion. + /// + /// The token count of your prompt plus `max_tokens` cannot exceed the model's context length. + /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, // default: 16 + + /// What sampling temperature to use, between 0 and 2. + /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// We generally recommend altering this or `top_p` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, // min: 0, max: 2, default: 1 + + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. + /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or `temperature` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, // default: 1 + + /// How many completions to generate for each prompt. + /// + /// **Note**: Because this parameter generates many completions, it can quickly consume your token quota. + /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // default: 1 + + /// Whether to stream back partial progress. + /// If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. + /// + /// For streamed progress, use [`create_with_stream`](Completions::create_with_stream). + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, // default: false + + /// Include the log probabilities on the `logprobs` most likely tokens, as well as the chosen tokens. + /// For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. + /// The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response. + /// + /// The maximum value for `logprobs` is 5. If you need more than this, please contact us through our [Help center](https://help.openai.com/) and describe your use case. + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, // default: null + + /// Echo back the prompt in addition to the completion. + #[serde(skip_serializing_if = "Option::is_none")] + pub echo: Option, // default: false + + /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, // default: null + + /// Number between -2.0 and 2.0. + /// Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + /// + /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, // min: -2.0, max: 2.0, default: 0 + + /// Number between -2.0 and 2.0. + /// Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + /// + /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details) + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, // min: -2.0, max: 2.0, default: 0 + + /// Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). + /// Results cannot be streamed.
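Since `best_of` rules out streaming, the natural call is a plain `create` (the remaining `best_of` notes continue below). A sketch, with `Prompt` constructed explicitly so nothing is assumed about `From` conversions, and integer widths guessed where this excerpt does not pin them down:

```rust
use rs_openai::interfaces::completions::{CreateCompletionRequestBuilder, Prompt};
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    // best_of > n: the server generates three candidates and returns the one
    // with the highest per-token log probability. Streaming must stay off.
    let req = CreateCompletionRequestBuilder::default()
        .model("text-davinci-003")
        .prompt(Prompt::String("Write a haiku about Rust.".to_string()))
        .best_of(3u8) // integer width assumed
        .n(1u8)
        .build()?;

    let res = client.completions().create(&req).await?;
    println!("{}", res.choices[0].text);
    Ok(())
}
```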
+ /// + /// When used with `n`, `best_of` controls the number of candidate completions and n specifies how many to return – `best_of` must be greater than `n`. + /// + /// **Note**: Because this parameter generates many completions, it can quickly consume your token quota. + /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, // default: 1 + + /// Modify the likelihood of specified tokens appearing in the completion. + /// + /// Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. + /// You can use this [tokenizer](https://platform.openai.com/tokenizer?view=bpe) tool (which works for both GPT-2 and GPT-3) to convert text to token IDs. + /// Mathematically, the bias is added to the logits generated by the model prior to sampling. + /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; + /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + /// + /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated. + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, // default: null + + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct CompletionChoice { + pub text: String, + pub index: u32, + pub logprobs: Option, + pub finish_reason: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct CompletionChoiceStream { + pub text: String, + pub index: usize, + pub logprobs: Option, + pub finish_reason: Option, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct CompletionStreamResponse { + pub id: String, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, +} diff --git a/rs_openai/src/interfaces/edits.rs b/rs_openai/src/interfaces/edits.rs new file mode 100644 index 0000000..969e759 --- /dev/null +++ b/rs_openai/src/interfaces/edits.rs @@ -0,0 +1,60 @@ +use crate::shared::response_wrapper::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateEditRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateEditRequest { + /// ID of the model to use. You can use the `text-davinci-edit-001` or `code-davinci-edit-001` model with this endpoint. + pub model: String, + + /// The input text to use as a starting point for the edit. + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + + /// The instruction that tells the model how to edit the prompt. 
+ pub instruction: String, + + /// How many edits to generate for the input and instruction. + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // default: 1 + + /// What sampling temperature to use, between 0 and 2. + /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// We generally recommend altering this or `top_p` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, // min: 0, max: 2, default: 1 + + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. + /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// + /// We generally recommend altering this or `temperature` but not both. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, // default: 1 +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct EditResponse { + pub object: String, + pub created: u32, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Choice { + pub text: String, + pub index: u32, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} diff --git a/rs_openai/src/interfaces/embeddings.rs b/rs_openai/src/interfaces/embeddings.rs new file mode 100644 index 0000000..fa75c0f --- /dev/null +++ b/rs_openai/src/interfaces/embeddings.rs @@ -0,0 +1,57 @@ +//! Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms. +//! +//! Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings) + +use crate::shared::response_wrapper::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Clone)] +#[serde(untagged)] +pub enum EmbeddingInput { + String(String), + ArrayOfString(Vec), + ArrayOfTokens(Vec), + ArrayOfTokenArrays(Vec>), +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateEmbeddingRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateEmbeddingRequest { + /// ID of the model to use. + /// Use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them. + pub model: String, + + /// Input text to get embeddings for, encoded as a string or array of tokens. + /// To get embeddings for multiple inputs in a single request, pass an array of strings or array of token arrays. Each input must not exceed 8192 tokens in length. + pub input: EmbeddingInput, + + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). 
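One note before the final field below: passing `EmbeddingInput::ArrayOfString` is how several inputs get embedded in a single round trip, with one `EmbeddingData` entry coming back per input. A sketch, under the usual client-construction assumption:

```rust
use rs_openai::interfaces::embeddings::{CreateEmbeddingRequestBuilder, EmbeddingInput};
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    let req = CreateEmbeddingRequestBuilder::default()
        .model("text-embedding-ada-002")
        .input(EmbeddingInput::ArrayOfString(vec![
            "The food was delicious".to_string(),
            "The service was slow".to_string(),
        ]))
        .build()?;

    // data[i].index ties each vector back to its input.
    let res = client.embeddings().create(&req).await?;
    for item in res.data {
        println!("input #{}: {} dimensions", item.index, item.embedding.len());
    }
    Ok(())
}
```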
+ #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct EmbeddingResponse { + pub object: String, + pub data: Vec, + pub model: String, + pub usage: Usage, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct EmbeddingData { + pub object: String, + pub embedding: Vec, + pub index: u32, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub total_tokens: u32, +} diff --git a/rs_openai/src/interfaces/engines.rs b/rs_openai/src/interfaces/engines.rs new file mode 100644 index 0000000..646b73e --- /dev/null +++ b/rs_openai/src/interfaces/engines.rs @@ -0,0 +1,15 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct EngineResponse { + pub id: String, + pub object: String, + pub owner: String, + pub ready: bool, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct EngineListResponse { + pub data: Vec, + pub object: String, +} diff --git a/rs_openai/src/interfaces/files.rs b/rs_openai/src/interfaces/files.rs new file mode 100644 index 0000000..5e0c54e --- /dev/null +++ b/rs_openai/src/interfaces/files.rs @@ -0,0 +1,46 @@ +use crate::shared::response_wrapper::OpenAIError; +use crate::shared::types::FileMeta; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "UploadFileRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct UploadFileRequest { + /// Name of the [JSON Lines](https://jsonlines.readthedocs.io/en/latest/) file to be uploaded. + /// + /// If the `purpose` is set to "fine-tune", each line is a JSON record with "prompt" and "completion" fields representing your [training examples](https://platform.openai.com/docs/guides/fine-tuning/prepare-training-data). + pub file: FileMeta, + + /// The intended purpose of the uploaded documents. + /// + /// Use "fine-tune" for [Fine-tuning](https://platform.openai.com/docs/api-reference/fine-tunes). + /// This allows us to validate the format of the uploaded file. + pub purpose: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct FileResponse { + pub id: String, + pub object: String, + pub bytes: u64, + pub created_at: u32, + pub filename: String, + pub purpose: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct FileListResponse { + pub data: Vec, + pub object: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct DeleteFileResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} diff --git a/rs_openai/src/interfaces/fine_tunes.rs b/rs_openai/src/interfaces/fine_tunes.rs new file mode 100644 index 0000000..488bab1 --- /dev/null +++ b/rs_openai/src/interfaces/fine_tunes.rs @@ -0,0 +1,163 @@ +use crate::shared::response_wrapper::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateFineTuneRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateFineTuneRequest { + /// The ID of an uploaded file that contains training data. 
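That uploaded-file ID is exactly what the Files API shown above hands back. A hedged upload sketch (the `upload` method name is an assumption, and the `FileMeta` field layout is inferred from how multipart parts are built elsewhere in this patch); the `training_file` docs continue right after:

```rust
use rs_openai::interfaces::files::UploadFileRequestBuilder;
use rs_openai::shared::types::FileMeta;
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    // Upload a JSONL training file with purpose "fine-tune" so it can be
    // referenced as a training_file when creating a fine-tune job.
    let req = UploadFileRequestBuilder::default()
        .file(FileMeta {
            buffer: std::fs::read("train.jsonl")?,
            filename: "train.jsonl".to_string(),
        })
        .purpose("fine-tune")
        .build()?;

    let uploaded = client.files().upload(&req).await?; // method name assumed
    println!("uploaded as {}", uploaded.id);
    Ok(())
}
```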
+ /// + /// See [upload file](https://platform.openai.com/docs/api-reference/files/upload) for how to upload a file. + /// + /// + /// Your dataset must be formatted as a JSONL file, where each training example is a JSON object with the keys "prompt" and "completion". + /// Additionally, you must upload your file with the purpose `fine-tune`. + /// + /// + /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning/creating-training-data) for more details. + pub training_file: String, + + /// The ID of an uploaded file that contains validation data. + /// + /// If you provide this file, the data is used to generate validation metrics periodically during fine-tuning. + /// These metrics can be viewed in the [fine-tuning results file](https://platform.openai.com/docs/guides/fine-tuning/analyzing-your-fine-tuned-model). + /// Your train and validation data should be mutually exclusive. + /// + /// Your dataset must be formatted as a JSONL file, where each validation example is a JSON object with the keys "prompt" and "completion". + /// Additionally, you must upload your file with the purpose `fine-tune`. + /// + /// See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning/creating-training-data) for more details. + #[serde(skip_serializing_if = "Option::is_none")] + pub validation_file: Option, + + /// The name of the base model to fine-tune. + /// You can select one of "ada", "babbage", "curie", "davinci", or a fine-tuned model created after 2022-04-21. + /// To learn more about these models, see the [Models](https://platform.openai.com/docs/models) documentation. + pub model: Option, + + /// The number of epochs to train the model for. + /// An epoch refers to one full cycle through the training dataset. + #[serde(skip_serializing_if = "Option::is_none")] + pub n_epochs: Option, + + /// The batch size to use for training. + /// The batch size is the number of training examples used to train a single forward and backward pass. + /// + /// By default, the batch size will be dynamically configured to be ~0.2% of the number of examples in the training set, capped at 256. + /// In general, we've found that larger batch sizes tend to work better for larger datasets. + #[serde(skip_serializing_if = "Option::is_none")] + pub batch_size: Option, + + /// The learning rate multiplier to use for training. + /// The fine-tuning learning rate is the original learning rate used for pretraining multiplied by this value. + /// + /// By default, the learning rate multiplier is 0.05, 0.1, or 0.2 depending on final `batch_size` (larger learning rates tend to perform better with larger batch sizes). + /// We recommend experimenting with values in the range 0.02 to 0.2 to see what produces the best results. + #[serde(skip_serializing_if = "Option::is_none")] + pub learning_rate_multiplier: Option, + + /// The weight to use for loss on the prompt tokens. + /// This controls how much the model tries to learn to generate the prompt (as compared to the completion which always has a weight of 1.0), and can add a stabilizing effect to training when completions are short. + /// + /// If prompts are extremely long (relative to completions), it may make sense to reduce this weight so as to avoid over-prioritizing learning the prompt. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_loss_weight: Option, + + /// If set, we calculate classification-specific metrics such as accuracy and F-1 score using the validation set at the end of every epoch. 
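Putting the fields above together (the classification-metrics docs pick up again right after this sketch); `file-abc123` is a placeholder ID and the integer width of `n_epochs` is assumed:

```rust
use rs_openai::interfaces::fine_tunes::CreateFineTuneRequestBuilder;
use rs_openai::OpenAI;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = OpenAI::new(&OpenAI {
        api_key: std::env::var("OPENAI_API_KEY")?,
        org_id: None,
    });

    // training_file must be an already-uploaded JSONL file with purpose
    // "fine-tune"; unset hyperparameters fall back to the server-side
    // defaults described in the doc comments above.
    let req = CreateFineTuneRequestBuilder::default()
        .training_file("file-abc123") // placeholder ID
        .model("curie")
        .n_epochs(4u8) // integer width assumed
        .build()?;

    let job = client.fine_tunes().create(&req).await?;
    println!("job {} is {}", job.id, job.status);
    Ok(())
}
```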
+ /// These metrics can be viewed in the [results file](https://platform.openai.com/docs/guides/fine-tuning/analyzing-your-fine-tuned-model). + /// + /// In order to compute classification metrics, you must provide a `validation_file`. + /// Additionally, you must specify `classification_n_classes` for multiclass classification or `classification_positive_class` for binary classification. + #[serde(skip_serializing_if = "Option::is_none")] + pub compute_classification_metrics: Option, + + /// The number of classes in a classification task. + /// + /// This parameter is required for multiclass classification. + #[serde(skip_serializing_if = "Option::is_none")] + pub classification_n_classes: Option, + + /// The positive class in binary classification. + /// + /// This parameter is needed to generate precision, recall, and F1 metrics when doing binary classification. + #[serde(skip_serializing_if = "Option::is_none")] + pub classification_positive_class: Option, + + /// If provided, we calculate F-beta scores at the specified beta values. The F-beta score is a generalization of the F-1 score. This is only used for binary classification. + /// + /// With a beta of 1 (i.e. the F-1 score), precision and recall are given the same weight. A larger beta score puts more weight on recall and less on precision. A smaller beta score puts more weight on precision and less on recall. + #[serde(skip_serializing_if = "Option::is_none")] + pub classification_betas: Option>, + + /// A string of up to 40 characters that will be added to your fine-tuned model name. + /// + /// For example, a `suffix` of "custom-model-name" would produce a model name like `ada:ft-your-org:custom-model-name-2022-02-15-04-21-04`. + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct FineTuneResponse { + pub id: String, + pub object: String, + pub model: String, + pub created_at: u32, + pub events: Option>, + pub fine_tuned_model: Option, + pub hyperparams: HyperParams, + pub organization_id: String, + pub result_files: Vec, + pub status: String, + pub validation_files: Vec, + pub training_files: Vec, + pub updated_at: u32, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct FineTuneEvent { + pub object: String, + pub created_at: u32, + pub level: String, + pub message: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct HyperParams { + pub batch_size: u32, + pub learning_rate_multiplier: f32, + pub n_epochs: u32, + pub prompt_loss_weight: f32, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct TrainingFile { + pub id: String, + pub object: String, + pub bytes: u32, + pub created_at: u32, + pub filename: String, + pub purpose: String, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct FineTuneListResponse { + pub object: String, + pub data: Vec, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct EventListResponse { + pub object: String, + pub data: Vec, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct DeleteFileResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} diff --git a/rs_openai/src/interfaces/images.rs b/rs_openai/src/interfaces/images.rs new file mode 100644 index 0000000..cd94b6a --- /dev/null +++ b/rs_openai/src/interfaces/images.rs @@ -0,0 +1,138 @@ +use crate::shared::response_wrapper::OpenAIError; +use crate::shared::types::FileMeta; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +#[derive(Debug,
Serialize, Default, Clone, strum::Display)] +#[serde(rename_all = "snake_case")] +pub enum ResponseFormat { + #[default] + #[strum(serialize = "url")] + Url, + + #[strum(serialize = "b64_json")] + #[serde(rename = "b64_json")] + B64Json, +} + +#[derive(Default, Debug, Serialize, Clone, strum::Display)] +pub enum ImageSize { + #[strum(serialize = "256x256")] + #[serde(rename = "256x256")] + S256x256, + + #[strum(serialize = "512x512")] + #[serde(rename = "512x512")] + S512x512, + + #[default] + #[strum(serialize = "1024x1024")] + #[serde(rename = "1024x1024")] + S1024x1024, +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateImageRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateImageRequest { + /// A text description of the desired image(s). The maximum length is 1000 characters. + pub prompt: String, + + /// The number of images to generate. Must be between 1 and 10. + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // default: 1, min: 1, max: 10 + + /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, // default: "1024x1024" + + /// The format in which the generated images are returned. Must be one of `url` or `b64_json`. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, // default: "url" + + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://beta.openai.com/docs/api-reference/authentication) + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateImageEditRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateImageEditRequest { + /// The image to edit. Must be a valid PNG file, less than 4MB, and square. + /// If mask is not provided, image must have transparency, which will be used as the mask. + pub image: FileMeta, + + /// A text description of the desired image(s). The maximum length is 1000 characters. + pub prompt: String, + + /// An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. + /// Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`. + #[serde(skip_serializing_if = "Option::is_none")] + pub mask: Option, + + /// The number of images to generate. Must be between 1 and 10. + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // default: 1, min: 1, max: 10 + + /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, // default: "1024x1024" + + /// The format in which the generated images are returned. Must be one of `url` or `b64_json`. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, // default: "url" + + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+ /// [Learn more](https://beta.openai.com/docs/api-reference/authentication) + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateImageVariationRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateImageVariationRequest { + /// The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. + pub image: FileMeta, + + /// The number of images to generate. Must be between 1 and 10. + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, // default: 1, min: 1, max: 10 + + /// The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, // default: "1024x1024" + + /// The format in which the generated images are returned. Must be one of `url` or `b64_json`. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, // default: "url" + + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + /// [Learn more](https://beta.openai.com/docs/api-reference/authentication) + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum ImageData { + Url(String), + + #[serde(rename = "b64_json")] + B64Json(String), +} +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ImageResponse { + pub created: i64, + pub data: Vec, +} diff --git a/rs_openai/src/interfaces/mod.rs b/rs_openai/src/interfaces/mod.rs new file mode 100644 index 0000000..545cce8 --- /dev/null +++ b/rs_openai/src/interfaces/mod.rs @@ -0,0 +1,11 @@ +pub mod audio; +pub mod chat; +pub mod completions; +pub mod edits; +pub mod embeddings; +pub mod engines; +pub mod files; +pub mod fine_tunes; +pub mod images; +pub mod models; +pub mod moderations; diff --git a/rs_openai/src/interfaces/models.rs b/rs_openai/src/interfaces/models.rs new file mode 100644 index 0000000..4ec2668 --- /dev/null +++ b/rs_openai/src/interfaces/models.rs @@ -0,0 +1,34 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ModelPermission { + pub id: String, + pub object: String, + pub created: u32, + pub allow_create_engine: bool, + pub allow_sampling: bool, + pub allow_logprobs: bool, + pub allow_search_indices: bool, + pub allow_view: bool, + pub allow_fine_tuning: bool, + pub organization: String, + pub group: Option, + pub is_blocking: bool, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ModelResponse { + pub id: String, + pub object: String, + pub created: u32, + pub owned_by: String, + pub permission: Vec, + pub root: String, + pub parent: Option, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ListModelResponse { + pub object: String, + pub data: Vec, +} diff --git a/rs_openai/src/interfaces/moderations.rs b/rs_openai/src/interfaces/moderations.rs new file mode 100644 index 0000000..3b31fa8 --- /dev/null +++ b/rs_openai/src/interfaces/moderations.rs @@ -0,0 +1,83 @@ +use crate::shared::response_wrapper::OpenAIError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Clone)] +#[serde(untagged)] +pub enum ModerationInput { + String(String), + 
ArrayOfString(Vec), +} + +#[derive(Debug, Serialize, Default, Clone)] +pub enum ModerationModel { + #[default] + #[serde(rename = "text-moderation-latest")] + Latest, + #[serde(rename = "text-moderation-stable")] + Stable, +} + +#[derive(Builder, Clone, Debug, Default, Serialize)] +#[builder(name = "CreateModerationRequestBuilder")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateModerationRequest { + /// The input text to classify. + pub input: ModerationInput, + + /// Two content moderation models are available: `text-moderation-stable` and `text-moderation-latest`. + /// + /// The default is `text-moderation-latest`, which will be automatically upgraded over time. + /// This ensures you are always using our most accurate model. + /// If you use `text-moderation-stable`, we will provide advance notice before updating the model. + /// Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, // default: "text-moderation-latest" +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ModerationResponse { + pub id: String, + pub model: String, + pub results: Vec, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ModerationCategory { + pub categories: ModerationCategories, + pub category_scores: ModerationCategoryScores, + pub flagged: bool, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ModerationCategories { + pub sexual: bool, + pub hate: bool, + pub violence: bool, + #[serde(rename = "self-harm")] + pub self_harm: bool, + #[serde(rename = "sexual/minors")] + pub sexual_minors: bool, + #[serde(rename = "hate/threatening")] + pub hate_threatening: bool, + #[serde(rename = "violence/graphic")] + pub violence_graphic: bool, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct ModerationCategoryScores { + pub sexual: f32, + pub hate: f32, + pub violence: f32, + #[serde(rename = "self-harm")] + pub self_harm: f32, + #[serde(rename = "sexual/minors")] + pub sexual_minors: f32, + #[serde(rename = "hate/threatening")] + pub hate_threatening: f32, + #[serde(rename = "violence/graphic")] + pub violence_graphic: f32, +} diff --git a/rs_openai/src/lib.rs b/rs_openai/src/lib.rs index c2248b1..c4f8254 100644 --- a/rs_openai/src/lib.rs +++ b/rs_openai/src/lib.rs @@ -57,6 +57,7 @@ pub mod apis; pub mod client; pub mod shared; +pub mod interfaces; pub use apis::*; -pub use client::*; \ No newline at end of file +pub use client::*; diff --git a/rs_openai/src/shared/macro.rs b/rs_openai/src/shared/macro.rs index 2c8b247..5f5a7e3 100644 --- a/rs_openai/src/shared/macro.rs +++ b/rs_openai/src/shared/macro.rs @@ -1,6 +1,6 @@ -use crate::apis::completions::Prompt; -use crate::apis::embeddings::EmbeddingInput; -use crate::apis::moderations::ModerationInput; +use crate::interfaces::completions::Prompt; +use crate::interfaces::embeddings::EmbeddingInput; +use crate::interfaces::moderations::ModerationInput; use crate::shared::types::Stop; macro_rules! impl_default {
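The macro.rs hunk above doubles as the migration recipe for 0.5.0: request and response types move from `apis::*` to `interfaces::*`, while the callable clients stay where they were (and remain re-exported from the crate root). In downstream code the change is typically just the import path, as in this sketch:

```rust
// 0.4.x import path (types lived beside the API clients):
// use rs_openai::moderations::{CreateModerationRequestBuilder, ModerationInput, ModerationModel};

// 0.5.0 import path (types now live under `interfaces`):
use rs_openai::interfaces::moderations::{
    CreateModerationRequestBuilder, ModerationInput, ModerationModel,
};
use rs_openai::shared::response_wrapper::OpenAIError;

fn build_request() -> Result<(), OpenAIError> {
    // The builder API itself is unchanged by the move.
    let req = CreateModerationRequestBuilder::default()
        .input(ModerationInput::String("I want to hug them.".to_string()))
        .model(ModerationModel::Latest)
        .build()?;
    println!("{req:?}");
    Ok(())
}
```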