Skip to content

Commit

Permalink
- [draft] generate argument names with LLM.
Browse files Browse the repository at this point in the history
  • Loading branch information
laststylebender14 committed Sep 3, 2024
1 parent 6073342 commit f043c97
Show file tree
Hide file tree
Showing 8 changed files with 410 additions and 3 deletions.
13 changes: 11 additions & 2 deletions src/cli/generator/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use pathdiff::diff_paths;

use super::config::{Config, LLMConfig, Resolved, Source};
use super::source::ConfigSource;
use crate::cli::llm::infer_arg_name::InferArgName;
use crate::cli::llm::InferTypeName;
use crate::core::config::transformer::{Preset, RenameTypes};
use crate::core::config::transformer::{Preset, RenameArgs, RenameTypes};
use crate::core::config::{self, ConfigModule, ConfigReaderContext};
use crate::core::generator::{Generator as ConfigGenerator, Input};
use crate::core::proto_reader::ProtoReader;
Expand Down Expand Up @@ -184,12 +185,20 @@ impl Generator {

if infer_type_names {
if let Some(LLMConfig { model: Some(model), secret }) = llm {
let mut llm_gen = InferTypeName::new(model, secret.map(|s| s.to_string()));
let mut llm_gen =
InferTypeName::new(model.clone(), secret.clone().map(|s| s.to_string()));
let suggested_names = llm_gen.generate(config.config()).await?;
let cfg = RenameTypes::new(suggested_names.iter())
.transform(config.config().to_owned())
.to_result()?;

let mut llm_gen = InferArgName::new(model, secret.map(|s| s.to_string()));
let suggested_names = llm_gen.generate(&cfg).await?;

let cfg = RenameArgs::new(suggested_names)
.transform(cfg)
.to_result()?;

config = ConfigModule::from(cfg);
}
}
Expand Down
210 changes: 210 additions & 0 deletions src/cli/llm/infer_arg_name.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
use genai::chat::{ChatMessage, ChatRequest, ChatResponse};
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use serde_json::json;

use super::{Error, Result, Wizard};
use crate::core::config::transformer::ArgumentInfo;
use crate::core::config::{Config, Resolver};
use crate::core::Mustache;

pub struct InferArgName {
wizard: Wizard<Question, Answer>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct TypeInfo {
name: String,
#[serde(rename = "outputType")]
output_type: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct FieldMapping {
argument: TypeInfo,
field: TypeInfo,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct Answer {
suggestions: Vec<String>,
}

impl TryFrom<ChatResponse> for Answer {
type Error = Error;

fn try_from(response: ChatResponse) -> Result<Self> {
let message_content = response.content.ok_or(Error::EmptyResponse)?;
let text_content = message_content.text_as_str().ok_or(Error::EmptyResponse)?;
Ok(serde_json::from_str(text_content)?)
}
}

#[derive(Clone, Serialize)]
struct Question {
args_info: FieldMapping,
}

impl TryInto<ChatRequest> for Question {
type Error = Error;

fn try_into(self) -> Result<ChatRequest> {
let input2 = FieldMapping {
argument: {
TypeInfo {
name: "input2".to_string(),
output_type: "Article".to_string(),
}
},
field: {
TypeInfo {
name: "createPost".to_string(),
output_type: "Post".to_string(),
}
},
};

let input = serde_json::to_string_pretty(&Question { args_info: input2 })?;
let output = serde_json::to_string_pretty(&Answer {
suggestions: vec![
"createPostInput".into(),
"postInput".into(),
"articleInput".into(),
"noteInput".into(),
"messageInput".into(),
],
})?;

let template_str = include_str!("prompts/infer_arg_name.md");
let template = Mustache::parse(template_str);

let context = json!({
"input": input,
"output": output,
"count": 5,
});

let rendered_prompt = template.render(&context);

Ok(ChatRequest::new(vec![
ChatMessage::system(rendered_prompt),
ChatMessage::user(serde_json::to_string(&self)?),
]))
}
}

impl InferArgName {
pub fn new(model: String, secret: Option<String>) -> InferArgName {
Self { wizard: Wizard::new(model, secret) }
}

pub async fn generate(&mut self, config: &Config) -> Result<IndexMap<String, ArgumentInfo>> {
let mut mapping: IndexMap<String, ArgumentInfo> = IndexMap::new();

for (type_name, type_) in config.types.iter() {
// collect all the args that's needs to be processed with LLM.
for (field_name, field) in type_.fields.iter() {
if field.args.is_empty() {
continue;
}
// filter out query params as we shouldn't change the names of query params.
for (arg_name, arg) in field.args.iter().filter(|(k, _)| match &field.resolver {
Some(Resolver::Http(http)) => !http.query.iter().any(|q| &q.key == *k),
_ => true,
}) {
let question = FieldMapping {
argument: TypeInfo {
name: arg_name.to_string(),
output_type: arg.type_of.name().to_owned(),
},
field: TypeInfo {
name: field_name.to_string(),
output_type: field.type_of.name().to_owned(),
},
};

let question = Question { args_info: question };

let mut delay = 3;
loop {
let answer = self.wizard.ask(question.clone()).await;
match answer {
Ok(answer) => {
tracing::info!(
"Suggestions for Argument {}: [{:?}]",
arg_name,
answer.suggestions,
);
mapping.insert(
arg_name.to_owned(),
ArgumentInfo::new(
answer.suggestions,
field_name.to_owned(),
type_name.to_owned(),
),
);
break;
}
Err(e) => {
// TODO: log errors after certain number of retries.
if let Error::GenAI(_) = e {
// TODO: retry only when it's required.
tracing::warn!(
"Unable to retrieve a name for the type '{}'. Retrying in {}s",
type_name,
delay
);
tokio::time::sleep(tokio::time::Duration::from_secs(delay))
.await;
delay *= std::cmp::min(delay * 2, 60);
}
}
}
}
}
}
}

Ok(mapping)
}
}

#[cfg(test)]
mod test {
use genai::chat::{ChatRequest, ChatResponse, MessageContent};

use super::{Answer, Question};
use crate::cli::llm::infer_arg_name::{FieldMapping, TypeInfo};
use crate::core::config::Config;

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Formatter and Lint Check

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-x64-gnu

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-x64-musl

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-arm64-gnu

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-arm64-musl

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-ia32-gnu

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on darwin-arm64

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on darwin-x64

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on win32-x64-msvc

unused import: `crate::core::config::Config`

Check failure on line 178 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on win32-ia32-msvc

unused import: `crate::core::config::Config`
use crate::core::valid::Validator;

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Formatter and Lint Check

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-x64-gnu

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-x64-musl

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-arm64-gnu

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-arm64-musl

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on linux-ia32-gnu

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on darwin-arm64

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on darwin-x64

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on win32-x64-msvc

unused import: `crate::core::valid::Validator`

Check failure on line 179 in src/cli/llm/infer_arg_name.rs

View workflow job for this annotation

GitHub Actions / Run Tests on win32-ia32-msvc

unused import: `crate::core::valid::Validator`

#[test]
fn test_to_chat_request_conversion() {
let question = Question {
args_info: FieldMapping {
argument: TypeInfo {
name: "input2".to_string(),
output_type: "Article".to_string(),
},
field: TypeInfo {
name: "createPost".to_string(),
output_type: "Post".to_string(),
},
},
};
let request: ChatRequest = question.try_into().unwrap();
insta::assert_debug_snapshot!(request);
}

#[test]
fn test_chat_response_parse() {
let resp = ChatResponse {
content: Some(MessageContent::Text(
"{\"suggestions\":[\"createPostInput\",\"postInput\",\"articleInput\",\"noteInput\",\"messageInput\"]}".to_owned(),
)),
..Default::default()
};
let answer = Answer::try_from(resp).unwrap();
insta::assert_debug_snapshot!(answer);
}
}
1 change: 1 addition & 0 deletions src/cli/llm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod error;
pub mod infer_arg_name;
pub mod infer_type_name;
pub use error::Error;
use error::Result;
Expand Down
19 changes: 19 additions & 0 deletions src/cli/llm/prompts/infer_arg_name.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Given the Operation Definition of GraphQL, suggest {{count}} meaningful names for the argument names.
The name should be concise and preferably a single word.

Example Input:
{
"argument": {
"name": "Input1",
"outputType: "Article"
},
"field": {
"name" : "createPost",
"outputType" : "Post"
}
}

Example Output:
suggestions: ["createPostInput","postInput", "articleInput","noteInput","messageInput"],

Ensure the output is in valid JSON format.
2 changes: 2 additions & 0 deletions src/core/config/transformer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod improve_type_names;
mod merge_types;
mod nested_unions;
mod preset;
mod rename_args;
mod rename_types;
mod required;
mod tree_shake;
Expand All @@ -17,6 +18,7 @@ pub use improve_type_names::ImproveTypeNames;
pub use merge_types::TypeMerger;
pub use nested_unions::NestedUnions;
pub use preset::Preset;
pub use rename_args::{ArgumentInfo, RenameArgs};
pub use rename_types::RenameTypes;
pub use required::Required;
pub use tree_shake::TreeShake;
Expand Down
Loading

0 comments on commit f043c97

Please sign in to comment.