Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basis for named agents #525

Merged
merged 11 commits into from
Jan 1, 2025
Merged
41 changes: 28 additions & 13 deletions crates/goose-cli/src/agents/agent.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,43 @@
use anyhow::Result;
// use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
// use futures::stream::BoxStream;
use goose::{
agent::Agent as GooseAgent, message::Message, providers::base::ProviderUsage, systems::System,
agents::Agent, providers::base::Provider, providers::base::ProviderUsage, systems::System,
};
use tokio::sync::Mutex;

#[async_trait]
pub trait Agent {
fn add_system(&mut self, system: Box<dyn System>);
async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>>;
async fn usage(&self) -> Result<Vec<ProviderUsage>>;
pub struct GooseAgent {
systems: Vec<Box<dyn System>>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
}

#[allow(dead_code)]
impl GooseAgent {
pub fn new(provider: Box<dyn Provider>) -> Self {
Self {
systems: Vec::new(),
provider,
provider_usage: Mutex::new(Vec::new()),
}
}
}

#[async_trait]
impl Agent for GooseAgent {
fn add_system(&mut self, system: Box<dyn System>) {
self.add_system(system);
self.systems.push(system);
}

fn get_systems(&self) -> &Vec<Box<dyn System>> {
&self.systems
}

async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
self.reply(messages).await
fn get_provider(&self) -> &Box<dyn Provider> {
&self.provider
}

async fn usage(&self) -> Result<Vec<ProviderUsage>> {
self.usage().await
fn get_provider_usage(&self) -> &Mutex<Vec<ProviderUsage>> {
&self.provider_usage
}
}
41 changes: 37 additions & 4 deletions crates/goose-cli/src/agents/mock_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,48 @@ use std::vec;
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{message::Message, providers::base::ProviderUsage, systems::System};
use goose::providers::mock::MockProvider;
use goose::{
agents::Agent,
message::Message,
providers::base::{Provider, ProviderUsage},
systems::System,
};
use tokio::sync::Mutex;

use crate::agents::agent::Agent;
pub struct MockAgent {
systems: Vec<Box<dyn System>>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
}

pub struct MockAgent;
impl MockAgent {
pub fn new() -> Self {
Self {
systems: Vec::new(),
provider: Box::new(MockProvider::new(Vec::new())),
provider_usage: Mutex::new(Vec::new()),
}
}
}

#[async_trait]
impl Agent for MockAgent {
fn add_system(&mut self, _system: Box<dyn System>) {}
fn add_system(&mut self, system: Box<dyn System>) {
self.systems.push(system);
}

fn get_systems(&self) -> &Vec<Box<dyn System>> {
&self.systems
}

fn get_provider(&self) -> &Box<dyn Provider> {
&self.provider
}

fn get_provider_usage(&self) -> &Mutex<Vec<ProviderUsage>> {
&self.provider_usage
}

async fn reply(&self, _messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
Ok(Box::pin(futures::stream::empty()))
Expand Down
28 changes: 28 additions & 0 deletions crates/goose-cli/src/commands/agent_version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use anyhow::Result;
use clap::Args;
use goose::agents::AgentFactory;
use std::fmt::Write;

#[derive(Args)]
pub struct AgentCommand {}

impl AgentCommand {
pub fn run(&self) -> Result<()> {
let mut output = String::new();
writeln!(output, "Available agent versions:")?;

let versions = AgentFactory::available_versions();
let default_version = AgentFactory::default_version();

for version in versions {
if version == default_version {
writeln!(output, "* {} (default)", version)?;
} else {
writeln!(output, " {}", version)?;
}
}

print!("{}", output);
Ok(())
}
}
2 changes: 1 addition & 1 deletion crates/goose-cli/src/commands/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod agent_version;
pub mod configure;
pub mod session;
pub mod version;
pub mod expected_config;
11 changes: 6 additions & 5 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use console::style;
use goose::agent::Agent;
use goose::agents::AgentFactory;
use goose::providers::factory;
use rand::{distributions::Alphanumeric, Rng};
use std::path::{Path, PathBuf};
Expand All @@ -13,6 +13,7 @@ use crate::session::{ensure_session_dir, get_most_recent_session, Session};
pub fn build_session<'a>(
session: Option<String>,
profile: Option<String>,
agent_version: Option<String>,
resume: bool,
) -> Box<Session<'a>> {
let session_dir = ensure_session_dir().expect("Failed to create session directory");
Expand Down Expand Up @@ -45,7 +46,7 @@ pub fn build_session<'a>(

// TODO: Odd to be prepping the provider rather than having that done in the agent?
let provider = factory::get_provider(provider_config).unwrap();
let agent = Box::new(Agent::new(provider));
let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("base"), provider).unwrap();
let prompt = match std::env::var("GOOSE_INPUT") {
Ok(val) => match val.as_str() {
"rustyline" => Box::new(RustylinePrompt::new()) as Box<dyn Prompt>,
Expand Down Expand Up @@ -173,7 +174,7 @@ mod tests {
#[should_panic(expected = "Cannot resume session: file")]
fn test_resume_nonexistent_session_panics() {
run_with_tmp_dir(|| {
build_session(Some("nonexistent-session".to_string()), None, true);
build_session(Some("nonexistent-session".to_string()), None, None, true);
})
}

Expand All @@ -190,7 +191,7 @@ mod tests {
fs::write(&file2_path, "{}")?;

// Test resuming without a session name
let session = build_session(None, None, true);
let session = build_session(None, None, None, true);
assert_eq!(session.session_file().as_path(), file2_path.as_path());

Ok(())
Expand All @@ -201,7 +202,7 @@ mod tests {
#[should_panic(expected = "No session files found")]
fn test_resume_most_recent_session_no_files() {
run_with_tmp_dir(|| {
build_session(None, None, true);
build_session(None, None, None, true);
});
}
}
82 changes: 68 additions & 14 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
mod commands {
pub mod configure;
pub mod session;
pub mod version;
}
pub mod agents;
use anyhow::Result;
use clap::{Parser, Subcommand};
use goose::agents::AgentFactory;

mod agents;
mod commands;
mod log_usage;
mod profile;
mod prompt;
pub mod session;

mod session;
mod systems;

use anyhow::Result;
use clap::{Parser, Subcommand};
use commands::agent_version::AgentCommand;
use commands::configure::handle_configure;
use commands::session::build_session;
use commands::version::print_version;
use profile::has_no_profiles;
use std::io::{self, Read};

mod log_usage;

#[cfg(test)]
mod test_helpers;

Expand Down Expand Up @@ -98,6 +95,15 @@ enum Command {
)]
profile: Option<String>,

/// Agent version to use (e.g., 'base', 'v1')
#[arg(
short,
long,
help = "Agent version to use (e.g., 'base', 'v1'), defaults to 'base'",
long_help = "Specify which agent version to use for this session."
)]
agent: Option<String>,

/// Resume a previous session
#[arg(
short,
Expand Down Expand Up @@ -151,6 +157,15 @@ enum Command {
)]
name: Option<String>,

/// Agent version to use (e.g., 'base', 'v1')
#[arg(
short,
long,
help = "Agent version to use (e.g., 'base', 'v1')",
long_help = "Specify which agent version to use for this session."
)]
agent: Option<String>,

