Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat(2658): introduce Adapter #2659

Merged
merged 13 commits into from
Aug 12, 2024
7 changes: 6 additions & 1 deletion src/cli/generator/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,12 @@
let mut config = config_gen.generate(true)?;

if infer_type_names {
let mut llm_gen = InferTypeName::default();
let key = self
.runtime
.env
.get("TAILCALL_SECRET")
.map(|s| s.into_owned());
let mut llm_gen = InferTypeName::new(key);

Check warning on line 173 in src/cli/generator/generator.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/generator/generator.rs#L168-L173

Added lines #L168 - L173 were not covered by tests
let suggested_names = llm_gen.generate(config.config()).await?;
let cfg = RenameTypes::new(suggested_names.iter())
.transform(config.config().to_owned())
Expand Down
14 changes: 10 additions & 4 deletions src/cli/llm/infer_type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
use genai::chat::{ChatMessage, ChatRequest, ChatResponse};
use serde::{Deserialize, Serialize};

use super::model::groq;
use super::{Error, Result, Wizard};
use crate::core::config::Config;

const MODEL: &str = "llama3-8b-8192";

#[derive(Default)]
pub struct InferTypeName {}
pub struct InferTypeName {
secret: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct Answer {
Expand Down Expand Up @@ -73,8 +74,13 @@
}

impl InferTypeName {
pub fn new(secret: Option<String>) -> InferTypeName {
Self { secret }
}

Check warning on line 79 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L77-L79

Added lines #L77 - L79 were not covered by tests
pub async fn generate(&mut self, config: &Config) -> Result<HashMap<String, String>> {
let wizard: Wizard<Question, Answer> = Wizard::new(MODEL.to_string());
let secret = self.secret.as_ref().map(|s| s.to_owned());

let wizard: Wizard<Question, Answer> = Wizard::new(groq::LLAMA38192, secret);

Check warning on line 83 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L81-L83

Added lines #L81 - L83 were not covered by tests

let mut new_name_mappings: HashMap<String, String> = HashMap::new();

Expand Down
2 changes: 2 additions & 0 deletions src/cli/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ pub mod infer_type_name;
pub use error::Error;
use error::Result;
pub use infer_type_name::InferTypeName;
mod model;
mod wizard;

pub use wizard::Wizard;
73 changes: 73 additions & 0 deletions src/cli/llm/model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#![allow(unused)]

use std::borrow::Cow;
use std::fmt::{Display, Formatter};
use std::marker::PhantomData;

use derive_setters::Setters;
use genai::adapter::AdapterKind;

/// Identifier for a concrete LLM model: a thin newtype over the
/// provider-specific model-name string (e.g. "gpt-4o", "llama3-8b-8192").
/// The wrapped string is what gets handed to the genai adapter layer.
#[derive(Clone)]
pub struct Model(&'static str);

/// OpenAI chat models.
pub mod open_ai {
    use super::*;

    // Fixed: was "gp-3.5-turbo" (missing the 't'); the misspelled id would
    // not match any real OpenAI model and would defeat adapter inference.
    pub const GPT3_5_TURBO: Model = Model("gpt-3.5-turbo");
    pub const GPT4: Model = Model("gpt-4");
    pub const GPT4_TURBO: Model = Model("gpt-4-turbo");
    pub const GPT4O_MINI: Model = Model("gpt-4o-mini");
    pub const GPT4O: Model = Model("gpt-4o");
}

/// Models served locally via Ollama.
pub mod ollama {
    use super::*;

    /// Google Gemma 2B ("gemma:2b" in Ollama's name:tag notation).
    pub const GEMMA2B: Model = Model("gemma:2b");
}

/// Anthropic Claude models, pinned to dated snapshot ids.
pub mod anthropic {
    use super::*;

    // Listed newest-generation first; ids embed their snapshot date.
    pub const CLAUDE35_SONNET_20240620: Model = Model("claude-3-5-sonnet-20240620");
    pub const CLAUDE3_OPUS_20240229: Model = Model("claude-3-opus-20240229");
    pub const CLAUDE3_SONNET_20240229: Model = Model("claude-3-sonnet-20240229");
    pub const CLAUDE3_HAIKU_20240307: Model = Model("claude-3-haiku-20240307");
}

/// Cohere Command-family models.
pub mod cohere {
    use super::*;

    // Stable releases.
    pub const COMMAND: Model = Model("command");
    pub const COMMAND_LIGHT: Model = Model("command-light");
    pub const COMMAND_R: Model = Model("command-r");
    pub const COMMAND_R_PLUS: Model = Model("command-r-plus");

    // Nightly channels.
    pub const COMMAND_NIGHTLY: Model = Model("command-nightly");
    pub const COMMAND_LIGHT_NIGHTLY: Model = Model("command-light-nightly");
}

/// Google Gemini models.
pub mod gemini {
    use super::*;

    pub const GEMINI10_PRO: Model = Model("gemini-1.0-pro");
    pub const GEMINI15_PRO: Model = Model("gemini-1.5-pro");
    pub const GEMINI15_FLASH: Model = Model("gemini-1.5-flash");
    /// Floating alias that tracks the latest 1.5 Flash revision.
    pub const GEMINI15_FLASH_LATEST: Model = Model("gemini-1.5-flash-latest");
}

/// Models hosted on Groq.
pub mod groq {
    use super::*;

    // Llama 3 base models (context size baked into the id).
    pub const LLAMA38192: Model = Model("llama3-8b-8192");
    pub const LLAMA708192: Model = Model("llama3-70b-8192");

    // Llama 3 tool-use preview variants.
    pub const LLAMA_GROQ8B8192_TOOL_USE_PREVIEW: Model =
        Model("llama3-groq-8b-8192-tool-use-preview");
    pub const LLAMA_GROQ70B8192_TOOL_USE_PREVIEW: Model =
        Model("llama3-groq-70b-8192-tool-use-preview");

    // Llama 3.1 family.
    pub const LLAMA8B_INSTANT: Model = Model("llama-3.1-8b-instant");
    pub const LLAMA70B_VERSATILE: Model = Model("llama-3.1-70b-versatile");
    pub const LLAMA405B_REASONING: Model = Model("llama-3.1-405b-reasoning");

    // Gemma family.
    pub const GEMMA7B_IT: Model = Model("gemma-7b-it");
    pub const GEMMA29B_IT: Model = Model("gemma2-9b-it");

    // Mixtral.
    pub const MIXTRAL_8X7B32768: Model = Model("mixtral-8x7b-32768");
}

impl Model {
pub fn as_str(&self) -> &'static str {
self.0
}

Check warning on line 72 in src/cli/llm/model.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/model.rs#L70-L72

Added lines #L70 - L72 were not covered by tests
}
15 changes: 12 additions & 3 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,36 @@
use derive_setters::Setters;
use genai::adapter::AdapterKind;
use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
use genai::Client;

use super::Result;
use crate::cli::llm::model::Model;

#[derive(Setters, Clone)]
pub struct Wizard<Q, A> {
client: Client,
// TODO: change model to enum
model: String,
model: Model,
tusharmath marked this conversation as resolved.
Show resolved Hide resolved
_q: std::marker::PhantomData<Q>,
_a: std::marker::PhantomData<A>,
}

impl<Q, A> Wizard<Q, A> {
pub fn new(model: String) -> Self {
pub fn new(model: Model, secret: Option<String>) -> Self {
let mut config = genai::adapter::AdapterConfig::default();
if let Some(key) = secret {
config = config.with_auth_env_name(key);
}

Check warning on line 22 in src/cli/llm/wizard.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/wizard.rs#L18-L22

Added lines #L18 - L22 were not covered by tests

let adapter = AdapterKind::from_model(model.as_str()).unwrap_or(AdapterKind::Ollama);

Check warning on line 25 in src/cli/llm/wizard.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/wizard.rs#L24-L25

Added lines #L24 - L25 were not covered by tests
Self {
client: Client::builder()
.with_chat_options(
ChatOptions::default()
.with_json_mode(true)
.with_temperature(0.0),
)
.insert_adapter_config(adapter, config)

Check warning on line 33 in src/cli/llm/wizard.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/wizard.rs#L33

Added line #L33 was not covered by tests
.build(),
model,
_q: Default::default(),
Expand Down
Loading