Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
galkleinman committed Nov 6, 2024
1 parent abb16d9 commit abaf459
Show file tree
Hide file tree
Showing 24 changed files with 85 additions and 79 deletions.
6 changes: 3 additions & 3 deletions src/ai_models/instance.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::sync::Arc;
use axum::http::StatusCode;
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::completion::{CompletionRequest, CompletionResponse};
use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
use crate::providers::provider::Provider;
use crate::state::AppState;
use axum::http::StatusCode;
use std::sync::Arc;

pub struct ModelInstance {
pub name: String,
Expand All @@ -28,7 +28,7 @@ impl ModelInstance {
mut payload: CompletionRequest,
) -> Result<CompletionResponse, StatusCode> {
payload.model = self.model_type.clone();

self.provider.completions(state, payload).await
}

Expand Down
13 changes: 6 additions & 7 deletions src/ai_models/registry.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;

use super::instance::ModelInstance;
use crate::config::models::Model as ModelConfig;
use crate::providers::registry::ProviderRegistry;
use super::instance::ModelInstance;


pub struct ModelRegistry {
models: HashMap<String, Arc<ModelInstance>>,
Expand All @@ -17,22 +16,22 @@ impl ModelRegistry {
provider_registry: Arc<ProviderRegistry>,
) -> Result<Self> {
let mut models = HashMap::new();

for config in model_configs {
if let Some(provider) = provider_registry.get(&config.provider) {
let model = Arc::new(ModelInstance {
name: config.name.clone(),
model_type: config.r#type.clone(),
provider,
});

models.insert(config.name.clone(), model);
}
}

Ok(Self { models })
}

/// Looks up a registered model by `name`.
///
/// Returns a cloned `Arc` handle to the stored instance, or `None` when
/// nothing is stored under `name`.
pub fn get(&self, name: &str) -> Option<Arc<ModelInstance>> {
    let found = self.models.get(name)?;
    Some(Arc::clone(found))
}
Expand Down
2 changes: 1 addition & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pub mod models;
pub mod lib;
pub mod models;
9 changes: 7 additions & 2 deletions src/handlers/chat.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::{models::chat::{ChatCompletionRequest, ChatCompletionResponse}, state::AppState};
use crate::{
models::chat::{ChatCompletionRequest, ChatCompletionResponse},
state::AppState,
};
use axum::{extract::State, http::StatusCode, Json};
use std::sync::Arc;

Expand All @@ -8,7 +11,9 @@ pub async fn completions(
) -> Result<Json<ChatCompletionResponse>, StatusCode> {
for model in state.config.models.iter() {
if let Some(model) = state.model_registry.get(&model.name) {
let response = model.chat_completions(state.clone(), payload.clone()).await?;
let response = model
.chat_completions(state.clone(), payload.clone())
.await?;
return Ok(Json(response));
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/handlers/completion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::{models::completion::{CompletionRequest, CompletionResponse}, state::AppState};
use crate::{
models::completion::{CompletionRequest, CompletionResponse},
state::AppState,
};
use axum::{extract::State, http::StatusCode, Json};
use std::sync::Arc;

Expand Down
5 changes: 4 additions & 1 deletion src/handlers/embeddings.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use axum::{extract::State, http::StatusCode, Json};
use std::sync::Arc;

use crate::{models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}, state::AppState};
use crate::{
models::embeddings::{EmbeddingsRequest, EmbeddingsResponse},
state::AppState,
};

pub async fn embeddings(
State(state): State<Arc<AppState>>,
Expand Down
6 changes: 4 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ async fn main() -> Result<(), anyhow::Error> {

info!("Starting the application...");

let config = load_config("config.yaml").map_err(|_| anyhow::anyhow!("Failed to load configuration"))?;
let state = Arc::new(AppState::new(config).map_err(|_| anyhow::anyhow!("Failed to create app state"))?);
let config =
load_config("config.yaml").map_err(|_| anyhow::anyhow!("Failed to load configuration"))?;
let state =
Arc::new(AppState::new(config).map_err(|_| anyhow::anyhow!("Failed to create app state"))?);
let app = routes::create_router(state);
let port: String = std::env::var("PORT").unwrap_or("3000".to_string());
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port))
Expand Down
2 changes: 1 addition & 1 deletion src/models/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
}
2 changes: 1 addition & 1 deletion src/models/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ pub struct Embeddings {
pub object: String,
pub embedding: Vec<f32>,
pub index: usize,
}
}
2 changes: 1 addition & 1 deletion src/pipelines/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod pipeline;
pub mod plugin;
pub mod plugins;
pub mod pipeline;
pub mod services;
6 changes: 2 additions & 4 deletions src/pipelines/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ pub fn select_pipeline<'a>(
.find(|p| p.name == pipeline_name)
} else {
// Default to pipeline named "default"
matching_pipelines
.into_iter()
.find(|p| p.name == "default")
matching_pipelines.into_iter().find(|p| p.name == "default")
}
}
}
6 changes: 2 additions & 4 deletions src/pipelines/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@ where
fn call(&mut self, request: Request) -> Self::Future {
if !self.plugin.enabled() {
let future = self.inner.call(request);
return Box::pin(async move {
future.await
});
return Box::pin(async move { future.await });
}

let future = self.inner.call(request);

Box::pin(async move {
let response = future.await?;
// Here you can add post-processing logic
Expand Down
6 changes: 2 additions & 4 deletions src/pipelines/plugins/logging.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::pipelines::plugin::Plugin;
use crate::config::models::PluginConfig;
use crate::pipelines::plugin::Plugin;

/// Field-less (unit) plugin type that implements [`Plugin`].
///
/// NOTE(review): presumably the logging stage of a pipeline; no logging
/// logic is visible in this part of the file, so confirm before relying on it.
pub struct LoggingPlugin;

Expand All @@ -12,9 +12,7 @@ impl Plugin for LoggingPlugin {
true
}

fn init(&mut self, _config: &PluginConfig) -> () {

}
/// Plugin initialization hook; a no-op for this plugin.
///
/// The body is empty and `_config` is never read.
fn init(&mut self, _config: &PluginConfig) {}

fn clone_box(&self) -> Box<dyn Plugin> {
Box::new(LoggingPlugin)
Expand Down
2 changes: 1 addition & 1 deletion src/pipelines/plugins/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pub mod logging;
pub mod tracing;
pub mod tracing;
6 changes: 2 additions & 4 deletions src/pipelines/plugins/tracing.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::pipelines::plugin::Plugin;
use crate::config::models::PluginConfig;
use crate::pipelines::plugin::Plugin;

/// Field-less (unit) plugin type that implements [`Plugin`].
///
/// NOTE(review): presumably the tracing stage of a pipeline; no tracing
/// logic is visible in this part of the file, so confirm before relying on it.
pub struct TracingPlugin;

Expand All @@ -12,9 +12,7 @@ impl Plugin for TracingPlugin {
true
}

fn init(&mut self, _config: &PluginConfig) -> () {

}
/// Plugin initialization hook; a no-op for this plugin.
///
/// The body is empty and `_config` is never read.
fn init(&mut self, _config: &PluginConfig) {}

fn clone_box(&self) -> Box<dyn Plugin> {
Box::new(TracingPlugin)
Expand Down
2 changes: 1 addition & 1 deletion src/pipelines/services/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub mod model_router;
pub mod model_router;
7 changes: 3 additions & 4 deletions src/pipelines/services/model_router.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::sync::Arc;
use tower::Service;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::Service;

use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::state::AppState;


pub struct ModelRouterService {
state: Arc<AppState>,
models: Vec<String>,
Expand Down Expand Up @@ -47,4 +46,4 @@ impl Service<ChatCompletionRequest> for ModelRouterService {
Err(axum::http::StatusCode::SERVICE_UNAVAILABLE)
})
}
}
}
27 changes: 15 additions & 12 deletions src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use axum::http::StatusCode;
use std::sync::Arc;

use super::provider::Provider;
use crate::config::models::Provider as ProviderConfig;
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::common::Usage;
use crate::models::completion::{CompletionRequest, CompletionResponse, CompletionChoice};
use crate::models::embeddings::{Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse};
use crate::config::models::Provider as ProviderConfig;
use crate::models::completion::{CompletionChoice, CompletionRequest, CompletionResponse};
use crate::models::embeddings::{
Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse,
};
use crate::state::AppState;

pub struct AnthropicProvider {
Expand All @@ -31,7 +33,7 @@ impl Provider for AnthropicProvider {
/// Provider type identifier; always the string `"anthropic"`.
fn r#type(&self) -> String {
    String::from("anthropic")
}

async fn chat_completions(
&self,
state: Arc<AppState>,
Expand All @@ -54,12 +56,11 @@ impl Provider for AnthropicProvider {
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} else {
Err(StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
}
}

async fn completions(
async fn completions(
&self,
state: Arc<AppState>,
payload: CompletionRequest,
Expand All @@ -85,8 +86,9 @@ impl Provider for AnthropicProvider {

let status = response.status();
if !status.is_success() {
return Err(StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR));
return Err(
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
);
}

let anthropic_response: serde_json::Value = response
Expand Down Expand Up @@ -119,7 +121,7 @@ impl Provider for AnthropicProvider {
})
}

async fn embeddings(
async fn embeddings(
&self,
state: Arc<AppState>,
payload: EmbeddingsRequest,
Expand Down Expand Up @@ -147,8 +149,9 @@ impl Provider for AnthropicProvider {

let status = response.status();
if !status.is_success() {
return Err(StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR));
return Err(
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
);
}

let anthropic_response: serde_json::Value = response
Expand Down
4 changes: 2 additions & 2 deletions src/providers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod anthropic;
pub mod registry;
pub mod provider;
pub mod openai;
pub mod provider;
pub mod registry;
15 changes: 6 additions & 9 deletions src/providers/openai.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use axum::async_trait;
use axum::http::StatusCode;

use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::config::models::Provider as ProviderConfig;
use crate::state::AppState;
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::completion::{CompletionRequest, CompletionResponse};
use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
use crate::state::AppState;
use std::sync::Arc;

use super::provider::Provider;
Expand Down Expand Up @@ -51,8 +51,7 @@ impl Provider for OpenAIProvider {
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} else {
Err(StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
}
}

Expand All @@ -77,12 +76,11 @@ impl Provider for OpenAIProvider {
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} else {
Err(StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
}
}

async fn embeddings(
async fn embeddings(
&self,
state: Arc<AppState>,
payload: EmbeddingsRequest,
Expand All @@ -103,8 +101,7 @@ impl Provider for OpenAIProvider {
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} else {
Err(StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
}
}
}
9 changes: 5 additions & 4 deletions src/providers/provider.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use axum::async_trait;
use axum::http::StatusCode;

use crate::config::models::Provider as ProviderConfig;
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::completion::{CompletionRequest, CompletionResponse};
use crate::config::models::Provider as ProviderConfig;
use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
use crate::state::AppState;
use std::sync::Arc;


#[async_trait]
pub trait Provider: Send + Sync {
fn new(config: &ProviderConfig) -> Self where Self: Sized;
fn new(config: &ProviderConfig) -> Self
where
Self: Sized;
fn name(&self) -> String;
fn r#type(&self) -> String;

Expand All @@ -20,7 +21,7 @@ pub trait Provider: Send + Sync {
state: Arc<AppState>,
payload: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, StatusCode>;

async fn completions(
&self,
state: Arc<AppState>,
Expand Down
Loading

0 comments on commit abaf459

Please sign in to comment.