From f043c9732ac0f1fd38173cb556105093e2b2612d Mon Sep 17 00:00:00 2001 From: laststylebender Date: Tue, 3 Sep 2024 20:58:46 +0530 Subject: [PATCH] - [draft] generate argument names with LLM. --- src/cli/generator/generator.rs | 13 +- src/cli/llm/infer_arg_name.rs | 210 ++++++++++++++++++ src/cli/llm/mod.rs | 1 + src/cli/llm/prompts/infer_arg_name.md | 19 ++ src/core/config/transformer/mod.rs | 2 + src/core/config/transformer/rename_args.rs | 159 +++++++++++++ ...rmer__rename_args__tests__rename_args.snap | 7 + .../generator/json/operation_generator.rs | 2 +- 8 files changed, 410 insertions(+), 3 deletions(-) create mode 100644 src/cli/llm/infer_arg_name.rs create mode 100644 src/cli/llm/prompts/infer_arg_name.md create mode 100644 src/core/config/transformer/rename_args.rs create mode 100644 src/core/config/transformer/snapshots/tailcall__core__config__transformer__rename_args__tests__rename_args.snap diff --git a/src/cli/generator/generator.rs b/src/cli/generator/generator.rs index 6da3ac426e..0e2010d98a 100644 --- a/src/cli/generator/generator.rs +++ b/src/cli/generator/generator.rs @@ -8,8 +8,9 @@ use pathdiff::diff_paths; use super::config::{Config, LLMConfig, Resolved, Source}; use super::source::ConfigSource; +use crate::cli::llm::infer_arg_name::InferArgName; use crate::cli::llm::InferTypeName; -use crate::core::config::transformer::{Preset, RenameTypes}; +use crate::core::config::transformer::{Preset, RenameArgs, RenameTypes}; use crate::core::config::{self, ConfigModule, ConfigReaderContext}; use crate::core::generator::{Generator as ConfigGenerator, Input}; use crate::core::proto_reader::ProtoReader; @@ -184,12 +185,20 @@ impl Generator { if infer_type_names { if let Some(LLMConfig { model: Some(model), secret }) = llm { - let mut llm_gen = InferTypeName::new(model, secret.map(|s| s.to_string())); + let mut llm_gen = + InferTypeName::new(model.clone(), secret.clone().map(|s| s.to_string())); let suggested_names = llm_gen.generate(config.config()).await?; 
let cfg = RenameTypes::new(suggested_names.iter()) .transform(config.config().to_owned()) .to_result()?; + let mut llm_gen = InferArgName::new(model, secret.map(|s| s.to_string())); + let suggested_names = llm_gen.generate(&cfg).await?; + + let cfg = RenameArgs::new(suggested_names) + .transform(cfg) + .to_result()?; + config = ConfigModule::from(cfg); } } diff --git a/src/cli/llm/infer_arg_name.rs b/src/cli/llm/infer_arg_name.rs new file mode 100644 index 0000000000..8b23aa7132 --- /dev/null +++ b/src/cli/llm/infer_arg_name.rs @@ -0,0 +1,210 @@ +use genai::chat::{ChatMessage, ChatRequest, ChatResponse}; +use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use super::{Error, Result, Wizard}; +use crate::core::config::transformer::ArgumentInfo; +use crate::core::config::{Config, Resolver}; +use crate::core::Mustache; + +pub struct InferArgName { + wizard: Wizard, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct TypeInfo { + name: String, + #[serde(rename = "outputType")] + output_type: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct FieldMapping { + argument: TypeInfo, + field: TypeInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Answer { + suggestions: Vec, +} + +impl TryFrom for Answer { + type Error = Error; + + fn try_from(response: ChatResponse) -> Result { + let message_content = response.content.ok_or(Error::EmptyResponse)?; + let text_content = message_content.text_as_str().ok_or(Error::EmptyResponse)?; + Ok(serde_json::from_str(text_content)?) 
+ } +} + +#[derive(Clone, Serialize)] +struct Question { + args_info: FieldMapping, +} + +impl TryInto for Question { + type Error = Error; + + fn try_into(self) -> Result { + let input2 = FieldMapping { + argument: { + TypeInfo { + name: "input2".to_string(), + output_type: "Article".to_string(), + } + }, + field: { + TypeInfo { + name: "createPost".to_string(), + output_type: "Post".to_string(), + } + }, + }; + + let input = serde_json::to_string_pretty(&Question { args_info: input2 })?; + let output = serde_json::to_string_pretty(&Answer { + suggestions: vec![ + "createPostInput".into(), + "postInput".into(), + "articleInput".into(), + "noteInput".into(), + "messageInput".into(), + ], + })?; + + let template_str = include_str!("prompts/infer_arg_name.md"); + let template = Mustache::parse(template_str); + + let context = json!({ + "input": input, + "output": output, + "count": 5, + }); + + let rendered_prompt = template.render(&context); + + Ok(ChatRequest::new(vec![ + ChatMessage::system(rendered_prompt), + ChatMessage::user(serde_json::to_string(&self)?), + ])) + } +} + +impl InferArgName { + pub fn new(model: String, secret: Option) -> InferArgName { + Self { wizard: Wizard::new(model, secret) } + } + + pub async fn generate(&mut self, config: &Config) -> Result> { + let mut mapping: IndexMap = IndexMap::new(); + + for (type_name, type_) in config.types.iter() { + // collect all the args that's needs to be processed with LLM. + for (field_name, field) in type_.fields.iter() { + if field.args.is_empty() { + continue; + } + // filter out query params as we shouldn't change the names of query params. 
+ for (arg_name, arg) in field.args.iter().filter(|(k, _)| match &field.resolver { + Some(Resolver::Http(http)) => !http.query.iter().any(|q| &q.key == *k), + _ => true, + }) { + let question = FieldMapping { + argument: TypeInfo { + name: arg_name.to_string(), + output_type: arg.type_of.name().to_owned(), + }, + field: TypeInfo { + name: field_name.to_string(), + output_type: field.type_of.name().to_owned(), + }, + }; + + let question = Question { args_info: question }; + + let mut delay = 3; + loop { + let answer = self.wizard.ask(question.clone()).await; + match answer { + Ok(answer) => { + tracing::info!( + "Suggestions for Argument {}: [{:?}]", + arg_name, + answer.suggestions, + ); + mapping.insert( + arg_name.to_owned(), + ArgumentInfo::new( + answer.suggestions, + field_name.to_owned(), + type_name.to_owned(), + ), + ); + break; + } + Err(e) => { + // TODO: log errors after certain number of retries. + if let Error::GenAI(_) = e { + // TODO: retry only when it's required. + tracing::warn!( + "Unable to retrieve a name for the type '{}'. 
Retrying in {}s", + type_name, + delay + ); + tokio::time::sleep(tokio::time::Duration::from_secs(delay)) + .await; + delay *= std::cmp::min(delay * 2, 60); + } + } + } + } + } + } + } + + Ok(mapping) + } +} + +#[cfg(test)] +mod test { + use genai::chat::{ChatRequest, ChatResponse, MessageContent}; + + use super::{Answer, Question}; + use crate::cli::llm::infer_arg_name::{FieldMapping, TypeInfo}; + use crate::core::config::Config; + use crate::core::valid::Validator; + + #[test] + fn test_to_chat_request_conversion() { + let question = Question { + args_info: FieldMapping { + argument: TypeInfo { + name: "input2".to_string(), + output_type: "Article".to_string(), + }, + field: TypeInfo { + name: "createPost".to_string(), + output_type: "Post".to_string(), + }, + }, + }; + let request: ChatRequest = question.try_into().unwrap(); + insta::assert_debug_snapshot!(request); + } + + #[test] + fn test_chat_response_parse() { + let resp = ChatResponse { + content: Some(MessageContent::Text( + "{\"suggestions\":[\"createPostInput\",\"postInput\",\"articleInput\",\"noteInput\",\"messageInput\"]}".to_owned(), + )), + ..Default::default() + }; + let answer = Answer::try_from(resp).unwrap(); + insta::assert_debug_snapshot!(answer); + } +} diff --git a/src/cli/llm/mod.rs b/src/cli/llm/mod.rs index 40c0dce610..09e9ab4ed7 100644 --- a/src/cli/llm/mod.rs +++ b/src/cli/llm/mod.rs @@ -1,4 +1,5 @@ mod error; +pub mod infer_arg_name; pub mod infer_type_name; pub use error::Error; use error::Result; diff --git a/src/cli/llm/prompts/infer_arg_name.md b/src/cli/llm/prompts/infer_arg_name.md new file mode 100644 index 0000000000..14ac000d93 --- /dev/null +++ b/src/cli/llm/prompts/infer_arg_name.md @@ -0,0 +1,19 @@ +Given the Operation Definition of GraphQL, suggest {{count}} meaningful names for the argument names. +The name should be concise and preferably a single word. 
+ +Example Input: +{ +"argument": { +"name": "Input1", +"outputType": "Article" +}, +"field": { +"name" : "createPost", +"outputType" : "Post" +} +} + +Example Output: +{"suggestions": ["createPostInput","postInput", "articleInput","noteInput","messageInput"]} + +Ensure the output is in valid JSON format. diff --git a/src/core/config/transformer/mod.rs b/src/core/config/transformer/mod.rs index bf22d6730c..5d5607ef6c 100644 --- a/src/core/config/transformer/mod.rs +++ b/src/core/config/transformer/mod.rs @@ -5,6 +5,7 @@ mod improve_type_names; mod merge_types; mod nested_unions; mod preset; +mod rename_args; mod rename_types; mod required; mod tree_shake; @@ -17,6 +18,7 @@ pub use improve_type_names::ImproveTypeNames; pub use merge_types::TypeMerger; pub use nested_unions::NestedUnions; pub use preset::Preset; +pub use rename_args::{ArgumentInfo, RenameArgs}; pub use rename_types::RenameTypes; pub use required::Required; pub use tree_shake::TreeShake; diff --git a/src/core/config/transformer/rename_args.rs b/src/core/config/transformer/rename_args.rs new file mode 100644 index 0000000000..b8181908e7 --- /dev/null +++ b/src/core/config/transformer/rename_args.rs @@ -0,0 +1,159 @@ +use indexmap::IndexMap; + +use crate::core::config::{Config, Resolver}; +use crate::core::valid::{Valid, Validator}; +use crate::core::Transform; + +type FieldName = String; +type TypeName = String; + +#[derive(Clone)] +pub struct ArgumentInfo { + suggested_names: Vec, + field_name: FieldName, + type_name: TypeName, +} + +impl ArgumentInfo { + pub fn new( + suggested_names: Vec, + field_name: FieldName, + type_name: TypeName, + ) -> ArgumentInfo { + Self { suggested_names, field_name, type_name } + } +} + +/// arg_name: { +/// suggested_names: Vec, suggested names for argument. +/// field_name: String, name of the field which requires the argument. +/// type_name: String, name of the type where the argument resides.
+/// } +pub struct RenameArgs(IndexMap); + +impl RenameArgs { + pub fn new(suggestions: IndexMap) -> Self { + Self(suggestions) + } +} + +impl Transform for RenameArgs { + type Value = Config; + type Error = String; + + fn transform(&self, mut config: Self::Value) -> Valid { + Valid::from_iter(self.0.iter(), |(existing_name, arg_info)| { + let type_name = &arg_info.type_name; + let field_name = &arg_info.field_name; + config.types.get_mut(type_name) + .and_then(|type_| type_.fields.get_mut(field_name)) + .and_then(|field_| field_.args.shift_remove(existing_name)) + .map_or_else( + || Valid::fail(format!("Argument '{}' not found in type '{}'.", existing_name, type_name)), + |arg| { + let field_ = config.types.get_mut(type_name) + .and_then(|type_| type_.fields.get_mut(field_name)) + .expect("Field should exist"); + + let new_name = arg_info.suggested_names.iter() + .find(|suggested_name| !field_.args.contains_key(*suggested_name)) + .cloned(); + + match new_name { + Some(name) => { + field_.args.insert(name.clone(), arg); + match field_.resolver.as_mut(){ + Some(Resolver::Http(http)) => { + // Note: we shouldn't modify the query params, as modifying them will change the API itself. + http.path = http.path.replace(existing_name, name.as_str()); + if let Some(body) = http.body.as_mut() { + *body = body.replace(existing_name, name.as_str()); + } + } + Some(Resolver::Grpc(grpc)) => { + if let Some(body) = grpc.body.as_mut() { + if let Some(str_val) = body.as_str() { + *body = serde_json::Value::String(str_val.replace(existing_name, &name)); + } + } + } + _ => { + // TODO: handle for other resolvers. + } + } + + Valid::succeed(()) + }, + None => { + field_.args.insert(existing_name.clone(), arg); + Valid::fail(format!( + "Could not rename argument '{}'. 
All suggested names are already in use.", + existing_name + )) + } + } + } + ) + }) + .map(|_| config) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::valid::ValidationError; + + #[test] + fn test_rename_args() { + let sdl = r#" + type Query { + user(id: ID!, name: String): JSON + } + "#; + let config = Config::from_sdl(sdl).to_result().unwrap(); + + let arg_info1 = + ArgumentInfo::new(vec!["userId".to_string()], "user".into(), "Query".into()); + let arg_info2 = + ArgumentInfo::new(vec!["userName".to_string()], "user".into(), "Query".into()); + + let rename_args = indexmap::indexmap! { + "id".to_string() => arg_info1.clone(), + "name".to_string() => arg_info2.clone(), + }; + + let transformed_config = RenameArgs::new(rename_args) + .transform(config) + .to_result() + .unwrap(); + + insta::assert_snapshot!(transformed_config.to_sdl()); + } + + #[test] + fn test_rename_args_conflict() { + let sdl = r#" + type Query { + user(id: ID!, name: String, userName: String): JSON + } + "#; + let config = Config::from_sdl(sdl).to_result().unwrap(); + + let arg_info = + ArgumentInfo::new(vec!["userName".to_string()], "user".into(), "Query".into()); + + let rename_args = indexmap::indexmap! { + "name".to_string() => arg_info, + }; + + let result = RenameArgs::new(rename_args).transform(config).to_result(); + + let expected_err = ValidationError::new( + "Could not rename argument 'name'. 
All suggested names are already in use.".to_string(), + ); + + assert!(result.is_err()); + assert_eq!(result.err().unwrap(), expected_err); + } +} diff --git a/src/core/config/transformer/snapshots/tailcall__core__config__transformer__rename_args__tests__rename_args.snap b/src/core/config/transformer/snapshots/tailcall__core__config__transformer__rename_args__tests__rename_args.snap new file mode 100644 index 0000000000..174e43a36e --- /dev/null +++ b/src/core/config/transformer/snapshots/tailcall__core__config__transformer__rename_args__tests__rename_args.snap @@ -0,0 +1,7 @@ +--- +source: src/core/config/transformer/rename_args.rs +expression: transformed_config.to_sdl() +--- +type Query { + user(userId: ID!, userName: String): JSON +} diff --git a/src/core/generator/json/operation_generator.rs b/src/core/generator/json/operation_generator.rs index e73058643a..ba98f385bc 100644 --- a/src/core/generator/json/operation_generator.rs +++ b/src/core/generator/json/operation_generator.rs @@ -41,7 +41,7 @@ impl OperationTypeGenerator { // add input type to field. let name_gen = NameGenerator::new("Input"); - let arg_name = name_gen.next(); + let arg_name = name_gen.next(); if let Some(Resolver::Http(http)) = &mut field.resolver { http.body = Some(format!("{{{{.args.{}}}}}", arg_name)); http.method = request_sample.method.to_owned();