Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(2658): introduce Adapter #2659

Merged
merged 13 commits into from
Aug 12, 2024
25 changes: 25 additions & 0 deletions src/cli/llm/adapter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::fmt::{Display, Formatter};

#[derive(Clone)]
pub enum Adapter {
Groq(GroqModel),
}

#[derive(Clone)]
pub enum GroqModel {
Llama38b8192,
}
ssddOnTop marked this conversation as resolved.
Show resolved Hide resolved

Check warning on line 12 in src/cli/llm/adapter.rs

View workflow job for this annotation

GitHub Actions / Run Formatter and Lint Check

Diff in /home/runner/work/tailcall/tailcall/src/cli/llm/adapter.rs
impl Display for Adapter {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let str = match self { Adapter::Groq(g) => g.to_string() };
write!(f, "{}", str)
}
}

Check warning on line 19 in src/cli/llm/adapter.rs

View workflow job for this annotation

GitHub Actions / Run Formatter and Lint Check

Diff in /home/runner/work/tailcall/tailcall/src/cli/llm/adapter.rs
impl Display for GroqModel {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let str = match self { GroqModel::Llama38b8192 => "Llama3_8b_8192" };
write!(f, "{}", str)
}
}
2 changes: 2 additions & 0 deletions src/cli/llm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod error;
pub mod infer_type_name;
pub use error::Error;

Check warning on line 3 in src/cli/llm/mod.rs

View workflow job for this annotation

GitHub Actions / Run Formatter and Lint Check

Diff in /home/runner/work/tailcall/tailcall/src/cli/llm/mod.rs
use error::Result;
pub use infer_type_name::InferTypeName;
mod wizard;
mod adapter;

pub use wizard::Wizard;
9 changes: 4 additions & 5 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use derive_setters::Setters;

Check warning on line 1 in src/cli/llm/wizard.rs

View workflow job for this annotation

GitHub Actions / Run Formatter and Lint Check

Diff in /home/runner/work/tailcall/tailcall/src/cli/llm/wizard.rs
use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
use genai::Client;

use crate::cli::llm::adapter::Adapter;
use super::Result;

#[derive(Setters, Clone)]
pub struct Wizard<Q, A> {
client: Client,
// TODO: change model to enum
model: String,
model: Adapter,
_q: std::marker::PhantomData<Q>,
_a: std::marker::PhantomData<A>,
}

impl<Q, A> Wizard<Q, A> {
pub fn new(model: String) -> Self {
pub fn new(model: Adapter) -> Self {
ssddOnTop marked this conversation as resolved.
Show resolved Hide resolved
Self {
client: Client::builder()
.with_chat_options(ChatOptions::default().with_json_mode(true))
Expand All @@ -32,7 +31,7 @@
{
let response = self
.client
.exec_chat(self.model.as_str(), q.try_into()?, None)
.exec_chat(self.model.to_string().as_str(), q.try_into()?, None)
.await?;
A::try_from(response)
}
Expand Down
Loading