/// Resume a previous run
#[arg(
short,
Expand All @@ -161,6 +176,9 @@ enum Command {
)]
resume: bool,
},

/// List available agent versions
Agents(AgentCommand),
}

#[derive(Subcommand)]
Expand Down Expand Up @@ -224,9 +242,25 @@ async fn main() -> Result<()> {
Some(Command::Session {
name,
profile,
agent,
resume,
}) => {
let mut session = build_session(name, profile, resume);
if let Some(agent_version) = agent.clone() {
if !AgentFactory::available_versions().contains(&agent_version.as_str()) {
eprintln!("Error: Invalid agent version '{}'", agent_version);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}
}

let mut session = build_session(name, profile, agent, resume);
let _ = session.start().await;
return Ok(());
}
Expand All @@ -235,8 +269,24 @@ async fn main() -> Result<()> {
input_text,
profile,
name,
agent,
resume,
}) => {
if let Some(agent_version) = agent.clone() {
if !AgentFactory::available_versions().contains(&agent_version.as_str()) {
eprintln!("Error: Invalid agent version '{}'", agent_version);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}
}

let contents = if let Some(file_name) = instructions {
let file_path = std::path::Path::new(&file_name);
std::fs::read_to_string(file_path).expect("Failed to read the instruction file")
Expand All @@ -249,10 +299,14 @@ async fn main() -> Result<()> {
.expect("Failed to read from stdin");
stdin
};
let mut session = build_session(name, profile, resume);
let mut session = build_session(name, profile, agent, resume);
let _ = session.headless_start(contents.clone()).await;
return Ok(());
}
Some(Command::Agents(cmd)) => {
cmd.run()?;
return Ok(());
}
None => {
println!("No command provided - Run 'goose help' to see available commands.");
if has_no_profiles().unwrap_or(false) {
Expand Down
8 changes: 5 additions & 3 deletions crates/goose-cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use std::fs::{self, File};
use std::io::{self, BufRead, Write};
use std::path::PathBuf;

use crate::agents::agent::Agent;
// use crate::agents::agent::Agent;
use crate::log_usage::log_usage;
use crate::prompt::{InputType, Prompt};
use goose::agents::Agent;
use goose::developer::DeveloperSystem;
use goose::message::{Message, MessageContent};
use goose::systems::goose_hints::GooseHintsSystem;
Expand Down Expand Up @@ -101,6 +102,7 @@ pub struct Session<'a> {
messages: Vec<Message>,
}

#[allow(dead_code)]
impl<'a> Session<'a> {
pub fn new(
agent: Box<dyn Agent>,
Expand Down Expand Up @@ -361,14 +363,14 @@ mod tests {
// Helper function to create a test session
fn create_test_session() -> Session<'static> {
let temp_file = NamedTempFile::new().unwrap();
let agent = Box::new(MockAgent {});
let agent = Box::new(MockAgent::new());
let prompt = Box::new(MockPrompt::new());
Session::new(agent, prompt, temp_file.path().to_path_buf())
}

fn create_test_session_with_prompt<'a>(prompt: Box<dyn Prompt + 'a>) -> Session<'a> {
let temp_file = NamedTempFile::new().unwrap();
let agent = Box::new(MockAgent {});
let agent = Box::new(MockAgent::new());
Session::new(agent, prompt, temp_file.path().to_path_buf())
}

Expand Down
6 changes: 6 additions & 0 deletions crates/goose-server/src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ pub struct Settings {
#[serde(default)]
pub server: ServerSettings,
pub provider: ProviderSettings,
#[serde(default = "default_agent_version")]
pub agent_version: Option<String>,
}

fn default_agent_version() -> Option<String> {
None // Will use AgentFactory::default_version() when None
}

impl Settings {
Expand Down
Loading
Loading