Skip to content

Commit

Permalink
new: runtime statistics with --stats
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Jun 24, 2024
1 parent b64a1b6 commit fbf7a4a
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 24 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ The main idea is giving the model a set of functions to perform operations and a
If you want to observe this (basically the debug mode of Nerve), run your tasklet by adding the following additional arguments:

```sh
nerve -G ... -T whatever-tasklet --save-to state.txt --full-dump
nerve -G ... -T whatever-tasklet --save-to state.txt --full-dump --stats
```

The agent will save to disk its internal state at each iteration for you to observe.
The agent will report more runtime statistics and save to disk its internal state at each iteration for you to observe.

## Installing from Crates.io

Expand Down
18 changes: 12 additions & 6 deletions src/agent/generator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Display;

use anyhow::Result;
use async_trait::async_trait;

Expand Down Expand Up @@ -33,12 +35,16 @@ pub enum Message {
Feedback(String, Option<Invocation>),
}

impl Message {
pub fn to_string(&self) -> String {
match self {
Message::Agent(data, _) => format!("[agent]\n\n{}\n", data),
Message::Feedback(data, _) => format!("[feedback]\n\n{}\n", data),
}
impl Display for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
Message::Agent(data, _) => format!("[agent]\n\n{}\n", data),
Message::Feedback(data, _) => format!("[feedback]\n\n{}\n", data),
}
)
}
}

Expand Down
15 changes: 14 additions & 1 deletion src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub struct AgentOptions {
pub max_iterations: usize,
pub save_to: Option<String>,
pub full_dump: bool,
pub with_stats: bool,
}

