From 6ec62916a9fa820bd2e9aa1b2a176d5a739b7e35 Mon Sep 17 00:00:00 2001 From: Sahil Yeole Date: Sat, 17 Aug 2024 17:00:08 +0530 Subject: [PATCH] make model enum and rename to secret Signed-off-by: Sahil Yeole --- src/cli/generator/config.rs | 17 ++-- src/cli/llm/infer_type_name.rs | 15 ++-- src/cli/llm/mod.rs | 1 + src/cli/llm/model.rs | 145 +++++++++++++++++++++++++++++++++ src/cli/llm/wizard.rs | 7 +- 5 files changed, 167 insertions(+), 18 deletions(-) create mode 100644 src/cli/llm/model.rs diff --git a/src/cli/generator/config.rs b/src/cli/generator/config.rs index aa79e16eb08..45f5c772425 100644 --- a/src/cli/generator/config.rs +++ b/src/cli/generator/config.rs @@ -9,6 +9,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use url::Url; +use crate::cli::llm::model::Model; use crate::core::config::transformer::Preset; use crate::core::config::{self, ConfigReaderContext}; use crate::core::http::Method; @@ -32,9 +33,9 @@ pub struct Config { #[serde(rename_all = "camelCase")] #[serde(deny_unknown_fields)] pub struct LLMConfig { - pub model: String, + pub model: Model, #[serde(skip_serializing_if = "TemplateString::is_empty")] - pub api_key: TemplateString, + pub secret: TemplateString, } #[derive(Clone, Deserialize, Serialize, Debug, Default)] @@ -281,8 +282,8 @@ impl Config { .collect::>>>()?; let output = self.output.resolve(parent_dir)?; - let llm_api_key = self.llm.api_key.resolve(&reader_context); - let llm = LLMConfig { model: self.llm.model, api_key: llm_api_key }; + let llm_api_key = self.llm.secret.resolve(&reader_context); + let llm = LLMConfig { model: self.llm.model, secret: llm_api_key }; Ok(Config { inputs, @@ -520,15 +521,15 @@ mod tests { }; let config = Config::default().llm(LLMConfig { - model: "gpt-3.5-turbo".to_string(), - api_key: TemplateString::parse("{{.env.TAILCALL_LLM_API_KEY}}").unwrap(), + model: Model::Gpt3_5Turbo, + secret: TemplateString::parse("{{.env.TAILCALL_LLM_API_KEY}}").unwrap(), }); let resolved_config = 
config.into_resolved("", reader_ctx).unwrap(); let actual = resolved_config.llm; let expected = LLMConfig { - model: "gpt-3.5-turbo".to_string(), - api_key: TemplateString::try_from(token).unwrap(), + model: Model::Gpt3_5Turbo, + secret: TemplateString::try_from(token).unwrap(), }; assert_eq!(actual, expected); diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 77154019c17..ee41a0c4b66 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -3,14 +3,15 @@ use std::collections::HashMap; use genai::chat::{ChatMessage, ChatRequest, ChatResponse}; use serde::{Deserialize, Serialize}; +use super::model::Model; use super::{Error, Result, Wizard}; use crate::cli::generator::config::LLMConfig; use crate::core::config::Config; #[derive(Default)] pub struct InferTypeName { - model: String, - api_key: Option, + model: Model, + secret: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -76,17 +77,17 @@ impl TryInto for Question { impl InferTypeName { pub fn new(llm_config: LLMConfig) -> InferTypeName { - let api_key = if !llm_config.api_key.is_empty() { - Some(llm_config.api_key.to_string()) + let secret = if !llm_config.secret.is_empty() { + Some(llm_config.secret.to_string()) } else { None }; - Self { model: llm_config.model, api_key } + Self { model: llm_config.model, secret } } pub async fn generate(&mut self, config: &Config) -> Result> { - let api_key = self.api_key.as_ref().map(|s| s.to_owned()); + let secret = self.secret.as_ref().map(|s| s.to_owned()); - let wizard: Wizard = Wizard::new(self.model.clone(), api_key); + let wizard: Wizard = Wizard::new(self.model, secret); let mut new_name_mappings: HashMap = HashMap::new(); diff --git a/src/cli/llm/mod.rs b/src/cli/llm/mod.rs index 40c0dce6102..e756a0b8f84 100644 --- a/src/cli/llm/mod.rs +++ b/src/cli/llm/mod.rs @@ -3,6 +3,7 @@ pub mod infer_type_name; pub use error::Error; use error::Result; pub use infer_type_name::InferTypeName; +pub mod model; 
mod wizard; pub use wizard::Wizard; diff --git a/src/cli/llm/model.rs b/src/cli/llm/model.rs new file mode 100644 index 00000000000..c258eb92e9b --- /dev/null +++ b/src/cli/llm/model.rs @@ -0,0 +1,145 @@ +use std::collections::HashMap; + +use serde::de::{self, Deserializer}; +use serde::ser::Serializer; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Default)] +pub enum Model { + // OpenAI Models + #[default] + Gpt3_5Turbo, + GPT4, + Gpt4Turbo, + Gpt4oMini, + GPT4O, + + // Ollama Models + Gemma2B, + Gemma2, + Llama3_1, + + // Anthropic Models + Claude3Haiku20240307, + Claude3Sonnet20240229, + Claude3Opus20240229, + Claude35Sonnet20240620, + + // Cohere Models + CommandLightNightly, + CommandLight, + CommandNightly, + Command, + CommandR, + CommandRPlus, + + // Gemini Models + Gemini15FlashLatest, + Gemini10Pro, + Gemini15Flash, + Gemini15Pro, + + // Groq Models + LLAMA708192, + LLAMA38192, + LlamaGroq8b8192ToolUsePreview, + LlamaGroq70b8192ToolUsePreview, + Gemma29bIt, + Gemma7bIt, + Mixtral8x7b32768, + Llama8bInstant, + Llama70bVersatile, + Llama405bReasoning, +} + +impl Model { + fn model_hashmap() -> &'static HashMap<&'static str, Model> { + static MAP: once_cell::sync::Lazy<HashMap<&'static str, Model>> = + once_cell::sync::Lazy::new(|| { + let mut map = HashMap::new(); + + // OpenAI Models + map.insert("gpt-3.5-turbo", Model::Gpt3_5Turbo); + map.insert("gpt-4", Model::GPT4); + map.insert("gpt-4-turbo", Model::Gpt4Turbo); + map.insert("gpt-4o-mini", Model::Gpt4oMini); + map.insert("gpt-4o", Model::GPT4O); + + // Ollama Models + map.insert("gemma:2b", Model::Gemma2B); + map.insert("gemma2", Model::Gemma2); + map.insert("llama3.1", Model::Llama3_1); + + // Anthropic Models + map.insert("claude-3-haiku-20240307", Model::Claude3Haiku20240307); + map.insert("claude-3-sonnet-20240229", Model::Claude3Sonnet20240229); + map.insert("claude-3-opus-20240229", Model::Claude3Opus20240229); + map.insert("claude-3-5-sonnet-20240620", Model::Claude35Sonnet20240620); + + //
Cohere Models + map.insert("command-light-nightly", Model::CommandLightNightly); + map.insert("command-light", Model::CommandLight); + map.insert("command-nightly", Model::CommandNightly); + map.insert("command", Model::Command); + map.insert("command-r", Model::CommandR); + map.insert("command-r-plus", Model::CommandRPlus); + + // Gemini Models + map.insert("gemini-1.5-flash-latest", Model::Gemini15FlashLatest); + map.insert("gemini-1.0-pro", Model::Gemini10Pro); + map.insert("gemini-1.5-flash", Model::Gemini15Flash); + map.insert("gemini-1.5-pro", Model::Gemini15Pro); + + // Groq Models + map.insert("llama3-70b-8192", Model::LLAMA708192); + map.insert("llama3-8b-8192", Model::LLAMA38192); + map.insert( + "llama3-groq-8b-8192-tool-use-preview", + Model::LlamaGroq8b8192ToolUsePreview, + ); + map.insert( + "llama3-groq-70b-8192-tool-use-preview", + Model::LlamaGroq70b8192ToolUsePreview, + ); + map.insert("gemma2-9b-it", Model::Gemma29bIt); + map.insert("gemma-7b-it", Model::Gemma7bIt); + map.insert("mixtral-8x7b-32768", Model::Mixtral8x7b32768); + map.insert("llama-3.1-8b-instant", Model::Llama8bInstant); + map.insert("llama-3.1-70b-versatile", Model::Llama70bVersatile); + map.insert("llama-3.1-405b-reasoning", Model::Llama405bReasoning); + + map + }); + &MAP + } + + pub fn from_str(model_name: &str) -> Option<Model> { + Self::model_hashmap().get(model_name).copied() + } + + pub fn as_str(&self) -> &'static str { + Self::model_hashmap() + .iter() + .find_map(|(&k, &v)| if v == *self { Some(k) } else { None }) + .unwrap_or_default() + } +} + +impl Serialize for Model { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for Model { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Model::from_str(&s).ok_or_else(|| de::Error::unknown_variant(&s, &[])) + } +} diff --git a/src/cli/llm/wizard.rs
b/src/cli/llm/wizard.rs index f90deca70dd..0c9291e4b0d 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -4,20 +4,21 @@ use genai::chat::{ChatOptions, ChatRequest, ChatResponse}; use genai::resolver::AuthResolver; use genai::Client; +use super::model::Model; use super::Result; #[derive(Setters, Clone)] pub struct Wizard<Q, A> { client: Client, - model: String, + model: Model, _q: std::marker::PhantomData<Q>, _a: std::marker::PhantomData<A>, } impl<Q, A> Wizard<Q, A> { - pub fn new(model: String, api_key: Option<String>) -> Self { + pub fn new(model: Model, secret: Option<String>) -> Self { let mut config = genai::adapter::AdapterConfig::default(); - if let Some(key) = api_key { + if let Some(key) = secret { config = config.with_auth_resolver(AuthResolver::from_key_value(key)); }