From c6c7e24ebf2e80268a326858951cba080f7507b7 Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Mon, 12 Aug 2024 20:17:47 +0000 Subject: [PATCH 1/2] refactor(2665): send system message once in InferTypeName generation --- src/cli/llm/infer_type_name.rs | 66 ++++++++++++++++++++++++++++------ src/cli/llm/wizard.rs | 5 ++- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index ed0869dfd5..695e6dcb24 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -10,6 +10,7 @@ use crate::core::config::Config; #[derive(Default)] pub struct InferTypeName { secret: Option, + wizard: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -75,16 +76,50 @@ impl TryInto for Question { impl InferTypeName { pub fn new(secret: Option) -> InferTypeName { - Self { secret } + Self { secret, wizard: None } } + + fn create_system_messages() -> Vec { + vec![ + ChatMessage::system( + "Given the sample schema of a GraphQL type suggest 5 meaningful names for it.", + ), + ChatMessage::system("The name should be concise and preferably a single word"), + ChatMessage::system("Example Input:"), + ChatMessage::system( + serde_json::to_string_pretty(&Question { + fields: vec![ + ("id".to_string(), "String".to_string()), + ("name".to_string(), "String".to_string()), + ("age".to_string(), "Int".to_string()), + ], + }) + .unwrap(), + ), + ChatMessage::system("Example Output:"), + ChatMessage::system( + serde_json::to_string_pretty(&Answer { + suggestions: vec![ + "Person".into(), + "Profile".into(), + "Member".into(), + "Individual".into(), + "Contact".into(), + ], + }) + .unwrap(), + ), + ChatMessage::system("Ensure the output is in valid JSON format"), + ChatMessage::system("Do not add any additional text before or after the json"), + ] + } + pub async fn generate(&mut self, config: &Config) -> Result> { let secret = self.secret.as_ref().map(|s| s.to_owned()); - - let wizard: Wizard = Wizard::new(groq::LLAMA38192, secret); + self.wizard = Some(Wizard::new(groq::LLAMA38192, secret)); let mut new_name_mappings: HashMap = HashMap::new(); - // removed root type from types. let types_to_be_processed = config .types .iter() @@ -92,8 +127,9 @@ impl InferTypeName { .collect::>(); let total = types_to_be_processed.len(); + let system_messages = Self::create_system_messages(); + for (i, (type_name, type_)) in types_to_be_processed.into_iter().enumerate() { - // convert type to sdl format. let question = Question { fields: type_ .fields @@ -104,7 +140,7 @@ impl InferTypeName { let mut delay = 3; loop { - let answer = wizard.ask(question.clone()).await; + let answer = self.ask_with_context(&system_messages, &question).await; match answer { Ok(answer) => { let name = &answer.suggestions.join(", "); @@ -125,15 +161,10 @@ impl InferTypeName { total ); - // TODO: case where suggested names are already used, then extend the base - // question with `suggest different names, we have already used following - // names: [names list]` break; } Err(e) => { - // TODO: log errors after certain number of retries. if let Error::GenAI(_) = e { - // TODO: retry only when it's required. tracing::warn!( "Unable to retrieve a name for the type '{}'. Retrying in {}s", type_name, @@ -149,6 +180,19 @@ impl InferTypeName { Ok(new_name_mappings.into_iter().map(|(k, v)| (v, k)).collect()) } + + async fn ask_with_context( + &self, + system_messages: &[ChatMessage], + question: &Question, + ) -> Result { + let content = serde_json::to_string(question)?; + let mut messages = system_messages.to_vec(); + messages.push(ChatMessage::user(content)); + + let request = ChatRequest::new(messages); + self.wizard.as_ref().unwrap().ask(request).await + } } #[cfg(test)] diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs index 1604d7f15f..0945ac55ef 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -39,14 +39,13 @@ impl Wizard { } } - pub async fn ask(&self, q: Q) -> Result + pub async fn ask(&self, request: ChatRequest) -> Result where - Q: TryInto, A: TryFrom, { let response = self .client - .exec_chat(self.model.as_str(), q.try_into()?, None) + .exec_chat(self.model.as_str(), request, None) .await?; A::try_from(response) } From 81bfcce4887f427456c3a94550e9619c93de93a0 Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Tue, 13 Aug 2024 12:39:34 +0000 Subject: [PATCH 2/2] fix(llm): correct type mismatch in InferTypeName.ask_with_context maintaining previous API --- src/cli/llm/infer_type_name.rs | 8 ++------ src/cli/llm/wizard.rs | 5 +++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 695e6dcb24..dd79920837 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -186,12 +186,8 @@ impl InferTypeName { system_messages: &[ChatMessage], question: &Question, ) -> Result { - let content = serde_json::to_string(question)?; - let mut messages = system_messages.to_vec(); - messages.push(ChatMessage::user(content)); - - let request = ChatRequest::new(messages); - self.wizard.as_ref().unwrap().ask(request).await + // Remove the conversion to ChatRequest + self.wizard.as_ref().unwrap().ask(question.clone()).await } } diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs index 0945ac55ef..1604d7f15f 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -39,13 +39,14 @@ impl Wizard { } } - pub async fn ask(&self, request: ChatRequest) -> Result + pub async fn ask(&self, q: Q) -> Result where + Q: TryInto, A: TryFrom, { let response = self .client - .exec_chat(self.model.as_str(), request, None) + .exec_chat(self.model.as_str(), q.try_into()?, None) .await?; A::try_from(response) }