Skip to content

Commit

Permalink
Allow using a custom model when using zed.dev (#14933)
Browse files Browse the repository at this point in the history
Release Notes:

- N/A
  • Loading branch information
as-cii authored Jul 22, 2024
1 parent a334c69 commit 0155435
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 110 deletions.
33 changes: 29 additions & 4 deletions crates/anthropic/src/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ pub enum Model {
Claude3Sonnet,
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
Claude3Haiku,
#[serde(rename = "custom")]
Custom {
name: String,
#[serde(default)]
max_tokens: Option<usize>,
},
}

impl Model {
Expand All @@ -33,30 +39,41 @@ impl Model {
} else if id.starts_with("claude-3-haiku") {
Ok(Self::Claude3Haiku)
} else {
Err(anyhow!("Invalid model id: {}", id))
Ok(Self::Custom {
name: id.to_string(),
max_tokens: None,
})
}
}

pub fn id(&self) -> &'static str {
pub fn id(&self) -> &str {
match self {
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
Model::Claude3Opus => "claude-3-opus-20240229",
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
Model::Claude3Haiku => "claude-3-opus-20240307",
Model::Custom { name, .. } => name,
}
}

pub fn display_name(&self) -> &'static str {
pub fn display_name(&self) -> &str {
match self {
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Custom { name, .. } => name,
}
}

/// Maximum context size, in tokens, for this model.
/// All first-party Claude variants here use the same window; a custom model
/// may override it via `max_tokens`, falling back to the same default.
pub fn max_token_count(&self) -> usize {
    let default_max = 200_000;
    if let Self::Custom { max_tokens, .. } = self {
        max_tokens.unwrap_or(default_max)
    } else {
        default_max
    }
}
}

Expand Down Expand Up @@ -90,13 +107,21 @@ impl From<Role> for String {

/// Serializable body of a completion request to the Anthropic HTTP API.
#[derive(Debug, Serialize)]
pub struct Request {
    // Serialized as the model's string id (see `serialize_request_model`)
    // instead of serde's default enum representation, so custom models
    // serialize as their raw name rather than a tagged variant.
    #[serde(serialize_with = "serialize_request_model")]
    pub model: Model,
    // Conversation history, in order.
    pub messages: Vec<RequestMessage>,
    // Whether the response should be streamed.
    pub stream: bool,
    // System prompt, sent separately from the message list.
    pub system: String,
    // Maximum number of tokens the model may generate.
    pub max_tokens: u32,
}

/// Serializes a `Request`'s `model` field as its plain string id, which is
/// the wire format the API expects (and lets `Model::Custom` serialize as
/// the bare custom name).
fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    // `Model::id()` already returns `&str`; the original `&model.id()`
    // created a needless `&&str` borrow (clippy::needless_borrow).
    serializer.serialize_str(model.id())
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
Expand Down
11 changes: 9 additions & 2 deletions crates/assistant/src/assistant_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,11 @@ mod tests {
"version": "1",
"provider": {
"name": "zed.dev",
"default_model": "custom"
"default_model": {
"custom": {
"name": "custom-provider"
}
}
}
}
}"#,
Expand All @@ -679,7 +683,10 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
model: CloudModel::Custom("custom".into())
model: CloudModel::Custom {
name: "custom-provider".into(),
max_tokens: None
}
}
);
}
Expand Down
51 changes: 38 additions & 13 deletions crates/collab/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}

