Skip to content

Commit

Permalink
feat(openai): support streaming (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
nirga authored Nov 24, 2024
1 parent ccf5be5 commit 55823de
Show file tree
Hide file tree
Showing 20 changed files with 555 additions and 408 deletions.
418 changes: 141 additions & 277 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ edition = "2021"
axum = "0.7"
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
reqwest = { version = "0.11", features = ["json"] }
reqwest = { version = "0.12", features = ["json", "stream"] }
serde_json = "1.0"
axum-extra = "0.9.4"
tracing = "0.1"
Expand All @@ -36,3 +36,6 @@ opentelemetry-otlp = { version = "0.27.0", features = [
"reqwest-rustls",
] }
axum-prometheus = "0.7.0"
reqwest-streams = { version = "0.8.1", features = ["json"] }
futures = "0.3.31"
async-stream = "0.3.6"
8 changes: 8 additions & 0 deletions src/config/constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use std::env;

pub fn stream_buffer_size_bytes() -> usize {
env::var("STREAM_BUFFER_SIZE_BYTES")
.unwrap_or_else(|_| "1000".to_string())
.parse::<usize>()
.unwrap_or(1000)
}
1 change: 1 addition & 0 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod constants;
pub mod lib;
pub mod models;
3 changes: 1 addition & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ async fn main() -> Result<(), anyhow::Error> {

info!("Starting Traceloop Hub...");

let config_path =
std::env::var("CONFIG_FILE_PATH").unwrap_or("/etc/config/default.yaml".to_string());
let config_path = std::env::var("CONFIG_FILE_PATH").unwrap_or("config.yaml".to_string());

info!("Loading configuration from {}", config_path);
let config = load_config(&config_path)
Expand Down
53 changes: 11 additions & 42 deletions src/models/chat.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use futures::stream::BoxStream;
use reqwest_streams::error::StreamBodyError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use super::common::Usage;
use super::content::ChatCompletionMessage;
use super::logprob::LogProbs;
use super::streaming::ChatCompletionChunk;
use super::usage::Usage;

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionRequest {
Expand Down Expand Up @@ -29,30 +34,13 @@ pub struct ChatCompletionRequest {
pub user: Option<String>,
}

#[derive(Deserialize, Serialize, Clone)]
#[serde(untagged)]
pub enum ChatMessageContent {
String(String),
Array(Vec<ChatMessageContentPart>),
pub enum ChatCompletionResponse {
Stream(BoxStream<'static, Result<ChatCompletionChunk, StreamBodyError>>),
NonStream(ChatCompletion),
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatMessageContentPart {
#[serde(rename = "type")]
pub r#type: String,
pub text: String,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
pub role: String,
pub content: ChatMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionResponse {
pub struct ChatCompletion {
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub object: Option<String>,
Expand All @@ -61,6 +49,7 @@ pub struct ChatCompletionResponse {
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Usage,
pub system_fingerprint: Option<String>,
}

#[derive(Deserialize, Serialize, Clone)]
Expand All @@ -72,23 +61,3 @@ pub struct ChatCompletionChoice {
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct LogProbs {
pub content: Vec<LogProbContent>,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct LogProbContent {
pub token: String,
pub logprob: f32,
pub bytes: Vec<u8>,
pub top_logprobs: Vec<TopLogProb>,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct TopLogProb {
pub token: String,
pub logprob: f32,
pub bytes: Vec<u8>,
}
8 changes: 0 additions & 8 deletions src/models/common.rs

This file was deleted.

2 changes: 1 addition & 1 deletion src/models/completion.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use super::common::Usage;
use super::usage::Usage;

#[derive(Deserialize, Serialize, Clone)]
pub struct CompletionRequest {
Expand Down
23 changes: 23 additions & 0 deletions src/models/content.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, Clone)]
#[serde(untagged)]
pub enum ChatMessageContent {
String(String),
Array(Vec<ChatMessageContentPart>),
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatMessageContentPart {
#[serde(rename = "type")]
pub r#type: String,
pub text: String,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
pub role: String,
pub content: ChatMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
2 changes: 1 addition & 1 deletion src/models/embeddings.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};

use super::common::Usage;
use super::usage::Usage;

#[derive(Deserialize, Serialize, Clone)]
pub struct EmbeddingsRequest {
Expand Down
39 changes: 39 additions & 0 deletions src/models/logprob.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct LogProbs {
pub content: Vec<LogProbContent>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct LogProbContent {
pub token: String,
pub logprob: f32,
pub bytes: Vec<u8>,
pub top_logprobs: Vec<TopLogprob>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct TopLogprob {
pub token: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub bytes: Option<Vec<i32>>,
pub logprob: f64,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChatCompletionTokenLogprob {
pub token: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub bytes: Option<Vec<i32>>,
pub logprob: f64,
pub top_logprobs: Vec<TopLogprob>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChoiceLogprobs {
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Vec<ChatCompletionTokenLogprob>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
}
6 changes: 5 additions & 1 deletion src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
pub mod chat;
pub mod common;
pub mod completion;
pub mod content;
pub mod embeddings;
pub mod logprob;
pub mod streaming;
pub mod tool_calls;
pub mod usage;
39 changes: 39 additions & 0 deletions src/models/streaming.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use serde::{Deserialize, Serialize};

use super::logprob::ChoiceLogprobs;
use super::tool_calls::ChoiceDeltaToolCall;
use super::usage::Usage;

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChoiceDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ChoiceDeltaToolCall>>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct Choice {
pub delta: ChoiceDelta,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChoiceLogprobs>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChatCompletionChunk {
pub id: String,
pub choices: Vec<Choice>,
pub created: i64,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_tier: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
28 changes: 28 additions & 0 deletions src/models/tool_calls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChoiceDeltaFunctionCall {
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChoiceDeltaToolCallFunction {
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ChoiceDeltaToolCall {
pub index: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<ChoiceDeltaToolCallFunction>,
#[serde(skip_serializing_if = "Option::is_none")]
pub r#type: Option<String>,
}
26 changes: 26 additions & 0 deletions src/models/usage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CompletionTokensDetails {
pub accepted_prediction_tokens: Option<u32>,
pub audio_tokens: Option<u32>,
pub reasoning_tokens: Option<u32>,
pub rejected_prediction_tokens: Option<u32>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct PromptTokensDetails {
pub audio_tokens: Option<u32>,
pub cached_tokens: Option<u32>,
}

#[derive(Deserialize, Serialize, Clone, Debug, Default)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<PromptTokensDetails>,
}
Loading

0 comments on commit 55823de

Please sign in to comment.