-
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3f97f86
commit 7640aa4
Showing
4 changed files
with
134 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
//! Create large batches of API requests for asynchronous processing. | ||
//! The Batch API returns completions within 24 hours for a 50% discount. Related guide: [Batch](https://platform.openai.com/docs/guides/batch) | ||
use crate::client::OpenAI; | ||
use crate::interfaces::batch; | ||
use crate::shared::response_wrapper::OpenAIResponse; | ||
|
||
pub struct Batch<'a> { | ||
openai: &'a OpenAI, | ||
} | ||
|
||
impl<'a> Batch<'a> { | ||
pub fn new(openai: &'a OpenAI) -> Self { | ||
Self { openai } | ||
} | ||
|
||
/// Creates and executes a batch from an uploaded file of requests | ||
pub async fn create( | ||
&self, | ||
req: &batch::CreateBatchRequest, | ||
) -> OpenAIResponse<batch::BatchResponse> { | ||
self.openai.post("/batches", req).await | ||
} | ||
|
||
/// Retrieve a batch | ||
pub async fn retrieve(&self, batch_id: String) -> OpenAIResponse<batch::BatchResponse> { | ||
self.openai.get(&format!("/batches/{batch_id}"), &()).await | ||
} | ||
|
||
/// Cancels an in-progress batch. The batch will be in status `cancelling` for up to 10 minutes, before changing to `cancelled`, | ||
/// where it will have partial results (if any) available in the output file. | ||
pub async fn cancel(&self, batch_id: String) -> OpenAIResponse<batch::BatchResponse> { | ||
self.openai | ||
.post(&format!("/batches/{batch_id}/cancel"), &()) | ||
.await | ||
} | ||
|
||
/// A list of paginated [Batch](https://platform.openai.com/docs/api-reference/batch/object) objects. | ||
pub async fn list( | ||
&self, | ||
req: &batch::ListBatchRequest, | ||
) -> OpenAIResponse<batch::BatchResponse> { | ||
self.openai.get("/batches", req).await | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
use crate::shared::response_wrapper::OpenAIError; | ||
use derive_builder::Builder; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(Builder, Clone, Debug, Default, Serialize)] | ||
#[builder(name = "CreateBatchRequestBuilder")] | ||
#[builder(pattern = "mutable")] | ||
#[builder(setter(into, strip_option), default)] | ||
#[builder(derive(Debug))] | ||
#[builder(build_fn(error = "OpenAIError"))] | ||
pub struct CreateBatchRequest { | ||
/// The ID of an uploaded file that contains training data. | ||
/// | ||
/// See [upload file](https://platform.openai.com/docs/api-reference/files/upload) for how to upload a file. | ||
/// | ||
/// | ||
/// Your input file must be formatted as a [JSONL file](https://platform.openai.com/docs/api-reference/batch/request-input), | ||
/// and must be uploaded with the purpose `batch`. The file can contain up to 50,000 requests, and can be up to 100 MB in size. | ||
pub input_file_id: String, | ||
|
||
/// The endpoint to be used for all requests in the batch. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. | ||
/// Note that `/v1/embeddings` batches are also restricted to a maximum of 50,000 embedding inputs across all requests in the batch. | ||
pub endpoint: String, | ||
|
||
/// The time frame within which the batch should be processed. Currently only `24h` is supported. | ||
pub completion_window: String, | ||
|
||
/// Optional custom metadata for the batch. | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub meta: Option<serde_json::Value>, | ||
} | ||
|
||
#[derive(Builder, Clone, Debug, Default, Serialize)] | ||
#[builder(name = "ListBatchRequestBuilder")] | ||
#[builder(pattern = "mutable")] | ||
#[builder(setter(into, strip_option), default)] | ||
#[builder(derive(Debug))] | ||
#[builder(build_fn(error = "OpenAIError"))] | ||
pub struct ListBatchRequest { | ||
/// A cursor for use in pagination. `after` is an object ID that defines your place in the list. | ||
/// For instance, if you make a list request and receive 100 objects, ending with obj_foo, | ||
/// your subsequent call can include after=obj_foo in order to fetch the next page of the list. | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub after: Option<String>, | ||
|
||
/// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20. | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub limit: Option<String>, | ||
} | ||
|
||
#[derive(Debug, Deserialize, Clone, Serialize)] | ||
pub struct RequestCounts { | ||
total: u32, | ||
completed: u32, | ||
failed: u32, | ||
} | ||
|
||
#[derive(Debug, Deserialize, Clone, Serialize)] | ||
pub struct BatchResponse { | ||
id: String, | ||
object: String, | ||
endpoint: String, | ||
errors: Option<String>, | ||
input_file_id: String, | ||
completion_window: String, | ||
status: String, | ||
output_file_id: Option<String>, | ||
error_file_id: Option<String>, | ||
created_at: i64, | ||
in_progress_at: Option<i64>, | ||
expires_at: Option<i64>, | ||
finalizing_at: Option<i64>, | ||
completed_at: Option<i64>, | ||
failed_at: Option<i64>, | ||
expired_at: Option<i64>, | ||
cancelling_at: Option<i64>, | ||
cancelled_at: Option<i64>, | ||
request_counts: RequestCounts, | ||
metadata: serde_json::Value, | ||
} | ||
|
||
#[derive(Debug, Deserialize, Clone, Serialize)] | ||
pub struct ListBatchResponse { | ||
object: String, | ||
data: Vec<BatchResponse>, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
pub mod audio; | ||
pub mod batch; | ||
pub mod chat; | ||
pub mod completions; | ||
pub mod edits; | ||
|