Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(2665): send system message once in InferTypeName generation #2676

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 55 additions & 11 deletions src/cli/llm/infer_type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#[derive(Default)]
pub struct InferTypeName {
secret: Option<String>,
wizard: Option<Wizard<Question, Answer>>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why it's optional?

}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -75,25 +76,60 @@

impl InferTypeName {
pub fn new(secret: Option<String>) -> InferTypeName {
Self { secret }
Self { secret, wizard: None }

Check warning on line 79 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L79

Added line #L79 was not covered by tests
}

fn create_system_messages() -> Vec<ChatMessage> {
vec![
ChatMessage::system(
"Given the sample schema of a GraphQL type suggest 5 meaningful names for it.",
),
ChatMessage::system("The name should be concise and preferably a single word"),
ChatMessage::system("Example Input:"),
ChatMessage::system(
serde_json::to_string_pretty(&Question {
fields: vec![
("id".to_string(), "String".to_string()),
("name".to_string(), "String".to_string()),
("age".to_string(), "Int".to_string()),
],
})
.unwrap(),
),
ChatMessage::system("Example Output:"),
ChatMessage::system(
serde_json::to_string_pretty(&Answer {
suggestions: vec![
"Person".into(),
"Profile".into(),
"Member".into(),
"Individual".into(),
"Contact".into(),
],
})
.unwrap(),
),
ChatMessage::system("Ensure the output is in valid JSON format"),
ChatMessage::system("Do not add any additional text before or after the json"),
]
}

Check warning on line 115 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L82-L115

Added lines #L82 - L115 were not covered by tests

pub async fn generate(&mut self, config: &Config) -> Result<HashMap<String, String>> {
let secret = self.secret.as_ref().map(|s| s.to_owned());

let wizard: Wizard<Question, Answer> = Wizard::new(groq::LLAMA38192, secret);
self.wizard = Some(Wizard::new(groq::LLAMA38192, secret));

Check warning on line 119 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L119

Added line #L119 was not covered by tests

let mut new_name_mappings: HashMap<String, String> = HashMap::new();

// removed root type from types.
let types_to_be_processed = config
.types
.iter()
.filter(|(type_name, _)| !config.is_root_operation_type(type_name))
.collect::<Vec<_>>();

let total = types_to_be_processed.len();
let system_messages = Self::create_system_messages();

Check warning on line 130 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L130

Added line #L130 was not covered by tests

for (i, (type_name, type_)) in types_to_be_processed.into_iter().enumerate() {
// convert type to sdl format.
let question = Question {
fields: type_
.fields
Expand All @@ -104,7 +140,7 @@

let mut delay = 3;
loop {
let answer = wizard.ask(question.clone()).await;
let answer = self.ask_with_context(&system_messages, &question).await;

Check warning on line 143 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L143

Added line #L143 was not covered by tests
match answer {
Ok(answer) => {
let name = &answer.suggestions.join(", ");
Expand All @@ -125,15 +161,10 @@
total
);

// TODO: case where suggested names are already used, then extend the base
// question with `suggest different names, we have already used following
// names: [names list]`
break;
}
Err(e) => {
// TODO: log errors after certain number of retries.
if let Error::GenAI(_) = e {
// TODO: retry only when it's required.
tracing::warn!(
"Unable to retrieve a name for the type '{}'. Retrying in {}s",
type_name,
Expand All @@ -149,6 +180,19 @@

Ok(new_name_mappings.into_iter().map(|(k, v)| (v, k)).collect())
}

async fn ask_with_context(
&self,
system_messages: &[ChatMessage],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can remove system_messages if we are not using it.

question: &Question,
) -> Result<Answer> {
let content = serde_json::to_string(question)?;
let mut messages = system_messages.to_vec();
messages.push(ChatMessage::user(content));

let request = ChatRequest::new(messages);
self.wizard.as_ref().unwrap().ask(request).await
}

Check warning on line 195 in src/cli/llm/infer_type_name.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/infer_type_name.rs#L184-L195

Added lines #L184 - L195 were not covered by tests
}

#[cfg(test)]
Expand Down
5 changes: 2 additions & 3 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,13 @@
}
}

pub async fn ask(&self, q: Q) -> Result<A>
pub async fn ask(&self, request: ChatRequest) -> Result<A>

Check warning on line 42 in src/cli/llm/wizard.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/wizard.rs#L42

Added line #L42 was not covered by tests
where
Q: TryInto<ChatRequest, Error = super::Error>,
A: TryFrom<ChatResponse, Error = super::Error>,
{
let response = self
.client
.exec_chat(self.model.as_str(), q.try_into()?, None)
.exec_chat(self.model.as_str(), request, None)

Check warning on line 48 in src/cli/llm/wizard.rs

View check run for this annotation

Codecov / codecov/patch

src/cli/llm/wizard.rs#L48

Added line #L48 was not covered by tests
onyedikachi-david marked this conversation as resolved.
Show resolved Hide resolved
.await?;
A::try_from(response)
}
Expand Down
Loading