Skip to content

Commit

Permalink
make model enum and rename to secret
Browse files Browse the repository at this point in the history
Signed-off-by: Sahil Yeole <[email protected]>
  • Loading branch information
beelchester committed Aug 17, 2024
1 parent a6cac2d commit 6ec6291
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 18 deletions.
17 changes: 9 additions & 8 deletions src/cli/generator/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::cli::llm::model::Model;
use crate::core::config::transformer::Preset;
use crate::core::config::{self, ConfigReaderContext};
use crate::core::http::Method;
Expand All @@ -32,9 +33,9 @@ pub struct Config<Status = UnResolved> {
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub struct LLMConfig {
pub model: String,
pub model: Model,
#[serde(skip_serializing_if = "TemplateString::is_empty")]
pub api_key: TemplateString,
pub secret: TemplateString,
}

#[derive(Clone, Deserialize, Serialize, Debug, Default)]
Expand Down Expand Up @@ -281,8 +282,8 @@ impl Config {
.collect::<anyhow::Result<Vec<Input<Resolved>>>>()?;

let output = self.output.resolve(parent_dir)?;
let llm_api_key = self.llm.api_key.resolve(&reader_context);
let llm = LLMConfig { model: self.llm.model, api_key: llm_api_key };
let llm_api_key = self.llm.secret.resolve(&reader_context);
let llm = LLMConfig { model: self.llm.model, secret: llm_api_key };

Ok(Config {
inputs,
Expand Down Expand Up @@ -520,15 +521,15 @@ mod tests {
};

let config = Config::default().llm(LLMConfig {
model: "gpt-3.5-turbo".to_string(),
api_key: TemplateString::parse("{{.env.TAILCALL_LLM_API_KEY}}").unwrap(),
model: Model::Gpt3_5Turbo,
secret: TemplateString::parse("{{.env.TAILCALL_LLM_API_KEY}}").unwrap(),
});
let resolved_config = config.into_resolved("", reader_ctx).unwrap();

let actual = resolved_config.llm;
let expected = LLMConfig {
model: "gpt-3.5-turbo".to_string(),
api_key: TemplateString::try_from(token).unwrap(),
model: Model::Gpt3_5Turbo,
secret: TemplateString::try_from(token).unwrap(),
};

assert_eq!(actual, expected);
Expand Down
15 changes: 8 additions & 7 deletions src/cli/llm/infer_type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use std::collections::HashMap;
use genai::chat::{ChatMessage, ChatRequest, ChatResponse};
use serde::{Deserialize, Serialize};

use super::model::Model;
use super::{Error, Result, Wizard};
use crate::cli::generator::config::LLMConfig;
use crate::core::config::Config;

#[derive(Default)]
pub struct InferTypeName {
model: String,
api_key: Option<String>,
model: Model,
secret: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -76,17 +77,17 @@ impl TryInto<ChatRequest> for Question {

impl InferTypeName {
pub fn new(llm_config: LLMConfig) -> InferTypeName {
let api_key = if !llm_config.api_key.is_empty() {
Some(llm_config.api_key.to_string())
let secret = if !llm_config.secret.is_empty() {
Some(llm_config.secret.to_string())
} else {
None
};
Self { model: llm_config.model, api_key }
Self { model: llm_config.model, secret }
}
pub async fn generate(&mut self, config: &Config) -> Result<HashMap<String, String>> {
let api_key = self.api_key.as_ref().map(|s| s.to_owned());
let secret = self.secret.as_ref().map(|s| s.to_owned());

let wizard: Wizard<Question, Answer> = Wizard::new(self.model.clone(), api_key);
let wizard: Wizard<Question, Answer> = Wizard::new(self.model, secret);

let mut new_name_mappings: HashMap<String, String> = HashMap::new();

Expand Down
1 change: 1 addition & 0 deletions src/cli/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod infer_type_name;
pub use error::Error;
use error::Result;
pub use infer_type_name::InferTypeName;
pub mod model;
mod wizard;

pub use wizard::Wizard;
145 changes: 145 additions & 0 deletions src/cli/llm/model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use std::collections::HashMap;

use serde::de::{self, Deserializer};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};

/// LLM models supported by the generator wizard, spanning several
/// providers (OpenAI, Ollama, Anthropic, Cohere, Gemini, Groq).
///
/// A `Model` is (de)serialized as the provider's canonical model-name
/// string — see [`Model::as_str`] and [`Model::from_str`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Model {
    // OpenAI Models
    #[default]
    Gpt3_5Turbo,
    GPT4,
    Gpt4Turbo,
    Gpt4oMini,
    GPT4O,

    // Ollama Models
    Gemma2B,
    Gemma2,
    Llama3_1,

    // Anthropic Models
    Claude3Haiku20240307,
    Claude3Sonnet20240229,
    Claude3Opus20240229,
    Claude35Sonnet20240620,

    // Cohere Models
    CommandLightNightly,
    CommandLight,
    CommandNightly,
    Command,
    CommandR,
    CommandRPlus,

    // Gemini Models
    Gemini15FlashLatest,
    Gemini10Pro,
    Gemini15Flash,
    Gemini15Pro,

    // Groq Models
    LLAMA708192,
    LLAMA38192,
    LlamaGroq8b8192ToolUsePreview,
    LlamaGroq70b8192ToolUsePreview,
    Gemma29bIt,
    Gemma7bIt,
    Mixtral8x7b32768,
    Llama8bInstant,
    Llama70bVersatile,
    Llama405bReasoning,
}

impl Model {
    /// Returns the canonical model-name string for this model,
    /// e.g. `Model::Gpt4oMini.as_str() == "gpt-4o-mini"`.
    ///
    /// The match is exhaustive and total: adding a variant without a
    /// name here is a compile error, and no variant can fall through to
    /// an empty string. (Remember to extend `from_str` as well.)
    pub fn as_str(&self) -> &'static str {
        match self {
            // OpenAI Models
            // Fixed: was "gp-3.5-turbo", which broke parsing of the
            // default model's real name.
            Model::Gpt3_5Turbo => "gpt-3.5-turbo",
            Model::GPT4 => "gpt-4",
            Model::Gpt4Turbo => "gpt-4-turbo",
            Model::Gpt4oMini => "gpt-4o-mini",
            Model::GPT4O => "gpt-4o",

            // Ollama Models
            Model::Gemma2B => "gemma:2b",
            Model::Gemma2 => "gemma2",
            Model::Llama3_1 => "llama3.1",

            // Anthropic Models
            Model::Claude3Haiku20240307 => "claude-3-haiku-20240307",
            Model::Claude3Sonnet20240229 => "claude-3-sonnet-20240229",
            Model::Claude3Opus20240229 => "claude-3-opus-20240229",
            Model::Claude35Sonnet20240620 => "claude-3-5-sonnet-20240620",

            // Cohere Models
            Model::CommandLightNightly => "command-light-nightly",
            Model::CommandLight => "command-light",
            Model::CommandNightly => "command-nightly",
            Model::Command => "command",
            Model::CommandR => "command-r",
            Model::CommandRPlus => "command-r-plus",

            // Gemini Models
            Model::Gemini15FlashLatest => "gemini-1.5-flash-latest",
            Model::Gemini10Pro => "gemini-1.0-pro",
            Model::Gemini15Flash => "gemini-1.5-flash",
            Model::Gemini15Pro => "gemini-1.5-pro",

            // Groq Models
            Model::LLAMA708192 => "llama3-70b-8192",
            Model::LLAMA38192 => "llama3-8b-8192",
            Model::LlamaGroq8b8192ToolUsePreview => "llama3-groq-8b-8192-tool-use-preview",
            Model::LlamaGroq70b8192ToolUsePreview => "llama3-groq-70b-8192-tool-use-preview",
            Model::Gemma29bIt => "gemma2-9b-it",
            Model::Gemma7bIt => "gemma-7b-it",
            Model::Mixtral8x7b32768 => "mixtral-8x7b-32768",
            Model::Llama8bInstant => "llama-3.1-8b-instant",
            Model::Llama70bVersatile => "llama-3.1-70b-versatile",
            Model::Llama405bReasoning => "llama-3.1-405b-reasoning",
        }
    }

    /// Parses a canonical model-name string into a `Model`.
    ///
    /// Returns `None` for names not in the supported set. This is the
    /// inverse of [`Model::as_str`]; the two tables must stay in sync.
    pub fn from_str(model_name: &str) -> Option<Self> {
        let model = match model_name {
            // OpenAI Models
            "gpt-3.5-turbo" => Model::Gpt3_5Turbo,
            "gpt-4" => Model::GPT4,
            "gpt-4-turbo" => Model::Gpt4Turbo,
            "gpt-4o-mini" => Model::Gpt4oMini,
            "gpt-4o" => Model::GPT4O,

            // Ollama Models
            "gemma:2b" => Model::Gemma2B,
            "gemma2" => Model::Gemma2,
            "llama3.1" => Model::Llama3_1,

            // Anthropic Models
            "claude-3-haiku-20240307" => Model::Claude3Haiku20240307,
            "claude-3-sonnet-20240229" => Model::Claude3Sonnet20240229,
            "claude-3-opus-20240229" => Model::Claude3Opus20240229,
            "claude-3-5-sonnet-20240620" => Model::Claude35Sonnet20240620,

            // Cohere Models
            "command-light-nightly" => Model::CommandLightNightly,
            "command-light" => Model::CommandLight,
            "command-nightly" => Model::CommandNightly,
            "command" => Model::Command,
            "command-r" => Model::CommandR,
            "command-r-plus" => Model::CommandRPlus,

            // Gemini Models
            "gemini-1.5-flash-latest" => Model::Gemini15FlashLatest,
            "gemini-1.0-pro" => Model::Gemini10Pro,
            "gemini-1.5-flash" => Model::Gemini15Flash,
            "gemini-1.5-pro" => Model::Gemini15Pro,

            // Groq Models
            "llama3-70b-8192" => Model::LLAMA708192,
            "llama3-8b-8192" => Model::LLAMA38192,
            "llama3-groq-8b-8192-tool-use-preview" => Model::LlamaGroq8b8192ToolUsePreview,
            "llama3-groq-70b-8192-tool-use-preview" => Model::LlamaGroq70b8192ToolUsePreview,
            "gemma2-9b-it" => Model::Gemma29bIt,
            "gemma-7b-it" => Model::Gemma7bIt,
            "mixtral-8x7b-32768" => Model::Mixtral8x7b32768,
            "llama-3.1-8b-instant" => Model::Llama8bInstant,
            "llama-3.1-70b-versatile" => Model::Llama70bVersatile,
            "llama-3.1-405b-reasoning" => Model::Llama405bReasoning,

            _ => return None,
        };
        Some(model)
    }
}

impl Serialize for Model {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.as_str())
}
}

impl<'de> Deserialize<'de> for Model {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Model::from_str(&s).ok_or_else(|| de::Error::unknown_variant(&s, &[]))
}
}
7 changes: 4 additions & 3 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@ use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
use genai::resolver::AuthResolver;
use genai::Client;

use super::model::Model;
use super::Result;

#[derive(Setters, Clone)]
pub struct Wizard<Q, A> {
client: Client,
model: String,
model: Model,
_q: std::marker::PhantomData<Q>,
_a: std::marker::PhantomData<A>,
}

impl<Q, A> Wizard<Q, A> {
pub fn new(model: String, api_key: Option<String>) -> Self {
pub fn new(model: Model, secret: Option<String>) -> Self {
let mut config = genai::adapter::AdapterConfig::default();
if let Some(key) = api_key {
if let Some(key) = secret {
config = config.with_auth_resolver(AuthResolver::from_key_value(key));
}

Expand Down

0 comments on commit 6ec6291

Please sign in to comment.