Skip to content

Commit

Permalink
misc: small fix or general refactoring i did not bother commenting
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Jun 25, 2024
1 parent 71a7c96 commit 098e4de
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 104 deletions.
145 changes: 79 additions & 66 deletions src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,85 +179,97 @@ impl Agent {
}

pub async fn step(&mut self) -> Result<()> {
let mut mut_state = self.state.lock().await;
let (invocations, options) = {
let mut mut_state = self.state.lock().await;

mut_state.on_step()?;
mut_state.on_step()?;

if self.options.with_stats {
println!("\n{}\n", &mut_state.metrics);
}
if self.options.with_stats {
println!("\n{}\n", &mut_state.metrics);
}

let system_prompt = serialization::state_to_system_prompt(&mut_state)?;
let prompt = mut_state.to_prompt()?;
let history = mut_state.to_chat_history(self.max_history as usize)?;
let system_prompt = serialization::state_to_system_prompt(&mut_state)?;
let prompt = mut_state.to_prompt()?;
let history = mut_state.to_chat_history(self.max_history as usize)?;
let options = Options::new(system_prompt, prompt, history);

let options = Options::new(system_prompt, prompt, history);
self.save_if_needed(&options, false).await?;

self.save_if_needed(&options, false).await?;
// run model inference
let response = self.generator.chat(&options).await?.trim().to_string();

// run model inference
let response = self.generator.chat(&options).await?.trim().to_string();
// parse the model response into invocations
let invocations = serialization::xml::parsing::try_parse(&response)?;

// parse the model response into invocations
let invocations = serialization::xml::parsing::try_parse(&response)?;
// nothing parsed, report the problem to the model
if invocations.is_empty() {
if response.is_empty() {
println!(
"{}: agent did not provide valid instructions: empty response",
"WARNING".bold().red(),
);

// nothing parsed, report the problem to the model
if invocations.is_empty() {
if response.is_empty() {
println!(
"{}: agent did not provide valid instructions: empty response",
"WARNING".bold().red(),
);

mut_state.metrics.errors.empty_responses += 1;
mut_state.add_unparsed_response_to_history(
&response,
"Do not return an empty responses.".to_string(),
);
} else {
println!("\n\n{}\n\n", response.dimmed());

mut_state.metrics.errors.unparsed_responses += 1;
mut_state.add_unparsed_response_to_history(
mut_state.metrics.errors.empty_responses += 1;
mut_state.add_unparsed_response_to_history(
&response,
"Do not return an empty responses.".to_string(),
);
} else {
println!(
"{}: agent did not provide valid instructions: \n\n{}\n\n",
"WARNING".bold().red(),
response.dimmed()
);

mut_state.metrics.errors.unparsed_responses += 1;
mut_state.add_unparsed_response_to_history(
&response,
"I could not parse any valid actions from your response, please correct it according to the instructions.".to_string(),
);
}
} else {
mut_state.metrics.valid_responses += 1;
}
} else {
mut_state.metrics.valid_responses += 1;
}

// to avoid dead locks, is this needed?
drop(mut_state);
(invocations, options)
};

// for each parsed invocation
// NOTE: the MutexGuard is purposedly captured in its own scope in order to avoid
// deadlocks and make its lifespan clearer.
for inv in invocations {
// lookup action
let mut mut_state = self.state.lock().await;
let action = mut_state.get_action(&inv.action);

let action = self.state.lock().await.get_action(&inv.action);
if action.is_none() {
mut_state.metrics.errors.unknown_actions += 1;
// tell the model that the action name is wrong
mut_state.add_error_to_history(
inv.clone(),
format!("'{}' is not a valid action name", inv.action),
);
drop(mut_state);
{
let mut mut_state = self.state.lock().await;
mut_state.metrics.errors.unknown_actions += 1;
// tell the model that the action name is wrong
mut_state.add_error_to_history(
inv.clone(),
format!("'{}' is not a valid action name", inv.action),
);
}
} else {
let action = action.unwrap();
// validate prerequisites
if let Err(err) = self.validate(&inv, &action) {
mut_state.metrics.errors.invalid_actions += 1;
mut_state.add_error_to_history(inv.clone(), err.to_string());
drop(mut_state);
} else {
mut_state.metrics.valid_actions += 1;
drop(mut_state);
let do_exec = {
let mut mut_state = self.state.lock().await;

if let Err(err) = self.validate(&inv, &action) {
mut_state.metrics.errors.invalid_actions += 1;
mut_state.add_error_to_history(inv.clone(), err.to_string());
false
} else {
mut_state.metrics.valid_actions += 1;
true
}
};

// TODO: timeout logic
// TODO: timeout logic

// execute
// execute
if do_exec {
let ret = action
.run(
self.state.clone(),
Expand All @@ -266,17 +278,18 @@ impl Agent {
)
.await;

let mut mut_state = self.state.lock().await;
if let Err(error) = ret {
mut_state.metrics.errors.errored_actions += 1;
// tell the model about the error
mut_state.add_error_to_history(inv, error.to_string());
} else {
mut_state.metrics.success_actions += 1;
// tell the model about the output
mut_state.add_success_to_history(inv, ret.unwrap());
{
let mut mut_state = self.state.lock().await;
if let Err(error) = ret {
mut_state.metrics.errors.errored_actions += 1;
// tell the model about the error
mut_state.add_error_to_history(inv, error.to_string());
} else {
mut_state.metrics.success_actions += 1;
// tell the model about the output
mut_state.add_success_to_history(inv, ret.unwrap());
}
}
drop(mut_state);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/agent/namespaces/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl Action for Search {
for (doc, score) in &docs {
println!(" * {} ({})", &doc.name, score);
}
println!("");
println!();

Ok(Some(format!(
"Here is some supporting information:\n\n{}",
Expand Down
2 changes: 0 additions & 2 deletions src/agent/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,5 @@ pub trait VectorStore: Send {
Self: Sized;

async fn add(&mut self, document: Document) -> Result<()>;
async fn delete(&mut self, doc_name: &str) -> Result<()>;
async fn clear(&mut self) -> Result<()>;
async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>>;
}
20 changes: 0 additions & 20 deletions src/agent/rag/naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,6 @@ impl VectorStore for NaiveVectorStore {
Ok(())
}

async fn delete(&mut self, doc_name: &str) -> Result<()> {
if self.documents.remove(doc_name).is_some() {
self.embeddings.remove(doc_name);
println!("[rag] removed document '{}'", doc_name);
Ok(())
} else {
Err(anyhow!(
"document with name '{}' not found in the index",
doc_name
))
}
}

async fn clear(&mut self) -> Result<()> {
self.embeddings.clear();
self.documents.clear();
println!("[rag] index cleared");
Ok(())
}

async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>> {
println!("[{}] {} (top {})", "rag".bold(), query, top_k);

Expand Down
21 changes: 6 additions & 15 deletions src/agent/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use metrics::Metrics;
use super::{
generator::{Client, Message},
namespaces::{self, Namespace},
rag::{self, naive::NaiveVectorStore, Document, VectorStore},
rag::{naive::NaiveVectorStore, Document, VectorStore},
task::Task,
Invocation,
};
Expand All @@ -18,12 +18,6 @@ mod history;
mod metrics;
pub(crate) mod storage;

#[allow(clippy::upper_case_acronyms)]
struct RAG {
config: rag::Configuration,
store: Box<dyn VectorStore>,
}

pub struct State {
// the task
task: Box<dyn Task>,
Expand All @@ -34,7 +28,7 @@ pub struct State {
// list of executed actions
history: History,
// optional rag engine
rag: Option<RAG>,
rag: Option<Box<dyn VectorStore>>,
// set to true when task is complete
complete: bool,
// runtime metrics
Expand Down Expand Up @@ -89,16 +83,13 @@ impl State {
}

// add RAG namespace
let rag = if let Some(config) = task.get_rag_config() {
let v_store =
let rag: Option<Box<dyn VectorStore>> = if let Some(config) = task.get_rag_config() {
let v_store: NaiveVectorStore =
NaiveVectorStore::from_indexed_path(generator.copy()?, &config.path).await?;

namespaces.push(namespaces::NAMESPACES.get("rag").unwrap()());

Some(RAG {
config: config.clone(),
store: Box::new(v_store),
})
Some(Box::new(v_store))
} else {
None
};
Expand Down Expand Up @@ -156,7 +147,7 @@ impl State {

pub async fn rag_query(&mut self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>> {
if let Some(rag) = &self.rag {
rag.store.retrieve(query, top_k).await
rag.retrieve(query, top_k).await
} else {
Err(anyhow!("no RAG engine has been configured"))
}
Expand Down

0 comments on commit 098e4de

Please sign in to comment.