From 5aef01ca67dcb46701894094b41bedcb972d1b4f Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Thu, 26 Dec 2024 13:42:32 -0800 Subject: [PATCH 01/11] Update agents with core traits to enable versioned agents --- crates/goose/src/agent.rs | 781 ------------------------------- crates/goose/src/agents/agent.rs | 396 ++++++++++++++++ crates/goose/src/agents/base.rs | 395 ++++++++++++++++ crates/goose/src/agents/mod.rs | 7 + crates/goose/src/agents/v1.rs | 91 ++++ crates/goose/src/lib.rs | 2 +- 6 files changed, 890 insertions(+), 782 deletions(-) delete mode 100644 crates/goose/src/agent.rs create mode 100644 crates/goose/src/agents/agent.rs create mode 100644 crates/goose/src/agents/base.rs create mode 100644 crates/goose/src/agents/mod.rs create mode 100644 crates/goose/src/agents/v1.rs diff --git a/crates/goose/src/agent.rs b/crates/goose/src/agent.rs deleted file mode 100644 index 5c02bf3be..000000000 --- a/crates/goose/src/agent.rs +++ /dev/null @@ -1,781 +0,0 @@ -use anyhow::Result; -use async_stream; -use futures::stream::BoxStream; -use rust_decimal_macros::dec; -use serde_json::json; -use std::collections::HashMap; -use tokio::sync::Mutex; - -use crate::errors::{AgentError, AgentResult}; -use crate::message::{Message, ToolRequest}; -use crate::prompt_template::load_prompt_file; -use crate::providers::base::{Provider, ProviderUsage}; -use crate::systems::System; -use crate::token_counter::TokenCounter; -use mcp_core::{Content, Resource, Tool, ToolCall}; -use serde::Serialize; - -// used to sort resources by priority within error margin -const PRIORITY_EPSILON: f32 = 0.001; - -#[derive(Clone, Debug, Serialize)] -struct SystemInfo { - name: String, - description: String, - instructions: String, -} - -impl SystemInfo { - fn new(name: &str, description: &str, instructions: &str) -> Self { - Self { - name: name.to_string(), - description: description.to_string(), - instructions: instructions.to_string(), - } - } -} - -#[derive(Clone, Debug, Serialize)] -struct SystemStatus { - name: String, - status: String, -} - -impl SystemStatus { - fn new(name: &str, status: String) -> Self { - Self { - name: name.to_string(), - status, - } - } -} - -/// Agent integrates a foundational LLM with the systems it needs to pilot -pub struct Agent { - systems: Vec>, - provider: Box, - provider_usage: Mutex>, -} - -#[allow(dead_code)] -impl Agent { - /// Create a new Agent with the specified provider - pub fn new(provider: Box) -> Self { - Self { - systems: Vec::new(), - provider, - provider_usage: Mutex::new(Vec::new()), - } - } - - /// Add a system to the agent - pub fn add_system(&mut self, system: Box) { - self.systems.push(system); - } - - /// Get the context limit from the provider's configuration - fn get_context_limit(&self) -> usize { - self.provider.get_model_config().context_limit() - } - - /// Get all tools from all systems with proper system prefixing - fn get_prefixed_tools(&self) -> Vec { - let mut tools = Vec::new(); - for system in &self.systems { - for tool in system.tools() { - tools.push(Tool::new( - format!("{}__{}", system.name(), tool.name), - &tool.description, - tool.input_schema.clone(), - )); - } - } - tools - } - - /// Find the appropriate system for a tool call based on the prefixed name - fn get_system_for_tool(&self, prefixed_name: &str) -> Option<&dyn System> { - let parts: Vec<&str> = prefixed_name.split("__").collect(); - if parts.len() != 2 { - return None; - } - let system_name = parts[0]; - self.systems - .iter() - .find(|sys| sys.name() == system_name) - .map(|v| &**v) - } - - /// Dispatch a single tool call to the appropriate system - async fn dispatch_tool_call( - &self, - tool_call: AgentResult, - ) -> AgentResult> { - let call = tool_call?; - let system = self - .get_system_for_tool(&call.name) - .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?; - - let tool_name = call - .name - .split("__") - .nth(1) - .ok_or_else(|| AgentError::InvalidToolName(call.name.clone()))?; - let system_tool_call = ToolCall::new(tool_name, call.arguments); - - system.call(system_tool_call).await - } - - fn get_system_prompt(&self) -> AgentResult { - let mut context = HashMap::new(); - let systems_info: Vec = self - .systems - .iter() - .map(|system| { - SystemInfo::new(system.name(), system.description(), system.instructions()) - }) - .collect(); - - context.insert("systems", systems_info); - load_prompt_file("system.md", &context).map_err(|e| AgentError::Internal(e.to_string())) - } - - async fn get_systems_resources( - &self, - ) -> AgentResult>> { - let mut system_resource_content: HashMap> = - HashMap::new(); - for system in &self.systems { - let system_status = system - .status() - .await - .map_err(|e| AgentError::Internal(e.to_string()))?; - - let mut resource_content: HashMap = HashMap::new(); - for resource in system_status { - if let Ok(content) = system.read_resource(&resource.uri).await { - resource_content.insert(resource.uri.to_string(), (resource, content)); - } - } - system_resource_content.insert(system.name().to_string(), resource_content); - } - Ok(system_resource_content) - } - - /// Setup the next inference by budgeting the context window as well as we can - async fn prepare_inference( - &self, - system_prompt: &str, - tools: &[Tool], - messages: &[Message], - pending: &Vec, - target_limit: usize, - ) -> AgentResult> { - // Prepares the inference by managing context window and token budget. - // This function: - // 1. Retrieves and formats system resources and status - // 2. Trims content if total tokens exceed the model's context limit - // 3. Adds pending messages if any. Pending messages are messages that have been added - // to the conversation but not yet responded to. - // 4. Adds two messages to the conversation: - // - A tool request message for status - // - A tool response message containing the (potentially trimmed) status - // - // Returns the updated message history with status information appended. - // - // Arguments: - // * `system_prompt` - The system prompt to include - // * `tools` - Available tools for the agent - // * `messages` - Current conversation history - // - // Returns: - // * `AgentResult>` - Updated message history with status appended - - let token_counter = TokenCounter::new(); - let resource_content = self.get_systems_resources().await?; - - // Flatten all resource content into a vector of strings - let mut resources = Vec::new(); - for system_resources in resource_content.values() { - for (_, content) in system_resources.values() { - resources.push(content.clone()); - } - } - - let approx_count = token_counter.count_everything( - system_prompt, - messages, - tools, - &resources, - Some(&self.provider.get_model_config().model_name), - ); - - let mut status_content: Vec = Vec::new(); - - if approx_count > target_limit { - println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over buget. Removing context", approx_count, approx_count - target_limit); - - // Get token counts for each resourcee - let mut system_token_counts = HashMap::new(); - - // Iterate through each system and its resources - for (system_name, resources) in &resource_content { - let mut resource_counts = HashMap::new(); - for (uri, (_resource, content)) in resources { - let token_count = token_counter - .count_tokens(&content, Some(&self.provider.get_model_config().model_name)) - as u32; - resource_counts.insert(uri.clone(), token_count); - } - system_token_counts.insert(system_name.clone(), resource_counts); - } - // Sort resources by priority and timestamp and trim to fit context limit - let mut all_resources: Vec<(String, String, Resource, u32)> = Vec::new(); - for (system_name, resources) in &resource_content { - for (uri, (resource, _)) in resources { - if let Some(token_count) = system_token_counts - .get(system_name) - .and_then(|counts| counts.get(uri)) - { - all_resources.push(( - system_name.clone(), - uri.clone(), - resource.clone(), - *token_count, - )); - } - } - } - - // Sort by priority (high to low) and timestamp (newest to oldest) - // since priority is float, we need to sort by priority within error margin - PRIORITY_EPSILON - all_resources.sort_by(|a, b| { - // Compare priorities with epsilon - // Compare priorities with Option handling - default to 0.0 if None - let a_priority = a.2.priority().unwrap_or(0.0); - let b_priority = b.2.priority().unwrap_or(0.0); - if (b_priority - a_priority).abs() < PRIORITY_EPSILON { - // Priorities are "equal" within epsilon, use timestamp as tiebreaker - b.2.timestamp().cmp(&a.2.timestamp()) - } else { - // Priorities are different enough, use priority ordering - b.2.priority() - .partial_cmp(&a.2.priority()) - .unwrap_or(std::cmp::Ordering::Equal) - } - }); - - // Remove resources until we're under target limit - let mut current_tokens = approx_count; - - while current_tokens > target_limit && !all_resources.is_empty() { - if let Some((system_name, uri, _, token_count)) = all_resources.pop() { - if let Some(system_counts) = system_token_counts.get_mut(&system_name) { - system_counts.remove(&uri); - current_tokens -= token_count as usize; - } - } - } - // Create status messages only from resources that remain after token trimming - for (system_name, uri, _, _) in &all_resources { - if let Some(system_resources) = resource_content.get(system_name) { - if let Some((resource, content)) = system_resources.get(uri) { - status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); - } - } - } - } else { - // Create status messages from all resources when no trimming needed - for resources in resource_content.values() { - for (resource, content) in resources.values() { - status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); - } - } - } - - // Join remaining status content and create status message - let status_str = status_content.join("\n"); - let mut context = HashMap::new(); - let systems_status = vec![SystemStatus::new("system", status_str)]; - context.insert("systems", &systems_status); - - // Load and format the status template with only remaining resources - let status = load_prompt_file("status.md", &context) - .map_err(|e| AgentError::Internal(e.to_string()))?; - - // Create a new messages vector with our changes - let mut new_messages = messages.to_vec(); - - // Add pending messages - for msg in pending { - new_messages.push(msg.clone()); - } - - // Finally add the status messages - let message_use = - Message::assistant().with_tool_request("000", Ok(ToolCall::new("status", json!({})))); - - let message_result = - Message::user().with_tool_response("000", Ok(vec![Content::text(status)])); - - new_messages.push(message_use); - new_messages.push(message_result); - - Ok(new_messages) - } - - /// Create a stream that yields each message as it's generated by the agent. - /// This includes both the assistant's responses and any tool responses. - pub async fn reply(&self, messages: &[Message]) -> Result>> { - let mut messages = messages.to_vec(); - let tools = self.get_prefixed_tools(); - let system_prompt = self.get_system_prompt()?; - let estimated_limit = self.provider.get_model_config().get_estimated_limit(); - - // Update conversation history for the start of the reply - messages = self - .prepare_inference( - &system_prompt, - &tools, - &messages, - &Vec::new(), - estimated_limit, - ) - .await?; - - Ok(Box::pin(async_stream::try_stream! { - loop { - // Get completion from provider - let (response, usage) = self.provider.complete( - &system_prompt, - &messages, - &tools, - ).await?; - self.provider_usage.lock().await.push(usage); - - // The assistant's response is added in rewrite_messages_on_tool_response - // Yield the assistant's response - yield response.clone(); - - // Not sure why this is needed, but this ensures that the above message is yielded - // before the following potentially long-running commands start processing - tokio::task::yield_now().await; - - // First collect any tool requests - let tool_requests: Vec<&ToolRequest> = response.content - .iter() - .filter_map(|content| content.as_tool_request()) - .collect(); - - if tool_requests.is_empty() { - // No more tool calls, end the reply loop - break; - } - - // Then dispatch each in parallel - let futures: Vec<_> = tool_requests - .iter() - .map(|request| self.dispatch_tool_call(request.tool_call.clone())) - .collect(); - - // Process all the futures in parallel but wait until all are finished - let outputs = futures::future::join_all(futures).await; - - // Create a message with the responses - let mut message_tool_response = Message::user(); - // Now combine these into MessageContent::ToolResponse using the original ID - for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); - } - - yield message_tool_response.clone(); - - // Now we have to remove the previous status tooluse and toolresponse - // before we add pending messages, then the status msgs back again - messages.pop(); - messages.pop(); - - let pending = vec![response, message_tool_response]; - messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit).await?; - } - })) - } - - pub async fn usage(&self) -> Result> { - let provider_usage = self.provider_usage.lock().await.clone(); - - let mut usage_map: HashMap = HashMap::new(); - provider_usage.iter().for_each(|usage| { - usage_map - .entry(usage.model.clone()) - .and_modify(|e| { - e.usage.input_tokens = Some( - e.usage.input_tokens.unwrap_or(0) + usage.usage.input_tokens.unwrap_or(0), - ); - e.usage.output_tokens = Some( - e.usage.output_tokens.unwrap_or(0) + usage.usage.output_tokens.unwrap_or(0), - ); - e.usage.total_tokens = Some( - e.usage.total_tokens.unwrap_or(0) + usage.usage.total_tokens.unwrap_or(0), - ); - if e.cost.is_none() || usage.cost.is_none() { - e.cost = None; // Pricing is not available for all models - } else { - e.cost = Some(e.cost.unwrap_or(dec!(0)) + usage.cost.unwrap_or(dec!(0))); - } - }) - .or_insert_with(|| usage.clone()); - }); - Ok(usage_map.into_values().collect()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::message::MessageContent; - use crate::providers::configs::ModelConfig; - use crate::providers::mock::MockProvider; - use async_trait::async_trait; - use chrono::Utc; - use futures::TryStreamExt; - use mcp_core::resource::Resource; - use mcp_core::Annotations; - use serde_json::json; - - // Mock system for testing - struct MockSystem { - name: String, - tools: Vec, - resources: Vec, - resource_content: HashMap, - } - - impl MockSystem { - fn new(name: &str) -> Self { - Self { - name: name.to_string(), - tools: vec![Tool::new( - "echo", - "Echoes back the input", - json!({"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}), - )], - resources: Vec::new(), - resource_content: HashMap::new(), - } - } - - fn add_resource(&mut self, name: &str, content: &str, priority: f32) { - let uri = format!("file://{}", name); - let resource = Resource { - name: name.to_string(), - uri: uri.clone(), - annotations: Some(Annotations::for_resource(priority, Utc::now())), - description: Some("A mock resource".to_string()), - mime_type: "text/plain".to_string(), - }; - self.resources.push(resource); - self.resource_content.insert(uri, content.to_string()); - } - } - - #[async_trait] - impl System for MockSystem { - fn name(&self) -> &str { - &self.name - } - - fn description(&self) -> &str { - "A mock system for testing" - } - - fn instructions(&self) -> &str { - "Mock system instructions" - } - - fn tools(&self) -> &[Tool] { - &self.tools - } - - async fn status(&self) -> anyhow::Result> { - Ok(self.resources.clone()) - } - - async fn call(&self, tool_call: ToolCall) -> AgentResult> { - match tool_call.name.as_str() { - "echo" => Ok(vec![Content::text( - tool_call.arguments["message"].as_str().unwrap_or(""), - )]), - _ => Err(AgentError::ToolNotFound(tool_call.name)), - } - } - - async fn read_resource(&self, uri: &str) -> AgentResult { - self.resource_content.get(uri).cloned().ok_or_else(|| { - AgentError::InvalidParameters(format!("Resource {} could not be found", uri)) - }) - } - } - - #[tokio::test] - async fn test_simple_response() -> Result<()> { - let response = Message::assistant().with_text("Hello!"); - let provider = MockProvider::new(vec![response.clone()]); - let agent = Agent::new(Box::new(provider)); - - let initial_message = Message::user().with_text("Hi"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - assert_eq!(messages.len(), 1); - assert_eq!(messages[0], response); - Ok(()) - } - - #[tokio::test] - async fn test_usage_rollup() -> Result<()> { - let response = Message::assistant().with_text("Hello!"); - let provider = MockProvider::new(vec![response.clone()]); - let agent = Agent::new(Box::new(provider)); - - let initial_message = Message::user().with_text("Hi"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - while stream.try_next().await?.is_some() {} - - // Second message - let mut stream = agent.reply(&initial_messages).await?; - while stream.try_next().await?.is_some() {} - - let usage = agent.usage().await?; - assert_eq!(usage.len(), 1); // 2 messages rolled up to one usage per model - assert_eq!(usage[0].usage.input_tokens, Some(2)); - assert_eq!(usage[0].usage.output_tokens, Some(2)); - assert_eq!(usage[0].usage.total_tokens, Some(4)); - assert_eq!(usage[0].model, "mock"); - assert_eq!(usage[0].cost, Some(dec!(2))); - Ok(()) - } - - #[tokio::test] - async fn test_tool_call() -> Result<()> { - let mut agent = Agent::new(Box::new(MockProvider::new(vec![ - Message::assistant().with_tool_request( - "1", - Ok(ToolCall::new("test_echo", json!({"message": "test"}))), - ), - Message::assistant().with_text("Done!"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Echo test"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - // Should have three messages: tool request, response, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!(messages[2].content[0], MessageContent::text("Done!")); - Ok(()) - } - - #[tokio::test] - async fn test_invalid_tool() -> Result<()> { - let mut agent = Agent::new(Box::new(MockProvider::new(vec![ - Message::assistant() - .with_tool_request("1", Ok(ToolCall::new("invalid_tool", json!({})))), - Message::assistant().with_text("Error occurred"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Invalid tool"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - // Should have three messages: failed tool request, fail response, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!( - messages[2].content[0], - MessageContent::text("Error occurred") - ); - Ok(()) - } - - #[tokio::test] - async fn test_multiple_tool_calls() -> Result<()> { - let mut agent = Agent::new(Box::new(MockProvider::new(vec![ - Message::assistant() - .with_tool_request( - "1", - Ok(ToolCall::new("test_echo", json!({"message": "first"}))), - ) - .with_tool_request( - "2", - Ok(ToolCall::new("test_echo", json!({"message": "second"}))), - ), - Message::assistant().with_text("All done!"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Multiple calls"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - // Should have three messages: tool requests, responses, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!(messages[2].content[0], MessageContent::text("All done!")); - Ok(()) - } - - #[tokio::test] - async fn test_prepare_inference_trims_resources_when_budget_exceeded() -> Result<()> { - // Create a mock provider - let provider = MockProvider::new(vec![]); - let mut agent = Agent::new(Box::new(provider)); - - // Create a mock system with two resources - let mut system = MockSystem::new("test"); - - // Add two resources with different priorities - let string_10toks = "hello ".repeat(10); - system.add_resource("high_priority", &string_10toks, 0.8); - system.add_resource("low_priority", &string_10toks, 0.1); - - agent.add_system(Box::new(system)); - - // Set up test parameters - // 18 tokens with system + user msg in chat format - let system_prompt = "This is a system prompt"; - let messages = vec![Message::user().with_text("Hi there")]; - let tools = vec![]; - let pending = vec![]; - - // Approx count is 40, so target limit of 35 will force trimming - let target_limit = 35; - - // Call prepare_inference - let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // Verify that only the high priority resource is included in the status - assert!(status_content.contains("high_priority")); - assert!(!status_content.contains("low_priority")); - - // Now test with a target limit that allows both resources (no trimming) - let target_limit = 100; - - // Call prepare_inference - let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // Verify that only the high priority resource is included in the status - assert!(status_content.contains("high_priority")); - assert!(status_content.contains("low_priority")); - Ok(()) - } - - #[tokio::test] - async fn test_context_trimming_with_custom_model_config() -> Result<()> { - let provider = MockProvider::with_config( - vec![], - ModelConfig::new("test_model".to_string()).with_context_limit(Some(20)), - ); - let mut agent = Agent::new(Box::new(provider)); - - // Create a mock system with a resource that will exceed the context limit - let mut system = MockSystem::new("test"); - - // Add a resource that will exceed our tiny context limit - let hello_1_tokens = "hello ".repeat(1); // 1 tokens - let goodbye_10_tokens = "goodbye ".repeat(10); // 10 tokens - system.add_resource("test_resource_removed", &goodbye_10_tokens, 0.1); - system.add_resource("test_resource_expected", &hello_1_tokens, 0.5); - - agent.add_system(Box::new(system)); - - // Set up test parameters - // 18 tokens with system + user msg in chat format - let system_prompt = "This is a system prompt"; - let messages = vec![Message::user().with_text("Hi there")]; - let tools = vec![]; - let pending = vec![]; - - // Use the context limit from the model config - let target_limit = agent.get_context_limit(); - assert_eq!(target_limit, 20, "Context limit should be 20"); - - // Call prepare_inference - let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // verify that "hello" is within the response, should be just under 20 tokens with "hello" - assert!(status_content.contains("hello")); - - Ok(()) - } -} diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs new file mode 100644 index 000000000..d141806ec --- /dev/null +++ b/crates/goose/src/agents/agent.rs @@ -0,0 +1,396 @@ +use anyhow::Result; +use async_stream; +use async_trait::async_trait; +use futures::stream::BoxStream; +use rust_decimal_macros::dec; +use serde_json::json; +use std::collections::HashMap; +use tokio::sync::Mutex; + +use crate::errors::{AgentError, AgentResult}; +use crate::message::{Message, ToolRequest}; +use crate::prompt_template::load_prompt_file; +use crate::providers::base::{Provider, ProviderUsage}; +use crate::systems::System; +use crate::token_counter::TokenCounter; +use mcp_core::{Content, Resource, Tool, ToolCall}; +use serde::Serialize; + +// used to sort resources by priority within error margin +const PRIORITY_EPSILON: f32 = 0.001; + +#[derive(Clone, Debug, Serialize)] +struct SystemInfo { + name: String, + description: String, + instructions: String, +} + +impl SystemInfo { + fn new(name: &str, description: &str, instructions: &str) -> Self { + Self { + name: name.to_string(), + description: description.to_string(), + instructions: instructions.to_string(), + } + } +} + +#[derive(Clone, Debug, Serialize)] +struct SystemStatus { + name: String, + status: String, +} + +impl SystemStatus { + fn new(name: &str, status: String) -> Self { + Self { + name: name.to_string(), + status, + } + } +} + +/// Core trait defining the behavior of an Agent +#[async_trait] +pub trait Agent: Send + Sync { + /// Get all tools from all systems with proper system prefixing + fn get_prefixed_tools(&self) -> Vec { + let mut tools = Vec::new(); + for system in self.get_systems() { + for tool in system.tools() { + tools.push(Tool::new( + format!("{}__{}", system.name(), tool.name), + &tool.description, + tool.input_schema.clone(), + )); + } + } + tools + } + + // add a system to the agent + fn add_system(&mut self, system: Box); + + /// Get the systems this agent has access to + fn get_systems(&self) -> &Vec>; + + /// Get the provider for this agent + fn get_provider(&self) -> &Box; + + /// Get the provider usage statistics + fn get_provider_usage(&self) -> &Mutex>; + + /// Setup the next inference by budgeting the context window + async fn prepare_inference( + &self, + system_prompt: &str, + tools: &[Tool], + messages: &[Message], + pending: &[Message], + target_limit: usize, + ) -> AgentResult> { + // Default implementation for prepare_inference + let token_counter = TokenCounter::new(); + let resource_content = self.get_systems_resources().await?; + + // Flatten all resource content into a vector of strings + let mut resources = Vec::new(); + for system_resources in resource_content.values() { + for (_, content) in system_resources.values() { + resources.push(content.clone()); + } + } + + let approx_count = token_counter.count_everything( + system_prompt, + messages, + tools, + &resources, + Some(&self.get_provider().get_model_config().model_name), + ); + + let mut status_content: Vec = Vec::new(); + + if approx_count > target_limit { + println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over buget. Removing context", approx_count, approx_count - target_limit); + + // Get token counts for each resource + let mut system_token_counts = HashMap::new(); + + // Iterate through each system and its resources + for (system_name, resources) in &resource_content { + let mut resource_counts = HashMap::new(); + for (uri, (_resource, content)) in resources { + let token_count = token_counter + .count_tokens(&content, Some(&self.get_provider().get_model_config().model_name)) + as u32; + resource_counts.insert(uri.clone(), token_count); + } + system_token_counts.insert(system_name.clone(), resource_counts); + } + // Sort resources by priority and timestamp and trim to fit context limit + let mut all_resources: Vec<(String, String, Resource, u32)> = Vec::new(); + for (system_name, resources) in &resource_content { + for (uri, (resource, _)) in resources { + if let Some(token_count) = system_token_counts + .get(system_name) + .and_then(|counts| counts.get(uri)) + { + all_resources.push(( + system_name.clone(), + uri.clone(), + resource.clone(), + *token_count, + )); + } + } + } + + // Sort by priority (high to low) and timestamp (newest to oldest) + all_resources.sort_by(|a, b| { + let a_priority = a.2.priority().unwrap_or(0.0); + let b_priority = b.2.priority().unwrap_or(0.0); + if (b_priority - a_priority).abs() < PRIORITY_EPSILON { + b.2.timestamp().cmp(&a.2.timestamp()) + } else { + b.2.priority() + .partial_cmp(&a.2.priority()) + .unwrap_or(std::cmp::Ordering::Equal) + } + }); + + // Remove resources until we're under target limit + let mut current_tokens = approx_count; + + while current_tokens > target_limit && !all_resources.is_empty() { + if let Some((system_name, uri, _, token_count)) = all_resources.pop() { + if let Some(system_counts) = system_token_counts.get_mut(&system_name) { + system_counts.remove(&uri); + current_tokens -= token_count as usize; + } + } + } + // Create status messages only from resources that remain after token trimming + for (system_name, uri, _, _) in &all_resources { + if let Some(system_resources) = resource_content.get(system_name) { + if let Some((resource, content)) = system_resources.get(uri) { + status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); + } + } + } + } else { + // Create status messages from all resources when no trimming needed + for resources in resource_content.values() { + for (resource, content) in resources.values() { + status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); + } + } + } + + // Join remaining status content and create status message + let status_str = status_content.join("\n"); + let mut context = HashMap::new(); + let systems_status = vec![SystemStatus::new("system", status_str)]; + context.insert("systems", &systems_status); + + // Load and format the status template with only remaining resources + let status = load_prompt_file("status.md", &context) + .map_err(|e| AgentError::Internal(e.to_string()))?; + + // Create a new messages vector with our changes + let mut new_messages = messages.to_vec(); + + // Add pending messages + for msg in pending { + new_messages.push(msg.clone()); + } + + // Finally add the status messages + let message_use = + Message::assistant().with_tool_request("000", Ok(ToolCall::new("status", json!({})))); + + let message_result = + Message::user().with_tool_response("000", Ok(vec![Content::text(status)])); + + new_messages.push(message_use); + new_messages.push(message_result); + + Ok(new_messages) + } + + /// Create a stream that yields each message as it's generated + async fn reply(&self, messages: &[Message]) -> Result>> { + let mut messages = messages.to_vec(); + let tools = self.get_prefixed_tools(); + let system_prompt = self.get_system_prompt()?; + let estimated_limit = self.get_provider().get_model_config().get_estimated_limit(); + + // Update conversation history for the start of the reply + messages = self + .prepare_inference( + &system_prompt, + &tools, + &messages, + &Vec::new(), + estimated_limit, + ) + .await?; + + Ok(Box::pin(async_stream::try_stream! { + loop { + // Get completion from provider + let (response, usage) = self.get_provider().complete( + &system_prompt, + &messages, + &tools, + ).await?; + self.get_provider_usage().lock().await.push(usage); + + // Yield the assistant's response + yield response.clone(); + + tokio::task::yield_now().await; + + // First collect any tool requests + let tool_requests: Vec<&ToolRequest> = response.content + .iter() + .filter_map(|content| content.as_tool_request()) + .collect(); + + if tool_requests.is_empty() { + break; + } + + // Then dispatch each in parallel + let futures: Vec<_> = tool_requests + .iter() + .map(|request| self.dispatch_tool_call(request.tool_call.clone())) + .collect(); + + // Process all the futures in parallel but wait until all are finished + let outputs = futures::future::join_all(futures).await; + + // Create a message with the responses + let mut message_tool_response = Message::user(); + // Now combine these into MessageContent::ToolResponse using the original ID + for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + output, + ); + } + + yield message_tool_response.clone(); + + // Now we have to remove the previous status tooluse and toolresponse + // before we add pending messages, then the status msgs back again + messages.pop(); + messages.pop(); + + let pending = vec![response, message_tool_response]; + messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit).await?; + } + })) + } + + /// Get usage statistics + async fn usage(&self) -> Result> { + let provider_usage = self.get_provider_usage().lock().await.clone(); + let mut usage_map: HashMap = HashMap::new(); + + provider_usage.iter().for_each(|usage| { + usage_map + .entry(usage.model.clone()) + .and_modify(|e| { + e.usage.input_tokens = Some( + e.usage.input_tokens.unwrap_or(0) + usage.usage.input_tokens.unwrap_or(0), + ); + e.usage.output_tokens = Some( + e.usage.output_tokens.unwrap_or(0) + usage.usage.output_tokens.unwrap_or(0), + ); + e.usage.total_tokens = Some( + e.usage.total_tokens.unwrap_or(0) + usage.usage.total_tokens.unwrap_or(0), + ); + if e.cost.is_none() || usage.cost.is_none() { + e.cost = None; // Pricing is not available for all models + } else { + e.cost = Some(e.cost.unwrap_or(dec!(0)) + usage.cost.unwrap_or(dec!(0))); + } + }) + .or_insert_with(|| usage.clone()); + }); + Ok(usage_map.into_values().collect()) + } + + /// Get system resources and their contents + async fn get_systems_resources( + &self, + ) -> AgentResult>> { + let mut system_resource_content = HashMap::new(); + for system in self.get_systems() { + let system_status = system + .status() + .await + .map_err(|e| AgentError::Internal(e.to_string()))?; + + let mut resource_content = HashMap::new(); + for resource in system_status { + if let Ok(content) = system.read_resource(&resource.uri).await { + resource_content.insert(resource.uri.to_string(), (resource, content)); + } + } + system_resource_content.insert(system.name().to_string(), resource_content); + } + Ok(system_resource_content) + } + + /// Get the system prompt + fn get_system_prompt(&self) -> AgentResult { + let mut context = HashMap::new(); + let systems_info: Vec = self + .get_systems() + .iter() + .map(|system| { + SystemInfo::new(system.name(), system.description(), system.instructions()) + }) + .collect(); + + context.insert("systems", systems_info); + load_prompt_file("system.md", &context) + .map_err(|e| AgentError::Internal(e.to_string())) + } + + /// Find the appropriate system for a tool call based on the prefixed name + fn get_system_for_tool(&self, prefixed_name: &str) -> Option<&dyn System> { + let parts: Vec<&str> = prefixed_name.split("__").collect(); + if parts.len() != 2 { + return None; + } + let system_name = parts[0]; + self.get_systems() + .iter() + .find(|sys| sys.name() == system_name) + .map(|v| &**v) + } + + /// Dispatch a single tool call to the appropriate system + async fn dispatch_tool_call( + &self, + tool_call: AgentResult, + ) -> AgentResult> { + let call = tool_call?; + let system = self + .get_system_for_tool(&call.name) + .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?; + + let tool_name = call + .name + .split("__") + .nth(1) + .ok_or_else(|| AgentError::InvalidToolName(call.name.clone()))?; + let system_tool_call = ToolCall::new(tool_name, call.arguments); + + system.call(system_tool_call).await + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/base.rs b/crates/goose/src/agents/base.rs new file mode 100644 index 000000000..c2d9cea39 --- /dev/null +++ b/crates/goose/src/agents/base.rs @@ -0,0 +1,395 @@ +use async_trait::async_trait; +use tokio::sync::Mutex; + +use super::Agent; +use crate::providers::base::{Provider, ProviderUsage}; +use crate::systems::System; + +/// Base implementation of an Agent +pub struct BaseAgent { + systems: Vec>, + provider: Box, + provider_usage: Mutex>, +} + +impl BaseAgent { + pub fn new(provider: Box) -> Self { + Self { + systems: Vec::new(), + provider, + provider_usage: Mutex::new(Vec::new()), + } + } + + pub fn add_system(&mut self, system: Box) { + self.systems.push(system); + } +} + +#[async_trait] +impl Agent for BaseAgent { + fn add_system(&mut self, system: Box) { + self.systems.push(system); + } + + fn get_systems(&self) -> &Vec> { + &self.systems + } + + fn get_provider(&self) -> &Box { + &self.provider + } + + fn get_provider_usage(&self) -> &Mutex> { + &self.provider_usage + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::{Message, MessageContent}; + use crate::providers::configs::ModelConfig; + use crate::providers::mock::MockProvider; + use async_trait::async_trait; + use chrono::Utc; + use futures::TryStreamExt; + use mcp_core::resource::Resource; + use mcp_core::{Annotations, Content, Tool, ToolCall}; + use rust_decimal_macros::dec; + use serde_json::json; + use std::collections::HashMap; + + // Mock system for testing + struct MockSystem { + name: String, + tools: Vec, + resources: Vec, + resource_content: HashMap, + } + + impl MockSystem { + fn new(name: &str) -> Self { + Self { + name: name.to_string(), + tools: vec![Tool::new( + "echo", + "Echoes back the input", + json!({"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}), + )], + resources: Vec::new(), + resource_content: HashMap::new(), + } + } + + fn add_resource(&mut self, name: &str, content: &str, priority: f32) { + let uri = format!("file://{}", name); + let resource = Resource { + name: name.to_string(), + uri: uri.clone(), + annotations: Some(Annotations::for_resource(priority, Utc::now())), + description: Some("A mock resource".to_string()), + mime_type: "text/plain".to_string(), + }; + self.resources.push(resource); + self.resource_content.insert(uri, content.to_string()); + } + } + + #[async_trait] + impl System for MockSystem { + fn name(&self) -> &str { + &self.name + } + + fn description(&self) -> &str { + "A mock system for testing" + } + + fn instructions(&self) -> &str { + "Mock system instructions" + } + + fn tools(&self) -> &[Tool] { + &self.tools + } + + async fn status(&self) -> anyhow::Result> { + Ok(self.resources.clone()) + } + + async fn call(&self, tool_call: ToolCall) -> crate::errors::AgentResult> { + match tool_call.name.as_str() { + "echo" => Ok(vec![Content::text( + tool_call.arguments["message"].as_str().unwrap_or(""), + )]), + _ => Err(crate::errors::AgentError::ToolNotFound(tool_call.name)), + } + } + + async fn read_resource(&self, uri: &str) -> crate::errors::AgentResult { + self.resource_content.get(uri).cloned().ok_or_else(|| { + crate::errors::AgentError::InvalidParameters(format!("Resource {} could not be found", uri)) + }) + } + } + + #[tokio::test] + async fn test_simple_response() -> anyhow::Result<()> { + let response = Message::assistant().with_text("Hello!"); + let provider = MockProvider::new(vec![response.clone()]); + let agent = BaseAgent::new(Box::new(provider)); + + let initial_message = Message::user().with_text("Hi"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? { + messages.push(msg); + } + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0], response); + Ok(()) + } + + #[tokio::test] + async fn test_usage_rollup() -> anyhow::Result<()> { + let response = Message::assistant().with_text("Hello!"); + let provider = MockProvider::new(vec![response.clone()]); + let agent = BaseAgent::new(Box::new(provider)); + + let initial_message = Message::user().with_text("Hi"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + while stream.try_next().await?.is_some() {} + + // Second message + let mut stream = agent.reply(&initial_messages).await?; + while stream.try_next().await?.is_some() {} + + let usage = agent.usage().await?; + assert_eq!(usage.len(), 1); // 2 messages rolled up to one usage per model + assert_eq!(usage[0].usage.input_tokens, Some(2)); + assert_eq!(usage[0].usage.output_tokens, Some(2)); + assert_eq!(usage[0].usage.total_tokens, Some(4)); + assert_eq!(usage[0].model, "mock"); + assert_eq!(usage[0].cost, Some(dec!(2))); + Ok(()) + } + + #[tokio::test] + async fn test_tool_call() -> anyhow::Result<()> { + let mut agent = BaseAgent::new(Box::new(MockProvider::new(vec![ + Message::assistant().with_tool_request( + "1", + Ok(ToolCall::new("test_echo", json!({"message": "test"}))), + ), + Message::assistant().with_text("Done!"), + ]))); + + agent.add_system(Box::new(MockSystem::new("test"))); + + let initial_message = Message::user().with_text("Echo test"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? { + messages.push(msg); + } + + // Should have three messages: tool request, response, and model text + assert_eq!(messages.len(), 3); + assert!(messages[0] + .content + .iter() + .any(|c| matches!(c, MessageContent::ToolRequest(_)))); + assert_eq!(messages[2].content[0], MessageContent::text("Done!")); + Ok(()) + } + + #[tokio::test] + async fn test_invalid_tool() -> anyhow::Result<()> { + let mut agent = BaseAgent::new(Box::new(MockProvider::new(vec![ + Message::assistant() + .with_tool_request("1", Ok(ToolCall::new("invalid_tool", json!({})))), + Message::assistant().with_text("Error occurred"), + ]))); + + agent.add_system(Box::new(MockSystem::new("test"))); + + let initial_message = Message::user().with_text("Invalid tool"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? { + messages.push(msg); + } + + // Should have three messages: failed tool request, fail response, and model text + assert_eq!(messages.len(), 3); + assert!(messages[0] + .content + .iter() + .any(|c| matches!(c, MessageContent::ToolRequest(_)))); + assert_eq!( + messages[2].content[0], + MessageContent::text("Error occurred") + ); + Ok(()) + } + + #[tokio::test] + async fn test_multiple_tool_calls() -> anyhow::Result<()> { + let mut agent = BaseAgent::new(Box::new(MockProvider::new(vec![ + Message::assistant() + .with_tool_request( + "1", + Ok(ToolCall::new("test_echo", json!({"message": "first"}))), + ) + .with_tool_request( + "2", + Ok(ToolCall::new("test_echo", json!({"message": "second"}))), + ), + Message::assistant().with_text("All done!"), + ]))); + + agent.add_system(Box::new(MockSystem::new("test"))); + + let initial_message = Message::user().with_text("Multiple calls"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? { + messages.push(msg); + } + + // Should have three messages: tool requests, responses, and model text + assert_eq!(messages.len(), 3); + assert!(messages[0] + .content + .iter() + .any(|c| matches!(c, MessageContent::ToolRequest(_)))); + assert_eq!(messages[2].content[0], MessageContent::text("All done!")); + Ok(()) + } + + #[tokio::test] + async fn test_prepare_inference_trims_resources_when_budget_exceeded() -> anyhow::Result<()> { + // Create a mock provider + let provider = MockProvider::new(vec![]); + let mut agent = BaseAgent::new(Box::new(provider)); + + // Create a mock system with two resources + let mut system = MockSystem::new("test"); + + // Add two resources with different priorities + let string_10toks = "hello ".repeat(10); + system.add_resource("high_priority", &string_10toks, 0.8); + system.add_resource("low_priority", &string_10toks, 0.1); + + agent.add_system(Box::new(system)); + + // Set up test parameters + // 18 tokens with system + user msg in chat format + let system_prompt = "This is a system prompt"; + let messages = vec![Message::user().with_text("Hi there")]; + let tools = vec![]; + let pending = vec![]; + + // Approx count is 40, so target limit of 35 will force trimming + let target_limit = 35; + + // Call prepare_inference + let result = agent + .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) + .await?; + + // Get the last message which should be the tool response containing status + let status_message = result.last().unwrap(); + let status_content = status_message + .content + .first() + .and_then(|content| content.as_tool_response_text()) + .unwrap_or_default(); + + // Verify that only the high priority resource is included in the status + assert!(status_content.contains("high_priority")); + assert!(!status_content.contains("low_priority")); + + // Now test with a target limit that allows both resources (no trimming) + let target_limit = 100; + + // Call prepare_inference + let result = agent + .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) + .await?; + + // Get the last message which should be the tool response containing status + let status_message = result.last().unwrap(); + let status_content = status_message + .content + .first() + .and_then(|content| content.as_tool_response_text()) + .unwrap_or_default(); + + // Verify that only the high priority resource is included in the status + assert!(status_content.contains("high_priority")); + assert!(status_content.contains("low_priority")); + Ok(()) + } + + #[tokio::test] + async fn test_context_trimming_with_custom_model_config() -> anyhow::Result<()> { + let provider = MockProvider::with_config( + vec![], + ModelConfig::new("test_model".to_string()).with_context_limit(Some(20)), + ); + let mut agent = BaseAgent::new(Box::new(provider)); + + // Create a mock system with a resource that will exceed the context limit + let mut system = MockSystem::new("test"); + + // Add a resource that will exceed our tiny context limit + let hello_1_tokens = "hello ".repeat(1); // 1 tokens + let goodbye_10_tokens = "goodbye ".repeat(10); // 10 tokens + system.add_resource("test_resource_removed", &goodbye_10_tokens, 0.1); + system.add_resource("test_resource_expected", &hello_1_tokens, 0.5); + + agent.add_system(Box::new(system)); + + // Set up test parameters + // 18 tokens with system + user msg in chat format + let system_prompt = "This is a system prompt"; + let messages = vec![Message::user().with_text("Hi there")]; + let tools = vec![]; + let pending = vec![]; + + // Use the context limit from the model config + let target_limit = agent.get_provider().get_model_config().context_limit(); + assert_eq!(target_limit, 20, "Context limit should be 20"); + + // Call prepare_inference + let result = agent + .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) + .await?; + + // Get the last message which should be the tool response containing status + let status_message = result.last().unwrap(); + let status_content = status_message + .content + .first() + .and_then(|content| content.as_tool_response_text()) + .unwrap_or_default(); + + // verify that "hello" is within the response, should be just under 20 tokens with "hello" + assert!(status_content.contains("hello")); + + Ok(()) + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs new file mode 100644 index 000000000..8b3b76e32 --- /dev/null +++ b/crates/goose/src/agents/mod.rs @@ -0,0 +1,7 @@ +mod agent; +mod base; +mod v1; + +pub use agent::Agent; +pub use base::BaseAgent; +pub use v1::AgentV1; \ No newline at end of file diff --git a/crates/goose/src/agents/v1.rs b/crates/goose/src/agents/v1.rs new file mode 100644 index 000000000..83c65643a --- /dev/null +++ b/crates/goose/src/agents/v1.rs @@ -0,0 +1,91 @@ +use async_trait::async_trait; +use tokio::sync::Mutex; +use anyhow::Result; + +use super::Agent; +use crate::errors::AgentResult; +use crate::message::Message; +use crate::providers::base::{Provider, ProviderUsage}; +use crate::systems::System; +use mcp_core::Tool; + +/// A version of the agent that uses a more aggressive context management strategy + +pub struct AgentV1 { + systems: Vec>, + provider: Box, + provider_usage: Mutex>, +} + +impl AgentV1 { + pub fn new(provider: Box) -> Self { + Self { + systems: Vec::new(), + provider, + provider_usage: Mutex::new(Vec::new()), + } + } + + pub fn add_system(&mut self, system: Box) { + self.systems.push(system); + } +} + +#[async_trait] +impl Agent for AgentV1 { + fn add_system(&mut self, system: Box) { + self.systems.push(system); + } + + fn get_systems(&self) -> &Vec> { + &self.systems + } + + fn get_provider(&self) -> &Box { + &self.provider + } + + fn get_provider_usage(&self) -> &Mutex> { + &self.provider_usage + } + async fn prepare_inference( + &self, + system_prompt: &str, + tools: &[Tool], + messages: &[Message], + pending: &[Message], + target_limit: usize, + ) -> AgentResult> { + todo!(); + // return Ok(messages.to_vec()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::providers::mock::MockProvider; + use futures::TryStreamExt; + + #[tokio::test] + async fn test_v1_agent() -> Result<(), anyhow::Error> { + // Create a mock provider that returns a simple response + let response = Message::assistant().with_text("Hello!"); + let provider = MockProvider::new(vec![response.clone()]); + let agent = AgentV1::new(Box::new(provider)); + + // Test basic reply functionality + let initial_message = Message::user().with_text("Hi"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? { + messages.push(msg); + } + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0], response); + Ok(()) + } +} \ No newline at end of file diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 394e6450c..cede26b24 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -1,4 +1,4 @@ -pub mod agent; +pub mod agents; pub mod developer; pub mod errors; pub mod key_manager; From c8bdf839dfdf93c6331734521dc44cd85792397d Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Thu, 26 Dec 2024 14:46:10 -0800 Subject: [PATCH 02/11] Update cli to use new agent versions --- crates/goose-cli/src/agents/agent.rs | 47 ++++++++++++++++------- crates/goose-cli/src/agents/mock_agent.rs | 38 ++++++++++++++++-- crates/goose-cli/src/commands/session.rs | 4 +- crates/goose-cli/src/session.rs | 7 ++-- crates/goose/src/providers.rs | 3 +- 5 files changed, 75 insertions(+), 24 deletions(-) diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs index eb0833489..a65085083 100644 --- a/crates/goose-cli/src/agents/agent.rs +++ b/crates/goose-cli/src/agents/agent.rs @@ -1,28 +1,49 @@ -use anyhow::Result; +// use anyhow::Result; use async_trait::async_trait; -use futures::stream::BoxStream; +// use futures::stream::BoxStream; +use tokio::sync::Mutex; use goose::{ - agent::Agent as GooseAgent, message::Message, providers::base::ProviderUsage, systems::System, + agents::Agent, providers::base::ProviderUsage, systems::System, providers::base::Provider }; -#[async_trait] -pub trait Agent { - fn add_system(&mut self, system: Box); - async fn reply(&self, messages: &[Message]) -> Result>>; - async fn usage(&self) -> Result>; +pub struct GooseAgent { + systems: Vec>, + provider: Box, + provider_usage: Mutex>, +} + +impl GooseAgent { + pub fn new(provider: Box) -> Self { + Self { + systems: Vec::new(), + provider, + provider_usage: Mutex::new(Vec::new()), + } + } } #[async_trait] impl Agent for GooseAgent { fn add_system(&mut self, system: Box) { - self.add_system(system); + self.systems.push(system); } - async fn reply(&self, messages: &[Message]) -> Result>> { - self.reply(messages).await + fn get_systems(&self) -> &Vec> { + &self.systems } - async fn usage(&self) -> Result> { - self.usage().await + fn get_provider(&self) -> &Box { + &self.provider } + + fn get_provider_usage(&self) -> &Mutex> { + &self.provider_usage + } + // async fn reply(&self, messages: &[Message]) -> Result>> { + // self.reply(messages).await + // } + + // async fn usage(&self) -> Result> { + // self.usage().await + // } } diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs index 542431ac3..6625eb400 100644 --- a/crates/goose-cli/src/agents/mock_agent.rs +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -3,15 +3,45 @@ use std::vec; use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; -use goose::{message::Message, providers::base::ProviderUsage, systems::System}; +use goose::{agents::Agent, message::Message, providers::base::{Provider, ProviderUsage}, systems::System}; +use goose::providers::mock::MockProvider; +use tokio::sync::Mutex; -use crate::agents::agent::Agent; -pub struct MockAgent; + +pub struct MockAgent { + systems: Vec>, + provider: Box, + provider_usage: Mutex>, +} + +impl MockAgent { + pub fn new() -> Self { + Self { + systems: Vec::new(), + provider: Box::new(MockProvider::new(Vec::new())), + provider_usage: Mutex::new(Vec::new()), + } + } +} #[async_trait] impl Agent for MockAgent { - fn add_system(&mut self, _system: Box) {} + fn add_system(&mut self, system: Box) { + self.systems.push(system); + } + + fn get_systems(&self) -> &Vec> { + &self.systems + } + + fn get_provider(&self) -> &Box { + &self.provider + } + + fn get_provider_usage(&self) -> &Mutex> { + &self.provider_usage + } async fn reply(&self, _messages: &[Message]) -> Result>> { Ok(Box::pin(futures::stream::empty())) diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 0e23dac49..e13b35433 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -1,5 +1,4 @@ use console::style; -use goose::agent::Agent; use goose::providers::factory; use rand::{distributions::Alphanumeric, Rng}; use std::path::{Path, PathBuf}; @@ -9,6 +8,7 @@ use crate::profile::{get_provider_config, load_profiles, Profile}; use crate::prompt::rustyline::RustylinePrompt; use crate::prompt::Prompt; use crate::session::{ensure_session_dir, get_most_recent_session, Session}; +use crate::agents::agent::GooseAgent; pub fn build_session<'a>( session: Option, @@ -45,7 +45,7 @@ pub fn build_session<'a>( // TODO: Odd to be prepping the provider rather than having that done in the agent? let provider = factory::get_provider(provider_config).unwrap(); - let agent = Box::new(Agent::new(provider)); + let agent = Box::new(GooseAgent::new(provider)); let prompt = match std::env::var("GOOSE_INPUT") { Ok(val) => match val.as_str() { "rustyline" => Box::new(RustylinePrompt::new()) as Box, diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index 0d0270d3e..f0b9cf4a7 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -6,12 +6,13 @@ use std::fs::{self, File}; use std::io::{self, BufRead, Write}; use std::path::PathBuf; -use crate::agents::agent::Agent; +// use crate::agents::agent::Agent; use crate::log_usage::log_usage; use crate::prompt::{InputType, Prompt}; use goose::developer::DeveloperSystem; use goose::message::{Message, MessageContent}; use goose::systems::goose_hints::GooseHintsSystem; +use goose::agents::Agent; use mcp_core::role::Role; // File management functions @@ -361,14 +362,14 @@ mod tests { // Helper function to create a test session fn create_test_session() -> Session<'static> { let temp_file = NamedTempFile::new().unwrap(); - let agent = Box::new(MockAgent {}); + let agent = Box::new(MockAgent::new()); let prompt = Box::new(MockPrompt::new()); Session::new(agent, prompt, temp_file.path().to_path_buf()) } fn create_test_session_with_prompt<'a>(prompt: Box) -> Session<'a> { let temp_file = NamedTempFile::new().unwrap(); - let agent = Box::new(MockAgent {}); + let agent = Box::new(MockAgent::new()); Session::new(agent, prompt, temp_file.path().to_path_buf()) } diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 6d564eeb8..badc74098 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -3,6 +3,7 @@ pub mod base; pub mod configs; pub mod databricks; pub mod factory; +pub mod mock; pub mod model_pricing; pub mod oauth; pub mod ollama; @@ -13,7 +14,5 @@ pub mod utils; pub mod google; pub mod groq; -#[cfg(test)] -pub mod mock; #[cfg(test)] pub mod mock_server; From 8f86c4257c40ef53bd48703102870e7976bc7502 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Thu, 26 Dec 2024 15:45:55 -0800 Subject: [PATCH 03/11] Update goose-server --- crates/goose-server/src/routes/reply.rs | 3 ++- crates/goose-server/src/state.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index c6bd48964..c000301f4 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -21,6 +21,7 @@ use std::{ use tokio::sync::mpsc; use tokio::time::timeout; use tokio_stream::wrappers::ReceiverStream; +use goose::agents::Agent; // Types matching the incoming JSON structure #[derive(Debug, Deserialize)] @@ -390,7 +391,7 @@ pub fn routes(state: AppState) -> Router { mod tests { use super::*; use goose::{ - agent::Agent, + agents::BaseAgent as Agent, providers::{ base::{Provider, ProviderUsage, Usage}, configs::{ModelConfig, OpenAiProviderConfig}, diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index fa430c316..f2ccbdf44 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,7 +1,7 @@ use anyhow::Result; use goose::providers::configs::GroqProviderConfig; use goose::{ - agent::Agent, + agents::BaseAgent as Agent, developer::DeveloperSystem, memory::MemorySystem, providers::{configs::ProviderConfig, factory}, From e25d30c7c3335f3e7b4a04e9202a0596e51a0e71 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Thu, 26 Dec 2024 15:49:38 -0800 Subject: [PATCH 04/11] Comment out v1 agent test --- crates/goose/src/agents/v1.rs | 48 +++++++++++++++++------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/crates/goose/src/agents/v1.rs b/crates/goose/src/agents/v1.rs index 83c65643a..15a9d947b 100644 --- a/crates/goose/src/agents/v1.rs +++ b/crates/goose/src/agents/v1.rs @@ -61,31 +61,31 @@ impl Agent for AgentV1 { } } -#[cfg(test)] -mod tests { - use super::*; - use crate::providers::mock::MockProvider; - use futures::TryStreamExt; +// #[cfg(test)] +// mod tests { +// use super::*; +// use crate::providers::mock::MockProvider; +// use futures::TryStreamExt; - #[tokio::test] - async fn test_v1_agent() -> Result<(), anyhow::Error> { - // Create a mock provider that returns a simple response - let response = Message::assistant().with_text("Hello!"); - let provider = MockProvider::new(vec![response.clone()]); - let agent = AgentV1::new(Box::new(provider)); +// #[tokio::test] +// async fn test_v1_agent() -> Result<(), anyhow::Error> { +// // Create a mock provider that returns a simple response +// let response = Message::assistant().with_text("Hello!"); +// let provider = MockProvider::new(vec![response.clone()]); +// let agent = AgentV1::new(Box::new(provider)); - // Test basic reply functionality - let initial_message = Message::user().with_text("Hi"); - let initial_messages = vec![initial_message]; +// // Test basic reply functionality +// let initial_message = Message::user().with_text("Hi"); +// let initial_messages = vec![initial_message]; - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } +// let mut stream = agent.reply(&initial_messages).await?; +// let mut messages = Vec::new(); +// while let Some(msg) = stream.try_next().await? { +// messages.push(msg); +// } - assert_eq!(messages.len(), 1); - assert_eq!(messages[0], response); - Ok(()) - } -} \ No newline at end of file +// assert_eq!(messages.len(), 1); +// assert_eq!(messages[0], response); +// Ok(()) +// } +// } \ No newline at end of file From 7d8f8b320bd24794af8043b090f268996aadd064 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Thu, 26 Dec 2024 17:03:50 -0800 Subject: [PATCH 05/11] cargo-fmt --- crates/goose-cli/src/agents/agent.rs | 4 ++-- crates/goose-cli/src/agents/mock_agent.rs | 9 ++++++--- crates/goose-cli/src/commands/session.rs | 2 +- crates/goose-cli/src/session.rs | 2 +- crates/goose-server/src/routes/reply.rs | 2 +- crates/goose/src/agents/agent.rs | 18 +++++++++--------- crates/goose/src/agents/base.rs | 7 +++++-- crates/goose/src/agents/mod.rs | 2 +- crates/goose/src/agents/v1.rs | 4 ++-- 9 files changed, 28 insertions(+), 22 deletions(-) diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs index a65085083..5c3b51e45 100644 --- a/crates/goose-cli/src/agents/agent.rs +++ b/crates/goose-cli/src/agents/agent.rs @@ -1,10 +1,10 @@ // use anyhow::Result; use async_trait::async_trait; // use futures::stream::BoxStream; -use tokio::sync::Mutex; use goose::{ - agents::Agent, providers::base::ProviderUsage, systems::System, providers::base::Provider + agents::Agent, providers::base::Provider, providers::base::ProviderUsage, systems::System, }; +use tokio::sync::Mutex; pub struct GooseAgent { systems: Vec>, diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs index 6625eb400..c8bf46a0c 100644 --- a/crates/goose-cli/src/agents/mock_agent.rs +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -3,12 +3,15 @@ use std::vec; use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; -use goose::{agents::Agent, message::Message, providers::base::{Provider, ProviderUsage}, systems::System}; use goose::providers::mock::MockProvider; +use goose::{ + agents::Agent, + message::Message, + providers::base::{Provider, ProviderUsage}, + systems::System, +}; use tokio::sync::Mutex; - - pub struct MockAgent { systems: Vec>, provider: Box, diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index e13b35433..2d4f4eee9 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -4,11 +4,11 @@ use rand::{distributions::Alphanumeric, Rng}; use std::path::{Path, PathBuf}; use std::process; +use crate::agents::agent::GooseAgent; use crate::profile::{get_provider_config, load_profiles, Profile}; use crate::prompt::rustyline::RustylinePrompt; use crate::prompt::Prompt; use crate::session::{ensure_session_dir, get_most_recent_session, Session}; -use crate::agents::agent::GooseAgent; pub fn build_session<'a>( session: Option, diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index f0b9cf4a7..3483ec5e5 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -9,10 +9,10 @@ use std::path::PathBuf; // use crate::agents::agent::Agent; use crate::log_usage::log_usage; use crate::prompt::{InputType, Prompt}; +use goose::agents::Agent; use goose::developer::DeveloperSystem; use goose::message::{Message, MessageContent}; use goose::systems::goose_hints::GooseHintsSystem; -use goose::agents::Agent; use mcp_core::role::Role; // File management functions diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index c000301f4..298410fd4 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -8,6 +8,7 @@ use axum::{ }; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; +use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::{content::Content, role::Role}; use serde::Deserialize; @@ -21,7 +22,6 @@ use std::{ use tokio::sync::mpsc; use tokio::time::timeout; use tokio_stream::wrappers::ReceiverStream; -use goose::agents::Agent; // Types matching the incoming JSON structure #[derive(Debug, Deserialize)] diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index d141806ec..aba706db7 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -74,10 +74,10 @@ pub trait Agent: Send + Sync { /// Get the systems this agent has access to fn get_systems(&self) -> &Vec>; - + /// Get the provider for this agent fn get_provider(&self) -> &Box; - + /// Get the provider usage statistics fn get_provider_usage(&self) -> &Mutex>; @@ -122,9 +122,10 @@ pub trait Agent: Send + Sync { for (system_name, resources) in &resource_content { let mut resource_counts = HashMap::new(); for (uri, (_resource, content)) in resources { - let token_count = token_counter - .count_tokens(&content, Some(&self.get_provider().get_model_config().model_name)) - as u32; + let token_count = token_counter.count_tokens( + &content, + Some(&self.get_provider().get_model_config().model_name), + ) as u32; resource_counts.insert(uri.clone(), token_count); } system_token_counts.insert(system_name.clone(), resource_counts); @@ -298,7 +299,7 @@ pub trait Agent: Send + Sync { async fn usage(&self) -> Result> { let provider_usage = self.get_provider_usage().lock().await.clone(); let mut usage_map: HashMap = HashMap::new(); - + provider_usage.iter().for_each(|usage| { usage_map .entry(usage.model.clone()) @@ -357,8 +358,7 @@ pub trait Agent: Send + Sync { .collect(); context.insert("systems", systems_info); - load_prompt_file("system.md", &context) - .map_err(|e| AgentError::Internal(e.to_string())) + load_prompt_file("system.md", &context).map_err(|e| AgentError::Internal(e.to_string())) } /// Find the appropriate system for a tool call based on the prefixed name @@ -393,4 +393,4 @@ pub trait Agent: Send + Sync { system.call(system_tool_call).await } -} \ No newline at end of file +} diff --git a/crates/goose/src/agents/base.rs b/crates/goose/src/agents/base.rs index c2d9cea39..7a4658124 100644 --- a/crates/goose/src/agents/base.rs +++ b/crates/goose/src/agents/base.rs @@ -129,7 +129,10 @@ mod tests { async fn read_resource(&self, uri: &str) -> crate::errors::AgentResult { self.resource_content.get(uri).cloned().ok_or_else(|| { - crate::errors::AgentError::InvalidParameters(format!("Resource {} could not be found", uri)) + crate::errors::AgentError::InvalidParameters(format!( + "Resource {} could not be found", + uri + )) }) } } @@ -392,4 +395,4 @@ mod tests { Ok(()) } -} \ No newline at end of file +} diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 8b3b76e32..9637af45f 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -4,4 +4,4 @@ mod v1; pub use agent::Agent; pub use base::BaseAgent; -pub use v1::AgentV1; \ No newline at end of file +pub use v1::AgentV1; diff --git a/crates/goose/src/agents/v1.rs b/crates/goose/src/agents/v1.rs index 15a9d947b..aa0204ffa 100644 --- a/crates/goose/src/agents/v1.rs +++ b/crates/goose/src/agents/v1.rs @@ -1,6 +1,6 @@ +use anyhow::Result; use async_trait::async_trait; use tokio::sync::Mutex; -use anyhow::Result; use super::Agent; use crate::errors::AgentResult; @@ -88,4 +88,4 @@ impl Agent for AgentV1 { // assert_eq!(messages[0], response); // Ok(()) // } -// } \ No newline at end of file +// } From 46d14a46aaf802e03e79152927604588bd07e619 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Fri, 27 Dec 2024 12:04:50 -0800 Subject: [PATCH 06/11] Implement AgentFactory for registration of named agents --- crates/goose/Cargo.toml | 2 + crates/goose/src/agents/base.rs | 3 + crates/goose/src/agents/factory.rs | 193 +++++++++++++++++++++++++++++ crates/goose/src/agents/mod.rs | 2 + crates/goose/src/agents/v1.rs | 44 ++----- crates/goose/src/errors.rs | 3 + crates/goose/src/providers/mock.rs | 4 +- 7 files changed, 213 insertions(+), 38 deletions(-) create mode 100644 crates/goose/src/agents/factory.rs diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 47131df0d..b3b6e7fd3 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -67,6 +67,8 @@ keyring = { version = "3.6.1", features = [ shellexpand = "3.1.0" rust_decimal = "1.36.0" rust_decimal_macros = "1.36.0" +ctor = "0.2.7" +paste = "1.0" [dev-dependencies] sysinfo = "0.32.1" diff --git a/crates/goose/src/agents/base.rs b/crates/goose/src/agents/base.rs index 7a4658124..ce3a6c94e 100644 --- a/crates/goose/src/agents/base.rs +++ b/crates/goose/src/agents/base.rs @@ -3,6 +3,7 @@ use tokio::sync::Mutex; use super::Agent; use crate::providers::base::{Provider, ProviderUsage}; +use crate::register_agent; use crate::systems::System; /// Base implementation of an Agent @@ -45,6 +46,8 @@ impl Agent for BaseAgent { } } +register_agent!("base", BaseAgent); + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/agents/factory.rs b/crates/goose/src/agents/factory.rs new file mode 100644 index 000000000..5ab41e745 --- /dev/null +++ b/crates/goose/src/agents/factory.rs @@ -0,0 +1,193 @@ +use std::collections::HashMap; +use std::sync::{OnceLock, RwLock}; + +use super::Agent; +use crate::errors::AgentError; +use crate::providers::base::Provider; + +type AgentConstructor = Box) -> Box + Send + Sync>; + +// Use std::sync::RwLock for interior mutability +static AGENT_REGISTRY: OnceLock>> = OnceLock::new(); + +/// Initialize the registry if it hasn't been initialized +fn registry() -> &'static RwLock> { + AGENT_REGISTRY.get_or_init(|| RwLock::new(HashMap::new())) +} + +/// Register a new agent version +pub fn register_agent( + version: &'static str, + constructor: impl Fn(Box) -> Box + Send + Sync + 'static, +) { + let registry = registry(); + if let Ok(mut map) = registry.write() { + map.insert(version, Box::new(constructor)); + } +} + +pub struct AgentFactory; + +impl AgentFactory { + /// Create a new agent instance of the specified version + pub fn create( + version: &str, + provider: Box, + ) -> Result, AgentError> { + let registry = registry(); + if let Ok(map) = registry.read() { + if let Some(constructor) = map.get(version) { + Ok(constructor(provider)) + } else { + Err(AgentError::VersionNotFound(version.to_string())) + } + } else { + Err(AgentError::Internal( + "Failed to access agent registry".to_string(), + )) + } + } + + /// Get a list of all available agent versions + pub fn available_versions() -> Vec<&'static str> { + registry() + .read() + .map(|map| map.keys().copied().collect()) + .unwrap_or_default() + } + + /// Get the default version name + pub fn default_version() -> &'static str { + "base" + } +} + +/// Macro to help with agent registration +#[macro_export] +macro_rules! register_agent { + ($version:expr, $agent_type:ty) => { + paste::paste! { + #[ctor::ctor] + #[allow(non_snake_case)] + fn [<__register_agent_ $version>]() { + $crate::agents::factory::register_agent($version, |provider| { + Box::new(<$agent_type>::new(provider)) + }); + } + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::providers::base::ProviderUsage; + use crate::providers::mock::MockProvider; + use crate::systems::System; + use async_trait::async_trait; + use tokio::sync::Mutex; + + // Test agent implementation + struct TestAgent { + systems: Vec>, + provider: Box, + provider_usage: Mutex>, + } + + impl TestAgent { + fn new(provider: Box) -> Self { + Self { + systems: Vec::new(), + provider, + provider_usage: Mutex::new(Vec::new()), + } + } + } + + #[async_trait] + impl Agent for TestAgent { + fn add_system(&mut self, system: Box) { + self.systems.push(system); + } + + fn get_systems(&self) -> &Vec> { + &self.systems + } + + fn get_provider(&self) -> &Box { + &self.provider + } + + fn get_provider_usage(&self) -> &Mutex> { + &self.provider_usage + } + } + + #[test] + fn test_register_and_create_agent() { + register_agent!("test_create", TestAgent); + + // Create a mock provider + let provider = Box::new(MockProvider::new(vec![])); + + // Create an agent instance + let result = AgentFactory::create("test_create", provider); + assert!(result.is_ok()); + } + + #[test] + fn test_version_not_found() { + // Try to create an agent with a non-existent version + let provider = Box::new(MockProvider::new(vec![])); + let result = AgentFactory::create("nonexistent", provider); + + assert!(matches!(result, Err(AgentError::VersionNotFound(_)))); + if let Err(AgentError::VersionNotFound(version)) = result { + assert_eq!(version, "nonexistent"); + } + } + + #[test] + fn test_available_versions() { + register_agent!("test_available_1", TestAgent); + register_agent!("test_available_2", TestAgent); + + // Get available versions + let versions = AgentFactory::available_versions(); + + assert!(versions.contains(&"test_available_1")); + assert!(versions.contains(&"test_available_2")); + } + + #[test] + fn test_default_version() { + assert_eq!(AgentFactory::default_version(), "base"); + } + + #[test] + fn test_multiple_registrations() { + register_agent!("test_duplicate", TestAgent); + register_agent!("test_duplicate_other", TestAgent); + + // Create an agent instance + let provider = Box::new(MockProvider::new(vec![])); + let result = AgentFactory::create("test_duplicate", provider); + + // Should still work, last registration wins + assert!(result.is_ok()); + } + + #[test] + fn test_agent_with_provider() { + register_agent!("test_provider_check", TestAgent); + + // Create a mock provider with specific configuration + let provider = Box::new(MockProvider::new(vec![])); + + // Create an agent instance + let agent = AgentFactory::create("test_provider_check", provider).unwrap(); + + // Verify the provider is correctly passed to the agent + assert_eq!(agent.get_provider().get_model_config().model_name, "mock"); + } +} diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 9637af45f..620435db4 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -1,7 +1,9 @@ mod agent; mod base; +mod factory; mod v1; pub use agent::Agent; pub use base::BaseAgent; +pub use factory::{register_agent, AgentFactory}; pub use v1::AgentV1; diff --git a/crates/goose/src/agents/v1.rs b/crates/goose/src/agents/v1.rs index aa0204ffa..73977e4e5 100644 --- a/crates/goose/src/agents/v1.rs +++ b/crates/goose/src/agents/v1.rs @@ -1,4 +1,3 @@ -use anyhow::Result; use async_trait::async_trait; use tokio::sync::Mutex; @@ -6,11 +5,11 @@ use super::Agent; use crate::errors::AgentResult; use crate::message::Message; use crate::providers::base::{Provider, ProviderUsage}; +use crate::register_agent; use crate::systems::System; use mcp_core::Tool; /// A version of the agent that uses a more aggressive context management strategy - pub struct AgentV1 { systems: Vec>, provider: Box, @@ -48,44 +47,17 @@ impl Agent for AgentV1 { fn get_provider_usage(&self) -> &Mutex> { &self.provider_usage } + async fn prepare_inference( &self, - system_prompt: &str, - tools: &[Tool], - messages: &[Message], - pending: &[Message], - target_limit: usize, + _system_prompt: &str, + _tools: &[Tool], + _messages: &[Message], + _pending: &[Message], + _target_limit: usize, ) -> AgentResult> { todo!(); - // return Ok(messages.to_vec()); } } -// #[cfg(test)] -// mod tests { -// use super::*; -// use crate::providers::mock::MockProvider; -// use futures::TryStreamExt; - -// #[tokio::test] -// async fn test_v1_agent() -> Result<(), anyhow::Error> { -// // Create a mock provider that returns a simple response -// let response = Message::assistant().with_text("Hello!"); -// let provider = MockProvider::new(vec![response.clone()]); -// let agent = AgentV1::new(Box::new(provider)); - -// // Test basic reply functionality -// let initial_message = Message::user().with_text("Hi"); -// let initial_messages = vec![initial_message]; - -// let mut stream = agent.reply(&initial_messages).await?; -// let mut messages = Vec::new(); -// while let Some(msg) = stream.try_next().await? { -// messages.push(msg); -// } - -// assert_eq!(messages.len(), 1); -// assert_eq!(messages[0], response); -// Ok(()) -// } -// } +register_agent!("v1", AgentV1); diff --git a/crates/goose/src/errors.rs b/crates/goose/src/errors.rs index 2d2497bdf..5a3eddac3 100644 --- a/crates/goose/src/errors.rs +++ b/crates/goose/src/errors.rs @@ -18,6 +18,9 @@ pub enum AgentError { #[error("Invalid tool name: {0}")] InvalidToolName(String), + + #[error("Agent version not found: {0}")] + VersionNotFound(String), } pub type AgentResult = Result; diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index fa84a63af..54aed6ad2 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -21,7 +21,7 @@ impl MockProvider { pub fn new(responses: Vec) -> Self { Self { responses: Arc::new(Mutex::new(responses)), - model_config: ModelConfig::new("mock-model".to_string()), + model_config: ModelConfig::new("mock".to_string()), } } @@ -62,7 +62,7 @@ impl Provider for MockProvider { } } - fn get_usage(&self, data: &Value) -> Result { + fn get_usage(&self, _data: &Value) -> Result { Ok(Usage::new(None, None, None)) } } From 10343f430b627738ab4ffe2dc9b67bec45c873e8 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Fri, 27 Dec 2024 12:46:12 -0800 Subject: [PATCH 07/11] cli list agent versions --- .../goose-cli/src/commands/agent_version.rs | 37 ++++++++++++++ crates/goose-cli/src/commands/mod.rs | 2 +- crates/goose-cli/src/main.rs | 50 ++++++++++++++----- 3 files changed, 75 insertions(+), 14 deletions(-) create mode 100644 crates/goose-cli/src/commands/agent_version.rs diff --git a/crates/goose-cli/src/commands/agent_version.rs b/crates/goose-cli/src/commands/agent_version.rs new file mode 100644 index 000000000..7b3c35fb0 --- /dev/null +++ b/crates/goose-cli/src/commands/agent_version.rs @@ -0,0 +1,37 @@ +use anyhow::Result; +use clap::Args; +use goose::agents::AgentFactory; +use std::fmt::Write; + +#[derive(Args)] +pub struct AgentCommand { + /// List available agent versions + #[arg(short, long)] + list: bool, +} + +impl AgentCommand { + pub fn run(&self) -> Result<()> { + if self.list { + let mut output = String::new(); + writeln!(output, "Available agent versions:")?; + + let versions = AgentFactory::available_versions(); + let default_version = AgentFactory::default_version(); + + for version in versions { + if version == default_version { + writeln!(output, "* {} (default)", version)?; + } else { + writeln!(output, " {}", version)?; + } + } + + print!("{}", output); + } else { + // When no flags are provided, show the default version + println!("Default version: {}", AgentFactory::default_version()); + } + Ok(()) + } +} \ No newline at end of file diff --git a/crates/goose-cli/src/commands/mod.rs b/crates/goose-cli/src/commands/mod.rs index 9420b16f7..d4e08c617 100644 --- a/crates/goose-cli/src/commands/mod.rs +++ b/crates/goose-cli/src/commands/mod.rs @@ -1,4 +1,4 @@ pub mod configure; pub mod session; pub mod version; -pub mod expected_config; \ No newline at end of file +pub mod agent_version; \ No newline at end of file diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index e9b1fa6b2..acda7c4b3 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -1,25 +1,22 @@ -mod commands { - pub mod configure; - pub mod session; - pub mod version; -} -pub mod agents; +use anyhow::Result; +use clap::{Parser, Subcommand}; +use goose::agents::AgentFactory; + +mod agents; +mod commands; mod profile; mod prompt; -pub mod session; - +mod session; mod systems; +mod log_usage; -use anyhow::Result; -use clap::{Parser, Subcommand}; use commands::configure::handle_configure; use commands::session::build_session; use commands::version::print_version; +use commands::agent_version::AgentCommand; use profile::has_no_profiles; use std::io::{self, Read}; -mod log_usage; - #[cfg(test)] mod test_helpers; @@ -31,6 +28,10 @@ struct Cli { #[arg(short = 'v', long = "version")] version: bool, + /// Agent version to use (e.g., 'base', 'v1') + #[arg(short = 'a', long = "agent", default_value_t = String::from("base"))] + agent: String, + #[command(subcommand)] command: Option, } @@ -161,6 +162,9 @@ enum Command { )] resume: bool, }, + + /// List available agent versions + Agent(AgentCommand), } #[derive(Subcommand)] @@ -202,6 +206,20 @@ async fn main() -> Result<()> { return Ok(()); } + // Validate agent version + if !AgentFactory::available_versions().contains(&cli.agent.as_str()) { + eprintln!("Error: Invalid agent version '{}'", cli.agent); + eprintln!("Available versions:"); + for version in AgentFactory::available_versions() { + if version == AgentFactory::default_version() { + eprintln!("* {} (default)", version); + } else { + eprintln!(" {}", version); + } + } + std::process::exit(1); + } + match cli.command { Some(Command::Configure { profile_name, @@ -227,6 +245,7 @@ async fn main() -> Result<()> { resume, }) => { let mut session = build_session(name, profile, resume); + session.agent_version = cli.agent; let _ = session.start().await; return Ok(()); } @@ -250,9 +269,14 @@ async fn main() -> Result<()> { stdin }; let mut session = build_session(name, profile, resume); + session.agent_version = cli.agent; let _ = session.headless_start(contents.clone()).await; return Ok(()); } + Some(Command::Agent(cmd)) => { + cmd.run()?; + return Ok(()); + } None => { println!("No command provided - Run 'goose help' to see available commands."); if has_no_profiles().unwrap_or(false) { @@ -261,4 +285,4 @@ async fn main() -> Result<()> { } } Ok(()) -} +} \ No newline at end of file From a46fed8af15b25ba49b607e6be08ed9e889c0413 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Fri, 27 Dec 2024 13:31:12 -0800 Subject: [PATCH 08/11] Add agent versions to session command --- crates/goose-cli/src/agents/agent.rs | 8 +- .../goose-cli/src/commands/agent_version.rs | 39 ++++----- crates/goose-cli/src/commands/mod.rs | 2 +- crates/goose-cli/src/commands/session.rs | 11 +-- crates/goose-cli/src/main.rs | 84 +++++++++++++------ crates/goose-cli/src/session.rs | 1 + 6 files changed, 81 insertions(+), 64 deletions(-) diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs index 5c3b51e45..85730ab10 100644 --- a/crates/goose-cli/src/agents/agent.rs +++ b/crates/goose-cli/src/agents/agent.rs @@ -12,6 +12,7 @@ pub struct GooseAgent { provider_usage: Mutex>, } +#[allow(dead_code)] impl GooseAgent { pub fn new(provider: Box) -> Self { Self { @@ -39,11 +40,4 @@ impl Agent for GooseAgent { fn get_provider_usage(&self) -> &Mutex> { &self.provider_usage } - // async fn reply(&self, messages: &[Message]) -> Result>> { - // self.reply(messages).await - // } - - // async fn usage(&self) -> Result> { - // self.usage().await - // } } diff --git a/crates/goose-cli/src/commands/agent_version.rs b/crates/goose-cli/src/commands/agent_version.rs index 7b3c35fb0..2f1628367 100644 --- a/crates/goose-cli/src/commands/agent_version.rs +++ b/crates/goose-cli/src/commands/agent_version.rs @@ -4,34 +4,25 @@ use goose::agents::AgentFactory; use std::fmt::Write; #[derive(Args)] -pub struct AgentCommand { - /// List available agent versions - #[arg(short, long)] - list: bool, -} +pub struct AgentCommand {} impl AgentCommand { pub fn run(&self) -> Result<()> { - if self.list { - let mut output = String::new(); - writeln!(output, "Available agent versions:")?; - - let versions = AgentFactory::available_versions(); - let default_version = AgentFactory::default_version(); - - for version in versions { - if version == default_version { - writeln!(output, "* {} (default)", version)?; - } else { - writeln!(output, " {}", version)?; - } + let mut output = String::new(); + writeln!(output, "Available agent versions:")?; + + let versions = AgentFactory::available_versions(); + let default_version = AgentFactory::default_version(); + + for version in versions { + if version == default_version { + writeln!(output, "* {} (default)", version)?; + } else { + writeln!(output, " {}", version)?; } - - print!("{}", output); - } else { - // When no flags are provided, show the default version - println!("Default version: {}", AgentFactory::default_version()); } + + print!("{}", output); Ok(()) } -} \ No newline at end of file +} diff --git a/crates/goose-cli/src/commands/mod.rs b/crates/goose-cli/src/commands/mod.rs index d4e08c617..b84916a20 100644 --- a/crates/goose-cli/src/commands/mod.rs +++ b/crates/goose-cli/src/commands/mod.rs @@ -1,4 +1,4 @@ +pub mod agent_version; pub mod configure; pub mod session; pub mod version; -pub mod agent_version; \ No newline at end of file diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 2d4f4eee9..1014c1d75 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -1,10 +1,10 @@ use console::style; +use goose::agents::AgentFactory; use goose::providers::factory; use rand::{distributions::Alphanumeric, Rng}; use std::path::{Path, PathBuf}; use std::process; -use crate::agents::agent::GooseAgent; use crate::profile::{get_provider_config, load_profiles, Profile}; use crate::prompt::rustyline::RustylinePrompt; use crate::prompt::Prompt; @@ -13,6 +13,7 @@ use crate::session::{ensure_session_dir, get_most_recent_session, Session}; pub fn build_session<'a>( session: Option, profile: Option, + agent_version: Option, resume: bool, ) -> Box> { let session_dir = ensure_session_dir().expect("Failed to create session directory"); @@ -45,7 +46,7 @@ pub fn build_session<'a>( // TODO: Odd to be prepping the provider rather than having that done in the agent? let provider = factory::get_provider(provider_config).unwrap(); - let agent = Box::new(GooseAgent::new(provider)); + let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("base"), provider).unwrap(); let prompt = match std::env::var("GOOSE_INPUT") { Ok(val) => match val.as_str() { "rustyline" => Box::new(RustylinePrompt::new()) as Box, @@ -173,7 +174,7 @@ mod tests { #[should_panic(expected = "Cannot resume session: file")] fn test_resume_nonexistent_session_panics() { run_with_tmp_dir(|| { - build_session(Some("nonexistent-session".to_string()), None, true); + build_session(Some("nonexistent-session".to_string()), None, None, true); }) } @@ -190,7 +191,7 @@ mod tests { fs::write(&file2_path, "{}")?; // Test resuming without a session name - let session = build_session(None, None, true); + let session = build_session(None, None, None, true); assert_eq!(session.session_file().as_path(), file2_path.as_path()); Ok(()) @@ -201,7 +202,7 @@ mod tests { #[should_panic(expected = "No session files found")] fn test_resume_most_recent_session_no_files() { run_with_tmp_dir(|| { - build_session(None, None, true); + build_session(None, None, None, true); }); } } diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index acda7c4b3..31f7e8b17 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -4,16 +4,16 @@ use goose::agents::AgentFactory; mod agents; mod commands; +mod log_usage; mod profile; mod prompt; mod session; mod systems; -mod log_usage; +use commands::agent_version::AgentCommand; use commands::configure::handle_configure; use commands::session::build_session; use commands::version::print_version; -use commands::agent_version::AgentCommand; use profile::has_no_profiles; use std::io::{self, Read}; @@ -28,10 +28,6 @@ struct Cli { #[arg(short = 'v', long = "version")] version: bool, - /// Agent version to use (e.g., 'base', 'v1') - #[arg(short = 'a', long = "agent", default_value_t = String::from("base"))] - agent: String, - #[command(subcommand)] command: Option, } @@ -99,6 +95,15 @@ enum Command { )] profile: Option, + /// Agent version to use (e.g., 'base', 'v1') + #[arg( + short, + long, + help = "Agent version to use (e.g., 'base', 'v1'), defaults to 'base'", + long_help = "Specify which agent version to use for this session." + )] + agent: Option, + /// Resume a previous session #[arg( short, @@ -152,6 +157,15 @@ enum Command { )] name: Option, + /// Agent version to use (e.g., 'base', 'v1') + #[arg( + short, + long, + help = "Agent version to use (e.g., 'base', 'v1')", + long_help = "Specify which agent version to use for this session." + )] + agent: Option, + /// Resume a previous run #[arg( short, @@ -164,7 +178,7 @@ enum Command { }, /// List available agent versions - Agent(AgentCommand), + Agents(AgentCommand), } #[derive(Subcommand)] @@ -206,20 +220,6 @@ async fn main() -> Result<()> { return Ok(()); } - // Validate agent version - if !AgentFactory::available_versions().contains(&cli.agent.as_str()) { - eprintln!("Error: Invalid agent version '{}'", cli.agent); - eprintln!("Available versions:"); - for version in AgentFactory::available_versions() { - if version == AgentFactory::default_version() { - eprintln!("* {} (default)", version); - } else { - eprintln!(" {}", version); - } - } - std::process::exit(1); - } - match cli.command { Some(Command::Configure { profile_name, @@ -242,10 +242,25 @@ async fn main() -> Result<()> { Some(Command::Session { name, profile, + agent, resume, }) => { - let mut session = build_session(name, profile, resume); - session.agent_version = cli.agent; + if let Some(agent_version) = agent.clone() { + if !AgentFactory::available_versions().contains(&agent_version.as_str()) { + eprintln!("Error: Invalid agent version '{}'", agent_version); + eprintln!("Available versions:"); + for version in AgentFactory::available_versions() { + if version == AgentFactory::default_version() { + eprintln!("* {} (default)", version); + } else { + eprintln!(" {}", version); + } + } + std::process::exit(1); + } + } + + let mut session = build_session(name, profile, agent, resume); let _ = session.start().await; return Ok(()); } @@ -254,8 +269,24 @@ async fn main() -> Result<()> { input_text, profile, name, + agent, resume, }) => { + if let Some(agent_version) = agent.clone() { + if !AgentFactory::available_versions().contains(&agent_version.as_str()) { + eprintln!("Error: Invalid agent version '{}'", agent_version); + eprintln!("Available versions:"); + for version in AgentFactory::available_versions() { + if version == AgentFactory::default_version() { + eprintln!("* {} (default)", version); + } else { + eprintln!(" {}", version); + } + } + std::process::exit(1); + } + } + let contents = if let Some(file_name) = instructions { let file_path = std::path::Path::new(&file_name); std::fs::read_to_string(file_path).expect("Failed to read the instruction file") @@ -268,12 +299,11 @@ async fn main() -> Result<()> { .expect("Failed to read from stdin"); stdin }; - let mut session = build_session(name, profile, resume); - session.agent_version = cli.agent; + let mut session = build_session(name, profile, agent, resume); let _ = session.headless_start(contents.clone()).await; return Ok(()); } - Some(Command::Agent(cmd)) => { + Some(Command::Agents(cmd)) => { cmd.run()?; return Ok(()); } @@ -285,4 +315,4 @@ async fn main() -> Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index 3483ec5e5..8454021a7 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -102,6 +102,7 @@ pub struct Session<'a> { messages: Vec, } +#[allow(dead_code)] impl<'a> Session<'a> { pub fn new( agent: Box, From 91555577bc4df22e0c07732c541de984059f0912 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Fri, 27 Dec 2024 15:36:28 -0800 Subject: [PATCH 09/11] Update server to configure agent versions --- crates/goose-server/src/configuration.rs | 6 +++++ crates/goose-server/src/main.rs | 6 ++++- crates/goose-server/src/routes/agent.rs | 32 ++++++++++++++++++++++++ crates/goose-server/src/routes/mod.rs | 5 +++- crates/goose-server/src/routes/reply.rs | 6 ++--- crates/goose-server/src/state.rs | 13 +++++++--- 6 files changed, 59 insertions(+), 9 deletions(-) create mode 100644 crates/goose-server/src/routes/agent.rs diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index bbece1276..05ff555da 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -219,6 +219,12 @@ pub struct Settings { #[serde(default)] pub server: ServerSettings, pub provider: ProviderSettings, + #[serde(default = "default_agent_version")] + pub agent_version: Option, +} + +fn default_agent_version() -> Option { + None // Will use AgentFactory::default_version() when None } impl Settings { diff --git a/crates/goose-server/src/main.rs b/crates/goose-server/src/main.rs index d575596a7..58dd3b537 100644 --- a/crates/goose-server/src/main.rs +++ b/crates/goose-server/src/main.rs @@ -19,7 +19,11 @@ async fn main() -> anyhow::Result<()> { std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string()); // Create app state - let state = state::AppState::new(settings.provider.into_config(), secret_key.clone())?; + let state = state::AppState::new( + settings.provider.into_config(), + secret_key.clone(), + settings.agent_version, + )?; // Create router with CORS support let cors = CorsLayer::new() diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs new file mode 100644 index 000000000..36233ecca --- /dev/null +++ b/crates/goose-server/src/routes/agent.rs @@ -0,0 +1,32 @@ +use axum::{ + extract::State, + routing::get, + Json, Router, +}; +use goose::agents::AgentFactory; +use serde::Serialize; +use crate::state::AppState; + +#[derive(Serialize)] +struct VersionsResponse { + current_version: String, + available_versions: Vec, + default_version: String, +} + +async fn get_versions(State(state): State) -> Json { + let versions = AgentFactory::available_versions(); + let default_version = AgentFactory::default_version().to_string(); + + Json(VersionsResponse { + current_version: state.agent_version.clone(), + available_versions: versions.iter().map(|v| v.to_string()).collect(), + default_version, + }) +} + +pub fn routes(state: AppState) -> Router { + Router::new() + .route("/api/agent/versions", get(get_versions)) + .with_state(state) +} \ No newline at end of file diff --git a/crates/goose-server/src/routes/mod.rs b/crates/goose-server/src/routes/mod.rs index 2d798a0da..b430136d0 100644 --- a/crates/goose-server/src/routes/mod.rs +++ b/crates/goose-server/src/routes/mod.rs @@ -1,9 +1,12 @@ // Export route modules +pub mod agent; pub mod reply; use axum::Router; // Function to configure all routes pub fn configure(state: crate::state::AppState) -> Router { - Router::new().merge(reply::routes(state)) + Router::new() + .merge(reply::routes(state.clone())) + .merge(agent::routes(state)) } diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 298410fd4..501260677 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -8,7 +8,6 @@ use axum::{ }; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; -use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::{content::Content, role::Role}; use serde::Deserialize; @@ -423,7 +422,7 @@ mod tests { &self.model_config } - fn get_usage(&self, data: &Value) -> anyhow::Result { + fn get_usage(&self, _data: &Value) -> anyhow::Result { Ok(Usage::new(None, None, None)) } } @@ -519,13 +518,14 @@ mod tests { }); let agent = Agent::new(mock_provider); let state = AppState { - agent: Arc::new(Mutex::new(agent)), + agent: Arc::new(Mutex::new(Box::new(agent))), provider_config: ProviderConfig::OpenAi(OpenAiProviderConfig { host: "https://api.openai.com".to_string(), api_key: "test-key".to_string(), model: ModelConfig::new("test-model".to_string()), }), secret_key: "test-secret".to_string(), + agent_version: "test-version".to_string(), }; // Build router diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index f2ccbdf44..8bb6c6576 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,7 +1,8 @@ use anyhow::Result; use goose::providers::configs::GroqProviderConfig; use goose::{ - agents::BaseAgent as Agent, + agents::AgentFactory, + agents::Agent, developer::DeveloperSystem, memory::MemorySystem, providers::{configs::ProviderConfig, factory}, @@ -13,14 +14,16 @@ use tokio::sync::Mutex; /// Shared application state pub struct AppState { pub provider_config: ProviderConfig, - pub agent: Arc>, + pub agent: Arc>>, pub secret_key: String, + pub agent_version: String, } impl AppState { - pub fn new(provider_config: ProviderConfig, secret_key: String) -> Result { + pub fn new(provider_config: ProviderConfig, secret_key: String, agent_version: Option) -> Result { let provider = factory::get_provider(provider_config.clone())?; - let mut agent = Agent::new(provider); + let mut agent = AgentFactory::create(agent_version.clone().unwrap_or(AgentFactory::default_version().to_string()).as_str(), provider)?; + agent.add_system(Box::new(DeveloperSystem::new())); // Add memory system only if GOOSE_SERVER__MEMORY is set to "true" @@ -37,6 +40,7 @@ impl AppState { provider_config, agent: Arc::new(Mutex::new(agent)), secret_key, + agent_version: agent_version.clone().unwrap_or(AgentFactory::default_version().to_string()), }) } } @@ -89,6 +93,7 @@ impl Clone for AppState { }, agent: self.agent.clone(), secret_key: self.secret_key.clone(), + agent_version: self.agent_version.clone(), } } } From 35893f691ea54acb6c26107b3863a16f03732fc1 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Fri, 27 Dec 2024 16:49:58 -0800 Subject: [PATCH 10/11] formatting --- crates/goose-server/src/configuration.rs | 2 +- crates/goose-server/src/routes/agent.rs | 12 ++++-------- crates/goose-server/src/state.rs | 20 ++++++++++++++++---- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index 05ff555da..7d402cec5 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -224,7 +224,7 @@ pub struct Settings { } fn default_agent_version() -> Option { - None // Will use AgentFactory::default_version() when None + None // Will use AgentFactory::default_version() when None } impl Settings { diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 36233ecca..4b78a7319 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -1,11 +1,7 @@ -use axum::{ - extract::State, - routing::get, - Json, Router, -}; +use crate::state::AppState; +use axum::{extract::State, routing::get, Json, Router}; use goose::agents::AgentFactory; use serde::Serialize; -use crate::state::AppState; #[derive(Serialize)] struct VersionsResponse { @@ -17,7 +13,7 @@ struct VersionsResponse { async fn get_versions(State(state): State) -> Json { let versions = AgentFactory::available_versions(); let default_version = AgentFactory::default_version().to_string(); - + Json(VersionsResponse { current_version: state.agent_version.clone(), available_versions: versions.iter().map(|v| v.to_string()).collect(), @@ -29,4 +25,4 @@ pub fn routes(state: AppState) -> Router { Router::new() .route("/api/agent/versions", get(get_versions)) .with_state(state) -} \ No newline at end of file +} diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 8bb6c6576..18bf03dba 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,8 +1,8 @@ use anyhow::Result; use goose::providers::configs::GroqProviderConfig; use goose::{ - agents::AgentFactory, agents::Agent, + agents::AgentFactory, developer::DeveloperSystem, memory::MemorySystem, providers::{configs::ProviderConfig, factory}, @@ -20,9 +20,19 @@ pub struct AppState { } impl AppState { - pub fn new(provider_config: ProviderConfig, secret_key: String, agent_version: Option) -> Result { + pub fn new( + provider_config: ProviderConfig, + secret_key: String, + agent_version: Option, + ) -> Result { let provider = factory::get_provider(provider_config.clone())?; - let mut agent = AgentFactory::create(agent_version.clone().unwrap_or(AgentFactory::default_version().to_string()).as_str(), provider)?; + let mut agent = AgentFactory::create( + agent_version + .clone() + .unwrap_or(AgentFactory::default_version().to_string()) + .as_str(), + provider, + )?; agent.add_system(Box::new(DeveloperSystem::new())); @@ -40,7 +50,9 @@ impl AppState { provider_config, agent: Arc::new(Mutex::new(agent)), secret_key, - agent_version: agent_version.clone().unwrap_or(AgentFactory::default_version().to_string()), + agent_version: agent_version + .clone() + .unwrap_or(AgentFactory::default_version().to_string()), }) } } From cf1a58c1a468b98e9718e746ad2b7398b54b8c4d Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Mon, 30 Dec 2024 22:47:37 -0800 Subject: [PATCH 11/11] Refactor Agent interface --- crates/goose-cli/src/agents/agent.rs | 43 -- crates/goose-cli/src/agents/mock_agent.rs | 27 +- crates/goose-cli/src/agents/mod.rs | 2 - crates/goose-cli/src/commands/session.rs | 2 +- crates/goose-cli/src/main.rs | 8 +- crates/goose-cli/src/session.rs | 13 +- crates/goose-server/src/main.rs | 2 +- crates/goose-server/src/routes/reply.rs | 2 +- crates/goose-server/src/state.rs | 8 +- crates/goose/src/agents/agent.rs | 397 +---------------- crates/goose/src/agents/base.rs | 401 ----------------- crates/goose/src/agents/default.rs | 505 ++++++++++++++++++++++ crates/goose/src/agents/factory.rs | 60 ++- crates/goose/src/agents/mcp_manager.rs | 196 +++++++++ crates/goose/src/agents/mod.rs | 8 +- crates/goose/src/agents/v1.rs | 63 --- crates/goose/src/errors.rs | 5 +- 17 files changed, 788 insertions(+), 954 deletions(-) delete mode 100644 crates/goose-cli/src/agents/agent.rs delete mode 100644 crates/goose/src/agents/base.rs create mode 100644 crates/goose/src/agents/default.rs create mode 100644 crates/goose/src/agents/mcp_manager.rs delete mode 100644 crates/goose/src/agents/v1.rs diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs deleted file mode 100644 index 85730ab10..000000000 --- a/crates/goose-cli/src/agents/agent.rs +++ /dev/null @@ -1,43 +0,0 @@ -// use anyhow::Result; -use async_trait::async_trait; -// use futures::stream::BoxStream; -use goose::{ - agents::Agent, providers::base::Provider, providers::base::ProviderUsage, systems::System, -}; -use tokio::sync::Mutex; - -pub struct GooseAgent { - systems: Vec>, - provider: Box, - provider_usage: Mutex>, -} - -#[allow(dead_code)] -impl GooseAgent { - pub fn new(provider: Box) -> Self { - Self { - systems: Vec::new(), - provider, - provider_usage: Mutex::new(Vec::new()), - } - } -} - -#[async_trait] -impl Agent for GooseAgent { - fn add_system(&mut self, system: Box) { - self.systems.push(system); - } - - fn get_systems(&self) -> &Vec> { - &self.systems - } - - fn get_provider(&self) -> &Box { - &self.provider - } - - fn get_provider_usage(&self) -> &Mutex> { - &self.provider_usage - } -} diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs index c8bf46a0c..090feabd8 100644 --- a/crates/goose-cli/src/agents/mock_agent.rs +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -1,15 +1,14 @@ -use std::vec; - -use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; use goose::providers::mock::MockProvider; use goose::{ agents::Agent, + errors::AgentResult, message::Message, providers::base::{Provider, ProviderUsage}, systems::System, }; +use serde_json::Value; use tokio::sync::Mutex; pub struct MockAgent { @@ -30,27 +29,31 @@ impl MockAgent { #[async_trait] impl Agent for MockAgent { - fn add_system(&mut self, system: Box) { + async fn add_system(&mut self, system: Box) -> AgentResult<()> { self.systems.push(system); + Ok(()) } - fn get_systems(&self) -> &Vec> { - &self.systems + async fn remove_system(&mut self, name: &str) -> AgentResult<()> { + self.systems.retain(|s| s.name() != name); + Ok(()) } - fn get_provider(&self) -> &Box { - &self.provider + async fn list_systems(&self) -> AgentResult> { + Ok(self.systems.iter() + .map(|s| (s.name().to_string(), s.description().to_string())) + .collect()) } - fn get_provider_usage(&self) -> &Mutex> { - &self.provider_usage + async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult { + Ok(Value::Null) } - async fn reply(&self, _messages: &[Message]) -> Result>> { + async fn reply(&self, _messages: &[Message]) -> anyhow::Result>> { Ok(Box::pin(futures::stream::empty())) } - async fn usage(&self) -> Result> { + async fn usage(&self) -> AgentResult> { Ok(vec![ProviderUsage::new( "mock".to_string(), Default::default(), diff --git a/crates/goose-cli/src/agents/mod.rs b/crates/goose-cli/src/agents/mod.rs index 14b5f00cd..a1a102c66 100644 --- a/crates/goose-cli/src/agents/mod.rs +++ b/crates/goose-cli/src/agents/mod.rs @@ -1,4 +1,2 @@ -pub mod agent; - #[cfg(test)] pub mod mock_agent; diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 1014c1d75..88ed4abee 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -46,7 +46,7 @@ pub fn build_session<'a>( // TODO: Odd to be prepping the provider rather than having that done in the agent? let provider = factory::get_provider(provider_config).unwrap(); - let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("base"), provider).unwrap(); + let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("default"), provider).unwrap(); let prompt = match std::env::var("GOOSE_INPUT") { Ok(val) => match val.as_str() { "rustyline" => Box::new(RustylinePrompt::new()) as Box, diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index 31f7e8b17..3b3cd852c 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -95,11 +95,11 @@ enum Command { )] profile: Option, - /// Agent version to use (e.g., 'base', 'v1') + /// Agent version to use (e.g., 'default', 'v1') #[arg( short, long, - help = "Agent version to use (e.g., 'base', 'v1'), defaults to 'base'", + help = "Agent version to use (e.g., 'default', 'v1'), defaults to 'default'", long_help = "Specify which agent version to use for this session." )] agent: Option, @@ -157,11 +157,11 @@ enum Command { )] name: Option, - /// Agent version to use (e.g., 'base', 'v1') + /// Agent version to use (e.g., 'default', 'v1') #[arg( short, long, - help = "Agent version to use (e.g., 'base', 'v1')", + help = "Agent version to use (e.g., 'default', 'v1')", long_help = "Specify which agent version to use for this session." )] agent: Option, diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index 8454021a7..96d652390 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -10,6 +10,7 @@ use std::path::PathBuf; use crate::log_usage::log_usage; use crate::prompt::{InputType, Prompt}; use goose::agents::Agent; +use goose::errors::AgentResult; use goose::developer::DeveloperSystem; use goose::message::{Message, MessageContent}; use goose::systems::goose_hints::GooseHintsSystem; @@ -134,7 +135,7 @@ impl<'a> Session<'a> { } pub async fn start(&mut self) -> Result<(), Box> { - self.setup_session(); + self.setup_session().await?; self.prompt.goose_ready(); loop { @@ -162,7 +163,7 @@ impl<'a> Session<'a> { &mut self, initial_message: String, ) -> Result<(), Box> { - self.setup_session(); + self.setup_session().await?; self.messages .push(Message::user().with_text(initial_message.as_str())); @@ -313,11 +314,12 @@ We've removed the conversation up to the most recent user message } } - fn setup_session(&mut self) { + async fn setup_session(&mut self) -> AgentResult<()> { let system = Box::new(DeveloperSystem::new()); - self.agent.add_system(system); + self.agent.add_system(system).await?; let goosehints_system = Box::new(GooseHintsSystem::new()); - self.agent.add_system(goosehints_system); + self.agent.add_system(goosehints_system).await?; + Ok(()) } async fn close_session(&mut self) { @@ -329,6 +331,7 @@ We've removed the conversation up to the most recent user message .as_str(), )); self.prompt.close(); + match self.agent.usage().await { Ok(usage) => log_usage(self.session_file.to_string_lossy().to_string(), usage), Err(e) => eprintln!("Failed to collect total provider usage: {}", e), diff --git a/crates/goose-server/src/main.rs b/crates/goose-server/src/main.rs index 58dd3b537..4c0eb2b82 100644 --- a/crates/goose-server/src/main.rs +++ b/crates/goose-server/src/main.rs @@ -23,7 +23,7 @@ async fn main() -> anyhow::Result<()> { settings.provider.into_config(), secret_key.clone(), settings.agent_version, - )?; + ).await?; // Create router with CORS support let cors = CorsLayer::new() diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 501260677..b5269aa8b 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -390,7 +390,7 @@ pub fn routes(state: AppState) -> Router { mod tests { use super::*; use goose::{ - agents::BaseAgent as Agent, + agents::DefaultAgent as Agent, providers::{ base::{Provider, ProviderUsage, Usage}, configs::{ModelConfig, OpenAiProviderConfig}, diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 18bf03dba..f62f96d07 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -20,7 +20,7 @@ pub struct AppState { } impl AppState { - pub fn new( + pub async fn new( provider_config: ProviderConfig, secret_key: String, agent_version: Option, @@ -34,17 +34,17 @@ impl AppState { provider, )?; - agent.add_system(Box::new(DeveloperSystem::new())); + agent.add_system(Box::new(DeveloperSystem::new())).await?; // Add memory system only if GOOSE_SERVER__MEMORY is set to "true" if let Ok(memory_enabled) = env::var("GOOSE_SERVER__MEMORY") { if memory_enabled.to_lowercase() == "true" { - agent.add_system(Box::new(MemorySystem::new())); + agent.add_system(Box::new(MemorySystem::new())).await?; } } let goosehints_system = Box::new(GooseHintsSystem::new()); - agent.add_system(goosehints_system); + agent.add_system(goosehints_system).await?; Ok(Self { provider_config, diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index aba706db7..5da9167d8 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1,396 +1,31 @@ use anyhow::Result; -use async_stream; use async_trait::async_trait; use futures::stream::BoxStream; -use rust_decimal_macros::dec; -use serde_json::json; -use std::collections::HashMap; -use tokio::sync::Mutex; +use serde_json::Value; -use crate::errors::{AgentError, AgentResult}; -use crate::message::{Message, ToolRequest}; -use crate::prompt_template::load_prompt_file; -use crate::providers::base::{Provider, ProviderUsage}; +use crate::errors::AgentResult; +use crate::message::Message; +use crate::providers::base::ProviderUsage; use crate::systems::System; -use crate::token_counter::TokenCounter; -use mcp_core::{Content, Resource, Tool, ToolCall}; -use serde::Serialize; - -// used to sort resources by priority within error margin -const PRIORITY_EPSILON: f32 = 0.001; - -#[derive(Clone, Debug, Serialize)] -struct SystemInfo { - name: String, - description: String, - instructions: String, -} - -impl SystemInfo { - fn new(name: &str, description: &str, instructions: &str) -> Self { - Self { - name: name.to_string(), - description: description.to_string(), - instructions: instructions.to_string(), - } - } -} - -#[derive(Clone, Debug, Serialize)] -struct SystemStatus { - name: String, - status: String, -} - -impl SystemStatus { - fn new(name: &str, status: String) -> Self { - Self { - name: name.to_string(), - status, - } - } -} /// Core trait defining the behavior of an Agent #[async_trait] pub trait Agent: Send + Sync { - /// Get all tools from all systems with proper system prefixing - fn get_prefixed_tools(&self) -> Vec { - let mut tools = Vec::new(); - for system in self.get_systems() { - for tool in system.tools() { - tools.push(Tool::new( - format!("{}__{}", system.name(), tool.name), - &tool.description, - tool.input_schema.clone(), - )); - } - } - tools - } - - // add a system to the agent - fn add_system(&mut self, system: Box); - - /// Get the systems this agent has access to - fn get_systems(&self) -> &Vec>; - - /// Get the provider for this agent - fn get_provider(&self) -> &Box; - - /// Get the provider usage statistics - fn get_provider_usage(&self) -> &Mutex>; - - /// Setup the next inference by budgeting the context window - async fn prepare_inference( - &self, - system_prompt: &str, - tools: &[Tool], - messages: &[Message], - pending: &[Message], - target_limit: usize, - ) -> AgentResult> { - // Default implementation for prepare_inference - let token_counter = TokenCounter::new(); - let resource_content = self.get_systems_resources().await?; - - // Flatten all resource content into a vector of strings - let mut resources = Vec::new(); - for system_resources in resource_content.values() { - for (_, content) in system_resources.values() { - resources.push(content.clone()); - } - } - - let approx_count = token_counter.count_everything( - system_prompt, - messages, - tools, - &resources, - Some(&self.get_provider().get_model_config().model_name), - ); - - let mut status_content: Vec = Vec::new(); - - if approx_count > target_limit { - println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over buget. Removing context", approx_count, approx_count - target_limit); - - // Get token counts for each resource - let mut system_token_counts = HashMap::new(); - - // Iterate through each system and its resources - for (system_name, resources) in &resource_content { - let mut resource_counts = HashMap::new(); - for (uri, (_resource, content)) in resources { - let token_count = token_counter.count_tokens( - &content, - Some(&self.get_provider().get_model_config().model_name), - ) as u32; - resource_counts.insert(uri.clone(), token_count); - } - system_token_counts.insert(system_name.clone(), resource_counts); - } - // Sort resources by priority and timestamp and trim to fit context limit - let mut all_resources: Vec<(String, String, Resource, u32)> = Vec::new(); - for (system_name, resources) in &resource_content { - for (uri, (resource, _)) in resources { - if let Some(token_count) = system_token_counts - .get(system_name) - .and_then(|counts| counts.get(uri)) - { - all_resources.push(( - system_name.clone(), - uri.clone(), - resource.clone(), - *token_count, - )); - } - } - } - - // Sort by priority (high to low) and timestamp (newest to oldest) - all_resources.sort_by(|a, b| { - let a_priority = a.2.priority().unwrap_or(0.0); - let b_priority = b.2.priority().unwrap_or(0.0); - if (b_priority - a_priority).abs() < PRIORITY_EPSILON { - b.2.timestamp().cmp(&a.2.timestamp()) - } else { - b.2.priority() - .partial_cmp(&a.2.priority()) - .unwrap_or(std::cmp::Ordering::Equal) - } - }); - - // Remove resources until we're under target limit - let mut current_tokens = approx_count; - - while current_tokens > target_limit && !all_resources.is_empty() { - if let Some((system_name, uri, _, token_count)) = all_resources.pop() { - if let Some(system_counts) = system_token_counts.get_mut(&system_name) { - system_counts.remove(&uri); - current_tokens -= token_count as usize; - } - } - } - // Create status messages only from resources that remain after token trimming - for (system_name, uri, _, _) in &all_resources { - if let Some(system_resources) = resource_content.get(system_name) { - if let Some((resource, content)) = system_resources.get(uri) { - status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); - } - } - } - } else { - // Create status messages from all resources when no trimming needed - for resources in resource_content.values() { - for (resource, content) in resources.values() { - status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); - } - } - } - - // Join remaining status content and create status message - let status_str = status_content.join("\n"); - let mut context = HashMap::new(); - let systems_status = vec![SystemStatus::new("system", status_str)]; - context.insert("systems", &systems_status); - - // Load and format the status template with only remaining resources - let status = load_prompt_file("status.md", &context) - .map_err(|e| AgentError::Internal(e.to_string()))?; - - // Create a new messages vector with our changes - let mut new_messages = messages.to_vec(); - - // Add pending messages - for msg in pending { - new_messages.push(msg.clone()); - } - - // Finally add the status messages - let message_use = - Message::assistant().with_tool_request("000", Ok(ToolCall::new("status", json!({})))); - - let message_result = - Message::user().with_tool_response("000", Ok(vec![Content::text(status)])); - - new_messages.push(message_use); - new_messages.push(message_result); - - Ok(new_messages) - } - - /// Create a stream that yields each message as it's generated - async fn reply(&self, messages: &[Message]) -> Result>> { - let mut messages = messages.to_vec(); - let tools = self.get_prefixed_tools(); - let system_prompt = self.get_system_prompt()?; - let estimated_limit = self.get_provider().get_model_config().get_estimated_limit(); - - // Update conversation history for the start of the reply - messages = self - .prepare_inference( - &system_prompt, - &tools, - &messages, - &Vec::new(), - estimated_limit, - ) - .await?; - - Ok(Box::pin(async_stream::try_stream! { - loop { - // Get completion from provider - let (response, usage) = self.get_provider().complete( - &system_prompt, - &messages, - &tools, - ).await?; - self.get_provider_usage().lock().await.push(usage); - - // Yield the assistant's response - yield response.clone(); - - tokio::task::yield_now().await; - - // First collect any tool requests - let tool_requests: Vec<&ToolRequest> = response.content - .iter() - .filter_map(|content| content.as_tool_request()) - .collect(); - - if tool_requests.is_empty() { - break; - } - - // Then dispatch each in parallel - let futures: Vec<_> = tool_requests - .iter() - .map(|request| self.dispatch_tool_call(request.tool_call.clone())) - .collect(); - - // Process all the futures in parallel but wait until all are finished - let outputs = futures::future::join_all(futures).await; - - // Create a message with the responses - let mut message_tool_response = Message::user(); - // Now combine these into MessageContent::ToolResponse using the original ID - for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); - } - - yield message_tool_response.clone(); - - // Now we have to remove the previous status tooluse and toolresponse - // before we add pending messages, then the status msgs back again - messages.pop(); - messages.pop(); - - let pending = vec![response, message_tool_response]; - messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit).await?; - } - })) - } - - /// Get usage statistics - async fn usage(&self) -> Result> { - let provider_usage = self.get_provider_usage().lock().await.clone(); - let mut usage_map: HashMap = HashMap::new(); - - provider_usage.iter().for_each(|usage| { - usage_map - .entry(usage.model.clone()) - .and_modify(|e| { - e.usage.input_tokens = Some( - e.usage.input_tokens.unwrap_or(0) + usage.usage.input_tokens.unwrap_or(0), - ); - e.usage.output_tokens = Some( - e.usage.output_tokens.unwrap_or(0) + usage.usage.output_tokens.unwrap_or(0), - ); - e.usage.total_tokens = Some( - e.usage.total_tokens.unwrap_or(0) + usage.usage.total_tokens.unwrap_or(0), - ); - if e.cost.is_none() || usage.cost.is_none() { - e.cost = None; // Pricing is not available for all models - } else { - e.cost = Some(e.cost.unwrap_or(dec!(0)) + usage.cost.unwrap_or(dec!(0))); - } - }) - .or_insert_with(|| usage.clone()); - }); - Ok(usage_map.into_values().collect()) - } - - /// Get system resources and their contents - async fn get_systems_resources( - &self, - ) -> AgentResult>> { - let mut system_resource_content = HashMap::new(); - for system in self.get_systems() { - let system_status = system - .status() - .await - .map_err(|e| AgentError::Internal(e.to_string()))?; - - let mut resource_content = HashMap::new(); - for resource in system_status { - if let Ok(content) = system.read_resource(&resource.uri).await { - resource_content.insert(resource.uri.to_string(), (resource, content)); - } - } - system_resource_content.insert(system.name().to_string(), resource_content); - } - Ok(system_resource_content) - } - - /// Get the system prompt - fn get_system_prompt(&self) -> AgentResult { - let mut context = HashMap::new(); - let systems_info: Vec = self - .get_systems() - .iter() - .map(|system| { - SystemInfo::new(system.name(), system.description(), system.instructions()) - }) - .collect(); + /// Create a stream that yields each message as it's generated by the agent + async fn reply(&self, messages: &[Message]) -> Result>>; - context.insert("systems", systems_info); - load_prompt_file("system.md", &context).map_err(|e| AgentError::Internal(e.to_string())) - } + /// Add a system to the agent + async fn add_system(&mut self, system: Box) -> AgentResult<()>; - /// Find the appropriate system for a tool call based on the prefixed name - fn get_system_for_tool(&self, prefixed_name: &str) -> Option<&dyn System> { - let parts: Vec<&str> = prefixed_name.split("__").collect(); - if parts.len() != 2 { - return None; - } - let system_name = parts[0]; - self.get_systems() - .iter() - .find(|sys| sys.name() == system_name) - .map(|v| &**v) - } + /// Remove a system by name + async fn remove_system(&mut self, name: &str) -> AgentResult<()>; - /// Dispatch a single tool call to the appropriate system - async fn dispatch_tool_call( - &self, - tool_call: AgentResult, - ) -> AgentResult> { - let call = tool_call?; - let system = self - .get_system_for_tool(&call.name) - .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?; + /// List all systems and their status + async fn list_systems(&self) -> AgentResult>; - let tool_name = call - .name - .split("__") - .nth(1) - .ok_or_else(|| AgentError::InvalidToolName(call.name.clone()))?; - let system_tool_call = ToolCall::new(tool_name, call.arguments); + /// Pass through a JSON-RPC request to a specific system + async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult; - system.call(system_tool_call).await - } + /// Get the total usage of the agent + async fn usage(&self) -> AgentResult>; } diff --git a/crates/goose/src/agents/base.rs b/crates/goose/src/agents/base.rs deleted file mode 100644 index ce3a6c94e..000000000 --- a/crates/goose/src/agents/base.rs +++ /dev/null @@ -1,401 +0,0 @@ -use async_trait::async_trait; -use tokio::sync::Mutex; - -use super::Agent; -use crate::providers::base::{Provider, ProviderUsage}; -use crate::register_agent; -use crate::systems::System; - -/// Base implementation of an Agent -pub struct BaseAgent { - systems: Vec>, - provider: Box, - provider_usage: Mutex>, -} - -impl BaseAgent { - pub fn new(provider: Box) -> Self { - Self { - systems: Vec::new(), - provider, - provider_usage: Mutex::new(Vec::new()), - } - } - - pub fn add_system(&mut self, system: Box) { - self.systems.push(system); - } -} - -#[async_trait] -impl Agent for BaseAgent { - fn add_system(&mut self, system: Box) { - self.systems.push(system); - } - - fn get_systems(&self) -> &Vec> { - &self.systems - } - - fn get_provider(&self) -> &Box { - &self.provider - } - - fn get_provider_usage(&self) -> &Mutex> { - &self.provider_usage - } -} - -register_agent!("base", BaseAgent); - -#[cfg(test)] -mod tests { - use super::*; - use crate::message::{Message, MessageContent}; - use crate::providers::configs::ModelConfig; - use crate::providers::mock::MockProvider; - use async_trait::async_trait; - use chrono::Utc; - use futures::TryStreamExt; - use mcp_core::resource::Resource; - use mcp_core::{Annotations, Content, Tool, ToolCall}; - use rust_decimal_macros::dec; - use serde_json::json; - use std::collections::HashMap; - - // Mock system for testing - struct MockSystem { - name: String, - tools: Vec, - resources: Vec, - resource_content: HashMap, - } - - impl MockSystem { - fn new(name: &str) -> Self { - Self { - name: name.to_string(), - tools: vec![Tool::new( - "echo", - "Echoes back the input", - json!({"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}), - )], - resources: Vec::new(), - resource_content: HashMap::new(), - } - } - - fn add_resource(&mut self, name: &str, content: &str, priority: f32) { - let uri = format!("file://{}", name); - let resource = Resource { - name: name.to_string(), - uri: uri.clone(), - annotations: Some(Annotations::for_resource(priority, Utc::now())), - description: Some("A mock resource".to_string()), - mime_type: "text/plain".to_string(), - }; - self.resources.push(resource); - self.resource_content.insert(uri, content.to_string()); - } - } - - #[async_trait] - impl System for MockSystem { - fn name(&self) -> &str { - &self.name - } - - fn description(&self) -> &str { - "A mock system for testing" - } - - fn instructions(&self) -> &str { - "Mock system instructions" - } - - fn tools(&self) -> &[Tool] { - &self.tools - } - - async fn status(&self) -> anyhow::Result> { - Ok(self.resources.clone()) - } - - async fn call(&self, tool_call: ToolCall) -> crate::errors::AgentResult> { - match tool_call.name.as_str() { - "echo" => Ok(vec![Content::text( - tool_call.arguments["message"].as_str().unwrap_or(""), - )]), - _ => Err(crate::errors::AgentError::ToolNotFound(tool_call.name)), - } - } - - async fn read_resource(&self, uri: &str) -> crate::errors::AgentResult { - self.resource_content.get(uri).cloned().ok_or_else(|| { - crate::errors::AgentError::InvalidParameters(format!( - "Resource {} could not be found", - uri - )) - }) - } - } - - #[tokio::test] - async fn test_simple_response() -> anyhow::Result<()> { - let response = Message::assistant().with_text("Hello!"); - let provider = MockProvider::new(vec![response.clone()]); - let agent = BaseAgent::new(Box::new(provider)); - - let initial_message = Message::user().with_text("Hi"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - assert_eq!(messages.len(), 1); - assert_eq!(messages[0], response); - Ok(()) - } - - #[tokio::test] - async fn test_usage_rollup() -> anyhow::Result<()> { - let response = Message::assistant().with_text("Hello!"); - let provider = MockProvider::new(vec![response.clone()]); - let agent = BaseAgent::new(Box::new(provider)); - - let initial_message = Message::user().with_text("Hi"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - while stream.try_next().await?.is_some() {} - - // Second message - let mut stream = agent.reply(&initial_messages).await?; - while stream.try_next().await?.is_some() {} - - let usage = agent.usage().await?; - assert_eq!(usage.len(), 1); // 2 messages rolled up to one usage per model - assert_eq!(usage[0].usage.input_tokens, Some(2)); - assert_eq!(usage[0].usage.output_tokens, Some(2)); - assert_eq!(usage[0].usage.total_tokens, Some(4)); - assert_eq!(usage[0].model, "mock"); - assert_eq!(usage[0].cost, Some(dec!(2))); - Ok(()) - } - - #[tokio::test] - async fn test_tool_call() -> anyhow::Result<()> { - let mut agent = BaseAgent::new(Box::new(MockProvider::new(vec![ - Message::assistant().with_tool_request( - "1", - Ok(ToolCall::new("test_echo", json!({"message": "test"}))), - ), - Message::assistant().with_text("Done!"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Echo test"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - // Should have three messages: tool request, response, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!(messages[2].content[0], MessageContent::text("Done!")); - Ok(()) - } - - #[tokio::test] - async fn test_invalid_tool() -> anyhow::Result<()> { - let mut agent = BaseAgent::new(Box::new(MockProvider::new(vec![ - Message::assistant() - .with_tool_request("1", Ok(ToolCall::new("invalid_tool", json!({})))), - Message::assistant().with_text("Error occurred"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Invalid tool"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - // Should have three messages: failed tool request, fail response, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!( - messages[2].content[0], - MessageContent::text("Error occurred") - ); - Ok(()) - } - - #[tokio::test] - async fn test_multiple_tool_calls() -> anyhow::Result<()> { - let mut agent = BaseAgent::new(Box::new(MockProvider::new(vec![ - Message::assistant() - .with_tool_request( - "1", - Ok(ToolCall::new("test_echo", json!({"message": "first"}))), - ) - .with_tool_request( - "2", - Ok(ToolCall::new("test_echo", json!({"message": "second"}))), - ), - Message::assistant().with_text("All done!"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Multiple calls"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - // Should have three messages: tool requests, responses, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!(messages[2].content[0], MessageContent::text("All done!")); - Ok(()) - } - - #[tokio::test] - async fn test_prepare_inference_trims_resources_when_budget_exceeded() -> anyhow::Result<()> { - // Create a mock provider - let provider = MockProvider::new(vec![]); - let mut agent = BaseAgent::new(Box::new(provider)); - - // Create a mock system with two resources - let mut system = MockSystem::new("test"); - - // Add two resources with different priorities - let string_10toks = "hello ".repeat(10); - system.add_resource("high_priority", &string_10toks, 0.8); - system.add_resource("low_priority", &string_10toks, 0.1); - - agent.add_system(Box::new(system)); - - // Set up test parameters - // 18 tokens with system + user msg in chat format - let system_prompt = "This is a system prompt"; - let messages = vec![Message::user().with_text("Hi there")]; - let tools = vec![]; - let pending = vec![]; - - // Approx count is 40, so target limit of 35 will force trimming - let target_limit = 35; - - // Call prepare_inference - let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // Verify that only the high priority resource is included in the status - assert!(status_content.contains("high_priority")); - assert!(!status_content.contains("low_priority")); - - // Now test with a target limit that allows both resources (no trimming) - let target_limit = 100; - - // Call prepare_inference - let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // Verify that only the high priority resource is included in the status - assert!(status_content.contains("high_priority")); - assert!(status_content.contains("low_priority")); - Ok(()) - } - - #[tokio::test] - async fn test_context_trimming_with_custom_model_config() -> anyhow::Result<()> { - let provider = MockProvider::with_config( - vec![], - ModelConfig::new("test_model".to_string()).with_context_limit(Some(20)), - ); - let mut agent = BaseAgent::new(Box::new(provider)); - - // Create a mock system with a resource that will exceed the context limit - let mut system = MockSystem::new("test"); - - // Add a resource that will exceed our tiny context limit - let hello_1_tokens = "hello ".repeat(1); // 1 tokens - let goodbye_10_tokens = "goodbye ".repeat(10); // 10 tokens - system.add_resource("test_resource_removed", &goodbye_10_tokens, 0.1); - system.add_resource("test_resource_expected", &hello_1_tokens, 0.5); - - agent.add_system(Box::new(system)); - - // Set up test parameters - // 18 tokens with system + user msg in chat format - let system_prompt = "This is a system prompt"; - let messages = vec![Message::user().with_text("Hi there")]; - let tools = vec![]; - let pending = vec![]; - - // Use the context limit from the model config - let target_limit = agent.get_provider().get_model_config().context_limit(); - assert_eq!(target_limit, 20, "Context limit should be 20"); - - // Call prepare_inference - let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // verify that "hello" is within the response, should be just under 20 tokens with "hello" - assert!(status_content.contains("hello")); - - Ok(()) - } -} diff --git a/crates/goose/src/agents/default.rs b/crates/goose/src/agents/default.rs new file mode 100644 index 000000000..f5f75d612 --- /dev/null +++ b/crates/goose/src/agents/default.rs @@ -0,0 +1,505 @@ +use async_trait::async_trait; +use futures::stream::BoxStream; +use serde::Serialize; +use serde_json::json; +use tokio::sync::Mutex; +use std::collections::HashMap; + +use super::{Agent, MCPManager}; +use crate::errors::{AgentError, AgentResult}; +use crate::message::{Message, ToolRequest}; +use crate::providers::base::Provider; +use crate::register_agent; +use crate::systems::System; +use crate::token_counter::TokenCounter; +use mcp_core::{Content, Resource, Tool, ToolCall}; +use crate::prompt_template::load_prompt_file; +use crate::providers::base::ProviderUsage; +use serde_json::Value; +// used to sort resources by priority within error margin +const PRIORITY_EPSILON: f32 = 0.001; + +#[derive(Clone, Debug, Serialize)] +struct SystemStatus { + name: String, + status: String, +} + +impl SystemStatus { + fn new(name: &str, status: String) -> Self { + Self { + name: name.to_string(), + status, + } + } +} + +/// Default implementation of an Agent +pub struct DefaultAgent { + mcp_manager: Mutex, +} + +impl DefaultAgent { + pub fn new(provider: Box) -> Self { + Self { + mcp_manager: Mutex::new(MCPManager::new(provider)), + } + } + + /// Setup the next inference by budgeting the context window + async fn prepare_inference( + &self, + system_prompt: &str, + tools: &[Tool], + messages: &[Message], + pending: &[Message], + target_limit: usize, + model_name: &str, + resource_content: &HashMap>, + ) -> AgentResult> { + let token_counter = TokenCounter::new(); + + // Flatten all resource content into a vector of strings + let mut resources = Vec::new(); + for system_resources in resource_content.values() { + for (_, content) in system_resources.values() { + resources.push(content.clone()); + } + } + + let approx_count = token_counter.count_everything( + system_prompt, + messages, + tools, + &resources, + Some(model_name), + ); + + let mut status_content: Vec = Vec::new(); + + if approx_count > target_limit { + println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over buget. Removing context", approx_count, approx_count - target_limit); + + // Get token counts for each resource + let mut system_token_counts = HashMap::new(); + + // Iterate through each system and its resources + for (system_name, resources) in resource_content { + let mut resource_counts = HashMap::new(); + for (uri, (_resource, content)) in resources { + let token_count = token_counter.count_tokens(&content, Some(model_name)) as u32; + resource_counts.insert(uri.clone(), token_count); + } + system_token_counts.insert(system_name.clone(), resource_counts); + } + + // Sort resources by priority and timestamp and trim to fit context limit + let mut all_resources: Vec<(String, String, Resource, u32)> = Vec::new(); + for (system_name, resources) in resource_content { + for (uri, (resource, _)) in resources { + if let Some(token_count) = system_token_counts + .get(system_name) + .and_then(|counts| counts.get(uri)) + { + all_resources.push(( + system_name.clone(), + uri.clone(), + resource.clone(), + *token_count, + )); + } + } + } + + // Sort by priority (high to low) and timestamp (newest to oldest) + all_resources.sort_by(|a, b| { + let a_priority = a.2.priority().unwrap_or(0.0); + let b_priority = b.2.priority().unwrap_or(0.0); + if (b_priority - a_priority).abs() < PRIORITY_EPSILON { + b.2.timestamp().cmp(&a.2.timestamp()) + } else { + b.2.priority() + .partial_cmp(&a.2.priority()) + .unwrap_or(std::cmp::Ordering::Equal) + } + }); + + // Remove resources until we're under target limit + let mut current_tokens = approx_count; + + while current_tokens > target_limit && !all_resources.is_empty() { + if let Some((system_name, uri, _, token_count)) = all_resources.pop() { + if let Some(system_counts) = system_token_counts.get_mut(&system_name) { + system_counts.remove(&uri); + current_tokens -= token_count as usize; + } + } + } + + // Create status messages only from resources that remain after token trimming + for (system_name, uri, _, _) in &all_resources { + if let Some(system_resources) = resource_content.get(system_name) { + if let Some((resource, content)) = system_resources.get(uri) { + status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); + } + } + } + } else { + // Create status messages from all resources when no trimming needed + for resources in resource_content.values() { + for (resource, content) in resources.values() { + status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); + } + } + } + + // Join remaining status content and create status message + let status_str = status_content.join("\n"); + let mut context = HashMap::new(); + let systems_status = vec![SystemStatus::new("system", status_str)]; + context.insert("systems", &systems_status); + + // Load and format the status template with only remaining resources + let status = load_prompt_file("status.md", &context) + .map_err(|e| AgentError::Internal(e.to_string()))?; + + // Create a new messages vector with our changes + let mut new_messages = messages.to_vec(); + + // Add pending messages + for msg in pending { + new_messages.push(msg.clone()); + } + + // Finally add the status messages + let message_use = + Message::assistant().with_tool_request("000", Ok(ToolCall::new("status", json!({})))); + + let message_result = + Message::user().with_tool_response("000", Ok(vec![Content::text(status)])); + + new_messages.push(message_use); + new_messages.push(message_result); + + Ok(new_messages) + } +} + +#[async_trait] +impl Agent for DefaultAgent { + async fn add_system(&mut self, system: Box) -> AgentResult<()> { + let mut manager = self.mcp_manager.lock().await; + manager.add_system(system); + Ok(()) + } + + async fn remove_system(&mut self, name: &str) -> AgentResult<()> { + let mut manager = self.mcp_manager.lock().await; + manager.remove_system(name) + } + + async fn list_systems(&self) -> AgentResult> { + let manager = self.mcp_manager.lock().await; + manager.list_systems().await + } + + async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult { + Ok(Value::Null) + } + + async fn reply(&self, messages: &[Message]) -> anyhow::Result>> { + let manager = self.mcp_manager.lock().await; + let tools = manager.get_prefixed_tools(); + let system_prompt = manager.get_system_prompt()?; + let estimated_limit = manager.provider().get_model_config().get_estimated_limit(); + + // Update conversation history for the start of the reply + let mut messages = self.prepare_inference( + &system_prompt, + &tools, + messages, + &Vec::new(), + estimated_limit, + &manager.provider().get_model_config().model_name, + &manager.get_systems_resources().await?, + ).await?; + + Ok(Box::pin(async_stream::try_stream! { + loop { + // Get completion from provider + let (response, usage) = manager.provider().complete( + &system_prompt, + &messages, + &tools, + ).await?; + manager.record_usage(usage).await; + + // Yield the assistant's response + yield response.clone(); + + tokio::task::yield_now().await; + + // First collect any tool requests + let tool_requests: Vec<&ToolRequest> = response.content + .iter() + .filter_map(|content| content.as_tool_request()) + .collect(); + + if tool_requests.is_empty() { + break; + } + + // Then dispatch each in parallel + let futures: Vec<_> = tool_requests + .iter() + .map(|request| manager.dispatch_tool_call(request.tool_call.clone())) + .collect(); + + // Process all the futures in parallel but wait until all are finished + let outputs = futures::future::join_all(futures).await; + + // Create a message with the responses + let mut message_tool_response = Message::user(); + // Now combine these into MessageContent::ToolResponse using the original ID + for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + output, + ); + } + + yield message_tool_response.clone(); + + // Now we have to remove the previous status tooluse and toolresponse + // before we add pending messages, then the status msgs back again + messages.pop(); + messages.pop(); + + let pending = vec![response, message_tool_response]; + messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit, &manager.provider().get_model_config().model_name, &manager.get_systems_resources().await?).await?; + } + })) + } + + async fn usage(&self) -> AgentResult> { + let manager = self.mcp_manager.lock().await; + manager.get_usage().await.map_err(|e| AgentError::Internal(e.to_string())) + } +} + +register_agent!("default", DefaultAgent); + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::{Message, MessageContent}; + use crate::providers::configs::ModelConfig; + use crate::providers::mock::MockProvider; + use async_trait::async_trait; + use chrono::Utc; + use futures::TryStreamExt; + use mcp_core::resource::Resource; + use mcp_core::{Annotations, Content, Tool, ToolCall}; + use serde_json::json; + use std::collections::HashMap; + + // Mock system for testing + struct MockSystem { + name: String, + tools: Vec, + resources: Vec, + resource_content: HashMap, + } + + impl MockSystem { + fn new(name: &str) -> Self { + Self { + name: name.to_string(), + tools: vec![Tool::new( + "echo", + "Echoes back the input", + json!({"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}), + )], + resources: Vec::new(), + resource_content: HashMap::new(), + } + } + + fn add_resource(&mut self, name: &str, content: &str, priority: f32) { + let uri = format!("file://{}", name); + let resource = Resource { + name: name.to_string(), + uri: uri.clone(), + annotations: Some(Annotations::for_resource(priority, Utc::now())), + description: Some("A mock resource".to_string()), + mime_type: "text/plain".to_string(), + }; + self.resources.push(resource); + self.resource_content.insert(uri, content.to_string()); + } + } + + #[async_trait] + impl System for MockSystem { + fn name(&self) -> &str { + &self.name + } + + fn description(&self) -> &str { + "A mock system for testing" + } + + fn instructions(&self) -> &str { + "Mock system instructions" + } + + fn tools(&self) -> &[Tool] { + &self.tools + } + + async fn status(&self) -> anyhow::Result> { + Ok(self.resources.clone()) + } + + async fn call(&self, tool_call: ToolCall) -> AgentResult> { + match tool_call.name.as_str() { + "echo" => Ok(vec![Content::text( + tool_call.arguments["message"].as_str().unwrap_or(""), + )]), + _ => Err(AgentError::ToolNotFound(tool_call.name)), + } + } + + async fn read_resource(&self, uri: &str) -> AgentResult { + self.resource_content.get(uri).cloned().ok_or_else(|| { + AgentError::InvalidParameters(format!("Resource {} could not be found", uri)) + }) + } + } + + #[tokio::test(flavor = "current_thread")] + async fn test_simple_response() -> anyhow::Result<()> { + let response = Message::assistant().with_text("Hello!"); + let provider = MockProvider::new(vec![response.clone()]); + let mut agent = DefaultAgent::new(Box::new(provider)); + + // Add a system to test system management + agent.add_system(Box::new(MockSystem::new("test"))).await?; + + let initial_message = Message::user().with_text("Hi"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? { + messages.push(msg); + } + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0], response); + Ok(()) + } + + #[tokio::test(flavor = "current_thread")] + async fn test_system_management() -> anyhow::Result<()> { + let provider = MockProvider::new(vec![]); + let mut agent = DefaultAgent::new(Box::new(provider)); + + // Add a system + agent.add_system(Box::new(MockSystem::new("test1"))).await?; + agent.add_system(Box::new(MockSystem::new("test2"))).await?; + + // List systems + let systems = agent.list_systems().await?; + assert_eq!(systems.len(), 2); + assert!(systems.iter().any(|(name, _)| name == "test1")); + assert!(systems.iter().any(|(name, _)| name == "test2")); + + // Remove a system + agent.remove_system("test1").await?; + let systems = agent.list_systems().await?; + assert_eq!(systems.len(), 1); + assert_eq!(systems[0].0, "test2"); + + Ok(()) + } + + #[tokio::test] + async fn test_tool_call() -> anyhow::Result<()> { + let mut agent = DefaultAgent::new(Box::new(MockProvider::new(vec![ + Message::assistant().with_tool_request( + "1", + Ok(ToolCall::new("test_echo", json!({"message": "test"}))), + ), + Message::assistant().with_text("Done!"), + ]))); + + agent.add_system(Box::new(MockSystem::new("test"))).await?; + + let initial_message = Message::user().with_text("Echo test"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? { + messages.push(msg); + } + + // Should have three messages: tool request, response, and model text + assert_eq!(messages.len(), 3); + assert!(messages[0] + .content + .iter() + .any(|c| matches!(c, MessageContent::ToolRequest(_)))); + assert_eq!(messages[2].content[0], MessageContent::text("Done!")); + Ok(()) + } + + #[tokio::test] + async fn test_prepare_inference_trims_resources() -> anyhow::Result<()> { + let provider = MockProvider::with_config( + vec![], + ModelConfig::new("test_model".to_string()).with_context_limit(Some(20)), + ); + let mut agent = DefaultAgent::new(Box::new(provider)); + + // Create a mock system with resources + let mut system = MockSystem::new("test"); + let hello_1_tokens = "hello ".repeat(1); // 1 tokens + let goodbye_10_tokens = "goodbye ".repeat(10); // 10 tokens + system.add_resource("test_resource_removed", &goodbye_10_tokens, 0.1); + system.add_resource("test_resource_expected", &hello_1_tokens, 0.5); + + agent.add_system(Box::new(system)).await?; + + // Set up test parameters + let manager = agent.mcp_manager.lock().await; + + let system_prompt = "This is a system prompt"; + let messages = vec![Message::user().with_text("Hi there")]; + let pending = vec![]; + let tools = vec![]; + let target_limit = manager.provider().get_model_config().context_limit(); + + assert_eq!(target_limit, 20, "Context limit should be 20"); + // Test prepare_inference + let result = agent + .prepare_inference(&system_prompt, &tools, &messages, &pending, target_limit, &manager.provider().get_model_config().model_name, &manager.get_systems_resources().await?) + .await?; + + // Get the last message which should be the tool response containing status + let status_message = result.last().unwrap(); + let status_content = status_message + .content + .first() + .and_then(|content| content.as_tool_response_text()) + .unwrap_or_default(); + + + // Verify that "hello" is within the response, should be just under 20 tokens with "hello" + assert!(status_content.contains("hello")); + assert!(!status_content.contains("goodbye")); + + Ok(()) + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/factory.rs b/crates/goose/src/agents/factory.rs index 5ab41e745..a91b095ab 100644 --- a/crates/goose/src/agents/factory.rs +++ b/crates/goose/src/agents/factory.rs @@ -58,7 +58,7 @@ impl AgentFactory { /// Get the default version name pub fn default_version() -> &'static str { - "base" + "default" } } @@ -81,45 +81,57 @@ macro_rules! register_agent { #[cfg(test)] mod tests { use super::*; - use crate::providers::base::ProviderUsage; + use crate::message::Message; use crate::providers::mock::MockProvider; + use crate::providers::base::ProviderUsage; + use crate::errors::AgentResult; use crate::systems::System; use async_trait::async_trait; + use futures::stream::BoxStream; + use serde_json::Value; use tokio::sync::Mutex; // Test agent implementation struct TestAgent { - systems: Vec>, - provider: Box, - provider_usage: Mutex>, + mcp_manager: Mutex, } impl TestAgent { fn new(provider: Box) -> Self { Self { - systems: Vec::new(), - provider, - provider_usage: Mutex::new(Vec::new()), + mcp_manager: Mutex::new(super::super::MCPManager::new(provider)), } } } #[async_trait] impl Agent for TestAgent { - fn add_system(&mut self, system: Box) { - self.systems.push(system); + async fn add_system(&mut self, system: Box) -> AgentResult<()> { + let mut manager = self.mcp_manager.lock().await; + manager.add_system(system); + Ok(()) + } + + async fn remove_system(&mut self, name: &str) -> AgentResult<()> { + let mut manager = self.mcp_manager.lock().await; + manager.remove_system(name) + } + + async fn list_systems(&self) -> AgentResult> { + let manager = self.mcp_manager.lock().await; + manager.list_systems().await } - fn get_systems(&self) -> &Vec> { - &self.systems + async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult { + Ok(Value::Null) } - fn get_provider(&self) -> &Box { - &self.provider + async fn reply(&self, _messages: &[Message]) -> anyhow::Result>> { + Ok(Box::pin(futures::stream::empty())) } - fn get_provider_usage(&self) -> &Mutex> { - &self.provider_usage + async fn usage(&self) -> AgentResult> { + Ok(vec![]) } } @@ -176,18 +188,4 @@ mod tests { // Should still work, last registration wins assert!(result.is_ok()); } - - #[test] - fn test_agent_with_provider() { - register_agent!("test_provider_check", TestAgent); - - // Create a mock provider with specific configuration - let provider = Box::new(MockProvider::new(vec![])); - - // Create an agent instance - let agent = AgentFactory::create("test_provider_check", provider).unwrap(); - - // Verify the provider is correctly passed to the agent - assert_eq!(agent.get_provider().get_model_config().model_name, "mock"); - } -} +} \ No newline at end of file diff --git a/crates/goose/src/agents/mcp_manager.rs b/crates/goose/src/agents/mcp_manager.rs new file mode 100644 index 000000000..c2dab215f --- /dev/null +++ b/crates/goose/src/agents/mcp_manager.rs @@ -0,0 +1,196 @@ +use std::collections::HashMap; +use tokio::sync::Mutex; +use rust_decimal_macros::dec; + +use crate::errors::{AgentError, AgentResult}; +use crate::prompt_template::load_prompt_file; +use crate::systems::System; +use crate::providers::base::{Provider, ProviderUsage}; +use mcp_core::{Content, Resource, Tool, ToolCall}; +use serde::Serialize; + +#[derive(Clone, Debug, Serialize)] +struct SystemInfo { + name: String, + description: String, + instructions: String, +} + +impl SystemInfo { + fn new(name: &str, description: &str, instructions: &str) -> Self { + Self { + name: name.to_string(), + description: description.to_string(), + instructions: instructions.to_string(), + } + } +} + +/// Manages MCP systems and their interactions +pub struct MCPManager { + systems: Vec>, + provider: Box, + provider_usage: Mutex>, +} + +impl MCPManager { + pub fn new(provider: Box) -> Self { + Self { + systems: Vec::new(), + provider, + provider_usage: Mutex::new(Vec::new()), + } + } + + /// Get a reference to the provider + pub fn provider(&self) -> &Box { + &self.provider + } + + /// Record provider usage + pub async fn record_usage(&self, usage: ProviderUsage) { + self.provider_usage.lock().await.push(usage); + } + + /// Get aggregated usage statistics + pub async fn get_usage(&self) -> anyhow::Result> { + let provider_usage = self.provider_usage.lock().await.clone(); + let mut usage_map: HashMap = HashMap::new(); + + provider_usage.iter().for_each(|usage| { + usage_map + .entry(usage.model.clone()) + .and_modify(|e| { + e.usage.input_tokens = Some( + e.usage.input_tokens.unwrap_or(0) + usage.usage.input_tokens.unwrap_or(0), + ); + e.usage.output_tokens = Some( + e.usage.output_tokens.unwrap_or(0) + usage.usage.output_tokens.unwrap_or(0), + ); + e.usage.total_tokens = Some( + e.usage.total_tokens.unwrap_or(0) + usage.usage.total_tokens.unwrap_or(0), + ); + if e.cost.is_none() || usage.cost.is_none() { + e.cost = None; // Pricing is not available for all models + } else { + e.cost = Some(e.cost.unwrap_or(dec!(0)) + usage.cost.unwrap_or(dec!(0))); + } + }) + .or_insert_with(|| usage.clone()); + }); + Ok(usage_map.into_values().collect()) + } + + /// Add a system to the manager + pub fn add_system(&mut self, system: Box) { + self.systems.push(system); + } + + /// Remove a system by name + pub fn remove_system(&mut self, name: &str) -> AgentResult<()> { + if let Some(pos) = self.systems.iter().position(|sys| sys.name() == name) { + self.systems.remove(pos); + Ok(()) + } else { + Err(AgentError::SystemNotFound(name.to_string())) + } + } + + /// List all systems and their status + pub async fn list_systems(&self) -> AgentResult> { + let mut statuses = Vec::new(); + for system in &self.systems { + let status = system + .status() + .await + .map_err(|e| AgentError::Internal(e.to_string()))?; + statuses.push((system.name().to_string(), format!("{:?}", status))); + } + Ok(statuses) + } + + /// Get all tools from all systems with proper system prefixing + pub fn get_prefixed_tools(&self) -> Vec { + let mut tools = Vec::new(); + for system in &self.systems { + for tool in system.tools() { + tools.push(Tool::new( + format!("{}__{}", system.name(), tool.name), + &tool.description, + tool.input_schema.clone(), + )); + } + } + tools + } + + /// Get system resources and their contents + pub async fn get_systems_resources( + &self, + ) -> AgentResult>> { + let mut system_resource_content = HashMap::new(); + for system in &self.systems { + let system_status = system + .status() + .await + .map_err(|e| AgentError::Internal(e.to_string()))?; + + let mut resource_content = HashMap::new(); + for resource in system_status { + if let Ok(content) = system.read_resource(&resource.uri).await { + resource_content.insert(resource.uri.to_string(), (resource, content)); + } + } + system_resource_content.insert(system.name().to_string(), resource_content); + } + Ok(system_resource_content) + } + + /// Get the system prompt + pub fn get_system_prompt(&self) -> AgentResult { + let mut context = HashMap::new(); + let systems_info: Vec = self + .systems + .iter() + .map(|system| { + SystemInfo::new(system.name(), system.description(), system.instructions()) + }) + .collect(); + + context.insert("systems", systems_info); + load_prompt_file("system.md", &context).map_err(|e| AgentError::Internal(e.to_string())) + } + + /// Find the appropriate system for a tool call based on the prefixed name + pub fn get_system_for_tool(&self, prefixed_name: &str) -> Option<&dyn System> { + let parts: Vec<&str> = prefixed_name.split("__").collect(); + if parts.len() != 2 { + return None; + } + let system_name = parts[0]; + self.systems + .iter() + .find(|sys| sys.name() == system_name) + .map(|v| &**v) + } + + /// Dispatch a single tool call to the appropriate system + pub async fn dispatch_tool_call( + &self, + tool_call: AgentResult, + ) -> AgentResult> { + let call = tool_call?; + let system = self + .get_system_for_tool(&call.name) + .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?; + + let tool_name = call + .name + .split("__") + .nth(1) + .ok_or_else(|| AgentError::InvalidToolName(call.name.clone()))?; + let system_tool_call = ToolCall::new(tool_name, call.arguments); + + system.call(system_tool_call).await + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 620435db4..5905bd186 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -1,9 +1,9 @@ mod agent; -mod base; +mod default; mod factory; -mod v1; +mod mcp_manager; pub use agent::Agent; -pub use base::BaseAgent; +pub use default::DefaultAgent; pub use factory::{register_agent, AgentFactory}; -pub use v1::AgentV1; +pub use mcp_manager::MCPManager; \ No newline at end of file diff --git a/crates/goose/src/agents/v1.rs b/crates/goose/src/agents/v1.rs deleted file mode 100644 index 73977e4e5..000000000 --- a/crates/goose/src/agents/v1.rs +++ /dev/null @@ -1,63 +0,0 @@ -use async_trait::async_trait; -use tokio::sync::Mutex; - -use super::Agent; -use crate::errors::AgentResult; -use crate::message::Message; -use crate::providers::base::{Provider, ProviderUsage}; -use crate::register_agent; -use crate::systems::System; -use mcp_core::Tool; - -/// A version of the agent that uses a more aggressive context management strategy -pub struct AgentV1 { - systems: Vec>, - provider: Box, - provider_usage: Mutex>, -} - -impl AgentV1 { - pub fn new(provider: Box) -> Self { - Self { - systems: Vec::new(), - provider, - provider_usage: Mutex::new(Vec::new()), - } - } - - pub fn add_system(&mut self, system: Box) { - self.systems.push(system); - } -} - -#[async_trait] -impl Agent for AgentV1 { - fn add_system(&mut self, system: Box) { - self.systems.push(system); - } - - fn get_systems(&self) -> &Vec> { - &self.systems - } - - fn get_provider(&self) -> &Box { - &self.provider - } - - fn get_provider_usage(&self) -> &Mutex> { - &self.provider_usage - } - - async fn prepare_inference( - &self, - _system_prompt: &str, - _tools: &[Tool], - _messages: &[Message], - _pending: &[Message], - _target_limit: usize, - ) -> AgentResult> { - todo!(); - } -} - -register_agent!("v1", AgentV1); diff --git a/crates/goose/src/errors.rs b/crates/goose/src/errors.rs index 5a3eddac3..045ef8fa7 100644 --- a/crates/goose/src/errors.rs +++ b/crates/goose/src/errors.rs @@ -19,8 +19,11 @@ pub enum AgentError { #[error("Invalid tool name: {0}")] InvalidToolName(String), + #[error("System not found: {0}")] + SystemNotFound(String), + #[error("Agent version not found: {0}")] VersionNotFound(String), } -pub type AgentResult = Result; +pub type AgentResult = Result; \ No newline at end of file