Skip to content

Commit

Permalink
[cli] Rewind messages past last user text message, ignore user tool r…
Browse files Browse the repository at this point in the history
…esult message
  • Loading branch information
jsibbison-square committed Nov 24, 2024
1 parent 470efe5 commit b75919d
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 48 deletions.
6 changes: 2 additions & 4 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ mod commands {
mod inputs;
mod profile;
mod prompt;
mod session {
pub mod session;
pub mod session_file;
}
pub mod session;

mod systems;

use anyhow::Result;
Expand Down
46 changes: 46 additions & 0 deletions crates/goose-cli/src/session/mock_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Arc;
use std::sync::Mutex;

use goose::models::message::Message;
use goose::models::tool::Tool;
use goose::providers::base::{Provider, Usage};

///
/// This is a copy of crates/goose/src/providers/mock.rs that I can't use as its configured out in that crate.
/// I need to use this in the test module of crates/goose-cli/src/session/session.rs but really what I need is to
/// mock the agent. But that requires a bit of refactor to use an agent trait which I don't want to do in this PR.
/// Therefore its a TODO to create a mock agent and remove this mock provider.
/// A mock provider that returns pre-configured responses for testing
pub struct MockProvider {
responses: Arc<Mutex<Vec<Message>>>,
}

impl MockProvider {
/// Create a new mock provider with a sequence of responses
pub fn new(responses: Vec<Message>) -> Self {
Self {
responses: Arc::new(Mutex::new(responses)),
}
}
}

#[async_trait]
impl Provider for MockProvider {
async fn complete(
&self,
_system_prompt: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, Usage)> {
let mut responses = self.responses.lock().unwrap();
if responses.is_empty() {
// Return empty response if no more pre-configured responses
Ok((Message::assistant().with_text(""), Usage::default()))
} else {
Ok((responses.remove(0), Usage::default()))
}
}
}
4 changes: 3 additions & 1 deletion crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod session;
pub mod session_file;
pub mod session_manager;

#[cfg(test)]
pub mod mock_provider;
117 changes: 74 additions & 43 deletions crates/goose-cli/src/session/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,6 @@ impl<'a> Session<'a> {
}
}

pub fn rewind_messages(&mut self) {
if self.messages.is_empty() {
return;
}

// Remove messages until we find the last user text message
while let Some(message) = self.messages.last() {
if message.role == Role::User && message.content.iter().any(|c| matches!(c, MessageContent::Text(_))) {
break;
}
self.messages.pop();
}

// Remove the last user text message we found
if !self.messages.is_empty() {
self.messages.pop();
}
}

pub async fn start(&mut self) -> Result<(), Box<dyn std::error::Error>> {
self.setup_session();

Expand Down Expand Up @@ -129,19 +110,39 @@ impl<'a> Session<'a> {
}
_ = tokio::signal::ctrl_c() => {
drop(stream);
println!("Popping...");
println!("Old Messages: {:?}", self.messages);
// Pop all 'messages' from the assistant and the most recent user message. Resets the interaction to before the interrupted user request.
self.rewind_messages();
println!("New Messages: {:?}", self.messages);

self.prompt.render(raw_message(" Interrupt: Resetting conversation to before the last sent message...\n"));
break;
}
}
}
}

/// Rewind the messages to before the last user message (they have cancelled it).
pub fn rewind_messages(&mut self) {
if self.messages.is_empty() {
return;
}

// Remove messages until we find the last user 'Text' message (not a tool response).
while let Some(message) = self.messages.last() {
if message.role == Role::User
&& message
.content
.iter()
.any(|c| matches!(c, MessageContent::Text(_)))
{
break;
}
self.messages.pop();
}

// Remove the last user text message we found.
if !self.messages.is_empty() {
self.messages.pop();
}
}

