Skip to content

Commit

Permalink
refactor(2665): send system message once in InferTypeName generation
Browse files Browse the repository at this point in the history
  • Loading branch information
onyedikachi-david committed Aug 12, 2024
1 parent e7fc1ce commit c6c7e24
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
66 changes: 55 additions & 11 deletions src/cli/llm/infer_type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::core::config::Config;
#[derive(Default)]
pub struct InferTypeName {
secret: Option<String>,
wizard: Option<Wizard<Question, Answer>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -75,25 +76,60 @@ impl TryInto<ChatRequest> for Question {

impl InferTypeName {
pub fn new(secret: Option<String>) -> InferTypeName {
Self { secret }
Self { secret, wizard: None }

Check warning on line 79 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L79

Added line #L79 was not covered by tests
}

fn create_system_messages() -> Vec<ChatMessage> {
vec![
ChatMessage::system(
"Given the sample schema of a GraphQL type suggest 5 meaningful names for it.",
),
ChatMessage::system("The name should be concise and preferably a single word"),
ChatMessage::system("Example Input:"),
ChatMessage::system(
serde_json::to_string_pretty(&Question {
fields: vec![
("id".to_string(), "String".to_string()),
("name".to_string(), "String".to_string()),
("age".to_string(), "Int".to_string()),
],
})
.unwrap(),
),
ChatMessage::system("Example Output:"),
ChatMessage::system(
serde_json::to_string_pretty(&Answer {
suggestions: vec![
"Person".into(),
"Profile".into(),
"Member".into(),
"Individual".into(),
"Contact".into(),
],
})
.unwrap(),
),
ChatMessage::system("Ensure the output is in valid JSON format"),
ChatMessage::system("Do not add any additional text before or after the json"),
]
}

Check warning on line 115 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L82-L115

Added lines #L82 - L115 were not covered by tests

pub async fn generate(&mut self, config: &Config) -> Result<HashMap<String, String>> {
let secret = self.secret.as_ref().map(|s| s.to_owned());

let wizard: Wizard<Question, Answer> = Wizard::new(groq::LLAMA38192, secret);
self.wizard = Some(Wizard::new(groq::LLAMA38192, secret));

Check warning on line 119 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L119

Added line #L119 was not covered by tests

let mut new_name_mappings: HashMap<String, String> = HashMap::new();

// removed root type from types.
let types_to_be_processed = config
.types
.iter()
.filter(|(type_name, _)| !config.is_root_operation_type(type_name))
.collect::<Vec<_>>();

let total = types_to_be_processed.len();
let system_messages = Self::create_system_messages();

Check warning on line 130 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L130

Added line #L130 was not covered by tests

for (i, (type_name, type_)) in types_to_be_processed.into_iter().enumerate() {
// convert type to sdl format.
let question = Question {
fields: type_
.fields
Expand All @@ -104,7 +140,7 @@ impl InferTypeName {

let mut delay = 3;
loop {
let answer = wizard.ask(question.clone()).await;
let answer = self.ask_with_context(&system_messages, &question).await;

Check warning on line 143 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L143

Added line #L143 was not covered by tests
match answer {
Ok(answer) => {
let name = &answer.suggestions.join(", ");
Expand All @@ -125,15 +161,10 @@ impl InferTypeName {
total
);

// TODO: case where suggested names are already used, then extend the base
// question with `suggest different names, we have already used following
// names: [names list]`
break;
}
Err(e) => {
// TODO: log errors after certain number of retries.
if let Error::GenAI(_) = e {
// TODO: retry only when it's required.
tracing::warn!(
"Unable to retrieve a name for the type '{}'. Retrying in {}s",
type_name,
Expand All @@ -149,6 +180,19 @@ impl InferTypeName {

Ok(new_name_mappings.into_iter().map(|(k, v)| (v, k)).collect())
}

async fn ask_with_context(
&self,
system_messages: &[ChatMessage],
question: &Question,
) -> Result<Answer> {
let content = serde_json::to_string(question)?;
let mut messages = system_messages.to_vec();
messages.push(ChatMessage::user(content));

let request = ChatRequest::new(messages);
self.wizard.as_ref().unwrap().ask(request).await
}

Check warning on line 195 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L184-L195

Added lines #L184 - L195 were not covered by tests
}

#[cfg(test)]
Expand Down
5 changes: 2 additions & 3 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,13 @@ impl<Q, A> Wizard<Q, A> {
}
}

pub async fn ask(&self, q: Q) -> Result<A>
pub async fn ask(&self, request: ChatRequest) -> Result<A>

Check warning on line 42 in src/cli/llm/wizard.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/wizard.rs#L42

Added line #L42 was not covered by tests
where
Q: TryInto<ChatRequest, Error = super::Error>,
A: TryFrom<ChatResponse, Error = super::Error>,
{
let response = self
.client
.exec_chat(self.model.as_str(), q.try_into()?, None)
.exec_chat(self.model.as_str(), request, None)

Check warning on line 48 in src/cli/llm/wizard.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/wizard.rs#L48

Added line #L48 was not covered by tests
.await?;
A::try_from(response)
}
Expand Down

0 comments on commit c6c7e24

Please sign in to comment.