pub struct Agent {
Expand Down Expand Up @@ -103,7 +104,11 @@ impl Agent {
}

pub async fn step(&mut self) -> Result<()> {
self.state.on_next_iteration()?;
self.state.on_step()?;

if self.options.with_stats {
println!("\n{}\n", &self.state.metrics);
}

let system_prompt = serialization::state_to_system_prompt(&self.state)?;
let prompt = self.state.to_prompt()?;
Expand All @@ -122,11 +127,15 @@ impl Agent {
// nothing parsed, report the problem to the model
if invocations.is_empty() {
if response.is_empty() {
self.state.metrics.errors.empty_responses += 1;

self.state.add_unparsed_response_to_history(
&response,
"Do not return an empty responses.".to_string(),
);
} else {
self.state.metrics.errors.unparsed_responses += 1;

self.state.add_unparsed_response_to_history(
&response,
"I could not parse any valid actions from your response, please correct it according to the instructions.".to_string(),
Expand All @@ -142,13 +151,17 @@ impl Agent {
format!("\n\n{}\n\n", response.dimmed().yellow())
}
);
} else {
self.state.metrics.valid_responses += 1;
}

// for each parsed invocation
for inv in invocations {
// see if valid action and execute
if let Err(e) = self.state.execute(inv.clone()).await {
println!("ERROR: {}", e);
} else {
self.state.metrics.valid_actions += 1;
}

self.save_if_needed(&options, true)?;
Expand Down
6 changes: 3 additions & 3 deletions src/agent/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ pub(crate) fn state_to_system_prompt(state: &State) -> Result<String> {
.join("\n");
let available_actions = state_available_actions(state)?;

let iterations = if state.max_iters > 0 {
let iterations = if state.metrics.max_steps > 0 {
format!(
"You are currently at step {} of a maximum of {}.",
state.curr_iter + 1,
state.max_iters
state.metrics.current_step + 1,
state.metrics.max_steps
)
} else {
"".to_string()
Expand Down
69 changes: 69 additions & 0 deletions src/agent/state/metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::fmt::Display;

use colored::Colorize;

#[derive(Debug, Default)]
pub struct ErrorMetrics {
pub empty_responses: usize,
pub unparsed_responses: usize,
pub unknown_actions: usize,
pub invalid_actions: usize,
pub errored_actions: usize,
}

impl ErrorMetrics {
fn has_response_errors(&self) -> bool {
self.empty_responses > 0 || self.unparsed_responses > 0
}

fn has_action_errors(&self) -> bool {
self.unknown_actions > 0 || self.invalid_actions > 0 || self.errored_actions > 0
}
}

#[derive(Debug, Default)]
pub struct Metrics {
pub max_steps: usize,
pub current_step: usize,
pub valid_responses: usize,
pub valid_actions: usize,
pub success_actions: usize,
pub errors: ErrorMetrics,
}

impl Display for Metrics {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[{}] steps:", "statistics".bold().blue())?;
if self.max_steps > 0 {
write!(f, "{}/{} ", self.current_step, self.max_steps)?;
} else {
write!(f, "{} ", self.current_step)?;
}

if self.errors.has_response_errors() {
write!(
f,
"responses(valid:{} empty:{} broken:{}) ",
self.valid_responses, self.errors.empty_responses, self.errors.unparsed_responses
)?;
} else if self.valid_responses > 0 {
write!(f, "responses:{} ", self.valid_responses)?;
}

if self.errors.has_action_errors() {
write!(
f,
"actions(valid:{} ok:{} errored:{} unknown:{} invalid:{})",
self.valid_actions,
self.success_actions,
self.errors.errored_actions,
self.errors.unknown_actions,
self.errors.invalid_actions
)?;
} else if self.valid_actions > 0 {
write!(f, "actions:{}", self.valid_actions,)?;
}

Ok(())
}
}
29 changes: 19 additions & 10 deletions src/agent/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{

use anyhow::Result;
use colored::Colorize;
use metrics::Metrics;

use super::{
generator::Message,
Expand All @@ -19,15 +20,13 @@ use history::{Execution, History};
use storage::Storage;

mod history;
mod metrics;
pub(crate) mod storage;

#[derive(Debug)]
pub struct State {
// the task
task: Box<dyn Task>,
// current iteration and max
pub curr_iter: usize,
pub max_iters: usize,
// model memories, goals and other storages
storages: HashMap<String, Storage>,
// available actions and execution history
Expand All @@ -36,6 +35,8 @@ pub struct State {
history: Mutex<History>,
// set to true when task is complete
complete: AtomicBool,
// runtime metrics
pub metrics: Metrics,
}

impl State {
Expand Down Expand Up @@ -105,21 +106,25 @@ impl State {
goal.set_current(&prompt, false);
}

let metrics = Metrics {
max_steps: max_iterations,
..Default::default()
};

Ok(Self {
task,
storages,
history,
namespaces,
complete,
max_iters: max_iterations,
curr_iter: 0,
metrics,
})
}

pub fn on_next_iteration(&mut self) -> Result<()> {
self.curr_iter += 1;
if self.max_iters > 0 && self.curr_iter >= self.max_iters {
Err(anyhow!("maximum number of iterations reached"))
pub fn on_step(&mut self) -> Result<()> {
self.metrics.current_step += 1;
if self.metrics.max_steps > 0 && self.metrics.current_step >= self.metrics.max_steps {
Err(anyhow!("maximum number of steps reached"))
} else {
Ok(())
}
Expand Down Expand Up @@ -262,10 +267,11 @@ impl State {
true
}

pub async fn execute(&self, invocation: Invocation) -> Result<()> {
pub async fn execute(&mut self, invocation: Invocation) -> Result<()> {
let action = match self.get_action(&invocation.action) {
Some(action) => action,
None => {
self.metrics.errors.unknown_actions += 1;
// tell the model that the action name is wrong
self.add_error_to_history(
invocation.clone(),
Expand All @@ -277,6 +283,7 @@ impl State {

// validate prerequisites
if !self.validate(&invocation, action) {
self.metrics.errors.invalid_actions += 1;
// not a core error, just inform the model and return
return Ok(());
}
Expand All @@ -285,9 +292,11 @@ impl State {
let inv = invocation.clone();
let ret = action.run(self, invocation.attributes, invocation.payload);
if let Err(error) = ret {
self.metrics.errors.errored_actions += 1;
// tell the model about the error
self.add_error_to_history(inv, error.to_string());
} else {
self.metrics.success_actions += 1;
// tell the model about the output
self.add_success_to_history(inv, ret.unwrap());
}
Expand Down
4 changes: 4 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ pub(crate) struct Args {
/// Print the documentation of the available action namespaces.
#[arg(long)]
pub generate_doc: bool,
/// Report runtime statistics.
#[arg(long)]
pub stats: bool,
}

impl Args {
Expand All @@ -61,6 +64,7 @@ impl Args {
max_iterations: self.max_iterations,
save_to: self.save_to.clone(),
full_dump: self.full_dump,
with_stats: self.stats,
}
}

Expand Down
2 changes: 0 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ async fn main() -> Result<()> {

// keep going until the task is complete or a fatal error is reached
while !agent.get_state().is_complete() {
// TODO: collect & report statistics (steps, model errors, etc)

// next step
if let Err(error) = agent.step().await {
println!("{}", error.to_string().bold().red());
Expand Down

0 comments on commit fbf7a4a

Please sign in to comment.