From 0155435142ccdf9994406f23c227833e8b5fb41b Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Mon, 22 Jul 2024 12:25:53 +0200
Subject: [PATCH] Allow using a custom model when using zed.dev (#14933)

Release Notes:

- N/A
---
 crates/anthropic/src/anthropic.rs           | 33 ++++-
 crates/assistant/src/assistant_settings.rs  | 11 +-
 crates/collab/src/rpc.rs                    | 51 ++++++--
 crates/completion/src/cloud.rs              | 12 +-
 crates/completion/src/open_ai.rs            |  1 +
 .../language_model/src/model/cloud_model.rs | 116 +++++-------------
 6 files changed, 114 insertions(+), 110 deletions(-)

diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index c36a7f37fd25c..21cb4d75aa9a7 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -20,6 +20,12 @@ pub enum Model {
     Claude3Sonnet,
     #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
     Claude3Haiku,
+    #[serde(rename = "custom")]
+    Custom {
+        name: String,
+        #[serde(default)]
+        max_tokens: Option<usize>,
+    },
 }
 
 impl Model {
@@ -33,30 +39,41 @@ impl Model {
         } else if id.starts_with("claude-3-haiku") {
             Ok(Self::Claude3Haiku)
         } else {
-            Err(anyhow!("Invalid model id: {}", id))
+            Ok(Self::Custom {
+                name: id.to_string(),
+                max_tokens: None,
+            })
         }
     }
 
-    pub fn id(&self) -> &'static str {
+    pub fn id(&self) -> &str {
         match self {
             Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
             Model::Claude3Opus => "claude-3-opus-20240229",
             Model::Claude3Sonnet => "claude-3-sonnet-20240229",
             Model::Claude3Haiku => "claude-3-haiku-20240307",
+            Model::Custom { name, .. } => name,
         }
     }
 
-    pub fn display_name(&self) -> &'static str {
+    pub fn display_name(&self) -> &str {
         match self {
             Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
             Self::Claude3Opus => "Claude 3 Opus",
             Self::Claude3Sonnet => "Claude 3 Sonnet",
             Self::Claude3Haiku => "Claude 3 Haiku",
+            Self::Custom { name, .. } => name,
         }
     }
 
     pub fn max_token_count(&self) -> usize {
-        200_000
+        match self {
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku => 200_000,
+            Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
+        }
     }
 }
 
@@ -90,6 +107,7 @@ impl From<Model> for String {
 
 #[derive(Debug, Serialize)]
 pub struct Request {
+    #[serde(serialize_with = "serialize_request_model")]
     pub model: Model,
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
@@ -97,6 +115,13 @@ pub struct Request {
     pub max_tokens: u32,
 }
 
+fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: serde::Serializer,
+{
+    serializer.serialize_str(model.id())
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct RequestMessage {
     pub role: Role,
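Why the new `serialize_with` hook matters: with the struct-like `Custom` variant, serializing `Request.model` through `Model`'s own representation would emit the externally tagged `{"custom": {...}}` object, whereas the Anthropic API expects `model` to be a plain id string; the hook flattens every variant to `Model::id()`. A minimal runnable sketch of that behavior with trimmed stand-in types, not the patch's full definitions (assumes `serde` with the derive feature plus `serde_json`; the custom model name is made up):

use serde::{Serialize, Serializer};

#[derive(Serialize)]
enum Model {
    #[serde(rename = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    #[serde(rename = "custom")]
    Custom { name: String },
}

impl Model {
    fn id(&self) -> &str {
        match self {
            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
            Model::Custom { name } => name,
        }
    }
}

#[derive(Serialize)]
struct Request {
    // Without this hook, a custom model would serialize as
    // {"custom":{"name":"..."}} instead of a bare string.
    #[serde(serialize_with = "serialize_request_model")]
    model: Model,
    max_tokens: u32,
}

fn serialize_request_model<S: Serializer>(model: &Model, s: S) -> Result<S::Ok, S::Error> {
    s.serialize_str(model.id())
}

fn main() {
    let request = Request {
        model: Model::Custom { name: "claude-x-experimental".into() },
        max_tokens: 1024,
    };
    // Prints {"model":"claude-x-experimental","max_tokens":1024},
    // a bare string in the model field rather than a tagged object.
    println!("{}", serde_json::to_string(&request).unwrap());
}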
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 7fca691e7a244..e19dc65a44542 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -668,7 +668,11 @@ mod tests {
                     "version": "1",
                     "provider": {
                         "name": "zed.dev",
-                        "default_model": "custom"
+                        "default_model": {
+                            "custom": {
+                                "name": "custom-provider"
+                            }
+                        }
                     }
                 }
             }"#,
@@ -679,7 +683,10 @@ mod tests {
         assert_eq!(
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::ZedDotDev {
-                model: CloudModel::Custom("custom".into())
+                model: CloudModel::Custom {
+                    name: "custom-provider".into(),
+                    max_tokens: None
+                }
             }
         );
     }
diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs
index d9113898084a2..4960eaa213488 100644
--- a/crates/collab/src/rpc.rs
+++ b/crates/collab/src/rpc.rs
@@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
 }
 
 async fn complete_with_language_model(
-    request: proto::CompleteWithLanguageModel,
+    mut request: proto::CompleteWithLanguageModel,
     response: StreamingResponse<proto::CompleteWithLanguageModel>,
     session: Session,
     open_ai_api_key: Option<Arc<str>>,
@@ -4530,18 +4530,43 @@ async fn complete_with_language_model(
         .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
         .await?;
 
-    if request.model.starts_with("gpt") {
-        let api_key =
-            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
-        complete_with_open_ai(request, response, session, api_key).await?;
-    } else if request.model.starts_with("gemini") {
-        let api_key = google_ai_api_key
-            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
-        complete_with_google_ai(request, response, session, api_key).await?;
-    } else if request.model.starts_with("claude") {
-        let api_key = anthropic_api_key
-            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
-        complete_with_anthropic(request, response, session, api_key).await?;
+    let mut provider_and_model = request.model.split('/');
+    let (provider, model) = match (
+        provider_and_model.next().unwrap(),
+        provider_and_model.next(),
+    ) {
+        (provider, Some(model)) => (provider, model),
+        (model, None) => {
+            if model.starts_with("gpt") {
+                ("openai", model)
+            } else if model.starts_with("gemini") {
+                ("google", model)
+            } else if model.starts_with("claude") {
+                ("anthropic", model)
+            } else {
+                ("unknown", model)
+            }
+        }
+    };
+    let provider = provider.to_string();
+    request.model = model.to_string();
+
+    match provider.as_str() {
+        "openai" => {
+            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
+            complete_with_open_ai(request, response, session, api_key).await?;
+        }
+        "anthropic" => {
+            let api_key =
+                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+            complete_with_anthropic(request, response, session, api_key).await?;
+        }
+        "google" => {
+            let api_key =
+                google_ai_api_key.context("no Google AI API key configured on the server")?;
+            complete_with_google_ai(request, response, session, api_key).await?;
+        }
+        provider => return Err(anyhow!("unknown provider {:?}", provider))?,
     }
 
     Ok(())
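In plain terms, the routing change above: a requested model id may now be qualified as `provider/model` (which is also how a zed.dev custom model targets Anthropic, per the `"anthropic/"` check later in this patch), and bare ids fall back to the old prefix sniffing for backwards compatibility. A standalone sketch of just the parse step, using only the standard library; the function name is ours, and it uses `split_once`, which differs from the patch's `split('/')` only when the model name itself contains another slash (the patch would drop the remainder, this keeps it):

/// Split "provider/model"; for bare ids, guess the provider from the
/// model-name prefix, mirroring the pre-patch behavior.
fn route(model_id: &str) -> (&str, &str) {
    match model_id.split_once('/') {
        Some((provider, model)) => (provider, model),
        None => {
            let provider = if model_id.starts_with("gpt") {
                "openai"
            } else if model_id.starts_with("gemini") {
                "google"
            } else if model_id.starts_with("claude") {
                "anthropic"
            } else {
                "unknown"
            };
            (provider, model_id)
        }
    }
}

fn main() {
    // Explicitly qualified: the provider prefix is stripped before the
    // request is forwarded upstream.
    assert_eq!(
        route("anthropic/claude-3-5-sonnet-20240620"),
        ("anthropic", "claude-3-5-sonnet-20240620")
    );
    // Bare ids keep working exactly as before.
    assert_eq!(route("gpt-4o"), ("openai", "gpt-4o"));
    assert_eq!(route("mystery-model"), ("unknown", "mystery-model"));
}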
diff --git a/crates/completion/src/cloud.rs b/crates/completion/src/cloud.rs
index f84576aeca101..ba1a7dd233455 100644
--- a/crates/completion/src/cloud.rs
+++ b/crates/completion/src/cloud.rs
@@ -54,15 +54,15 @@ impl CloudCompletionProvider {
 
 impl LanguageModelCompletionProvider for CloudCompletionProvider {
     fn available_models(&self) -> Vec<LanguageModel> {
-        let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
-            Some(custom_model)
+        let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
+            Some(self.model.clone())
         } else {
             None
         };
         CloudModel::iter()
             .filter_map(move |model| {
-                if let CloudModel::Custom(_) = model {
-                    Some(CloudModel::Custom(custom_model.take()?))
+                if let CloudModel::Custom { .. } = model {
+                    custom_model.take()
                 } else {
                     Some(model)
                 }
             })
@@ -117,9 +117,9 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
                 // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
                 count_open_ai_tokens(request, cx.background_executor())
             }
-            LanguageModel::Cloud(CloudModel::Custom(model)) => {
+            LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
                 let request = self.client.request(proto::CountTokensWithLanguageModel {
-                    model,
+                    model: name,
                    messages: request
                        .messages
                        .iter()
diff --git a/crates/completion/src/open_ai.rs b/crates/completion/src/open_ai.rs
index d187842bcbef5..21a0bbd73eee3 100644
--- a/crates/completion/src/open_ai.rs
+++ b/crates/completion/src/open_ai.rs
@@ -241,6 +241,7 @@ pub fn count_open_ai_tokens(
             | LanguageModel::Cloud(CloudModel::Claude3Opus)
             | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
             | LanguageModel::Cloud(CloudModel::Claude3Haiku)
+            | LanguageModel::Cloud(CloudModel::Custom { .. })
             | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
                 // Tiktoken doesn't yet support these models, so we manually use the
                 // same tokenizer as GPT-4.
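The `available_models` change above leans on a small trick worth spelling out: `CloudModel::iter()` (from `strum`'s `EnumIter`) yields one `Custom` placeholder whose fields are default values, and the `filter_map` swaps in the user-configured custom model at most once via `Option::take`. A reduced sketch of that pattern with a stand-in enum, assuming the `strum` crate with its derive feature:

use strum::{EnumIter, IntoEnumIterator};

#[derive(Clone, Debug, PartialEq, EnumIter)]
enum CloudModel {
    Gpt4Omni,
    Claude3_5Sonnet,
    // strum's EnumIter fills variant fields with Default::default(), so
    // iteration yields a placeholder Custom { name: "" }.
    Custom { name: String },
}

fn available_models(configured: &CloudModel) -> Vec<CloudModel> {
    let mut custom_model = if matches!(configured, CloudModel::Custom { .. }) {
        Some(configured.clone())
    } else {
        None
    };
    CloudModel::iter()
        .filter_map(move |model| {
            if let CloudModel::Custom { .. } = model {
                // Swap the placeholder for the configured custom model;
                // take() ensures it is emitted at most once.
                custom_model.take()
            } else {
                Some(model)
            }
        })
        .collect()
}

fn main() {
    let configured = CloudModel::Custom { name: "my-model".into() };
    assert!(available_models(&configured)
        .contains(&CloudModel::Custom { name: "my-model".into() }));
    // With a non-custom model configured, no placeholder leaks through.
    assert!(!available_models(&CloudModel::Gpt4Omni)
        .iter()
        .any(|m| matches!(m, CloudModel::Custom { .. })));
}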
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
index 20b2bf7d4f90e..43cb393a04836 100644
--- a/crates/language_model/src/model/cloud_model.rs
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
 pub use anthropic::Model as AnthropicModel;
 pub use ollama::Model as OllamaModel;
 pub use open_ai::Model as OpenAiModel;
-use schemars::{
-    schema::{InstanceType, Metadata, Schema, SchemaObject},
-    JsonSchema,
-};
-use serde::{
-    de::{self, Visitor},
-    Deserialize, Deserializer, Serialize, Serializer,
-};
-use std::fmt;
-use strum::{EnumIter, IntoEnumIterator};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use strum::EnumIter;
 
-#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
+#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
 pub enum CloudModel {
+    #[serde(rename = "gpt-3.5-turbo")]
     Gpt3Point5Turbo,
+    #[serde(rename = "gpt-4")]
     Gpt4,
+    #[serde(rename = "gpt-4-turbo-preview")]
     Gpt4Turbo,
+    #[serde(rename = "gpt-4o")]
     #[default]
     Gpt4Omni,
+    #[serde(rename = "gpt-4o-mini")]
     Gpt4OmniMini,
+    #[serde(rename = "claude-3-5-sonnet")]
     Claude3_5Sonnet,
+    #[serde(rename = "claude-3-opus")]
     Claude3Opus,
+    #[serde(rename = "claude-3-sonnet")]
     Claude3Sonnet,
+    #[serde(rename = "claude-3-haiku")]
     Claude3Haiku,
+    #[serde(rename = "gemini-1.5-pro")]
     Gemini15Pro,
+    #[serde(rename = "gemini-1.5-flash")]
     Gemini15Flash,
-    Custom(String),
-}
-
-impl Serialize for CloudModel {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        serializer.serialize_str(self.id())
-    }
-}
-
-impl<'de> Deserialize<'de> for CloudModel {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        struct ZedDotDevModelVisitor;
-
-        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
-            type Value = CloudModel;
-
-            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
-                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
-            }
-
-            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
-            where
-                E: de::Error,
-            {
-                let model = CloudModel::iter()
-                    .find(|model| model.id() == value)
-                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
-                Ok(model)
-            }
-        }
-
-        deserializer.deserialize_str(ZedDotDevModelVisitor)
-    }
-}
-
-impl JsonSchema for CloudModel {
-    fn schema_name() -> String {
-        "ZedDotDevModel".to_owned()
-    }
-
-    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
-        let variants = CloudModel::iter()
-            .filter_map(|model| {
-                let id = model.id();
-                if id.is_empty() {
-                    None
-                } else {
-                    Some(id.to_string())
-                }
-            })
-            .collect::<Vec<_>>();
-        Schema::Object(SchemaObject {
-            instance_type: Some(InstanceType::String.into()),
-            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
-            metadata: Some(Box::new(Metadata {
-                title: Some("ZedDotDevModel".to_owned()),
-                default: Some(CloudModel::default().id().into()),
-                examples: variants.into_iter().map(Into::into).collect(),
-                ..Default::default()
-            })),
-            ..Default::default()
-        })
-    }
+    #[serde(rename = "custom")]
+    Custom {
+        name: String,
+        max_tokens: Option<usize>,
+    },
 }
 
 impl CloudModel {
@@ -112,7 +52,7 @@ impl CloudModel {
             Self::Claude3Haiku => "claude-3-haiku",
             Self::Gemini15Pro => "gemini-1.5-pro",
             Self::Gemini15Flash => "gemini-1.5-flash",
-            Self::Custom(id) => id,
+            Self::Custom { name, .. } => name,
         }
     }
 
@@ -129,7 +69,7 @@ impl CloudModel {
             Self::Claude3Haiku => "Claude 3 Haiku",
             Self::Gemini15Pro => "Gemini 1.5 Pro",
             Self::Gemini15Flash => "Gemini 1.5 Flash",
-            Self::Custom(id) => id.as_str(),
+            Self::Custom { name, .. } => name,
         }
     }
 
@@ -145,14 +85,20 @@ impl CloudModel {
             | Self::Claude3Haiku => 200000,
             Self::Gemini15Pro => 128000,
             Self::Gemini15Flash => 32000,
-            Self::Custom(_) => 4096, // TODO: Make this configurable
+            Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
         }
     }
 
     pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
         match self {
-            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
-                request.preprocess_anthropic()
+            Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku
+            | Self::Claude3_5Sonnet => {
+                request.preprocess_anthropic();
+            }
+            Self::Custom { name, .. } if name.starts_with("anthropic/") => {
+                request.preprocess_anthropic();
             }
             _ => {}
         }
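For reference, the wire format the derived serde implementation above settles on: built-in models serialize as their renamed plain strings, while a custom model uses serde's externally tagged object form, which is exactly the `"default_model"` shape exercised in the assistant_settings test earlier in this patch. A round-trip sketch trimmed to two variants, assuming `serde` with the derive feature plus `serde_json` (the model name is illustrative):

use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum CloudModel {
    #[serde(rename = "gpt-4o")]
    Gpt4Omni,
    #[serde(rename = "custom")]
    Custom {
        name: String,
        max_tokens: Option<usize>,
    },
}

fn main() {
    // Built-in models stay plain strings, so existing settings keep parsing.
    assert_eq!(
        serde_json::from_str::<CloudModel>(r#""gpt-4o""#).unwrap(),
        CloudModel::Gpt4Omni
    );

    // A custom model uses the externally tagged object form. An "anthropic/"
    // name prefix additionally opts into the Anthropic request preprocessing
    // added in the final hunk above.
    assert_eq!(
        serde_json::from_str::<CloudModel>(
            r#"{"custom": {"name": "anthropic/my-model", "max_tokens": 100000}}"#
        )
        .unwrap(),
        CloudModel::Custom {
            name: "anthropic/my-model".into(),
            max_tokens: Some(100000),
        }
    );
}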