diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs new file mode 100644 index 000000000..850f7685b --- /dev/null +++ b/crates/goose-cli/src/agents/agent.rs @@ -0,0 +1,21 @@ +use anyhow::Result; +use async_trait::async_trait; +use futures::stream::BoxStream; +use goose::{agent::Agent as GooseAgent, models::message::Message, systems::System}; + +#[async_trait] +pub trait Agent { + fn add_system(&mut self, system: Box); + async fn reply(&self, messages: &[Message]) -> Result>>; +} + +#[async_trait] +impl Agent for GooseAgent { + fn add_system(&mut self, system: Box) { + self.add_system(system); + } + + async fn reply(&self, messages: &[Message]) -> Result>> { + self.reply(messages).await + } +} diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs new file mode 100644 index 000000000..c844ff94b --- /dev/null +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -0,0 +1,19 @@ +use anyhow::Result; +use async_trait::async_trait; +use futures::stream::BoxStream; +use goose::{models::message::Message, systems::System}; + +use crate::agents::agent::Agent; + +pub struct MockAgent; + +#[async_trait] +impl Agent for MockAgent { + fn add_system(&mut self, _system: Box) { + (); + } + + async fn reply(&self, _messages: &[Message]) -> Result>> { + Ok(Box::pin(futures::stream::empty())) + } +} diff --git a/crates/goose-cli/src/agents/mod.rs b/crates/goose-cli/src/agents/mod.rs new file mode 100644 index 000000000..14b5f00cd --- /dev/null +++ b/crates/goose-cli/src/agents/mod.rs @@ -0,0 +1,4 @@ +pub mod agent; + +#[cfg(test)] +pub mod mock_agent; diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index d3658d2f4..9e38cc416 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -4,13 +4,12 @@ mod commands { pub mod session; pub mod version; } +pub mod agents; mod inputs; mod profile; mod prompt; -mod session { - pub mod session; - pub mod session_file; -} +pub mod session; + mod systems; use anyhow::Result; diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 2c1d08a10..c8cb955fe 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1,3 +1,2 @@ pub mod session; pub mod session_file; -pub mod session_manager; diff --git a/crates/goose-cli/src/session/session.rs b/crates/goose-cli/src/session/session.rs index feb9c6977..0e0b61f29 100644 --- a/crates/goose-cli/src/session/session.rs +++ b/crates/goose-cli/src/session/session.rs @@ -2,25 +2,25 @@ use anyhow::Result; use futures::StreamExt; use std::path::PathBuf; +use crate::agents::agent::Agent; use crate::prompt::prompt::{InputType, Prompt}; use crate::session::session_file::{persist_messages, readable_session_file}; use crate::systems::goose_hints::GooseHintsSystem; -use goose::agent::Agent; use goose::developer::DeveloperSystem; -use goose::models::message::Message; +use goose::models::message::{Message, MessageContent}; use goose::models::role::Role; use super::session_file::deserialize_messages; pub struct Session<'a> { - agent: Box, + agent: Box, prompt: Box, session_file: PathBuf, messages: Vec, } impl<'a> Session<'a> { - pub fn new(agent: Box, prompt: Box, session_file: PathBuf) -> Self { + pub fn new(agent: Box, prompt: Box, session_file: PathBuf) -> Self { let messages = match readable_session_file(&session_file) { Ok(file) => deserialize_messages(file).unwrap_or_else(|e| { eprintln!( @@ -110,14 +110,7 @@ impl<'a> Session<'a> { } _ = tokio::signal::ctrl_c() => { drop(stream); - // Pop all 'messages' from the assistant and the most recent user message. Resets the interaction to before the interrupted user request. - while let Some(message) = self.messages.pop() { - if message.role == Role::User { - break; - } - // else drop any assistant messages. - } - + self.rewind_messages(); self.prompt.render(raw_message(" Interrupt: Resetting conversation to before the last sent message...\n")); break; } @@ -125,6 +118,31 @@ impl<'a> Session<'a> { } } + /// Rewind the messages to before the last user message (they have cancelled it). + pub fn rewind_messages(&mut self) { + if self.messages.is_empty() { + return; + } + + // Remove messages until we find the last user 'Text' message (not a tool response). + while let Some(message) = self.messages.last() { + if message.role == Role::User + && message + .content + .iter() + .any(|c| matches!(c, MessageContent::Text(_))) + { + break; + } + self.messages.pop(); + } + + // Remove the last user text message we found. + if !self.messages.is_empty() { + self.messages.pop(); + } + } + fn setup_session(&mut self) { let system = Box::new(DeveloperSystem::new()); self.agent.add_system(system); @@ -154,3 +172,118 @@ impl<'a> Session<'a> { fn raw_message(content: &str) -> Box { Box::new(Message::assistant().with_text(content)) } + +#[cfg(test)] +mod tests { + use crate::agents::mock_agent::MockAgent; + use crate::prompt::prompt::{self, Input}; + + use super::*; + use goose::{errors::AgentResult, models::tool::ToolCall}; + use tempfile::NamedTempFile; + + // Helper function to create a test session + fn create_test_session() -> Session<'static> { + let temp_file = NamedTempFile::new().unwrap(); + let agent = Box::new(MockAgent {}); + let prompt = Box::new(MockPrompt {}); + Session::new(agent, prompt, temp_file.path().to_path_buf()) + } + + // Mock prompt implementation for testing + struct MockPrompt {} + impl Prompt for MockPrompt { + fn get_input(&mut self) -> std::result::Result { + Ok(Input { + input_type: InputType::Message, + content: Some("Msg:".to_string()), + }) + } + fn render(&mut self, _: Box) {} + fn show_busy(&mut self) {} + fn hide_busy(&self) {} + fn goose_ready(&self) {} + fn close(&self) {} + } + + #[test] + fn test_rewind_messages_only_user() { + let mut session = create_test_session(); + session.messages.push(Message::user().with_text("Hello")); + + session.rewind_messages(); + assert!(session.messages.is_empty()); + } + + #[test] + fn test_rewind_messages_user_then_assistant() { + let mut session = create_test_session(); + session.messages.push(Message::user().with_text("Hello")); + session + .messages + .push(Message::assistant().with_text("World")); + + session.rewind_messages(); + assert!(session.messages.is_empty()); + } + + #[test] + fn test_rewind_messages_multiple_user_messages() { + let mut session = create_test_session(); + session.messages.push(Message::user().with_text("First")); + session + .messages + .push(Message::assistant().with_text("Response 1")); + session.messages.push(Message::user().with_text("Second")); + session.rewind_messages(); + assert_eq!(session.messages.len(), 2); + assert_eq!(session.messages[0].role, Role::User); + assert_eq!(session.messages[1].role, Role::Assistant); + assert_eq!( + session.messages[0].content[0], + MessageContent::text("First") + ); + assert_eq!( + session.messages[1].content[0], + MessageContent::text("Response 1") + ); + } + + #[test] + fn test_rewind_messages_after_interrupted_tool_request() { + let mut session = create_test_session(); + session.messages.push(Message::user().with_text("First")); + session + .messages + .push(Message::assistant().with_text("Response 1")); + session.messages.push(Message::user().with_text("Use tool")); + + let mut mixed_msg = Message::assistant(); + mixed_msg.content.push(MessageContent::text("Using tool")); + mixed_msg.content.push(MessageContent::tool_request( + "test", + AgentResult::Ok(ToolCall::new("test", "test".into())), + )); + session.messages.push(mixed_msg); + + session.messages.push(Message::user().with_tool_response( + "test", + Err(goose::errors::AgentError::ExecutionError( + "Test".to_string(), + )), + )); + + session.rewind_messages(); + assert_eq!(session.messages.len(), 2); + assert_eq!(session.messages[0].role, Role::User); + assert_eq!(session.messages[1].role, Role::Assistant); + assert_eq!( + session.messages[0].content[0], + MessageContent::text("First") + ); + assert_eq!( + session.messages[1].content[0], + MessageContent::text("Response 1") + ); + } +}