async fn complete_with_language_model(
request: proto::CompleteWithLanguageModel,
mut request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
Expand All @@ -4530,18 +4530,43 @@ async fn complete_with_language_model(
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;

if request.model.starts_with("gpt") {
let api_key =
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
complete_with_open_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("gemini") {
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("claude") {
let api_key = anthropic_api_key
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
complete_with_anthropic(request, response, session, api_key).await?;
let mut provider_and_model = request.model.split('/');
let (provider, model) = match (
provider_and_model.next().unwrap(),
provider_and_model.next(),
) {
(provider, Some(model)) => (provider, model),
(model, None) => {
if model.starts_with("gpt") {
("openai", model)
} else if model.starts_with("gemini") {
("google", model)
} else if model.starts_with("claude") {
("anthropic", model)
} else {
("unknown", model)
}
}
};
let provider = provider.to_string();
request.model = model.to_string();

match provider.as_str() {
"openai" => {
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
complete_with_open_ai(request, response, session, api_key).await?;
}
"anthropic" => {
let api_key =
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
complete_with_anthropic(request, response, session, api_key).await?;
}
"google" => {
let api_key =
google_ai_api_key.context("no Google AI API key configured on the server")?;
complete_with_google_ai(request, response, session, api_key).await?;
}
provider => return Err(anyhow!("unknown provider {:?}", provider))?,
}

Ok(())
Expand Down
12 changes: 6 additions & 6 deletions crates/completion/src/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ impl CloudCompletionProvider {

impl LanguageModelCompletionProvider for CloudCompletionProvider {
fn available_models(&self) -> Vec<LanguageModel> {
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model)
let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
Some(self.model.clone())
} else {
None
};
CloudModel::iter()
.filter_map(move |model| {
if let CloudModel::Custom(_) = model {
Some(CloudModel::Custom(custom_model.take()?))
if let CloudModel::Custom { .. } = model {
custom_model.take()
} else {
Some(model)
}
Expand Down Expand Up @@ -117,9 +117,9 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::Cloud(CloudModel::Custom(model)) => {
LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
model: name,
messages: request
.messages
.iter()
Expand Down
1 change: 1 addition & 0 deletions crates/completion/src/open_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ pub fn count_open_ai_tokens(
| LanguageModel::Cloud(CloudModel::Claude3Opus)
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
| LanguageModel::Cloud(CloudModel::Custom { .. })
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
Expand Down
116 changes: 31 additions & 85 deletions crates/language_model/src/model/cloud_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
pub use anthropic::Model as AnthropicModel;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::fmt;
use strum::{EnumIter, IntoEnumIterator};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum::EnumIter;

#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum CloudModel {
#[serde(rename = "gpt-3.5-turbo")]
Gpt3Point5Turbo,
#[serde(rename = "gpt-4")]
Gpt4,
#[serde(rename = "gpt-4-turbo-preview")]
Gpt4Turbo,
#[serde(rename = "gpt-4o")]
#[default]
Gpt4Omni,
#[serde(rename = "gpt-4o-mini")]
Gpt4OmniMini,
#[serde(rename = "claude-3-5-sonnet")]
Claude3_5Sonnet,
#[serde(rename = "claude-3-opus")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet")]
Claude3Sonnet,
#[serde(rename = "claude-3-haiku")]
Claude3Haiku,
#[serde(rename = "gemini-1.5-pro")]
Gemini15Pro,
#[serde(rename = "gemini-1.5-flash")]
Gemini15Flash,
Custom(String),
}

impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}

impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;

impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}

deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}

impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}

fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
#[serde(rename = "custom")]
Custom {
name: String,
max_tokens: Option<usize>,
},
}

impl CloudModel {
Expand All @@ -112,7 +52,7 @@ impl CloudModel {
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
Self::Custom { name, .. } => name,
}
}

Expand All @@ -129,7 +69,7 @@ impl CloudModel {
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
Self::Custom { name, .. } => name,
}
}

Expand All @@ -145,14 +85,20 @@ impl CloudModel {
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
}
}

pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
request.preprocess_anthropic()
Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku
| Self::Claude3_5Sonnet => {
request.preprocess_anthropic();
}
Self::Custom { name, .. } if name.starts_with("anthropic/") => {
request.preprocess_anthropic();
}
_ => {}
}
Expand Down

0 comments on commit 0155435

Please sign in to comment.