From abaf4593877526d3bc64ea309e14b527eb454007 Mon Sep 17 00:00:00 2001 From: Gal Kleinman Date: Wed, 6 Nov 2024 17:19:55 +0200 Subject: [PATCH] formatting --- src/ai_models/instance.rs | 6 +++--- src/ai_models/registry.rs | 13 ++++++------- src/config/mod.rs | 2 +- src/handlers/chat.rs | 9 +++++++-- src/handlers/completion.rs | 5 ++++- src/handlers/embeddings.rs | 5 ++++- src/main.rs | 6 ++++-- src/models/common.rs | 2 +- src/models/embeddings.rs | 2 +- src/pipelines/mod.rs | 2 +- src/pipelines/pipeline.rs | 6 ++---- src/pipelines/plugin.rs | 6 ++---- src/pipelines/plugins/logging.rs | 6 ++---- src/pipelines/plugins/mod.rs | 2 +- src/pipelines/plugins/tracing.rs | 6 ++---- src/pipelines/services/mod.rs | 2 +- src/pipelines/services/model_router.rs | 7 +++---- src/providers/anthropic.rs | 27 ++++++++++++++------------ src/providers/mod.rs | 4 ++-- src/providers/openai.rs | 15 ++++++-------- src/providers/provider.rs | 9 +++++---- src/providers/registry.rs | 10 +++++----- src/routes.rs | 8 +++++--- src/state.rs | 4 ++-- 24 files changed, 85 insertions(+), 79 deletions(-) diff --git a/src/ai_models/instance.rs b/src/ai_models/instance.rs index 64a7442..5324dd2 100644 --- a/src/ai_models/instance.rs +++ b/src/ai_models/instance.rs @@ -1,10 +1,10 @@ -use std::sync::Arc; -use axum::http::StatusCode; use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}; use crate::providers::provider::Provider; use crate::state::AppState; +use axum::http::StatusCode; +use std::sync::Arc; pub struct ModelInstance { pub name: String, @@ -28,7 +28,7 @@ impl ModelInstance { mut payload: CompletionRequest, ) -> Result { payload.model = self.model_type.clone(); - + self.provider.completions(state, payload).await } diff --git a/src/ai_models/registry.rs b/src/ai_models/registry.rs index 7e9d801..9bc6b73 100644 --- a/src/ai_models/registry.rs +++ b/src/ai_models/registry.rs @@ -1,11 +1,10 @@ +use anyhow::Result; use std::collections::HashMap; use std::sync::Arc; -use anyhow::Result; +use super::instance::ModelInstance; use crate::config::models::Model as ModelConfig; use crate::providers::registry::ProviderRegistry; -use super::instance::ModelInstance; - pub struct ModelRegistry { models: HashMap>, @@ -17,7 +16,7 @@ impl ModelRegistry { provider_registry: Arc, ) -> Result { let mut models = HashMap::new(); - + for config in model_configs { if let Some(provider) = provider_registry.get(&config.provider) { let model = Arc::new(ModelInstance { @@ -25,14 +24,14 @@ impl ModelRegistry { model_type: config.r#type.clone(), provider, }); - + models.insert(config.name.clone(), model); } } - + Ok(Self { models }) } - + pub fn get(&self, name: &str) -> Option> { self.models.get(name).cloned() } diff --git a/src/config/mod.rs b/src/config/mod.rs index 5253185..9ea4d54 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,2 +1,2 @@ -pub mod models; pub mod lib; +pub mod models; diff --git a/src/handlers/chat.rs b/src/handlers/chat.rs index 79b6d99..2e2ee14 100644 --- a/src/handlers/chat.rs +++ b/src/handlers/chat.rs @@ -1,4 +1,7 @@ -use crate::{models::chat::{ChatCompletionRequest, ChatCompletionResponse}, state::AppState}; +use crate::{ + models::chat::{ChatCompletionRequest, ChatCompletionResponse}, + state::AppState, +}; use axum::{extract::State, http::StatusCode, Json}; use std::sync::Arc; @@ -8,7 +11,9 @@ pub async fn completions( ) -> Result, StatusCode> { for model in state.config.models.iter() { if let Some(model) = state.model_registry.get(&model.name) { - let response = model.chat_completions(state.clone(), payload.clone()).await?; + let response = model + .chat_completions(state.clone(), payload.clone()) + .await?; return Ok(Json(response)); } } diff --git a/src/handlers/completion.rs b/src/handlers/completion.rs index 279a54d..6278409 100644 --- a/src/handlers/completion.rs +++ b/src/handlers/completion.rs @@ -1,4 +1,7 @@ -use crate::{models::completion::{CompletionRequest, CompletionResponse}, state::AppState}; +use crate::{ + models::completion::{CompletionRequest, CompletionResponse}, + state::AppState, +}; use axum::{extract::State, http::StatusCode, Json}; use std::sync::Arc; diff --git a/src/handlers/embeddings.rs b/src/handlers/embeddings.rs index 3c3af7f..9dae585 100644 --- a/src/handlers/embeddings.rs +++ b/src/handlers/embeddings.rs @@ -1,7 +1,10 @@ use axum::{extract::State, http::StatusCode, Json}; use std::sync::Arc; -use crate::{models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}, state::AppState}; +use crate::{ + models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}, + state::AppState, +}; pub async fn embeddings( State(state): State>, diff --git a/src/main.rs b/src/main.rs index 0c5cdf9..3182ddb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,8 +9,10 @@ async fn main() -> Result<(), anyhow::Error> { info!("Starting the application..."); - let config = load_config("config.yaml").map_err(|_| anyhow::anyhow!("Failed to load configuration"))?; - let state = Arc::new(AppState::new(config).map_err(|_| anyhow::anyhow!("Failed to create app state"))?); + let config = + load_config("config.yaml").map_err(|_| anyhow::anyhow!("Failed to load configuration"))?; + let state = + Arc::new(AppState::new(config).map_err(|_| anyhow::anyhow!("Failed to create app state"))?); let app = routes::create_router(state); let port: String = std::env::var("PORT").unwrap_or("3000".to_string()); let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port)) diff --git a/src/models/common.rs b/src/models/common.rs index b400d4b..77f8880 100644 --- a/src/models/common.rs +++ b/src/models/common.rs @@ -5,4 +5,4 @@ pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, -} \ No newline at end of file +} diff --git a/src/models/embeddings.rs b/src/models/embeddings.rs index f823bcc..a280f3d 100644 --- a/src/models/embeddings.rs +++ b/src/models/embeddings.rs @@ -32,4 +32,4 @@ pub struct Embeddings { pub object: String, pub embedding: Vec, pub index: usize, -} \ No newline at end of file +} diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index 7c34bd9..819afbb 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -1,4 +1,4 @@ +pub mod pipeline; pub mod plugin; pub mod plugins; -pub mod pipeline; pub mod services; diff --git a/src/pipelines/pipeline.rs b/src/pipelines/pipeline.rs index be86e71..1846587 100644 --- a/src/pipelines/pipeline.rs +++ b/src/pipelines/pipeline.rs @@ -30,8 +30,6 @@ pub fn select_pipeline<'a>( .find(|p| p.name == pipeline_name) } else { // Default to pipeline named "default" - matching_pipelines - .into_iter() - .find(|p| p.name == "default") + matching_pipelines.into_iter().find(|p| p.name == "default") } -} \ No newline at end of file +} diff --git a/src/pipelines/plugin.rs b/src/pipelines/plugin.rs index 75d8459..efe4a07 100644 --- a/src/pipelines/plugin.rs +++ b/src/pipelines/plugin.rs @@ -38,13 +38,11 @@ where fn call(&mut self, request: Request) -> Self::Future { if !self.plugin.enabled() { let future = self.inner.call(request); - return Box::pin(async move { - future.await - }); + return Box::pin(async move { future.await }); } let future = self.inner.call(request); - + Box::pin(async move { let response = future.await?; // Here you can add post-processing logic diff --git a/src/pipelines/plugins/logging.rs b/src/pipelines/plugins/logging.rs index 58eae23..f90ab9f 100644 --- a/src/pipelines/plugins/logging.rs +++ b/src/pipelines/plugins/logging.rs @@ -1,5 +1,5 @@ -use crate::pipelines::plugin::Plugin; use crate::config::models::PluginConfig; +use crate::pipelines::plugin::Plugin; pub struct LoggingPlugin; @@ -12,9 +12,7 @@ impl Plugin for LoggingPlugin { true } - fn init(&mut self, _config: &PluginConfig) -> () { - - } + fn init(&mut self, _config: &PluginConfig) -> () {} fn clone_box(&self) -> Box { Box::new(LoggingPlugin) diff --git a/src/pipelines/plugins/mod.rs b/src/pipelines/plugins/mod.rs index f8e67b0..8f89f06 100644 --- a/src/pipelines/plugins/mod.rs +++ b/src/pipelines/plugins/mod.rs @@ -1,2 +1,2 @@ pub mod logging; -pub mod tracing; \ No newline at end of file +pub mod tracing; diff --git a/src/pipelines/plugins/tracing.rs b/src/pipelines/plugins/tracing.rs index 75e8cd1..b2b99e7 100644 --- a/src/pipelines/plugins/tracing.rs +++ b/src/pipelines/plugins/tracing.rs @@ -1,5 +1,5 @@ -use crate::pipelines::plugin::Plugin; use crate::config::models::PluginConfig; +use crate::pipelines::plugin::Plugin; pub struct TracingPlugin; @@ -12,9 +12,7 @@ impl Plugin for TracingPlugin { true } - fn init(&mut self, _config: &PluginConfig) -> () { - - } + fn init(&mut self, _config: &PluginConfig) -> () {} fn clone_box(&self) -> Box { Box::new(TracingPlugin) diff --git a/src/pipelines/services/mod.rs b/src/pipelines/services/mod.rs index 5f42390..c5b1955 100644 --- a/src/pipelines/services/mod.rs +++ b/src/pipelines/services/mod.rs @@ -1 +1 @@ -pub mod model_router; \ No newline at end of file +pub mod model_router; diff --git a/src/pipelines/services/model_router.rs b/src/pipelines/services/model_router.rs index 892a815..6d0188f 100644 --- a/src/pipelines/services/model_router.rs +++ b/src/pipelines/services/model_router.rs @@ -1,13 +1,12 @@ -use std::sync::Arc; -use tower::Service; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; +use tower::Service; use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; use crate::state::AppState; - pub struct ModelRouterService { state: Arc, models: Vec, @@ -47,4 +46,4 @@ impl Service for ModelRouterService { Err(axum::http::StatusCode::SERVICE_UNAVAILABLE) }) } -} \ No newline at end of file +} diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index a2e953c..6cf54e6 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -3,11 +3,13 @@ use axum::http::StatusCode; use std::sync::Arc; use super::provider::Provider; +use crate::config::models::Provider as ProviderConfig; use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; use crate::models::common::Usage; -use crate::models::completion::{CompletionRequest, CompletionResponse, CompletionChoice}; -use crate::models::embeddings::{Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse}; -use crate::config::models::Provider as ProviderConfig; +use crate::models::completion::{CompletionChoice, CompletionRequest, CompletionResponse}; +use crate::models::embeddings::{ + Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse, +}; use crate::state::AppState; pub struct AnthropicProvider { @@ -31,7 +33,7 @@ impl Provider for AnthropicProvider { fn r#type(&self) -> String { "anthropic".to_string() } - + async fn chat_completions( &self, state: Arc, @@ -54,12 +56,11 @@ impl Provider for AnthropicProvider { .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } else { - Err(StatusCode::from_u16(status.as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) } } - async fn completions( + async fn completions( &self, state: Arc, payload: CompletionRequest, @@ -85,8 +86,9 @@ impl Provider for AnthropicProvider { let status = response.status(); if !status.is_success() { - return Err(StatusCode::from_u16(status.as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)); + return Err( + StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + ); } let anthropic_response: serde_json::Value = response @@ -119,7 +121,7 @@ impl Provider for AnthropicProvider { }) } - async fn embeddings( + async fn embeddings( &self, state: Arc, payload: EmbeddingsRequest, @@ -147,8 +149,9 @@ impl Provider for AnthropicProvider { let status = response.status(); if !status.is_success() { - return Err(StatusCode::from_u16(status.as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)); + return Err( + StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + ); } let anthropic_response: serde_json::Value = response diff --git a/src/providers/mod.rs b/src/providers/mod.rs index f38e756..68133c8 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,4 +1,4 @@ pub mod anthropic; -pub mod registry; -pub mod provider; pub mod openai; +pub mod provider; +pub mod registry; diff --git a/src/providers/openai.rs b/src/providers/openai.rs index eac6c49..4ab1c43 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,11 +1,11 @@ use axum::async_trait; use axum::http::StatusCode; -use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; use crate::config::models::Provider as ProviderConfig; -use crate::state::AppState; +use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; use crate::models::completion::{CompletionRequest, CompletionResponse}; use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}; +use crate::state::AppState; use std::sync::Arc; use super::provider::Provider; @@ -51,8 +51,7 @@ impl Provider for OpenAIProvider { .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } else { - Err(StatusCode::from_u16(status.as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) } } @@ -77,12 +76,11 @@ impl Provider for OpenAIProvider { .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } else { - Err(StatusCode::from_u16(status.as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) } } - async fn embeddings( + async fn embeddings( &self, state: Arc, payload: EmbeddingsRequest, @@ -103,8 +101,7 @@ impl Provider for OpenAIProvider { .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } else { - Err(StatusCode::from_u16(status.as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) } } } diff --git a/src/providers/provider.rs b/src/providers/provider.rs index 5b1471f..77d59f6 100644 --- a/src/providers/provider.rs +++ b/src/providers/provider.rs @@ -1,17 +1,18 @@ use axum::async_trait; use axum::http::StatusCode; +use crate::config::models::Provider as ProviderConfig; use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; use crate::models::completion::{CompletionRequest, CompletionResponse}; -use crate::config::models::Provider as ProviderConfig; use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}; use crate::state::AppState; use std::sync::Arc; - #[async_trait] pub trait Provider: Send + Sync { - fn new(config: &ProviderConfig) -> Self where Self: Sized; + fn new(config: &ProviderConfig) -> Self + where + Self: Sized; fn name(&self) -> String; fn r#type(&self) -> String; @@ -20,7 +21,7 @@ pub trait Provider: Send + Sync { state: Arc, payload: ChatCompletionRequest, ) -> Result; - + async fn completions( &self, state: Arc, diff --git a/src/providers/registry.rs b/src/providers/registry.rs index f5821fc..79eea7f 100644 --- a/src/providers/registry.rs +++ b/src/providers/registry.rs @@ -1,9 +1,9 @@ +use anyhow::Result; use std::collections::HashMap; use std::sync::Arc; -use anyhow::Result; -use crate::providers::{anthropic::AnthropicProvider, openai::OpenAIProvider, provider::Provider}; use crate::config::models::Provider as ProviderConfig; +use crate::providers::{anthropic::AnthropicProvider, openai::OpenAIProvider, provider::Provider}; pub struct ProviderRegistry { providers: HashMap>, @@ -12,7 +12,7 @@ pub struct ProviderRegistry { impl ProviderRegistry { pub fn new(provider_configs: &[ProviderConfig]) -> Result { let mut providers = HashMap::new(); - + for config in provider_configs { let provider: Arc = match config.r#type.as_str() { "openai" => Arc::new(OpenAIProvider::new(config)), @@ -21,10 +21,10 @@ impl ProviderRegistry { }; providers.insert(config.name.clone(), provider); } - + Ok(Self { providers }) } - + pub fn get(&self, name: &str) -> Option> { self.providers.get(name).cloned() } diff --git a/src/routes.rs b/src/routes.rs index 9dc4076..9de3a02 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -5,13 +5,13 @@ use axum::{ routing::{get, post}, Router, }; +use std::iter::once; +use std::sync::Arc; use tower_http::compression::CompressionLayer; use tower_http::propagate_header::PropagateHeaderLayer; use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer; use tower_http::trace::TraceLayer; use tower_http::validate_request::ValidateRequestHeaderLayer; -use std::iter::once; -use std::sync::Arc; pub fn create_router(state: Arc) -> Router { let v1_routes = Router::new() @@ -21,7 +21,9 @@ pub fn create_router(state: Arc) -> Router { .layer(SetSensitiveRequestHeadersLayer::new(once(AUTHORIZATION))) .layer(TraceLayer::new_for_http()) .layer(CompressionLayer::new()) - .layer(PropagateHeaderLayer::new(HeaderName::from_static("x-request-id"))) + .layer(PropagateHeaderLayer::new(HeaderName::from_static( + "x-request-id", + ))) .layer(ValidateRequestHeaderLayer::accept("application/json")); Router::new() diff --git a/src/state.rs b/src/state.rs index f437ab6..e93a8c7 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,9 +1,9 @@ -use std::sync::Arc; use crate::ai_models::registry::ModelRegistry; use crate::config::models::Config; use crate::providers::registry::ProviderRegistry; -use reqwest::Client; use anyhow::Result; +use reqwest::Client; +use std::sync::Arc; #[derive(Clone)] pub struct AppState {