Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cli] Runs agent as an async spawned task #328

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use rand::{distributions::Alphanumeric, Rng};
use std::path::{Path, PathBuf};

use goose::agent::Agent;
use goose::models::message::Message;
use goose::providers::factory;

use crate::commands::expected_config::get_recommended_models;
use crate::profile::profile::Profile;
Expand Down Expand Up @@ -47,9 +45,6 @@ pub fn build_session<'a>(
let provider_config =
set_provider_config(&loaded_profile.provider, loaded_profile.model.clone());

// TODO: Odd to be prepping the provider rather than having that done in the agent?
let provider = factory::get_provider(provider_config).unwrap();
let agent = Box::new(Agent::new(provider));
let mut prompt = match std::env::var("GOOSE_INPUT") {
Ok(val) => match val.as_str() {
"cliclack" => Box::new(CliclackPrompt::new()) as Box<dyn Prompt>,
Expand All @@ -70,7 +65,7 @@ pub fn build_session<'a>(
session_file.display()
))));

Box::new(Session::new(agent, prompt, session_file))
Box::new(Session::new(provider_config.clone(), prompt, session_file))
}

fn session_path(
Expand Down
119 changes: 79 additions & 40 deletions crates/goose-cli/src/session/session.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
use anyhow::Result;
use futures::StreamExt;
use futures::{FutureExt, StreamExt};
use goose::providers::configs::ProviderConfig;
use goose::providers::factory;
use std::path::PathBuf;
use tokio::sync::mpsc;

use crate::agents::agent::Agent;
use crate::prompt::prompt::{InputType, Prompt};
use crate::session::session_file::{persist_messages, readable_session_file};
use crate::systems::goose_hints::GooseHintsSystem;
use goose::agent::Agent;
use goose::developer::DeveloperSystem;
use goose::models::message::{Message, MessageContent};
use goose::models::role::Role;

use super::session_file::deserialize_messages;

pub struct Session<'a> {
agent: Box<dyn Agent>,
provider_config: ProviderConfig,
prompt: Box<dyn Prompt + 'a>,
session_file: PathBuf,
messages: Vec<Message>,
}

impl<'a> Session<'a> {
pub fn new(agent: Box<dyn Agent>, prompt: Box<dyn Prompt + 'a>, session_file: PathBuf) -> Self {
pub fn new(
provider_config: ProviderConfig,
prompt: Box<dyn Prompt + 'a>,
session_file: PathBuf,
) -> Self {
let messages = match readable_session_file(&session_file) {
Ok(file) => deserialize_messages(file).unwrap_or_else(|e| {
eprintln!(
Expand All @@ -36,7 +43,7 @@ impl<'a> Session<'a> {
};

Session {
agent,
provider_config,
prompt,
session_file,
messages,
Expand Down Expand Up @@ -84,36 +91,72 @@ impl<'a> Session<'a> {
}

async fn agent_process_messages(&mut self) {
let mut stream = match self.agent.reply(&self.messages).await {
Ok(stream) => stream,
Err(e) => {
eprintln!("Error starting reply stream: {}", e);
return;
let (tx, mut rx) = mpsc::channel::<Option<Result<Message>>>(1);

let messages = self.messages.clone();
let provider_config = self.provider_config.clone();
let (abort_tx, abort_rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let abort_rx = abort_rx.fuse();
futures::pin_mut!(abort_rx);
let provider = factory::get_provider(provider_config).unwrap();
let mut agent = Box::new(Agent::new(provider));

let system = Box::new(DeveloperSystem::new());
agent.add_system(system);
let goosehints_system = Box::new(GooseHintsSystem::new());
agent.add_system(goosehints_system);

let mut stream = match agent.reply(&messages).await {
Ok(stream) => stream,
Err(e) => {
eprintln!("Error starting reply stream: {}", e);
return;
}
};
let mut done = false;
loop {
tokio::select! {
response = stream.next() => {
match response {
Some(something)=>{tx.send(Some(something)).await.unwrap();}
None => break
}
}
_ = &mut abort_rx => {
done = true;
eprintln!("Agent thread aborted");
}
}
if done {
drop(stream);
break;
}
}
};
loop {
tokio::select! {
response = stream.next() => {
match response {
});

tokio::select! {
_ = async {
while let Some(res) = rx.recv().await {
match res {
Some(Ok(message)) => {
self.messages.push(message.clone());
persist_messages(&self.session_file, &self.messages).unwrap_or_else(|e| eprintln!("Failed to persist messages: {}", e));
self.prompt.hide_busy();
self.prompt.render(Box::new(message.clone()));
}
self.prompt.show_busy();
},
Some(Err(e)) => {
// TODO: Handle error display through prompt
eprintln!("Error: {}", e);
break;
}
None => break,
},
None => {}
}
}
_ = tokio::signal::ctrl_c() => {
drop(stream);
self.rewind_messages();
self.prompt.render(raw_message(" Interrupt: Resetting conversation to before the last sent message...\n"));
break;
}
} => {}
_ = tokio::signal::ctrl_c() => {
let _ = abort_tx.send(());
self.rewind_messages();
self.prompt.render(raw_message(" Interrupt: Resetting conversation to before the last sent message...\n"));
}
}
}
Expand Down Expand Up @@ -144,16 +187,6 @@ impl<'a> Session<'a> {
}

fn setup_session(&mut self) {
let system = Box::new(DeveloperSystem::new());
self.agent.add_system(system);
self.prompt
.render(raw_message("Connected developer system."));

let goosehints_system = Box::new(GooseHintsSystem::new());
self.agent.add_system(goosehints_system);
self.prompt
.render(raw_message("Connected .goosehints system."));

self.prompt.goose_ready();
}

Expand All @@ -175,19 +208,25 @@ fn raw_message(content: &str) -> Box<Message> {

#[cfg(test)]
mod tests {
use crate::agents::mock_agent::MockAgent;
use crate::prompt::prompt::{self, Input};

use super::*;
use goose::{errors::AgentResult, models::tool::ToolCall};
use goose::{
errors::AgentResult, models::tool::ToolCall, providers::configs::OllamaProviderConfig,
};
use tempfile::NamedTempFile;

// Helper function to create a test session
fn create_test_session() -> Session<'static> {
let temp_file = NamedTempFile::new().unwrap();
let agent = Box::new(MockAgent {});
let prompt = Box::new(MockPrompt {});
Session::new(agent, prompt, temp_file.path().to_path_buf())
let provider_config = ProviderConfig::Ollama(OllamaProviderConfig {
model: "test".to_string(),
host: "".to_string(),
temperature: None,
max_tokens: None,
});
Session::new(provider_config, prompt, temp_file.path().to_path_buf())
}

// Mock prompt implementation for testing
Expand Down
Loading