diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 61ae8bb360..6bcaa9d9df 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -12,6 +12,7 @@ use crate::core::Mustache; #[derive(Default)] pub struct InferTypeName { secret: Option, + wizard: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -75,16 +76,50 @@ impl TryInto for Question { impl InferTypeName { pub fn new(secret: Option) -> InferTypeName { - Self { secret } + Self { secret, wizard: None } } + + fn create_system_messages() -> Vec { + vec![ + ChatMessage::system( + "Given the sample schema of a GraphQL type suggest 5 meaningful names for it.", + ), + ChatMessage::system("The name should be concise and preferably a single word"), + ChatMessage::system("Example Input:"), + ChatMessage::system( + serde_json::to_string_pretty(&Question { + fields: vec![ + ("id".to_string(), "String".to_string()), + ("name".to_string(), "String".to_string()), + ("age".to_string(), "Int".to_string()), + ], + }) + .unwrap(), + ), + ChatMessage::system("Example Output:"), + ChatMessage::system( + serde_json::to_string_pretty(&Answer { + suggestions: vec![ + "Person".into(), + "Profile".into(), + "Member".into(), + "Individual".into(), + "Contact".into(), + ], + }) + .unwrap(), + ), + ChatMessage::system("Ensure the output is in valid JSON format"), + ChatMessage::system("Do not add any additional text before or after the json"), + ] + } + pub async fn generate(&mut self, config: &Config) -> Result> { let secret = self.secret.as_ref().map(|s| s.to_owned()); - - let wizard: Wizard = Wizard::new(groq::LLAMA38192, secret); + self.wizard = Some(Wizard::new(groq::LLAMA38192, secret)); let mut new_name_mappings: HashMap = HashMap::new(); - // removed root type from types. let types_to_be_processed = config .types .iter() @@ -92,8 +127,9 @@ impl InferTypeName { .collect::>(); let total = types_to_be_processed.len(); + let system_messages = Self::create_system_messages(); + for (i, (type_name, type_)) in types_to_be_processed.into_iter().enumerate() { - // convert type to sdl format. let question = Question { fields: type_ .fields @@ -104,7 +140,7 @@ impl InferTypeName { let mut delay = 3; loop { - let answer = wizard.ask(question.clone()).await; + let answer = self.ask_with_context(&system_messages, &question).await; match answer { Ok(answer) => { let name = &answer.suggestions.join(", "); @@ -125,15 +161,10 @@ impl InferTypeName { total ); - // TODO: case where suggested names are already used, then extend the base - // question with `suggest different names, we have already used following - // names: [names list]` break; } Err(e) => { - // TODO: log errors after certain number of retries. if let Error::GenAI(_) = e { - // TODO: retry only when it's required. tracing::warn!( "Unable to retrieve a name for the type '{}'. Retrying in {}s", type_name, @@ -149,6 +180,15 @@ impl InferTypeName { Ok(new_name_mappings.into_iter().map(|(k, v)| (v, k)).collect()) } + + async fn ask_with_context( + &self, + system_messages: &[ChatMessage], + question: &Question, + ) -> Result { + // Remove the conversion to ChatRequest + self.wizard.as_ref().unwrap().ask(question.clone()).await + } } #[cfg(test)]