From e17a3f24468e9f98df8ea19d2c94cb35f98b3d88 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Fri, 20 Dec 2024 14:45:30 -0500 Subject: [PATCH 01/10] add instructions to InitializeRequest & Router trait --- crates/mcp-core/src/protocol.rs | 2 ++ crates/mcp-server/src/main.rs | 4 ++++ crates/mcp-server/src/router.rs | 3 +++ 3 files changed, 9 insertions(+) diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index a98d8d794..91d76888a 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -147,6 +147,8 @@ pub struct InitializeResult { pub protocol_version: String, pub capabilities: ServerCapabilities, pub server_info: Implementation, + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs index b8911b369..5ea232bb9 100644 --- a/crates/mcp-server/src/main.rs +++ b/crates/mcp-server/src/main.rs @@ -44,6 +44,10 @@ impl CounterRouter { } impl Router for CounterRouter { + fn instructions(&self) -> String { + "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() + } + fn capabilities(&self) -> ServerCapabilities { CapabilitiesBuilder::new().with_tools(true).build() } diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index d2d8a9ddf..07deca32b 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -78,6 +78,8 @@ impl CapabilitiesBuilder { } pub trait Router: Send + Sync + 'static { + // in the protocol, instructions are optional but we make it required + fn instructions(&self) -> String; fn capabilities(&self) -> ServerCapabilities; fn list_tools(&self) -> Vec; fn call_tool( @@ -113,6 +115,7 @@ pub trait Router: Send + Sync + 'static { name: "mcp-server".to_string(), version: env!("CARGO_PKG_VERSION").to_string(), }, + instructions: Some(self.instructions()), }; let mut response = self.create_response(req.id); From 90038126a1068b391a5f811b2ded715bd2bad1af Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Fri, 20 Dec 2024 16:33:48 -0500 Subject: [PATCH 02/10] add developer crate, which is a MCP server --- crates/developer/Cargo.toml | 29 ++ crates/developer/README.md | 7 + crates/developer/src/errors.rs | 69 +++ crates/developer/src/lang.rs | 34 ++ crates/developer/src/main.rs | 584 ++++++++++++++++++++++++++ crates/developer/src/process_store.rs | 163 +++++++ 6 files changed, 886 insertions(+) create mode 100644 crates/developer/Cargo.toml create mode 100644 crates/developer/README.md create mode 100644 crates/developer/src/errors.rs create mode 100644 crates/developer/src/lang.rs create mode 100644 crates/developer/src/main.rs create mode 100644 crates/developer/src/process_store.rs diff --git a/crates/developer/Cargo.toml b/crates/developer/Cargo.toml new file mode 100644 index 000000000..aefade2a7 --- /dev/null +++ b/crates/developer/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "developer" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +description.workspace = true + +[dependencies] +mcp-core = { path = "../mcp-core" } +mcp-server = { path = "../mcp-server" } +anyhow = "1.0.94" +tokio = { version = "1", features = ["full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tracing-appender = "0.2" +url = "2.5" +thiserror = "1.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +serde_urlencoded = "0.7" +lazy_static = "1.5" +kill_tree = "0.2.4" +shellexpand = "3.1.0" +indoc = "2.0.5" + +[dev-dependencies] +sysinfo = "0.32.1" diff --git a/crates/developer/README.md b/crates/developer/README.md new file mode 100644 index 000000000..e74452199 --- /dev/null +++ b/crates/developer/README.md @@ -0,0 +1,7 @@ +### Test with MCP Inspector + +```bash +npx @modelcontextprotocol/inspector cargo run -p developer +``` + +Then visit the Inspector in the browser window and test the different endpoints. diff --git a/crates/developer/src/errors.rs b/crates/developer/src/errors.rs new file mode 100644 index 000000000..00ac10b8b --- /dev/null +++ b/crates/developer/src/errors.rs @@ -0,0 +1,69 @@ +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use mcp_core::handler::{ToolError, ResourceError}; +use mcp_server::RouterError; + +#[non_exhaustive] +#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq)] +pub enum AgentError { + #[error("Tool not found: {0}")] + ToolNotFound(String), + + #[error("The parameters to the tool call were invalid: {0}")] + InvalidParameters(String), + + #[error("The tool failed during execution with the following output: \n{0}")] + ExecutionError(String), + + #[error("Internal error: {0}")] + Internal(String), + + #[error("Invalid tool name: {0}")] + InvalidToolName(String), +} + +pub type AgentResult = Result; + + +impl From for ToolError { + fn from(err: AgentError) -> Self { + match err { + AgentError::InvalidParameters(msg) => ToolError::InvalidParameters(msg), + AgentError::InvalidToolName(msg) => ToolError::InvalidParameters(msg), + AgentError::ToolNotFound(msg) => ToolError::NotFound(msg), + AgentError::ExecutionError(msg) => ToolError::ExecutionError(msg), + AgentError::Internal(msg) => ToolError::ExecutionError(msg), + } + } +} + +impl From for ResourceError { + fn from(err: AgentError) -> Self { + match err { + AgentError::InvalidParameters(msg) => ResourceError::NotFound(msg), + _ => ResourceError::NotFound(err.to_string()), + } + } +} + + +impl From for RouterError { + fn from(err: AgentError) -> Self { + match err { + AgentError::ToolNotFound(msg) => RouterError::ToolNotFound(msg), + AgentError::InvalidParameters(msg) => RouterError::InvalidParams(msg), + AgentError::ExecutionError(msg) => RouterError::Internal(msg), + AgentError::Internal(msg) => RouterError::Internal(msg), + AgentError::InvalidToolName(msg) => RouterError::ToolNotFound(msg), + } + } +} + +impl From for AgentError { + fn from(err: ResourceError) -> Self { + match err { + ResourceError::NotFound(msg) => AgentError::InvalidParameters(format!("Resource not found: {}", msg)), + ResourceError::ExecutionError(msg) => AgentError::ExecutionError(msg), + } + } +} diff --git a/crates/developer/src/lang.rs b/crates/developer/src/lang.rs new file mode 100644 index 000000000..4d5609de2 --- /dev/null +++ b/crates/developer/src/lang.rs @@ -0,0 +1,34 @@ +use std::path::Path; + +/// Get the markdown language identifier for a file extension +pub fn get_language_identifier(path: &Path) -> &'static str { + match path.extension().and_then(|ext| ext.to_str()) { + Some("rs") => "rust", + Some("py") => "python", + Some("js") => "javascript", + Some("ts") => "typescript", + Some("json") => "json", + Some("toml") => "toml", + Some("yaml") | Some("yml") => "yaml", + Some("sh") => "bash", + Some("go") => "go", + Some("md") => "markdown", + Some("html") => "html", + Some("css") => "css", + Some("sql") => "sql", + Some("java") => "java", + Some("cpp") | Some("cc") | Some("cxx") => "cpp", + Some("c") => "c", + Some("h") | Some("hpp") => "cpp", + Some("rb") => "ruby", + Some("php") => "php", + Some("swift") => "swift", + Some("kt") | Some("kts") => "kotlin", + Some("scala") => "scala", + Some("r") => "r", + Some("m") => "matlab", + Some("pl") => "perl", + Some("dockerfile") => "dockerfile", + _ => "", + } +} diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs new file mode 100644 index 000000000..7a8d5eb06 --- /dev/null +++ b/crates/developer/src/main.rs @@ -0,0 +1,584 @@ +mod process_store; +mod errors; +mod lang; + +use indoc::formatdoc; +use anyhow::Result; +use serde_json::{json, Value}; +use std::{collections::HashMap, future::Future, path::{Path, PathBuf}, pin::Pin}; +use tokio::process::Command; +use url::Url; + +use mcp_core::{handler::{ToolError, ResourceError}, protocol::ServerCapabilities, resource::Resource, tool::Tool}; +use mcp_server::router::{CapabilitiesBuilder, RouterService}; + +use mcp_core::role::Role; +use mcp_core::content::Content; +use crate::errors::{AgentError, AgentResult}; + +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::{self, EnvFilter}; +use mcp_server::{ByteTransport, Router, Server}; +use std::sync::Mutex; +use std::process::Stdio; +use tokio::io::{stdin, stdout}; + +pub struct DeveloperRouter { + tools: Vec, + cwd: Mutex, + active_resources: Mutex>, + file_history: Mutex>>, + instructions: String, +} + +impl DeveloperRouter { + pub fn new() -> Self { + let bash_tool = Tool::new( + "bash".to_string(), + "Run a bash command in the shell in the current working directory".to_string(), + json!({ + "type": "object", + "required": ["command"], + "properties": { + "command": {"type": "string"} + } + }), + ); + + let text_editor_tool = Tool::new( + "text_editor".to_string(), + "Perform text editing operations on files.".to_string(), + json!({ + "type": "object", + "required": ["command", "path"], + "properties": { + "path": {"type": "string"}, + "command": { + "type": "string", + "enum": ["view", "write", "str_replace", "undo_edit"] + }, + "new_str": {"type": "string"}, + "old_str": {"type": "string"}, + "file_text": {"type": "string"} + } + }), + ); + + let instructions = "Developer instructions...".to_string(); // Reuse from original code + + let cwd = std::env::current_dir().unwrap(); + let mut resources = HashMap::new(); + let uri = format!("str:///{}", cwd.display()); + let resource = Resource::new(uri.clone(), Some("text".to_string()), Some("cwd".to_string())).unwrap(); + resources.insert(uri, resource); + + Self { + tools: vec![bash_tool, text_editor_tool], + cwd: Mutex::new(cwd), + active_resources: Mutex::new(resources), + file_history: Mutex::new(HashMap::new()), + instructions, + } + } + + // Example utility function to call the underlying logic + async fn call_bash(&self, args: Value) -> Result { + let result = self.bash(args).await; // adapt your logic from DeveloperSystem + self.map_agent_result_to_value(result) + } + + async fn call_text_editor(&self, args: Value) -> Result { + let result = self.text_editor(args).await; // adapt from DeveloperSystem + self.map_agent_result_to_value(result) + } + + // Convert AgentResult> to Result + fn map_agent_result_to_value(&self, result: AgentResult>) -> Result { + match result { + Ok(contents) => { + let messages: Vec = contents.iter().map(|c| { + json!({ + "text": c.as_text().unwrap_or(""), + "audience": c.audience(), + "priority": c.priority() + }) + }).collect(); + Ok(json!({"messages": messages})) + }, + Err(e) => Err(e.into()) + } + } + + // Helper method to resolve a path relative to cwd + fn resolve_path(&self, path_str: &str) -> AgentResult { + let cwd = self.cwd.lock().unwrap(); + let expanded = shellexpand::tilde(path_str); + let path = Path::new(expanded.as_ref()); + let resolved_path = if path.is_absolute() { + path.to_path_buf() + } else { + cwd.join(path) + }; + + Ok(resolved_path) + } + + // Implement bash tool functionality + async fn bash(&self, params: Value) -> AgentResult> { + let command = + params + .get("command") + .and_then(|v| v.as_str()) + .ok_or(AgentError::InvalidParameters( + "The command string is required".into(), + ))?; + + // Disallow commands that should use other tools + if command.trim_start().starts_with("cat") { + return Err(AgentError::InvalidParameters( + "Do not use `cat` to read files, use the view mode on the text editor tool" + .to_string(), + )); + } + // TODO consider enforcing ripgrep over find? + + // Redirect stderr to stdout to interleave outputs + let cmd_with_redirect = format!("{} 2>&1", command); + + // Execute the command + let child = Command::new("bash") + .stdout(Stdio::piped()) // These two pipes required to capture output later. + .stderr(Stdio::piped()) + .kill_on_drop(true) // Critical so that the command is killed when the agent.reply stream is interrupted. + .arg("-c") + .arg(cmd_with_redirect) + .spawn() + .map_err(|e| AgentError::ExecutionError(e.to_string()))?; + + // Store the process ID with the command as the key + let pid: Option = child.id(); + if let Some(pid) = pid { + crate::process_store::store_process(pid); + } + + // Wait for the command to complete and get output + let output = child + .wait_with_output() + .await + .map_err(|e| AgentError::ExecutionError(e.to_string()))?; + + // Remove the process ID from the store + if let Some(pid) = pid { + crate::process_store::remove_process(pid); + } + + let output_str = format!( + "Finished with Status Code: {}\nOutput:\n{}", + output.status, + String::from_utf8_lossy(&output.stdout) + ); + Ok(vec![ + Content::text(output_str.clone()).with_audience(vec![Role::Assistant]), + Content::text(output_str) + .with_audience(vec![Role::User]) + .with_priority(0.0), + ]) + } + + + + + // Implement text_editor tool functionality + async fn text_editor(&self, params: Value) -> AgentResult> { + let command = params + .get("command") + .and_then(|v| v.as_str()) + .ok_or_else(|| AgentError::InvalidParameters("Missing 'command' parameter".into()))?; + + let path_str = params + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| AgentError::InvalidParameters("Missing 'path' parameter".into()))?; + + let path = self.resolve_path(path_str)?; + + match command { + "view" => self.text_editor_view(&path).await, + "write" => { + let file_text = params + .get("file_text") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + AgentError::InvalidParameters("Missing 'file_text' parameter".into()) + })?; + + self.text_editor_write(&path, file_text).await + } + "str_replace" => { + let old_str = params + .get("old_str") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + AgentError::InvalidParameters("Missing 'old_str' parameter".into()) + })?; + let new_str = params + .get("new_str") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + AgentError::InvalidParameters("Missing 'new_str' parameter".into()) + })?; + + self.text_editor_replace(&path, old_str, new_str).await + } + "undo_edit" => self.text_editor_undo(&path).await, + _ => Err(AgentError::InvalidParameters(format!( + "Unknown command '{}'", + command + ))), + } + } + + async fn text_editor_view(&self, path: &PathBuf) -> AgentResult> { + if path.is_file() { + // Check file size first (2MB limit) + const MAX_FILE_SIZE: u64 = 2 * 1024 * 1024; // 2MB in bytes + const MAX_CHAR_COUNT: usize = 1 << 20; // 2^20 characters (1,048,576) + + let file_size = std::fs::metadata(path) + .map_err(|e| { + AgentError::ExecutionError(format!("Failed to get file metadata: {}", e)) + })? + .len(); + + if file_size > MAX_FILE_SIZE { + return Err(AgentError::ExecutionError(format!( + "File '{}' is too large ({:.2}MB). Maximum size is 2MB to prevent memory issues.", + path.display(), + file_size as f64 / 1024.0 / 1024.0 + ))); + } + + // Create a new resource and add it to active_resources + let uri = Url::from_file_path(path) + .map_err(|_| AgentError::ExecutionError("Invalid file path".into()))? + .to_string(); + + // Read the content once + let content = std::fs::read_to_string(path) + .map_err(|e| AgentError::ExecutionError(format!("Failed to read file: {}", e)))?; + + let char_count = content.chars().count(); + if char_count > MAX_CHAR_COUNT { + return Err(AgentError::ExecutionError(format!( + "File '{}' has too many characters ({}). Maximum character count is {}.", + path.display(), + char_count, + MAX_CHAR_COUNT + ))); + } + + // Create and store the resource + let resource = + Resource::new(uri.clone(), Some("text".to_string()), None).map_err(|e| { + AgentError::ExecutionError(format!("Failed to create resource: {}", e)) + })?; + + self.active_resources.lock().unwrap().insert(uri, resource); + + let language = lang::get_language_identifier(path); + let formatted = formatdoc! {" + ### {path} + ```{language} + {content} + ``` + ", + path=path.display(), + language=language, + content=content, + }; + + // The LLM gets just a quick update as we expect the file to view in the status + // but we send a low priority message for the human + Ok(vec![ + Content::text(format!( + "The file content for {} is now available in the system status.", + path.display() + )) + .with_audience(vec![Role::Assistant]), + Content::text(formatted) + .with_audience(vec![Role::User]) + .with_priority(0.0), + ]) + } else { + Err(AgentError::ExecutionError(format!( + "The path '{}' does not exist or is not a file.", + path.display() + ))) + } + } + + async fn text_editor_write( + &self, + path: &PathBuf, + file_text: &str, + ) -> AgentResult> { + // Get the URI for the file + let uri = Url::from_file_path(path) + .map_err(|_| AgentError::ExecutionError("Invalid file path".into()))? + .to_string(); + + // Check if file already exists and is active + if path.exists() && !self.active_resources.lock().unwrap().contains_key(&uri) { + return Err(AgentError::InvalidParameters(format!( + "File '{}' exists but is not active. View it first before overwriting.", + path.display() + ))); + } + + // Save history for undo + self.save_file_history(path)?; + + // Write to the file + std::fs::write(path, file_text) + .map_err(|e| AgentError::ExecutionError(format!("Failed to write file: {}", e)))?; + + // Create and store resource + + let resource = Resource::new(uri.clone(), Some("text".to_string()), None) + .map_err(|e| AgentError::ExecutionError(e.to_string()))?; + self.active_resources.lock().unwrap().insert(uri, resource); + + // Try to detect the language from the file extension + let language = path.extension().and_then(|ext| ext.to_str()).unwrap_or(""); + + Ok(vec![ + Content::text(format!("Successfully wrote to {}", path.display())) + .with_audience(vec![Role::Assistant]), + Content::text(formatdoc! {r#" + ### {path} + ```{language} + {content} + ``` + "#, + path=path.display(), + language=language, + content=file_text, + }) + .with_audience(vec![Role::User]) + .with_priority(0.2), + ]) + } + + async fn text_editor_replace( + &self, + path: &PathBuf, + old_str: &str, + new_str: &str, + ) -> AgentResult> { + // Get the URI for the file + let uri = Url::from_file_path(path) + .map_err(|_| AgentError::ExecutionError("Invalid file path".into()))? + .to_string(); + + // Check if file exists and is active + if !path.exists() { + return Err(AgentError::InvalidParameters(format!( + "File '{}' does not exist", + path.display() + ))); + } + if !self.active_resources.lock().unwrap().contains_key(&uri) { + return Err(AgentError::InvalidParameters(format!( + "You must view '{}' before editing it", + path.display() + ))); + } + + // Read content + let content = std::fs::read_to_string(path) + .map_err(|e| AgentError::ExecutionError(format!("Failed to read file: {}", e)))?; + + // Ensure 'old_str' appears exactly once + if content.matches(old_str).count() > 1 { + return Err(AgentError::InvalidParameters( + "'old_str' must appear exactly once in the file, but it appears multiple times" + .into(), + )); + } + if content.matches(old_str).count() == 0 { + return Err(AgentError::InvalidParameters( + "'old_str' must appear exactly once in the file, but it does not appear in the file. Make sure the string exactly matches existing file content, including spacing.".into(), + )); + } + + // Save history for undo + self.save_file_history(path)?; + + // Replace and write back + let new_content = content.replace(old_str, new_str); + std::fs::write(path, &new_content) + .map_err(|e| AgentError::ExecutionError(format!("Failed to write file: {}", e)))?; + + // Update resource + if let Some(resource) = self.active_resources.lock().unwrap().get_mut(&uri) { + resource.update_timestamp(); + } + + // Try to detect the language from the file extension + let language = path.extension().and_then(|ext| ext.to_str()).unwrap_or(""); + + Ok(vec![ + Content::text("Successfully replaced text").with_audience(vec![Role::Assistant]), + Content::text(formatdoc! {r#" + ### {path} + + *Before*: + ```{language} + {old_str} + ``` + + *After*: + ```{language} + {new_str} + ``` + "#, + path=path.display(), + language=language, + old_str=old_str, + new_str=new_str, + }) + .with_audience(vec![Role::User]) + .with_priority(0.2), + ]) + } + + async fn text_editor_undo(&self, path: &PathBuf) -> AgentResult> { + let mut history = self.file_history.lock().unwrap(); + if let Some(contents) = history.get_mut(path) { + if let Some(previous_content) = contents.pop() { + // Write previous content back to file + std::fs::write(path, previous_content).map_err(|e| { + AgentError::ExecutionError(format!("Failed to write file: {}", e)) + })?; + Ok(vec![Content::text("Undid the last edit")]) + } else { + Err(AgentError::InvalidParameters( + "No edit history available to undo".into(), + )) + } + } else { + Err(AgentError::InvalidParameters( + "No edit history available to undo".into(), + )) + } + } + + fn save_file_history(&self, path: &PathBuf) -> AgentResult<()> { + let mut history = self.file_history.lock().unwrap(); + let content = if path.exists() { + std::fs::read_to_string(path) + .map_err(|e| AgentError::ExecutionError(format!("Failed to read file: {}", e)))? + } else { + String::new() + }; + history.entry(path.clone()).or_default().push(content); + Ok(()) + } + + + async fn do_read_resource(&self, uri: &str) -> AgentResult { + let content = self.read_resource(uri).await.map_err(AgentError::from)?; + Ok(content) + } +} + +impl Router for DeveloperRouter { + fn instructions(&self) -> String { + self.instructions.clone() + } + + fn capabilities(&self) -> ServerCapabilities { + CapabilitiesBuilder::new().with_tools(true).build() + } + + fn list_tools(&self) -> Vec { + self.tools.clone() + } + + fn call_tool( + &self, + tool_name: &str, + arguments: Value, + ) -> Pin> + Send + 'static>> { + let this = self.clone(); + let tool_name = tool_name.to_string(); + Box::pin(async move { + match tool_name.as_str() { + "bash" => this.call_bash(arguments).await, + "text_editor" => this.call_text_editor(arguments).await, + _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), + } + }) + } + + fn list_resources(&self) -> Vec { + self.active_resources.lock().unwrap().values().cloned().collect() + } + + fn read_resource( + &self, + uri: &str, + ) -> Pin> + Send + 'static>> { + let this = self.clone(); + let uri = uri.to_string(); + Box::pin(async move { + match this.do_read_resource(&uri).await { + Ok(content) => Ok(content), + Err(e) => Err(e.into()), + } + }) + } +} + +impl Clone for DeveloperRouter { + fn clone(&self) -> Self { + Self { + tools: self.tools.clone(), + cwd: Mutex::new(self.cwd.lock().unwrap().clone()), + active_resources: Mutex::new(self.active_resources.lock().unwrap().clone()), + file_history: Mutex::new(self.file_history.lock().unwrap().clone()), + instructions: self.instructions.clone(), + } + } +} + + + + +#[tokio::main] +async fn main() -> Result<()> { + // Set up file appender for logging + let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "mcp-server.log"); + + // Initialize the tracing subscriber with file and stdout logging + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) + .with_writer(file_appender) + .with_target(false) + .with_thread_ids(true) + .with_file(true) + .with_line_number(true) + .init(); + + tracing::info!("Starting MCP server"); + + // Create an instance of our counter router + let router = RouterService(DeveloperRouter::new()); + + // Create and run the server + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + + tracing::info!("Server initialized and ready to handle requests"); + Ok(server.run(transport).await?) +} diff --git a/crates/developer/src/process_store.rs b/crates/developer/src/process_store.rs new file mode 100644 index 000000000..efa38629a --- /dev/null +++ b/crates/developer/src/process_store.rs @@ -0,0 +1,163 @@ +use kill_tree::{blocking::kill_tree_with_config, Config}; +use lazy_static::lazy_static; +use std::sync::Mutex; + +// Singleton that will store process IDs for spawned child processes implementing agent tasks. +lazy_static! { + static ref PROCESS_STORE: Mutex> = Mutex::new(Vec::new()); +} + +pub fn store_process(pid: u32) { + let mut store = PROCESS_STORE.lock().unwrap(); + store.push(pid); +} + +// This removes the record of a process from the store, it does not kill it or check that it is dead. +pub fn remove_process(pid: u32) -> bool { + let mut store = PROCESS_STORE.lock().unwrap(); + if let Some(index) = store.iter().position(|&x| x == pid) { + store.remove(index); + true + } else { + false + } +} + +/// Kill all stored processes +pub fn kill_processes() { + let mut killed_processes = Vec::new(); + { + let store = PROCESS_STORE.lock().unwrap(); + for &pid in store.iter() { + let config = Config { + signal: "SIGKILL".to_string(), + ..Default::default() + }; + let outputs = match kill_tree_with_config(pid, &config) { + Ok(outputs) => outputs, + Err(e) => { + eprintln!("Failed to kill process {}: {}", pid, e); + continue; + } + }; + for output in outputs { + match output { + kill_tree::Output::Killed { process_id, .. } => { + killed_processes.push(process_id); + } + kill_tree::Output::MaybeAlreadyTerminated { process_id, .. } => { + killed_processes.push(process_id); + } + } + } + } + } + // Clean up the store + for pid in killed_processes { + remove_process(pid); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::os::unix::fs::PermissionsExt; + use std::time::Duration; + use std::{fs, thread}; + use sysinfo::{Pid, ProcessesToUpdate, System}; + use tokio::process::Command; + + #[tokio::test] + async fn test_kill_processes_with_children() { + // Create a temporary script that spawns a child process + let temp_dir = std::env::temp_dir(); + let script_path = temp_dir.join("test_script.sh"); + let script_content = r#"#!/bin/bash + # Sleep in the parent process + sleep 300 + "#; + + fs::write(&script_path, script_content).unwrap(); + fs::set_permissions(&script_path, std::fs::Permissions::from_mode(0o755)).unwrap(); + + // Start the parent process which will spawn a child + let parent = Command::new("bash") + .arg("-c") + .arg(script_path.to_str().unwrap()) + .spawn() + .expect("Failed to start parent process"); + + let parent_pid = parent.id().unwrap() as u32; + + // Store the parent process ID + store_process(parent_pid); + + let mut attempt = 0; + let mut child_pids: Vec; + loop { + attempt += 1; + if attempt > 100 { + panic!("Failed to start processes for process kill test"); + } + + // Give processes time to start + thread::sleep(Duration::from_millis(10)); + + // Get the child process ID using pgrep + let child_pids_cmd = Command::new("pgrep") + .arg("-P") + .arg(parent_pid.to_string()) + .output() + .await + .expect("Failed to get child PIDs"); + + let child_pid_str = String::from_utf8_lossy(&child_pids_cmd.stdout); + child_pids = child_pid_str + .lines() + .filter_map(|line| line.trim().parse::().ok()) + .collect::>(); + if child_pids.len() != 1 { + print!("Waiting for 1 child_pids. Got {:?}", child_pids.len()); + continue; + } + + if is_process_running(parent_pid).await && is_process_running(child_pids[0]).await { + break; + } + } + + kill_processes(); + + // Wait until processes are killed + let mut attempts = 0; + while attempts < 100 { + if !is_process_running(parent_pid).await && !is_process_running(child_pids[0]).await { + break; + } + thread::sleep(Duration::from_millis(10)); + attempts += 1; + } + + // Verify processes are dead + assert!(!is_process_running(parent_pid).await); + assert!(!is_process_running(child_pids[0]).await); + + // Clean up the temporary script + fs::remove_file(script_path).unwrap(); + } + + async fn is_process_running(pid: u32) -> bool { + let mut system = System::new_all(); + system.refresh_processes(ProcessesToUpdate::All, true); + + match system.process(Pid::from_u32(pid)) { + Some(process) => !matches!( + process.status(), + sysinfo::ProcessStatus::Stop + | sysinfo::ProcessStatus::Zombie + | sysinfo::ProcessStatus::Dead + ), + None => false, + } + } +} From a712d659dc51f49c5d080f2ed1bfaf54305ae707 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Fri, 20 Dec 2024 16:38:36 -0500 Subject: [PATCH 03/10] fmt --- crates/developer/src/errors.rs | 10 ++--- crates/developer/src/main.rs | 73 +++++++++++++++++++++------------- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/crates/developer/src/errors.rs b/crates/developer/src/errors.rs index 00ac10b8b..6025e626c 100644 --- a/crates/developer/src/errors.rs +++ b/crates/developer/src/errors.rs @@ -1,7 +1,7 @@ +use mcp_core::handler::{ResourceError, ToolError}; +use mcp_server::RouterError; use serde::{Deserialize, Serialize}; use thiserror::Error; -use mcp_core::handler::{ToolError, ResourceError}; -use mcp_server::RouterError; #[non_exhaustive] #[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq)] @@ -24,7 +24,6 @@ pub enum AgentError { pub type AgentResult = Result; - impl From for ToolError { fn from(err: AgentError) -> Self { match err { @@ -46,7 +45,6 @@ impl From for ResourceError { } } - impl From for RouterError { fn from(err: AgentError) -> Self { match err { @@ -62,7 +60,9 @@ impl From for RouterError { impl From for AgentError { fn from(err: ResourceError) -> Self { match err { - ResourceError::NotFound(msg) => AgentError::InvalidParameters(format!("Resource not found: {}", msg)), + ResourceError::NotFound(msg) => { + AgentError::InvalidParameters(format!("Resource not found: {}", msg)) + } ResourceError::ExecutionError(msg) => AgentError::ExecutionError(msg), } } diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs index 7a8d5eb06..8b44d33d4 100644 --- a/crates/developer/src/main.rs +++ b/crates/developer/src/main.rs @@ -1,27 +1,37 @@ -mod process_store; mod errors; mod lang; +mod process_store; -use indoc::formatdoc; use anyhow::Result; +use indoc::formatdoc; use serde_json::{json, Value}; -use std::{collections::HashMap, future::Future, path::{Path, PathBuf}, pin::Pin}; +use std::{ + collections::HashMap, + future::Future, + path::{Path, PathBuf}, + pin::Pin, +}; use tokio::process::Command; use url::Url; -use mcp_core::{handler::{ToolError, ResourceError}, protocol::ServerCapabilities, resource::Resource, tool::Tool}; +use mcp_core::{ + handler::{ResourceError, ToolError}, + protocol::ServerCapabilities, + resource::Resource, + tool::Tool, +}; use mcp_server::router::{CapabilitiesBuilder, RouterService}; -use mcp_core::role::Role; -use mcp_core::content::Content; use crate::errors::{AgentError, AgentResult}; +use mcp_core::content::Content; +use mcp_core::role::Role; -use tracing_appender::rolling::{RollingFileAppender, Rotation}; -use tracing_subscriber::{self, EnvFilter}; use mcp_server::{ByteTransport, Router, Server}; -use std::sync::Mutex; use std::process::Stdio; +use std::sync::Mutex; use tokio::io::{stdin, stdout}; +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::{self, EnvFilter}; pub struct DeveloperRouter { tools: Vec, @@ -69,7 +79,12 @@ impl DeveloperRouter { let cwd = std::env::current_dir().unwrap(); let mut resources = HashMap::new(); let uri = format!("str:///{}", cwd.display()); - let resource = Resource::new(uri.clone(), Some("text".to_string()), Some("cwd".to_string())).unwrap(); + let resource = Resource::new( + uri.clone(), + Some("text".to_string()), + Some("cwd".to_string()), + ) + .unwrap(); resources.insert(uri, resource); Self { @@ -93,19 +108,25 @@ impl DeveloperRouter { } // Convert AgentResult> to Result - fn map_agent_result_to_value(&self, result: AgentResult>) -> Result { + fn map_agent_result_to_value( + &self, + result: AgentResult>, + ) -> Result { match result { Ok(contents) => { - let messages: Vec = contents.iter().map(|c| { - json!({ - "text": c.as_text().unwrap_or(""), - "audience": c.audience(), - "priority": c.priority() + let messages: Vec = contents + .iter() + .map(|c| { + json!({ + "text": c.as_text().unwrap_or(""), + "audience": c.audience(), + "priority": c.priority() + }) }) - }).collect(); + .collect(); Ok(json!({"messages": messages})) - }, - Err(e) => Err(e.into()) + } + Err(e) => Err(e.into()), } } @@ -185,9 +206,6 @@ impl DeveloperRouter { ]) } - - - // Implement text_editor tool functionality async fn text_editor(&self, params: Value) -> AgentResult> { let command = params @@ -485,7 +503,6 @@ impl DeveloperRouter { Ok(()) } - async fn do_read_resource(&self, uri: &str) -> AgentResult { let content = self.read_resource(uri).await.map_err(AgentError::from)?; Ok(content) @@ -522,7 +539,12 @@ impl Router for DeveloperRouter { } fn list_resources(&self) -> Vec { - self.active_resources.lock().unwrap().values().cloned().collect() + self.active_resources + .lock() + .unwrap() + .values() + .cloned() + .collect() } fn read_resource( @@ -552,9 +574,6 @@ impl Clone for DeveloperRouter { } } - - - #[tokio::main] async fn main() -> Result<()> { // Set up file appender for logging From f7a438682549851d9607af656c9c216a3f0cc52d Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 23 Dec 2024 11:41:01 -0500 Subject: [PATCH 04/10] fix bug in read_resource_internal for developer --- crates/developer/Cargo.toml | 2 + crates/developer/src/main.rs | 82 +++++++++++++++++-- crates/mcp-client/examples/sse.rs | 12 ++- .../mcp-client/examples/stdio_integration.rs | 8 ++ crates/mcp-server/src/main.rs | 18 ++-- 5 files changed, 108 insertions(+), 14 deletions(-) diff --git a/crates/developer/Cargo.toml b/crates/developer/Cargo.toml index aefade2a7..89458267e 100644 --- a/crates/developer/Cargo.toml +++ b/crates/developer/Cargo.toml @@ -16,6 +16,8 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" url = "2.5" +urlencoding = "2.1.3" +base64 = "0.21" thiserror = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs index 8b44d33d4..140086854 100644 --- a/crates/developer/src/main.rs +++ b/crates/developer/src/main.rs @@ -3,10 +3,12 @@ mod lang; mod process_store; use anyhow::Result; +use base64::Engine; use indoc::formatdoc; use serde_json::{json, Value}; use std::{ collections::HashMap, + fs, future::Future, path::{Path, PathBuf}, pin::Pin, @@ -30,6 +32,7 @@ use mcp_server::{ByteTransport, Router, Server}; use std::process::Stdio; use std::sync::Mutex; use tokio::io::{stdin, stdout}; +use tracing::info; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{self, EnvFilter}; @@ -503,9 +506,72 @@ impl DeveloperRouter { Ok(()) } - async fn do_read_resource(&self, uri: &str) -> AgentResult { - let content = self.read_resource(uri).await.map_err(AgentError::from)?; - Ok(content) + async fn read_resource_internal(&self, uri: &str) -> AgentResult { + // Ensure the resource exists in the active resources map + let active_resources = self.active_resources.lock().unwrap(); + let resource = active_resources + .get(uri) + .ok_or_else(|| AgentError::ToolNotFound(format!("Resource '{}' not found", uri)))?; + + let url = Url::parse(uri) + .map_err(|e| AgentError::InvalidParameters(format!("Invalid URI: {}", e)))?; + + // Read content based on scheme and mime_type + match url.scheme() { + "file" => { + let path = url.to_file_path().map_err(|_| { + AgentError::InvalidParameters("Invalid file path in URI".into()) + })?; + + // Ensure file exists + if !path.exists() { + return Err(AgentError::ExecutionError(format!( + "File does not exist: {}", + path.display() + ))); + } + + match resource.mime_type.as_str() { + "text" => { + // Read the file as UTF-8 text + fs::read_to_string(&path).map_err(|e| { + AgentError::ExecutionError(format!("Failed to read file: {}", e)) + }) + } + "blob" => { + // Read as bytes, base64 encode + let bytes = fs::read(&path).map_err(|e| { + AgentError::ExecutionError(format!("Failed to read file: {}", e)) + })?; + Ok(base64::prelude::BASE64_STANDARD.encode(bytes)) + } + mime_type => Err(AgentError::InvalidParameters(format!( + "Unsupported mime type: {}", + mime_type + ))), + } + } + "str" => { + // For str:// URIs, we only support text + if resource.mime_type != "text" { + return Err(AgentError::InvalidParameters(format!( + "str:// URI only supports text mime type, got {}", + resource.mime_type + ))); + } + + // The `Url::path()` gives us the portion after `str:///` + let content_encoded = url.path().trim_start_matches('/'); + let decoded = urlencoding::decode(content_encoded).map_err(|e| { + AgentError::ExecutionError(format!("Failed to decode str:// content: {}", e)) + })?; + Ok(decoded.into_owned()) + } + scheme => Err(AgentError::InvalidParameters(format!( + "Unsupported URI scheme: {}", + scheme + ))), + } } } @@ -539,12 +605,15 @@ impl Router for DeveloperRouter { } fn list_resources(&self) -> Vec { - self.active_resources + let resources = self + .active_resources .lock() .unwrap() .values() .cloned() - .collect() + .collect(); + info!("Listing resources: {:?}", resources); + resources } fn read_resource( @@ -553,8 +622,9 @@ impl Router for DeveloperRouter { ) -> Pin> + Send + 'static>> { let this = self.clone(); let uri = uri.to_string(); + info!("Reading resource: {}", uri); Box::pin(async move { - match this.do_read_resource(&uri).await { + match this.read_resource_internal(&uri).await { Ok(content) => Ok(content), Err(e) => Err(e.into()), } diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 961aecafb..86826e690 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -12,7 +12,7 @@ async fn main() -> Result<()> { .with_env_filter( EnvFilter::from_default_env() .add_directive("mcp_client=debug".parse().unwrap()) - .add_directive("eventsource_client=debug".parse().unwrap()), + .add_directive("eventsource_client=info".parse().unwrap()), ) .init(); @@ -53,7 +53,15 @@ async fn main() -> Result<()> { serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), ) .await?; - println!("Tool result: {tool_result:?}"); + println!("Tool result: {tool_result:?}\n"); + + // List resources + let resources = client.list_resources().await?; + println!("Resources: {resources:?}\n"); + + // // Read resource + // let resource = client.read_resource("echo://fixedresource").await?; + // println!("Resource: {resource:?}\n"); Ok(()) } diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index 582267113..544337af6 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -69,5 +69,13 @@ async fn main() -> Result<(), ClientError> { let get_value_result = client.call_tool("get_value", serde_json::json!({})).await?; println!("Tool result for 'get_value': {get_value_result:?}\n"); + // List resources + let resources = client.list_resources().await?; + println!("Resources: {resources:?}\n"); + + // Read resource + let resource = client.read_resource("memo://insights").await?; + println!("Resource: {resource:?}\n"); + Ok(()) } diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs index 5ea232bb9..6282aff1a 100644 --- a/crates/mcp-server/src/main.rs +++ b/crates/mcp-server/src/main.rs @@ -41,6 +41,10 @@ impl CounterRouter { let counter = self.counter.lock().await; Ok(*counter) } + + fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { + Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() + } } impl Router for CounterRouter { @@ -112,12 +116,10 @@ impl Router for CounterRouter { } fn list_resources(&self) -> Vec { - vec![Resource::new( - "memo://insights", - Some("text/plain".to_string()), - Some("memo-resource".to_string()), - ) - .unwrap()] + vec![ + self._create_resource_text("str:////Users/to/some/path/", "cwd"), + self._create_resource_text("memo://insights", "memo-name"), + ] } fn read_resource( @@ -127,6 +129,10 @@ impl Router for CounterRouter { let uri = uri.to_string(); Box::pin(async move { match uri.as_str() { + "str:////Users/to/some/path/" => { + let cwd = "/Users/to/some/path/"; + Ok(cwd.to_string()) + } "memo://insights" => { let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; From f77a9cafa4e78052534b28e3317ed2293ccaf32e Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 23 Dec 2024 12:23:53 -0500 Subject: [PATCH 05/10] Share cwd, active_resources, and file_history in DeveloperRouter with an Arc --- crates/developer/src/main.rs | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs index 140086854..2b9b688ae 100644 --- a/crates/developer/src/main.rs +++ b/crates/developer/src/main.rs @@ -30,7 +30,7 @@ use mcp_core::role::Role; use mcp_server::{ByteTransport, Router, Server}; use std::process::Stdio; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use tokio::io::{stdin, stdout}; use tracing::info; use tracing_appender::rolling::{RollingFileAppender, Rotation}; @@ -38,9 +38,11 @@ use tracing_subscriber::{self, EnvFilter}; pub struct DeveloperRouter { tools: Vec, - cwd: Mutex, - active_resources: Mutex>, - file_history: Mutex>>, + // The cwd, active_resources, and file_history are shared across threads + // so we need to use an Arc to ensure thread safety + cwd: Arc>, + active_resources: Arc>>, + file_history: Arc>>>, instructions: String, } @@ -92,9 +94,9 @@ impl DeveloperRouter { Self { tools: vec![bash_tool, text_editor_tool], - cwd: Mutex::new(cwd), - active_resources: Mutex::new(resources), - file_history: Mutex::new(HashMap::new()), + cwd: Arc::new(Mutex::new(cwd)), + active_resources: Arc::new(Mutex::new(resources)), + file_history: Arc::new(Mutex::new(HashMap::new())), instructions, } } @@ -636,9 +638,9 @@ impl Clone for DeveloperRouter { fn clone(&self) -> Self { Self { tools: self.tools.clone(), - cwd: Mutex::new(self.cwd.lock().unwrap().clone()), - active_resources: Mutex::new(self.active_resources.lock().unwrap().clone()), - file_history: Mutex::new(self.file_history.lock().unwrap().clone()), + cwd: Arc::clone(&self.cwd), + active_resources: Arc::clone(&self.active_resources), + file_history: Arc::clone(&self.file_history), instructions: self.instructions.clone(), } } From afd059f6ca719d62285602cdf64c1f09ed4d9066 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 23 Dec 2024 12:52:52 -0500 Subject: [PATCH 06/10] add unit tests for developer crate --- crates/developer/Cargo.toml | 1 + crates/developer/src/main.rs | 434 +++++++++++++++++++++++++++++++++++ 2 files changed, 435 insertions(+) diff --git a/crates/developer/Cargo.toml b/crates/developer/Cargo.toml index 89458267e..5f7a40c24 100644 --- a/crates/developer/Cargo.toml +++ b/crates/developer/Cargo.toml @@ -29,3 +29,4 @@ indoc = "2.0.5" [dev-dependencies] sysinfo = "0.32.1" +tempfile = "3.8" diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs index 2b9b688ae..10a64a62f 100644 --- a/crates/developer/src/main.rs +++ b/crates/developer/src/main.rs @@ -673,3 +673,437 @@ async fn main() -> Result<()> { tracing::info!("Server initialized and ready to handle requests"); Ok(server.run(transport).await?) } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use tokio::sync::OnceCell; + + static DEV_ROUTER: OnceCell = OnceCell::const_new(); + + async fn get_router() -> &'static DeveloperRouter { + DEV_ROUTER + .get_or_init(|| async { DeveloperRouter::new() }) + .await + } + + #[tokio::test] + async fn test_bash_missing_parameters() { + let router = get_router().await; + let result = router.call_tool("bash", json!({})).await; + + assert!(result.is_err()); + let err = result.err().unwrap(); + assert!(matches!(err, ToolError::InvalidParameters(_))); + } + + #[tokio::test] + async fn test_bash_change_directory() { + let router = get_router().await; + let result = router + .call_tool("bash", json!({ "working_dir": ".", "command": "pwd" })) + .await; + assert!(result.is_ok()); + let output = result.unwrap(); + + // Check that the output contains the current directory + assert!(output.get("messages").unwrap().as_array().unwrap().len() > 0); + let messages = output.get("messages").unwrap().as_array().unwrap(); + let message = messages.first().unwrap(); + let text = message.get("text").unwrap().as_str().unwrap(); + assert!(text.contains(&std::env::current_dir().unwrap().display().to_string())); + } + + #[tokio::test] + async fn test_bash_invalid_directory() { + let router = get_router().await; + let result = router + .call_tool("bash", json!({ "working_dir": "non_existent_dir" })) + .await; + assert!(result.is_err()); + let err = result.err().unwrap(); + assert!(matches!(err, ToolError::InvalidParameters(_))); + } + + #[tokio::test] + async fn test_text_editor_size_limits() { + let router = get_router().await; + let temp_dir = tempfile::tempdir().unwrap(); + + // Test file size limit + { + let large_file_path = temp_dir.path().join("large.txt"); + let large_file_str = large_file_path.to_str().unwrap(); + + // Create a file larger than 2MB + let content = "x".repeat(3 * 1024 * 1024); // 3MB + std::fs::write(&large_file_path, content).unwrap(); + + let result = router + .call_tool( + "text_editor", + json!({ + "command": "view", + "path": large_file_str + }), + ) + .await; + + assert!(result.is_err()); + let err = result.err().unwrap(); + assert!(matches!(err, ToolError::ExecutionError(_))); + assert!(err.to_string().contains("too large")); + } + + // Test character count limit + { + let many_chars_path = temp_dir.path().join("many_chars.txt"); + let many_chars_str = many_chars_path.to_str().unwrap(); + + // Create a file with more than 2^20 characters but less than 2MB + let content = "x".repeat((1 << 20) + 1); // 2^20 + 1 characters + std::fs::write(&many_chars_path, content).unwrap(); + + let result = router + .call_tool( + "text_editor", + json!({ + "command": "view", + "path": many_chars_str + }), + ) + .await; + + assert!(result.is_err()); + let err = result.err().unwrap(); + assert!(matches!(err, ToolError::ExecutionError(_))); + assert!(err.to_string().contains("too many characters")); + } + + temp_dir.close().unwrap(); + } + + #[tokio::test] + async fn test_text_editor_write_and_view_file() { + let router = get_router().await; + + let temp_dir = tempfile::tempdir().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + let file_path_str = file_path.to_str().unwrap(); + + // Create a new file + router + .call_tool( + "text_editor", + json!({ + "command": "write", + "path": file_path_str, + "file_text": "Hello, world!" + }), + ) + .await + .unwrap(); + + // View the file + let view_result = router + .call_tool( + "text_editor", + json!({ + "command": "view", + "path": file_path_str + }), + ) + .await + .unwrap(); + + assert!( + view_result + .get("messages") + .unwrap() + .as_array() + .unwrap() + .len() + > 0 + ); + let messages = view_result.get("messages").unwrap().as_array().unwrap(); + let message = messages.first().unwrap(); + let text = message.get("text").unwrap().as_str().unwrap(); + assert!(text.contains("The file content for")); + + temp_dir.close().unwrap(); + } + + #[tokio::test] + async fn test_text_editor_str_replace() { + let router = get_router().await; + + let temp_dir = tempfile::tempdir().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + let file_path_str = file_path.to_str().unwrap(); + + // Create a new file + router + .call_tool( + "text_editor", + json!({ + "command": "write", + "path": file_path_str, + "file_text": "Hello, world!" + }), + ) + .await + .unwrap(); + + // View the file to make it active + router + .call_tool( + "text_editor", + json!({ + "command": "view", + "path": file_path_str + }), + ) + .await + .unwrap(); + + // Replace string + let replace_result = router + .call_tool( + "text_editor", + json!({ + "command": "str_replace", + "path": file_path_str, + "old_str": "world", + "new_str": "Rust" + }), + ) + .await + .unwrap(); + + let messages = replace_result.get("messages").unwrap().as_array().unwrap(); + let message = messages.first().unwrap(); + assert!(message + .get("text") + .unwrap() + .as_str() + .unwrap() + .contains("Successfully replaced text")); + + // View the file again + let view_result = router + .call_tool( + "text_editor", + json!({ + "command": "view", + "path": file_path_str + }), + ) + .await + .unwrap(); + + let messages = view_result.get("messages").unwrap().as_array().unwrap(); + let message = messages.first().unwrap(); + assert!(message + .get("text") + .unwrap() + .as_str() + .unwrap() + .contains("The file content for")); + + temp_dir.close().unwrap(); + } + + #[tokio::test] + async fn test_read_resource() { + let router = get_router().await; + + let temp_dir = tempfile::tempdir().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + let test_content = "Hello, world!"; + std::fs::write(&file_path, test_content).unwrap(); + + let uri = Url::from_file_path(&file_path).unwrap().to_string(); + + // Test text mime type with file:// URI + { + let mut active_resources = router.active_resources.lock().unwrap(); + let resource = Resource::new(uri.clone(), Some("text".to_string()), None).unwrap(); + active_resources.insert(uri.clone(), resource); + } + let content = router.read_resource(&uri).await.unwrap(); + assert_eq!(content, test_content); + + // Test blob mime type with file:// URI + let blob_path = temp_dir.path().join("test.bin"); + let blob_content = b"Binary content"; + std::fs::write(&blob_path, blob_content).unwrap(); + let blob_uri = Url::from_file_path(&blob_path).unwrap().to_string(); + { + let mut active_resources = router.active_resources.lock().unwrap(); + let resource = Resource::new(blob_uri.clone(), Some("blob".to_string()), None).unwrap(); + active_resources.insert(blob_uri.clone(), resource); + } + let encoded_content = router.read_resource(&blob_uri).await.unwrap(); + assert_eq!( + base64::prelude::BASE64_STANDARD + .decode(encoded_content) + .unwrap(), + blob_content + ); + + // Test str:// URI with text mime type + let str_uri = format!("str:///{}", test_content); + { + let mut active_resources = router.active_resources.lock().unwrap(); + let resource = Resource::new(str_uri.clone(), Some("text".to_string()), None).unwrap(); + active_resources.insert(str_uri.clone(), resource); + } + let str_content = router.read_resource(&str_uri).await.unwrap(); + assert_eq!(str_content, test_content); + + // Test str:// URI with blob mime type (should fail) + let str_blob_uri = format!("str:///{}", test_content); + { + let mut active_resources = router.active_resources.lock().unwrap(); + let resource = + Resource::new(str_blob_uri.clone(), Some("blob".to_string()), None).unwrap(); + active_resources.insert(str_blob_uri.clone(), resource); + } + let error = router.read_resource(&str_blob_uri).await.unwrap_err(); + assert!(matches!(error, ResourceError::NotFound(_))); + assert!(error.to_string().contains("only supports text mime type")); + + // Test invalid URI + let error = router.read_resource("invalid://uri").await.unwrap_err(); + assert!(matches!(error, ResourceError::NotFound(_))); + + // Test file:// URI without registration + let non_registered = Url::from_file_path(temp_dir.path().join("not_registered.txt")) + .unwrap() + .to_string(); + let error = router.read_resource(&non_registered).await.unwrap_err(); + assert!(matches!(error, ResourceError::NotFound(_))); + + // Test file:// URI with non-existent file but registered + let non_existent = Url::from_file_path(temp_dir.path().join("non_existent.txt")) + .unwrap() + .to_string(); + { + let mut active_resources = router.active_resources.lock().unwrap(); + let resource = + Resource::new(non_existent.clone(), Some("text".to_string()), None).unwrap(); + active_resources.insert(non_existent.clone(), resource); + } + let error = router.read_resource(&non_existent).await.unwrap_err(); + assert!(matches!(error, ResourceError::NotFound(_))); + assert!(error.to_string().contains("does not exist")); + + // Test invalid mime type + let invalid_mime = Url::from_file_path(&file_path).unwrap().to_string(); + { + let mut active_resources = router.active_resources.lock().unwrap(); + let mut resource = + Resource::new(invalid_mime.clone(), Some("text".to_string()), None).unwrap(); + resource.mime_type = "invalid".to_string(); + active_resources.insert(invalid_mime.clone(), resource); + } + let error = router.read_resource(&invalid_mime).await.unwrap_err(); + assert!(matches!(error, ResourceError::NotFound(_))); + assert!(error.to_string().contains("Unsupported mime type")); + + temp_dir.close().unwrap(); + } + + #[tokio::test] + async fn test_text_editor_undo_edit() { + let router = get_router().await; + + let temp_dir = tempfile::tempdir().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + let file_path_str = file_path.to_str().unwrap(); + + // Create a new file + router + .call_tool( + "text_editor", + json!({ + "command": "write", + "path": file_path_str, + "file_text": "First line" + }), + ) + .await + .unwrap(); + + // View the file to make it active + router + .call_tool( + "text_editor", + json!({ + "command": "view", + "path": file_path_str + }), + ) + .await + .unwrap(); + + // Replace string + router + .call_tool( + "text_editor", + json!({ + "command": "str_replace", + "path": file_path_str, + "old_str": "First line", + "new_str": "Second line" + }), + ) + .await + .unwrap(); + + // Undo the edit + let undo_result = router + .call_tool( + "text_editor", + json!({ + "command": "undo_edit", + "path": file_path_str + }), + ) + .await + .unwrap(); + + let messages = undo_result.get("messages").unwrap().as_array().unwrap(); + let message = messages.first().unwrap(); + assert!(message + .get("text") + .unwrap() + .as_str() + .unwrap() + .contains("Undid the last edit")); + + // View the file again + let view_result = router + .call_tool( + "text_editor", + json!({ + "command": "view", + "path": file_path_str + }), + ) + .await + .unwrap(); + + let messages = view_result.get("messages").unwrap().as_array().unwrap(); + let message = messages.first().unwrap(); + assert!(message + .get("text") + .unwrap() + .as_str() + .unwrap() + .contains("The file content for")); + + temp_dir.close().unwrap(); + } +} From 16fa6a0c3665ce6d5a6c9a729619fe9a843c3853 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 23 Dec 2024 12:56:59 -0500 Subject: [PATCH 07/10] use a helper to get first message text --- crates/developer/src/main.rs | 55 +++++++++++------------------------- 1 file changed, 16 insertions(+), 39 deletions(-) diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs index 10a64a62f..1706506a8 100644 --- a/crates/developer/src/main.rs +++ b/crates/developer/src/main.rs @@ -682,6 +682,12 @@ mod tests { static DEV_ROUTER: OnceCell = OnceCell::const_new(); + fn get_first_message_text(value: &Value) -> &str { + let messages = value.get("messages").unwrap().as_array().unwrap(); + let first = messages.first().unwrap(); + first.get("text").unwrap().as_str().unwrap() + } + async fn get_router() -> &'static DeveloperRouter { DEV_ROUTER .get_or_init(|| async { DeveloperRouter::new() }) @@ -706,12 +712,9 @@ mod tests { .await; assert!(result.is_ok()); let output = result.unwrap(); - // Check that the output contains the current directory assert!(output.get("messages").unwrap().as_array().unwrap().len() > 0); - let messages = output.get("messages").unwrap().as_array().unwrap(); - let message = messages.first().unwrap(); - let text = message.get("text").unwrap().as_str().unwrap(); + let text = get_first_message_text(&output); assert!(text.contains(&std::env::current_dir().unwrap().display().to_string())); } @@ -826,9 +829,7 @@ mod tests { .len() > 0 ); - let messages = view_result.get("messages").unwrap().as_array().unwrap(); - let message = messages.first().unwrap(); - let text = message.get("text").unwrap().as_str().unwrap(); + let text = get_first_message_text(&view_result); assert!(text.contains("The file content for")); temp_dir.close().unwrap(); @@ -881,14 +882,8 @@ mod tests { .await .unwrap(); - let messages = replace_result.get("messages").unwrap().as_array().unwrap(); - let message = messages.first().unwrap(); - assert!(message - .get("text") - .unwrap() - .as_str() - .unwrap() - .contains("Successfully replaced text")); + let text = get_first_message_text(&replace_result); + assert!(text.contains("Successfully replaced text")); // View the file again let view_result = router @@ -902,14 +897,8 @@ mod tests { .await .unwrap(); - let messages = view_result.get("messages").unwrap().as_array().unwrap(); - let message = messages.first().unwrap(); - assert!(message - .get("text") - .unwrap() - .as_str() - .unwrap() - .contains("The file content for")); + let text = get_first_message_text(&view_result); + assert!(text.contains("The file content for")); temp_dir.close().unwrap(); } @@ -1074,14 +1063,8 @@ mod tests { .await .unwrap(); - let messages = undo_result.get("messages").unwrap().as_array().unwrap(); - let message = messages.first().unwrap(); - assert!(message - .get("text") - .unwrap() - .as_str() - .unwrap() - .contains("Undid the last edit")); + let text = get_first_message_text(&undo_result); + assert!(text.contains("Undid the last edit")); // View the file again let view_result = router @@ -1095,14 +1078,8 @@ mod tests { .await .unwrap(); - let messages = view_result.get("messages").unwrap().as_array().unwrap(); - let message = messages.first().unwrap(); - assert!(message - .get("text") - .unwrap() - .as_str() - .unwrap() - .contains("The file content for")); + let text = get_first_message_text(&view_result); + assert!(text.contains("The file content for")); temp_dir.close().unwrap(); } From 6187430362b5b965c9a3f000bdcb17fae1e8e394 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Fri, 27 Dec 2024 21:25:49 -0800 Subject: [PATCH 08/10] parity --- crates/developer/Cargo.toml | 1 + crates/developer/src/main.rs | 138 ++++++++++++++++++++++++++++++++++- 2 files changed, 138 insertions(+), 1 deletion(-) diff --git a/crates/developer/Cargo.toml b/crates/developer/Cargo.toml index 5f7a40c24..deac6cf79 100644 --- a/crates/developer/Cargo.toml +++ b/crates/developer/Cargo.toml @@ -26,6 +26,7 @@ lazy_static = "1.5" kill_tree = "0.2.4" shellexpand = "3.1.0" indoc = "2.0.5" +xcap = "0.0.14" [dev-dependencies] sysinfo = "0.32.1" diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs index 1706506a8..bf03648a6 100644 --- a/crates/developer/src/main.rs +++ b/crates/developer/src/main.rs @@ -10,6 +10,7 @@ use std::{ collections::HashMap, fs, future::Future, + io::Cursor, path::{Path, PathBuf}, pin::Pin, }; @@ -35,6 +36,7 @@ use tokio::io::{stdin, stdout}; use tracing::info; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{self, EnvFilter}; +use xcap::{Monitor, Window}; pub struct DeveloperRouter { tools: Vec, @@ -79,6 +81,37 @@ impl DeveloperRouter { }), ); + let list_windows_tool = Tool::new( + "list_windows".to_string(), + "List all open windows".to_string(), + json!({ + "type": "object", + "required": [], + "properties": {} + }), + ); + + let screen_capture_tool = Tool::new( + "screen_capture".to_string(), + "Capture a screenshot of a specified display or window.\nYou can capture either:\n1. A full display (monitor) using the display parameter\n2. A specific window by its title using the window_title parameter\n\nOnly one of display or window_title should be specified.".to_string(), + json!({ + "type": "object", + "required": [], + "properties": { + "display": { + "type": "integer", + "default": 0, + "description": "The display number to capture (0 is main display)" + }, + "window_title": { + "type": "string", + "default": null, + "description": "Optional: the exact title of the window to capture. use the list_windows tool to find the available windows." + } + } + }), + ); + let instructions = "Developer instructions...".to_string(); // Reuse from original code let cwd = std::env::current_dir().unwrap(); @@ -93,7 +126,7 @@ impl DeveloperRouter { resources.insert(uri, resource); Self { - tools: vec![bash_tool, text_editor_tool], + tools: vec![bash_tool, text_editor_tool, list_windows_tool, screen_capture_tool], cwd: Arc::new(Mutex::new(cwd)), active_resources: Arc::new(Mutex::new(resources)), file_history: Arc::new(Mutex::new(HashMap::new())), @@ -508,6 +541,96 @@ impl DeveloperRouter { Ok(()) } + // Implement window listing functionality + async fn list_windows(&self, _params: Value) -> AgentResult> { + let windows = Window::all() + .map_err(|_| AgentError::ExecutionError("Failed to list windows".into()))?; + + let window_titles: Vec = windows + .into_iter() + .map(|w| w.title().to_string()) + .collect(); + + Ok(vec![ + Content::text("The following windows are available.").with_audience(vec![Role::Assistant]), + Content::text(format!("Available windows:\n{}", window_titles.join("\n"))) + .with_audience(vec![Role::User]) + .with_priority(0.0), + ]) + } + + async fn screen_capture(&self, params: Value) -> AgentResult> { + let mut image = if let Some(window_title) = params.get("window_title").and_then(|v| v.as_str()) { + // Try to find and capture the specified window + let windows = Window::all() + .map_err(|_| AgentError::ExecutionError("Failed to list windows".into()))?; + + let window = windows + .into_iter() + .find(|w| w.title() == window_title) + .ok_or_else(|| { + AgentError::ExecutionError(format!( + "No window found with title '{}'", + window_title + )) + })?; + + window.capture_image().map_err(|e| { + AgentError::ExecutionError(format!( + "Failed to capture window '{}': {}", + window_title, e + )) + })? + } else { + // Default to display capture if no window title is specified + let display = params.get("display").and_then(|v| v.as_u64()).unwrap_or(0) as usize; + + let monitors = Monitor::all() + .map_err(|_| AgentError::ExecutionError("Failed to access monitors".into()))?; + let monitor = monitors.get(display).ok_or_else(|| { + AgentError::ExecutionError(format!( + "{} was not an available monitor, {} found.", + display, + monitors.len() + )) + })?; + + monitor.capture_image().map_err(|e| { + AgentError::ExecutionError(format!("Failed to capture display {}: {}", display, e)) + })? + }; + + // Resize the image to a reasonable width while maintaining aspect ratio + let max_width = 768; + if image.width() > max_width { + let scale = max_width as f32 / image.width() as f32; + let new_height = (image.height() as f32 * scale) as u32; + image = xcap::image::imageops::resize( + &image, + max_width, + new_height, + xcap::image::imageops::FilterType::Lanczos3, + ) + }; + + let mut bytes: Vec = Vec::new(); + image + .write_to(&mut Cursor::new(&mut bytes), xcap::image::ImageFormat::Png) + .map_err(|e| { + AgentError::ExecutionError(format!("Failed to write image buffer {}", e)) + })?; + + // Convert to base64 + let data = base64::prelude::BASE64_STANDARD.encode(bytes); + + Ok(vec![ + Content::text("Screenshot captured").with_audience(vec![Role::Assistant]), + Content::image(data, "image/png") + .with_audience(vec![Role::User]) + .with_priority(0.0) + ]) + } + async fn read_resource_internal(&self, uri: &str) -> AgentResult { // Ensure the resource exists in the active resources map let active_resources = self.active_resources.lock().unwrap(); @@ -575,6 +698,17 @@ impl DeveloperRouter { ))), } } + + // Add this helper function similar to the other tool calls + async fn call_list_windows(&self, args: Value) -> Result { + let result = self.list_windows(args).await; + self.map_agent_result_to_value(result) + } + + async fn call_screen_capture(&self, args: Value) -> Result { + let result = self.screen_capture(args).await; + self.map_agent_result_to_value(result) + } } impl Router for DeveloperRouter { @@ -601,6 +735,8 @@ impl Router for DeveloperRouter { match tool_name.as_str() { "bash" => this.call_bash(arguments).await, "text_editor" => this.call_text_editor(arguments).await, + "list_windows" => this.call_list_windows(arguments).await, + "screen_capture" => this.call_screen_capture(arguments).await, _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), } }) From b5f2a596183391e4d5298901dfbb359ee5299689 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 30 Dec 2024 17:51:09 -0500 Subject: [PATCH 09/10] remove errors.rs, return ToolError or ResourceError directly --- crates/developer/src/errors.rs | 69 ------------- crates/developer/src/main.rs | 137 +++++++++++++------------- crates/developer/src/process_store.rs | 1 + 3 files changed, 69 insertions(+), 138 deletions(-) delete mode 100644 crates/developer/src/errors.rs diff --git a/crates/developer/src/errors.rs b/crates/developer/src/errors.rs deleted file mode 100644 index 6025e626c..000000000 --- a/crates/developer/src/errors.rs +++ /dev/null @@ -1,69 +0,0 @@ -use mcp_core::handler::{ResourceError, ToolError}; -use mcp_server::RouterError; -use serde::{Deserialize, Serialize}; -use thiserror::Error; - -#[non_exhaustive] -#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq)] -pub enum AgentError { - #[error("Tool not found: {0}")] - ToolNotFound(String), - - #[error("The parameters to the tool call were invalid: {0}")] - InvalidParameters(String), - - #[error("The tool failed during execution with the following output: \n{0}")] - ExecutionError(String), - - #[error("Internal error: {0}")] - Internal(String), - - #[error("Invalid tool name: {0}")] - InvalidToolName(String), -} - -pub type AgentResult = Result; - -impl From for ToolError { - fn from(err: AgentError) -> Self { - match err { - AgentError::InvalidParameters(msg) => ToolError::InvalidParameters(msg), - AgentError::InvalidToolName(msg) => ToolError::InvalidParameters(msg), - AgentError::ToolNotFound(msg) => ToolError::NotFound(msg), - AgentError::ExecutionError(msg) => ToolError::ExecutionError(msg), - AgentError::Internal(msg) => ToolError::ExecutionError(msg), - } - } -} - -impl From for ResourceError { - fn from(err: AgentError) -> Self { - match err { - AgentError::InvalidParameters(msg) => ResourceError::NotFound(msg), - _ => ResourceError::NotFound(err.to_string()), - } - } -} - -impl From for RouterError { - fn from(err: AgentError) -> Self { - match err { - AgentError::ToolNotFound(msg) => RouterError::ToolNotFound(msg), - AgentError::InvalidParameters(msg) => RouterError::InvalidParams(msg), - AgentError::ExecutionError(msg) => RouterError::Internal(msg), - AgentError::Internal(msg) => RouterError::Internal(msg), - AgentError::InvalidToolName(msg) => RouterError::ToolNotFound(msg), - } - } -} - -impl From for AgentError { - fn from(err: ResourceError) -> Self { - match err { - ResourceError::NotFound(msg) => { - AgentError::InvalidParameters(format!("Resource not found: {}", msg)) - } - ResourceError::ExecutionError(msg) => AgentError::ExecutionError(msg), - } - } -} diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs index 1706506a8..ebf9c6b82 100644 --- a/crates/developer/src/main.rs +++ b/crates/developer/src/main.rs @@ -1,4 +1,4 @@ -mod errors; + mod lang; mod process_store; @@ -24,7 +24,7 @@ use mcp_core::{ }; use mcp_server::router::{CapabilitiesBuilder, RouterService}; -use crate::errors::{AgentError, AgentResult}; + use mcp_core::content::Content; use mcp_core::role::Role; @@ -101,21 +101,10 @@ impl DeveloperRouter { } } - // Example utility function to call the underlying logic - async fn call_bash(&self, args: Value) -> Result { - let result = self.bash(args).await; // adapt your logic from DeveloperSystem - self.map_agent_result_to_value(result) - } - - async fn call_text_editor(&self, args: Value) -> Result { - let result = self.text_editor(args).await; // adapt from DeveloperSystem - self.map_agent_result_to_value(result) - } - - // Convert AgentResult> to Result - fn map_agent_result_to_value( + /// Helper method to map the result of a tool call to a JSON value + fn map_result_to_value( &self, - result: AgentResult>, + result: Result, ToolError>, ) -> Result { match result { Ok(contents) => { @@ -131,12 +120,22 @@ impl DeveloperRouter { .collect(); Ok(json!({"messages": messages})) } - Err(e) => Err(e.into()), + Err(e) => Err(e), } } + async fn call_bash(&self, args: Value) -> Result { + let result = self.bash(args).await; + self.map_result_to_value(result) + } + + async fn call_text_editor(&self, args: Value) -> Result { + let result = self.text_editor(args).await; + self.map_result_to_value(result) + } + // Helper method to resolve a path relative to cwd - fn resolve_path(&self, path_str: &str) -> AgentResult { + fn resolve_path(&self, path_str: &str) -> Result { let cwd = self.cwd.lock().unwrap(); let expanded = shellexpand::tilde(path_str); let path = Path::new(expanded.as_ref()); @@ -150,18 +149,18 @@ impl DeveloperRouter { } // Implement bash tool functionality - async fn bash(&self, params: Value) -> AgentResult> { + async fn bash(&self, params: Value) -> Result, ToolError> { let command = params .get("command") .and_then(|v| v.as_str()) - .ok_or(AgentError::InvalidParameters( - "The command string is required".into(), + .ok_or(ToolError::InvalidParameters( + "The command string is required".to_string(), ))?; // Disallow commands that should use other tools if command.trim_start().starts_with("cat") { - return Err(AgentError::InvalidParameters( + return Err(ToolError::InvalidParameters( "Do not use `cat` to read files, use the view mode on the text editor tool" .to_string(), )); @@ -179,7 +178,7 @@ impl DeveloperRouter { .arg("-c") .arg(cmd_with_redirect) .spawn() - .map_err(|e| AgentError::ExecutionError(e.to_string()))?; + .map_err(|e| ToolError::ExecutionError(e.to_string()))?; // Store the process ID with the command as the key let pid: Option = child.id(); @@ -191,7 +190,7 @@ impl DeveloperRouter { let output = child .wait_with_output() .await - .map_err(|e| AgentError::ExecutionError(e.to_string()))?; + .map_err(|e| ToolError::ExecutionError(e.to_string()))?; // Remove the process ID from the store if let Some(pid) = pid { @@ -212,16 +211,16 @@ impl DeveloperRouter { } // Implement text_editor tool functionality - async fn text_editor(&self, params: Value) -> AgentResult> { + async fn text_editor(&self, params: Value) -> Result, ToolError> { let command = params .get("command") .and_then(|v| v.as_str()) - .ok_or_else(|| AgentError::InvalidParameters("Missing 'command' parameter".into()))?; + .ok_or_else(|| ToolError::InvalidParameters("Missing 'command' parameter".to_string()))?; let path_str = params .get("path") .and_then(|v| v.as_str()) - .ok_or_else(|| AgentError::InvalidParameters("Missing 'path' parameter".into()))?; + .ok_or_else(|| ToolError::InvalidParameters("Missing 'path' parameter".into()))?; let path = self.resolve_path(path_str)?; @@ -232,7 +231,7 @@ impl DeveloperRouter { .get("file_text") .and_then(|v| v.as_str()) .ok_or_else(|| { - AgentError::InvalidParameters("Missing 'file_text' parameter".into()) + ToolError::InvalidParameters("Missing 'file_text' parameter".into()) })?; self.text_editor_write(&path, file_text).await @@ -242,26 +241,26 @@ impl DeveloperRouter { .get("old_str") .and_then(|v| v.as_str()) .ok_or_else(|| { - AgentError::InvalidParameters("Missing 'old_str' parameter".into()) + ToolError::InvalidParameters("Missing 'old_str' parameter".into()) })?; let new_str = params .get("new_str") .and_then(|v| v.as_str()) .ok_or_else(|| { - AgentError::InvalidParameters("Missing 'new_str' parameter".into()) + ToolError::InvalidParameters("Missing 'new_str' parameter".into()) })?; self.text_editor_replace(&path, old_str, new_str).await } "undo_edit" => self.text_editor_undo(&path).await, - _ => Err(AgentError::InvalidParameters(format!( + _ => Err(ToolError::InvalidParameters(format!( "Unknown command '{}'", command ))), } } - async fn text_editor_view(&self, path: &PathBuf) -> AgentResult> { + async fn text_editor_view(&self, path: &PathBuf) -> Result, ToolError> { if path.is_file() { // Check file size first (2MB limit) const MAX_FILE_SIZE: u64 = 2 * 1024 * 1024; // 2MB in bytes @@ -269,12 +268,12 @@ impl DeveloperRouter { let file_size = std::fs::metadata(path) .map_err(|e| { - AgentError::ExecutionError(format!("Failed to get file metadata: {}", e)) + ToolError::ExecutionError(format!("Failed to get file metadata: {}", e)) })? .len(); if file_size > MAX_FILE_SIZE { - return Err(AgentError::ExecutionError(format!( + return Err(ToolError::ExecutionError(format!( "File '{}' is too large ({:.2}MB). Maximum size is 2MB to prevent memory issues.", path.display(), file_size as f64 / 1024.0 / 1024.0 @@ -283,16 +282,16 @@ impl DeveloperRouter { // Create a new resource and add it to active_resources let uri = Url::from_file_path(path) - .map_err(|_| AgentError::ExecutionError("Invalid file path".into()))? + .map_err(|_| ToolError::ExecutionError("Invalid file path".into()))? .to_string(); // Read the content once let content = std::fs::read_to_string(path) - .map_err(|e| AgentError::ExecutionError(format!("Failed to read file: {}", e)))?; + .map_err(|e| ToolError::ExecutionError(format!("Failed to read file: {}", e)))?; let char_count = content.chars().count(); if char_count > MAX_CHAR_COUNT { - return Err(AgentError::ExecutionError(format!( + return Err(ToolError::ExecutionError(format!( "File '{}' has too many characters ({}). Maximum character count is {}.", path.display(), char_count, @@ -303,7 +302,7 @@ impl DeveloperRouter { // Create and store the resource let resource = Resource::new(uri.clone(), Some("text".to_string()), None).map_err(|e| { - AgentError::ExecutionError(format!("Failed to create resource: {}", e)) + ToolError::ExecutionError(format!("Failed to create resource: {}", e)) })?; self.active_resources.lock().unwrap().insert(uri, resource); @@ -333,7 +332,7 @@ impl DeveloperRouter { .with_priority(0.0), ]) } else { - Err(AgentError::ExecutionError(format!( + Err(ToolError::ExecutionError(format!( "The path '{}' does not exist or is not a file.", path.display() ))) @@ -344,15 +343,15 @@ impl DeveloperRouter { &self, path: &PathBuf, file_text: &str, - ) -> AgentResult> { + ) -> Result, ToolError> { // Get the URI for the file let uri = Url::from_file_path(path) - .map_err(|_| AgentError::ExecutionError("Invalid file path".into()))? + .map_err(|_| ToolError::ExecutionError("Invalid file path".into()))? .to_string(); // Check if file already exists and is active if path.exists() && !self.active_resources.lock().unwrap().contains_key(&uri) { - return Err(AgentError::InvalidParameters(format!( + return Err(ToolError::InvalidParameters(format!( "File '{}' exists but is not active. View it first before overwriting.", path.display() ))); @@ -363,12 +362,12 @@ impl DeveloperRouter { // Write to the file std::fs::write(path, file_text) - .map_err(|e| AgentError::ExecutionError(format!("Failed to write file: {}", e)))?; + .map_err(|e| ToolError::ExecutionError(format!("Failed to write file: {}", e)))?; // Create and store resource let resource = Resource::new(uri.clone(), Some("text".to_string()), None) - .map_err(|e| AgentError::ExecutionError(e.to_string()))?; + .map_err(|e| ToolError::ExecutionError(e.to_string()))?; self.active_resources.lock().unwrap().insert(uri, resource); // Try to detect the language from the file extension @@ -397,21 +396,21 @@ impl DeveloperRouter { path: &PathBuf, old_str: &str, new_str: &str, - ) -> AgentResult> { + ) -> Result, ToolError> { // Get the URI for the file let uri = Url::from_file_path(path) - .map_err(|_| AgentError::ExecutionError("Invalid file path".into()))? + .map_err(|_| ToolError::ExecutionError("Invalid file path".into()))? .to_string(); // Check if file exists and is active if !path.exists() { - return Err(AgentError::InvalidParameters(format!( + return Err(ToolError::InvalidParameters(format!( "File '{}' does not exist", path.display() ))); } if !self.active_resources.lock().unwrap().contains_key(&uri) { - return Err(AgentError::InvalidParameters(format!( + return Err(ToolError::InvalidParameters(format!( "You must view '{}' before editing it", path.display() ))); @@ -419,17 +418,17 @@ impl DeveloperRouter { // Read content let content = std::fs::read_to_string(path) - .map_err(|e| AgentError::ExecutionError(format!("Failed to read file: {}", e)))?; + .map_err(|e| ToolError::ExecutionError(format!("Failed to read file: {}", e)))?; // Ensure 'old_str' appears exactly once if content.matches(old_str).count() > 1 { - return Err(AgentError::InvalidParameters( + return Err(ToolError::InvalidParameters( "'old_str' must appear exactly once in the file, but it appears multiple times" .into(), )); } if content.matches(old_str).count() == 0 { - return Err(AgentError::InvalidParameters( + return Err(ToolError::InvalidParameters( "'old_str' must appear exactly once in the file, but it does not appear in the file. Make sure the string exactly matches existing file content, including spacing.".into(), )); } @@ -440,7 +439,7 @@ impl DeveloperRouter { // Replace and write back let new_content = content.replace(old_str, new_str); std::fs::write(path, &new_content) - .map_err(|e| AgentError::ExecutionError(format!("Failed to write file: {}", e)))?; + .map_err(|e| ToolError::ExecutionError(format!("Failed to write file: {}", e)))?; // Update resource if let Some(resource) = self.active_resources.lock().unwrap().get_mut(&uri) { @@ -475,32 +474,32 @@ impl DeveloperRouter { ]) } - async fn text_editor_undo(&self, path: &PathBuf) -> AgentResult> { + async fn text_editor_undo(&self, path: &PathBuf) -> Result, ToolError> { let mut history = self.file_history.lock().unwrap(); if let Some(contents) = history.get_mut(path) { if let Some(previous_content) = contents.pop() { // Write previous content back to file std::fs::write(path, previous_content).map_err(|e| { - AgentError::ExecutionError(format!("Failed to write file: {}", e)) + ToolError::ExecutionError(format!("Failed to write file: {}", e)) })?; Ok(vec![Content::text("Undid the last edit")]) } else { - Err(AgentError::InvalidParameters( + Err(ToolError::InvalidParameters( "No edit history available to undo".into(), )) } } else { - Err(AgentError::InvalidParameters( + Err(ToolError::InvalidParameters( "No edit history available to undo".into(), )) } } - fn save_file_history(&self, path: &PathBuf) -> AgentResult<()> { + fn save_file_history(&self, path: &PathBuf) -> Result<(), ToolError> { let mut history = self.file_history.lock().unwrap(); let content = if path.exists() { std::fs::read_to_string(path) - .map_err(|e| AgentError::ExecutionError(format!("Failed to read file: {}", e)))? + .map_err(|e| ToolError::ExecutionError(format!("Failed to read file: {}", e)))? } else { String::new() }; @@ -508,26 +507,26 @@ impl DeveloperRouter { Ok(()) } - async fn read_resource_internal(&self, uri: &str) -> AgentResult { + async fn read_resource_internal(&self, uri: &str) -> Result { // Ensure the resource exists in the active resources map let active_resources = self.active_resources.lock().unwrap(); let resource = active_resources .get(uri) - .ok_or_else(|| AgentError::ToolNotFound(format!("Resource '{}' not found", uri)))?; + .ok_or_else(|| ResourceError::NotFound(format!("Resource '{}' not found", uri)))?; let url = Url::parse(uri) - .map_err(|e| AgentError::InvalidParameters(format!("Invalid URI: {}", e)))?; + .map_err(|e| ResourceError::NotFound(format!("Invalid URI: {}", e)))?; // Read content based on scheme and mime_type match url.scheme() { "file" => { let path = url.to_file_path().map_err(|_| { - AgentError::InvalidParameters("Invalid file path in URI".into()) + ResourceError::NotFound("Invalid file path in URI".into()) })?; // Ensure file exists if !path.exists() { - return Err(AgentError::ExecutionError(format!( + return Err(ResourceError::NotFound(format!( "File does not exist: {}", path.display() ))); @@ -537,17 +536,17 @@ impl DeveloperRouter { "text" => { // Read the file as UTF-8 text fs::read_to_string(&path).map_err(|e| { - AgentError::ExecutionError(format!("Failed to read file: {}", e)) + ResourceError::ExecutionError(format!("Failed to read file: {}", e)) }) } "blob" => { // Read as bytes, base64 encode let bytes = fs::read(&path).map_err(|e| { - AgentError::ExecutionError(format!("Failed to read file: {}", e)) + ResourceError::ExecutionError(format!("Failed to read file: {}", e)) })?; Ok(base64::prelude::BASE64_STANDARD.encode(bytes)) } - mime_type => Err(AgentError::InvalidParameters(format!( + mime_type => Err(ResourceError::ExecutionError(format!( "Unsupported mime type: {}", mime_type ))), @@ -556,7 +555,7 @@ impl DeveloperRouter { "str" => { // For str:// URIs, we only support text if resource.mime_type != "text" { - return Err(AgentError::InvalidParameters(format!( + return Err(ResourceError::ExecutionError(format!( "str:// URI only supports text mime type, got {}", resource.mime_type ))); @@ -565,11 +564,11 @@ impl DeveloperRouter { // The `Url::path()` gives us the portion after `str:///` let content_encoded = url.path().trim_start_matches('/'); let decoded = urlencoding::decode(content_encoded).map_err(|e| { - AgentError::ExecutionError(format!("Failed to decode str:// content: {}", e)) + ResourceError::ExecutionError(format!("Failed to decode str:// content: {}", e)) })?; Ok(decoded.into_owned()) } - scheme => Err(AgentError::InvalidParameters(format!( + scheme => Err(ResourceError::NotFound(format!( "Unsupported URI scheme: {}", scheme ))), diff --git a/crates/developer/src/process_store.rs b/crates/developer/src/process_store.rs index efa38629a..2e5d75a4a 100644 --- a/crates/developer/src/process_store.rs +++ b/crates/developer/src/process_store.rs @@ -24,6 +24,7 @@ pub fn remove_process(pid: u32) -> bool { } /// Kill all stored processes +#[allow(dead_code)] pub fn kill_processes() { let mut killed_processes = Vec::new(); { From e52d4b1ced557cab2ffba1b163332d67774f7c1e Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 31 Dec 2024 01:20:29 -0800 Subject: [PATCH 10/10] change AgentResult to Result and AgentError to ToolError --- crates/developer/src/main.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/crates/developer/src/main.rs b/crates/developer/src/main.rs index 311364ac1..128a439ba 100644 --- a/crates/developer/src/main.rs +++ b/crates/developer/src/main.rs @@ -541,9 +541,9 @@ impl DeveloperRouter { } // Implement window listing functionality - async fn list_windows(&self, _params: Value) -> AgentResult> { + async fn list_windows(&self, _params: Value) -> Result, ToolError> { let windows = Window::all() - .map_err(|_| AgentError::ExecutionError("Failed to list windows".into()))?; + .map_err(|_| ToolError::ExecutionError("Failed to list windows".into()))?; let window_titles: Vec = windows .into_iter() @@ -558,24 +558,24 @@ impl DeveloperRouter { ]) } - async fn screen_capture(&self, params: Value) -> AgentResult> { + async fn screen_capture(&self, params: Value) -> Result, ToolError> { let mut image = if let Some(window_title) = params.get("window_title").and_then(|v| v.as_str()) { // Try to find and capture the specified window let windows = Window::all() - .map_err(|_| AgentError::ExecutionError("Failed to list windows".into()))?; + .map_err(|_| ToolError::ExecutionError("Failed to list windows".into()))?; let window = windows .into_iter() .find(|w| w.title() == window_title) .ok_or_else(|| { - AgentError::ExecutionError(format!( + ToolError::ExecutionError(format!( "No window found with title '{}'", window_title )) })?; window.capture_image().map_err(|e| { - AgentError::ExecutionError(format!( + ToolError::ExecutionError(format!( "Failed to capture window '{}': {}", window_title, e )) @@ -585,9 +585,9 @@ impl DeveloperRouter { let display = params.get("display").and_then(|v| v.as_u64()).unwrap_or(0) as usize; let monitors = Monitor::all() - .map_err(|_| AgentError::ExecutionError("Failed to access monitors".into()))?; + .map_err(|_| ToolError::ExecutionError("Failed to access monitors".into()))?; let monitor = monitors.get(display).ok_or_else(|| { - AgentError::ExecutionError(format!( + ToolError::ExecutionError(format!( "{} was not an available monitor, {} found.", display, monitors.len() @@ -595,7 +595,7 @@ impl DeveloperRouter { })?; monitor.capture_image().map_err(|e| { - AgentError::ExecutionError(format!("Failed to capture display {}: {}", display, e)) + ToolError::ExecutionError(format!("Failed to capture display {}: {}", display, e)) })? }; @@ -616,7 +616,7 @@ impl DeveloperRouter { image .write_to(&mut Cursor::new(&mut bytes), xcap::image::ImageFormat::Png) .map_err(|e| { - AgentError::ExecutionError(format!("Failed to write image buffer {}", e)) + ToolError::ExecutionError(format!("Failed to write image buffer {}", e)) })?; // Convert to base64 @@ -701,12 +701,12 @@ impl DeveloperRouter { // Add this helper function similar to the other tool calls async fn call_list_windows(&self, args: Value) -> Result { let result = self.list_windows(args).await; - self.map_agent_result_to_value(result) + self.map_result_to_value(result) } async fn call_screen_capture(&self, args: Value) -> Result { let result = self.screen_capture(args).await; - self.map_agent_result_to_value(result) + self.map_result_to_value(result) } }