From 6cd157a535ed2f0bb6501064c8a482f0f4ebab81 Mon Sep 17 00:00:00 2001 From: Jarrod Sibbison Date: Mon, 25 Nov 2024 13:28:29 +1100 Subject: [PATCH] [cli] Runs agent as an async spawned task --- crates/goose-cli/src/commands/session.rs | 7 +- crates/goose-cli/src/session/session.rs | 119 +++++++++++++++-------- 2 files changed, 80 insertions(+), 46 deletions(-) diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 928a85078..c60f61c0e 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -1,9 +1,7 @@ use rand::{distributions::Alphanumeric, Rng}; use std::path::{Path, PathBuf}; -use goose::agent::Agent; use goose::models::message::Message; -use goose::providers::factory; use crate::commands::expected_config::get_recommended_models; use crate::profile::profile::Profile; @@ -47,9 +45,6 @@ pub fn build_session<'a>( let provider_config = set_provider_config(&loaded_profile.provider, loaded_profile.model.clone()); - // TODO: Odd to be prepping the provider rather than having that done in the agent? - let provider = factory::get_provider(provider_config).unwrap(); - let agent = Box::new(Agent::new(provider)); let mut prompt = match std::env::var("GOOSE_INPUT") { Ok(val) => match val.as_str() { "cliclack" => Box::new(CliclackPrompt::new()) as Box, @@ -70,7 +65,7 @@ pub fn build_session<'a>( session_file.display() )))); - Box::new(Session::new(agent, prompt, session_file)) + Box::new(Session::new(provider_config.clone(), prompt, session_file)) } fn session_path( diff --git a/crates/goose-cli/src/session/session.rs b/crates/goose-cli/src/session/session.rs index 0e0b61f29..071a09005 100644 --- a/crates/goose-cli/src/session/session.rs +++ b/crates/goose-cli/src/session/session.rs @@ -1,11 +1,14 @@ use anyhow::Result; -use futures::StreamExt; +use futures::{FutureExt, StreamExt}; +use goose::providers::configs::ProviderConfig; +use goose::providers::factory; use std::path::PathBuf; +use tokio::sync::mpsc; -use crate::agents::agent::Agent; use crate::prompt::prompt::{InputType, Prompt}; use crate::session::session_file::{persist_messages, readable_session_file}; use crate::systems::goose_hints::GooseHintsSystem; +use goose::agent::Agent; use goose::developer::DeveloperSystem; use goose::models::message::{Message, MessageContent}; use goose::models::role::Role; @@ -13,14 +16,18 @@ use goose::models::role::Role; use super::session_file::deserialize_messages; pub struct Session<'a> { - agent: Box, + provider_config: ProviderConfig, prompt: Box, session_file: PathBuf, messages: Vec, } impl<'a> Session<'a> { - pub fn new(agent: Box, prompt: Box, session_file: PathBuf) -> Self { + pub fn new( + provider_config: ProviderConfig, + prompt: Box, + session_file: PathBuf, + ) -> Self { let messages = match readable_session_file(&session_file) { Ok(file) => deserialize_messages(file).unwrap_or_else(|e| { eprintln!( @@ -36,7 +43,7 @@ impl<'a> Session<'a> { }; Session { - agent, + provider_config, prompt, session_file, messages, @@ -84,36 +91,72 @@ impl<'a> Session<'a> { } async fn agent_process_messages(&mut self) { - let mut stream = match self.agent.reply(&self.messages).await { - Ok(stream) => stream, - Err(e) => { - eprintln!("Error starting reply stream: {}", e); - return; + let (tx, mut rx) = mpsc::channel::>>(1); + + let messages = self.messages.clone(); + let provider_config = self.provider_config.clone(); + let (abort_tx, abort_rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + let abort_rx = abort_rx.fuse(); + futures::pin_mut!(abort_rx); + let provider = factory::get_provider(provider_config).unwrap(); + let mut agent = Box::new(Agent::new(provider)); + + let system = Box::new(DeveloperSystem::new()); + agent.add_system(system); + let goosehints_system = Box::new(GooseHintsSystem::new()); + agent.add_system(goosehints_system); + + let mut stream = match agent.reply(&messages).await { + Ok(stream) => stream, + Err(e) => { + eprintln!("Error starting reply stream: {}", e); + return; + } + }; + let mut done = false; + loop { + tokio::select! { + response = stream.next() => { + match response { + Some(something)=>{tx.send(Some(something)).await.unwrap();} + None => break + } + } + _ = &mut abort_rx => { + done = true; + eprintln!("Agent thread aborted"); + } + } + if done { + drop(stream); + break; + } } - }; - loop { - tokio::select! { - response = stream.next() => { - match response { + }); + + tokio::select! { + _ = async { + while let Some(res) = rx.recv().await { + match res { Some(Ok(message)) => { self.messages.push(message.clone()); persist_messages(&self.session_file, &self.messages).unwrap_or_else(|e| eprintln!("Failed to persist messages: {}", e)); + self.prompt.hide_busy(); self.prompt.render(Box::new(message.clone())); - } + self.prompt.show_busy(); + }, Some(Err(e)) => { - // TODO: Handle error display through prompt eprintln!("Error: {}", e); - break; - } - None => break, + }, + None => {} } } - _ = tokio::signal::ctrl_c() => { - drop(stream); - self.rewind_messages(); - self.prompt.render(raw_message(" Interrupt: Resetting conversation to before the last sent message...\n")); - break; - } + } => {} + _ = tokio::signal::ctrl_c() => { + let _ = abort_tx.send(()); + self.rewind_messages(); + self.prompt.render(raw_message(" Interrupt: Resetting conversation to before the last sent message...\n")); } } } @@ -144,16 +187,6 @@ impl<'a> Session<'a> { } fn setup_session(&mut self) { - let system = Box::new(DeveloperSystem::new()); - self.agent.add_system(system); - self.prompt - .render(raw_message("Connected developer system.")); - - let goosehints_system = Box::new(GooseHintsSystem::new()); - self.agent.add_system(goosehints_system); - self.prompt - .render(raw_message("Connected .goosehints system.")); - self.prompt.goose_ready(); } @@ -175,19 +208,25 @@ fn raw_message(content: &str) -> Box { #[cfg(test)] mod tests { - use crate::agents::mock_agent::MockAgent; use crate::prompt::prompt::{self, Input}; use super::*; - use goose::{errors::AgentResult, models::tool::ToolCall}; + use goose::{ + errors::AgentResult, models::tool::ToolCall, providers::configs::OllamaProviderConfig, + }; use tempfile::NamedTempFile; // Helper function to create a test session fn create_test_session() -> Session<'static> { let temp_file = NamedTempFile::new().unwrap(); - let agent = Box::new(MockAgent {}); let prompt = Box::new(MockPrompt {}); - Session::new(agent, prompt, temp_file.path().to_path_buf()) + let provider_config = ProviderConfig::Ollama(OllamaProviderConfig { + model: "test".to_string(), + host: "".to_string(), + temperature: None, + max_tokens: None, + }); + Session::new(provider_config, prompt, temp_file.path().to_path_buf()) } // Mock prompt implementation for testing