fn setup_session(&mut self) {
let system = Box::new(DeveloperSystem::new());
self.agent.add_system(system);
Expand Down Expand Up @@ -175,9 +176,10 @@ fn raw_message(content: &str) -> Box<Message> {
#[cfg(test)]
mod tests {
use crate::prompt::prompt::{self, Input};
use crate::session::mock_provider::MockProvider;

use super::*;
use goose::{errors::AgentResult, models::tool::{Tool, ToolCall}, providers::{base::Provider, mock::MockProvider}};
use goose::{errors::AgentResult, models::tool::ToolCall, providers::base::Provider};
use tempfile::NamedTempFile;

// Helper function to create a test session
Expand All @@ -192,7 +194,12 @@ mod tests {
// Mock prompt implementation for testing
struct MockPrompt {}
impl Prompt for MockPrompt {
fn get_input(&mut self) -> std::result::Result<prompt::Input, anyhow::Error> { Ok(Input {input_type: InputType::Message, content: Some("Msg:".to_string())}) }
fn get_input(&mut self) -> std::result::Result<prompt::Input, anyhow::Error> {
Ok(Input {
input_type: InputType::Message,
content: Some("Msg:".to_string()),
})
}
fn render(&mut self, _: Box<Message>) {}
fn show_busy(&mut self) {}
fn hide_busy(&self) {}
Expand All @@ -204,7 +211,7 @@ mod tests {
fn test_rewind_messages_only_user() {
let mut session = create_test_session();
session.messages.push(Message::user().with_text("Hello"));

session.rewind_messages();
assert!(session.messages.is_empty());
}
Expand All @@ -213,8 +220,10 @@ mod tests {
fn test_rewind_messages_user_then_assistant() {
let mut session = create_test_session();
session.messages.push(Message::user().with_text("Hello"));
session.messages.push(Message::assistant().with_text("World"));

session
.messages
.push(Message::assistant().with_text("World"));

session.rewind_messages();
assert!(session.messages.is_empty());
}
Expand All @@ -223,37 +232,59 @@ mod tests {
fn test_rewind_messages_multiple_user_messages() {
let mut session = create_test_session();
session.messages.push(Message::user().with_text("First"));
session.messages.push(Message::assistant().with_text("Response 1"));
session
.messages
.push(Message::assistant().with_text("Response 1"));
session.messages.push(Message::user().with_text("Second"));
session.messages.push(Message::assistant().with_text("Response 2"));

session.rewind_messages();
assert_eq!(session.messages.len(), 2);
assert_eq!(session.messages[0].role, Role::User);
assert_eq!(session.messages[1].role, Role::Assistant);
assert_eq!(session.messages[0].content[0], MessageContent::text("First"));
assert_eq!(session.messages[1].content[0], MessageContent::text("Response 1"));
assert_eq!(
session.messages[0].content[0],
MessageContent::text("First")
);
assert_eq!(
session.messages[1].content[0],
MessageContent::text("Response 1")
);
}

#[test]
fn test_rewind_messages_after_interrupted_tool_request() {
let mut session = create_test_session();
session.messages.push(Message::user().with_text("First"));
session.messages.push(Message::assistant().with_text("Response 1"));
session
.messages
.push(Message::assistant().with_text("Response 1"));
session.messages.push(Message::user().with_text("Use tool"));

let mut mixed_msg = Message::assistant();
mixed_msg.content.push(MessageContent::text("Using tool"));
mixed_msg.content.push(MessageContent::tool_request("test", AgentResult::Ok(ToolCall::new("test", "test".into()))));
mixed_msg.content.push(MessageContent::tool_request(
"test",
AgentResult::Ok(ToolCall::new("test", "test".into())),
));
session.messages.push(mixed_msg);

session.messages.push(Message::user().with_tool_response("test", Err(goose::errors::AgentError::ExecutionError("Test".to_string()))));


session.messages.push(Message::user().with_tool_response(
"test",
Err(goose::errors::AgentError::ExecutionError(
"Test".to_string(),
)),
));

session.rewind_messages();
assert_eq!(session.messages.len(), 2);
assert_eq!(session.messages[0].role, Role::User);
assert_eq!(session.messages[1].role, Role::Assistant);
assert_eq!(session.messages[0].content[0], MessageContent::text("First"));
assert_eq!(session.messages[1].content[0], MessageContent::text("Response 1"));
assert_eq!(
session.messages[0].content[0],
MessageContent::text("First")
);
assert_eq!(
session.messages[1].content[0],
MessageContent::text("Response 1")
);
}
}
}

0 comments on commit b75919d

Please sign in to comment.