Skip to content

Commit

Permalink
improve infer arg name
Browse files Browse the repository at this point in the history
Signed-off-by: Sahil Yeole <[email protected]>
  • Loading branch information
beelchester committed Aug 16, 2024
1 parent 284bcc3 commit 859daf2
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions src/cli/llm/infer_arg_name.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::collections::{BTreeMap, HashMap};
use std::collections::HashMap;

use genai::chat::{ChatMessage, ChatRequest, ChatResponse};
use serde::{Deserialize, Serialize};

use super::model::{gemini, groq};
use super::model::gemini;
use super::{Error, Result, Wizard};
use crate::core::config::{Arg, Config};
use crate::core::config::Config;

#[derive(Default)]
pub struct InferArgName {
Expand Down Expand Up @@ -83,16 +83,14 @@ impl InferArgName {

let mut new_name_mappings: HashMap<String, String> = HashMap::new();

// removed root type from types.
let types_to_be_processed = config
let query_type = config
.types
.iter()
.filter(|(type_name, _)| *type_name == "Query")
.collect::<Vec<_>>();
.find(|(type_name, _)| *type_name == "Query")
.map(|(_, type_)| type_);

let total = types_to_be_processed.len();
for (i, (type_name, type_)) in types_to_be_processed.into_iter().enumerate() {
let mut args_vec = Vec::new();
if let Some(type_) = query_type {
let mut args_to_be_processed = HashMap::new();
let fields = &type_.fields.keys().collect::<Vec<_>>();
for key in fields {
if let Some(field) = &type_.fields.get(key.as_str()) {
Expand All @@ -102,28 +100,27 @@ impl InferArgName {
.iter()
.map(|(k, v)| (k.to_string(), v.type_of.clone()))
.collect::<Vec<_>>();
args_vec.push((key.to_string(), args));
args_to_be_processed.insert(key.to_string(), args);
}
}
}
for arg in args_vec {
tracing::info!("arg: {:?}", arg);
let total = args_to_be_processed.len();
for (i, arg) in args_to_be_processed.into_iter().enumerate() {
let question = Question { fields: arg.clone() };

let mut delay = 3;
loop {
// let answer = wizard.ask(question.clone()).await;
let answer = wizard.ask(question.clone()).await;
match answer {
Ok(answer) => {
let name = &answer.suggestions.join(", ");
for name in answer.suggestions {
if config.types.contains_key(&name)
if type_.fields.contains_key(&name)
|| new_name_mappings.contains_key(&name)
{
continue;
}
new_name_mappings.insert(name, type_name.to_owned());
new_name_mappings.insert(name, arg.0.to_owned());
break;
}
tracing::info!(
Expand All @@ -144,8 +141,8 @@ impl InferArgName {
if let Error::GenAI(_) = e {
// TODO: retry only when it's required.
tracing::warn!(
"Unable to retrieve a name for the type '{}'. Retrying in {}s. Error: {}",
type_name,
"Unable to retrieve a name for the argument '{}'. Retrying in {}s. Error: {}",
arg.0,
delay,
e
);
Expand All @@ -156,8 +153,10 @@ impl InferArgName {
}
}
}
}

Ok(new_name_mappings.into_iter().map(|(k, v)| (v, k)).collect())
Ok(new_name_mappings.into_iter().map(|(k, v)| (v, k)).collect())
} else {
Ok(HashMap::new())
}
}
}

0 comments on commit 859daf2

Please sign in to comment.