From 5d55bc3486f977764a49b9f4f00be22d37b22b4e Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Mon, 16 Dec 2024 23:26:38 -0500 Subject: [PATCH 01/11] mcp-client using tower service trait --- crates/mcp-client/Cargo.toml | 3 + crates/mcp-client/src/client.rs | 163 +++++++ crates/mcp-client/src/lib.rs | 5 +- crates/mcp-client/src/main.rs | 117 ++--- crates/mcp-client/src/service.rs | 74 +++ crates/mcp-client/src/session.rs | 544 ----------------------- crates/mcp-client/src/sse_transport.rs | 229 ---------- crates/mcp-client/src/stdio_transport.rs | 198 --------- crates/mcp-client/src/transport.rs | 14 - crates/mcp-client/src/transport/mod.rs | 42 ++ crates/mcp-client/src/transport/stdio.rs | 94 ++++ crates/mcp-core/src/tool.rs | 1 + 12 files changed, 426 insertions(+), 1058 deletions(-) create mode 100644 crates/mcp-client/src/client.rs create mode 100644 crates/mcp-client/src/service.rs delete mode 100644 crates/mcp-client/src/session.rs delete mode 100644 crates/mcp-client/src/sse_transport.rs delete mode 100644 crates/mcp-client/src/stdio_transport.rs delete mode 100644 crates/mcp-client/src/transport.rs create mode 100644 crates/mcp-client/src/transport/mod.rs create mode 100644 crates/mcp-client/src/transport/stdio.rs diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 6de3d3903..a666e6522 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -14,10 +14,13 @@ serde_json = "1.0" clap = { version = "4.5", features = ["derive"] } async-trait = "0.1.83" url = "2.5.4" +thiserror = "1.0" anyhow = "1.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tokio-retry = "0.3" +tower = { version = "0.4", features = ["timeout", "util"] } +tower-service = "0.3" [dev-dependencies] warp = "0.3" diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs new file mode 100644 index 000000000..0cdab4391 --- /dev/null +++ b/crates/mcp-client/src/client.rs @@ -0,0 +1,163 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; +use tower::ServiceExt; // for Service::ready() + +use mcp_core::protocol::{ + InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, + ListResourcesResult, ReadResourceResult, +}; + +/// Error type for MCP client operations. +#[derive(Debug, Error)] +pub enum Error { + #[error("Service error: {0}")] + Service(#[from] super::service::ServiceError), + + #[error("RPC error: code={code}, message={message}")] + RpcError { code: i32, message: String }, + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + #[error("Unexpected response from server")] + UnexpectedResponse, + + #[error("Timeout or service not ready")] + NotReady, +} + +#[derive(Serialize, Deserialize)] +pub struct ClientInfo { + pub name: String, + pub version: String, +} + +#[derive(Serialize, Deserialize, Default)] +pub struct ClientCapabilities { + // Add fields as needed. For now, empty capabilities are fine. +} + +#[derive(Serialize, Deserialize)] +pub struct InitializeParams { + #[serde(rename = "protocolVersion")] + pub protocol_version: String, + pub capabilities: ClientCapabilities, + #[serde(rename = "clientInfo")] + pub client_info: ClientInfo, +} + +/// The MCP client that sends requests via the provided service. +pub struct McpClient { + service: S, + next_id: u64, +} + +impl McpClient +where + S: tower::Service< + JsonRpcRequest, + Response = JsonRpcMessage, + Error = super::service::ServiceError, + > + Send, + S::Future: Send, +{ + pub fn new(service: S) -> Self { + Self { + service, + next_id: 1, + } + } + + /// Send a JSON-RPC request and wait for a response. + async fn send_message(&mut self, method: &str, params: Value) -> Result + where + T: for<'de> Deserialize<'de>, + { + self.service.ready().await.map_err(|_| Error::NotReady)?; + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Some(self.next_id), + method: method.to_string(), + params: Some(params), + }; + + self.next_id += 1; + + let response_msg = self.service.call(request).await?; + + match response_msg { + JsonRpcMessage::Response(JsonRpcResponse { + id, result, error, .. + }) => { + // Verify id matches + if id != Some(self.next_id - 1) { + return Err(Error::UnexpectedResponse); + } + if let Some(err) = error { + Err(Error::RpcError { + code: err.code, + message: err.message, + }) + } else if let Some(r) = result { + Ok(serde_json::from_value(r)?) + } else { + Err(Error::UnexpectedResponse) + } + } + JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => { + if id != Some(self.next_id - 1) { + return Err(Error::UnexpectedResponse); + } + Err(Error::RpcError { + code: error.code, + message: error.message, + }) + } + _ => { + // Requests/notifications not expected as a response + Err(Error::UnexpectedResponse) + } + } + } + + // /// Send a JSON-RPC notification. + // pub async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { + // let notification = mcp_core::protocol::JsonRpcNotification { + // jsonrpc: "2.0".to_string(), + // method: method.to_string(), + // params: Some(params), + // }; + // let msg = serde_json::to_string(¬ification)?; + // let mut transport = self.transport.lock().await; + // transport.send(msg).await + // } + + /// Initialize the connection with the server. + pub async fn initialize( + &mut self, + info: ClientInfo, + capabilities: ClientCapabilities, + ) -> Result { + let params = InitializeParams { + protocol_version: "1.0.0".into(), + client_info: info, + capabilities, + }; + self.send_message("initialize", serde_json::to_value(params)?) + .await + } + + /// List available resources. + pub async fn list_resources(&mut self) -> Result { + self.send_message("resources/list", serde_json::json!({})) + .await + } + + /// Read a resource's content. + pub async fn read_resource(&mut self, uri: &str) -> Result { + let params = serde_json::json!({ "uri": uri }); + self.send_message("resources/read", params).await + } +} diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index 3172f2944..c2ab27500 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -1,4 +1,3 @@ -pub mod session; -pub mod sse_transport; -pub mod stdio_transport; +pub mod client; +pub mod service; pub mod transport; diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/src/main.rs index 70495fffd..bd4490fd0 100644 --- a/crates/mcp-client/src/main.rs +++ b/crates/mcp-client/src/main.rs @@ -1,24 +1,31 @@ -use anyhow::{anyhow, Result}; -use clap::Parser; -use mcp_client::{ - session::Session, - sse_transport::{SseTransport, SseTransportParams}, - stdio_transport::{StdioServerParams, StdioTransport}, - transport::Transport, -}; -use serde_json::json; +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; +use mcp_client::{service::TransportService, transport::StdioTransport}; +use tower::ServiceBuilder; use tracing_subscriber::EnvFilter; -#[derive(Parser)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Mode to run in: "git" or "echo" - #[arg(short, long, default_value = "git")] - mode: String, -} +// use mcp_client::{ +// service::{ServiceError}, +// transport::{Error as TransportError}, +// }; +// use std::time::Duration; +// use tower::timeout::error::Elapsed; + +// fn convert_box_error(err: Box) -> ServiceError { +// if let Some(elapsed) = err.downcast_ref::() { +// ServiceError::Transport(TransportError::Io( +// std::io::Error::new( +// std::io::ErrorKind::TimedOut, +// format!("Timeout elapsed: {}", elapsed), +// ), +// )) +// } else { +// ServiceError::Other(err.to_string()) +// } +// } #[tokio::main] -async fn main() -> Result<()> { +async fn main() -> Result<(), ClientError> { // Initialize logging tracing_subscriber::fmt() .with_env_filter( @@ -28,64 +35,34 @@ async fn main() -> Result<()> { ) .init(); - let args = Args::parse(); - println!("Args - mode: {}", args.mode); - - // Create session based on mode - let transport: Box = match args.mode.as_str() { - "git" => Box::new(StdioTransport { - params: StdioServerParams { - command: "uvx".into(), - args: vec!["mcp-server-git".into()], - env: None, - }, - }), - "echo" => Box::new(SseTransport { - params: SseTransportParams { - url: "http://localhost:8000/sse".into(), - headers: None, - }, - }), - _ => return Err(anyhow!("Invalid mode. Use 'git' or 'echo'")), - }; + // Create the base transport + let transport = StdioTransport::new("uvx", ["mcp-server-git"]); - let (read_stream, write_stream) = transport.connect().await?; - let mut session = Session::new(read_stream, write_stream).await?; + // Build service with middleware + let service = ServiceBuilder::new().service(TransportService::new(transport)); - // Initialize the connection - let init_result = session.initialize().await?; - println!("Initialized: {:?}", init_result); + // Create client + let mut client = McpClient::new(service); - // List tools - let tools = session.list_tools().await?; - println!("Tools: {:?}", tools); - - if args.mode == "echo" { - // Call a tool (replace with actual tool name and arguments) - let call_result = session - .call_tool("echo_tool", Some(json!({"message": "Hello, world!"}))) - .await?; - println!("Call tool result: {:?}", call_result); - - // List available resources - let resources = session.list_resources().await?; - println!("Resources: {:?}", resources); + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}"); - // Read a resource (replace with actual URI) - if let Some(resource) = resources.resources.first() { - let read_result = session.read_resource(&resource.uri).await?; - println!("Read resource result: {:?}", read_result); - } - } else { - // Call a tool (replace with actual tool name and arguments) - let call_result = session - .call_tool("git_status", Some(json!({"repo_path": "."}))) - .await?; - println!("Call tool result: {:?}", call_result); - } + // List resources + let resources = client.list_resources().await?; + println!("Available resources: {resources:?}"); - session.shutdown().await?; - println!("Done!"); + // Read a resource + let content = client.read_resource("file:///example.txt".into()).await?; + println!("Content: {content:?}"); Ok(()) } diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs new file mode 100644 index 000000000..f422bc924 --- /dev/null +++ b/crates/mcp-client/src/service.rs @@ -0,0 +1,74 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::Mutex; +use tower::Service; + +use crate::transport::{Error as TransportError, Transport}; +use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; + +#[derive(Debug, thiserror::Error)] +pub enum ServiceError { + #[error("Transport error: {0}")] + Transport(#[from] TransportError), + + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + #[error("Other error: {0}")] + Other(String), + + #[error("Unexpected server response")] + UnexpectedResponse, +} + +/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcRequests and JsonRpcMessages. +pub struct TransportService { + transport: Arc>, + initialized: AtomicBool, +} + +impl TransportService { + pub fn new(transport: T) -> Self { + Self { + transport: Arc::new(Mutex::new(transport)), + initialized: AtomicBool::new(false), + } + } +} + +impl Service for TransportService { + type Response = JsonRpcMessage; + type Error = ServiceError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // Always ready. We do on-demand initialization in call(). + Poll::Ready(Ok(())) + } + + fn call(&mut self, request: JsonRpcRequest) -> Self::Future { + let transport = Arc::clone(&self.transport); + let started = self.initialized.load(Ordering::SeqCst); + + Box::pin(async move { + let mut transport = transport.lock().await; + + // Initialize (start) transport on the first call. + if !started { + transport.start().await?; + } + + // Serialize request to JSON line + let msg = serde_json::to_string(&request)?; + transport.send(msg).await?; + + let line = transport.receive().await?; + let response_msg: JsonRpcMessage = serde_json::from_str(&line)?; + + Ok(response_msg) + }) + } +} diff --git a/crates/mcp-client/src/session.rs b/crates/mcp-client/src/session.rs deleted file mode 100644 index 1946cea36..000000000 --- a/crates/mcp-client/src/session.rs +++ /dev/null @@ -1,544 +0,0 @@ -use crate::transport::{ReadStream, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use mcp_core::protocol::*; -use serde::de::DeserializeOwned; -use serde_json::{json, Value}; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::sync::Mutex; -use tracing::debug; - -struct OutgoingMessage { - message: JsonRpcMessage, - response_tx: mpsc::Sender>>, -} - -pub struct Session { - request_tx: mpsc::Sender, - id_counter: AtomicU64, - shutdown_tx: mpsc::Sender<()>, - background_task: Arc>>>, - is_closed: Arc, -} - -impl Session { - pub async fn new(read_stream: ReadStream, write_stream: WriteStream) -> Result { - let (request_tx, mut request_rx) = mpsc::channel::(32); - let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - let is_closed = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let is_closed_clone = is_closed.clone(); - - // Spawn the background task - let background_task = Arc::new(Mutex::new(Some(tokio::spawn({ - async move { - let mut pending_requests: Vec<( - u64, - mpsc::Sender>>, - )> = Vec::new(); - let mut read_stream = read_stream; - let write_stream = write_stream; - - loop { - tokio::select! { - // Handle shutdown signal - Some(()) = shutdown_rx.recv() => { - // Notify all pending requests of shutdown - for (_, tx) in pending_requests { - let _ = tx.send(Err(anyhow!("Session shutdown"))).await; - } - break; - } - - // Handle outgoing messages - Some(outgoing) = request_rx.recv() => { - // If session is closed, reject new messages - if is_closed_clone.load(Ordering::SeqCst) { - let _ = outgoing.response_tx.send(Err(anyhow!("Session is closed"))).await; - continue; - } - - // Send the message - if let Err(e) = write_stream.send(outgoing.message.clone()).await { - debug!("Write error occurred: {}", e); - // let _ = outgoing.response_tx.send(Err(e.into())).await; - // On write error, mark session as closed - is_closed_clone.store(true, Ordering::SeqCst); - break; - } - - // For requests, store the response channel for later - if let JsonRpcMessage::Request(request) = outgoing.message { - if let Some(id) = request.id { - pending_requests.push((id, outgoing.response_tx)); - } - } else { - // For notifications, just confirm success - let _ = outgoing.response_tx.send(Ok(None)).await; - } - } - - // Handle incoming messages - Some(message_result) = read_stream.recv() => { - match message_result { - Ok(JsonRpcMessage::Response(response)) => { - if let Some(id) = response.id { - if let Some(pos) = pending_requests.iter().position(|(req_id, _)| *req_id == id) { - let (_, tx) = pending_requests.remove(pos); - let _ = tx.send(Ok(Some(response))).await; - } - } - } - Ok(JsonRpcMessage::Notification(_)) => { - // Handle incoming notifications if needed - } - Ok(_) => { - eprintln!("Unexpected message type"); - } - Err(e) => { - // On transport error, notify all pending requests and shutdown - eprintln!("Transport error: {}", e); - for (_, tx) in pending_requests { - let _ = tx.send(Err(anyhow!("{}", e))).await; - } - - // Mark session as closed - is_closed_clone.store(true, Ordering::SeqCst); - break; - } - } - } - } - } - } - })))); - - Ok(Self { - request_tx, - id_counter: AtomicU64::new(1), - shutdown_tx, - background_task, - is_closed, - }) - } - - pub async fn shutdown(&self) -> Result<()> { - // Mark session as closed - self.is_closed.store(true, Ordering::SeqCst); - - // Send shutdown signal - self.shutdown_tx - .send(()) - .await - .map_err(|e| anyhow!("Failed to shutdown session: {}", e))?; - - // Wait for background task to complete - if let Some(task) = self.background_task.lock().await.take() { - task.await - .map_err(|e| anyhow!("Background task failed: {}", e))?; - } - - Ok(()) - } - - async fn send_message(&self, message: JsonRpcMessage) -> Result> { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let (response_tx, mut response_rx) = mpsc::channel(1); - - self.request_tx - .send(OutgoingMessage { - message, - response_tx, - }) - .await - .context("Failed to send message")?; - - response_rx - .recv() - .await - .context("Failed to receive response")? - } - - async fn rpc_call( - &self, - method: &str, - params: Option, - ) -> Result { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let id = self.id_counter.fetch_add(1, Ordering::SeqCst); - let request = JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Some(id), - method: method.to_string(), - params, - }; - - let response = self - .send_message(JsonRpcMessage::Request(request)) - .await? - .context("Expected response for request")?; - - match (response.error, response.result) { - (Some(error), _) => Err(anyhow!("RPC Error {}: {}", error.code, error.message)), - (_, Some(result)) => { - serde_json::from_value(result).context("Failed to deserialize result") - } - (None, None) => Err(anyhow!("No result in response")), - } - } - - async fn send_notification(&self, method: &str, params: Option) -> Result<()> { - // Check if session is closed - if self.is_closed.load(Ordering::SeqCst) { - return Err(anyhow!("Session is closed")); - } - - let notification = JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: method.to_string(), - params, - }; - - self.send_message(JsonRpcMessage::Notification(notification)) - .await?; - - Ok(()) - } - - pub async fn initialize(&mut self) -> Result { - let params = json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "sampling": null, - "experimental": null, - "roots": { - "listChanged": true - } - }, - "clientInfo": { - "name": "RustMCPClient", - "version": "0.1.0" - } - }); - - let result: InitializeResult = self.rpc_call("initialize", Some(params)).await?; - self.send_notification("notifications/initialized", None) - .await?; - Ok(result) - } - - pub async fn list_resources(&self) -> Result { - self.rpc_call("resources/list", Some(json!({}))).await - } - - pub async fn read_resource(&self, uri: &str) -> Result { - self.rpc_call("resources/read", Some(json!({ "uri": uri }))) - .await - } - - pub async fn list_tools(&self) -> Result { - self.rpc_call("tools/list", Some(json!({}))).await - } - - pub async fn call_tool(&self, name: &str, arguments: Option) -> Result { - self.rpc_call( - "tools/call", - Some(json!({ - "name": name, - "arguments": arguments.unwrap_or_else(|| json!({})), - })), - ) - .await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::transport::{ReadStream, Transport, WriteStream}; - use anyhow::{anyhow, Result}; - use async_trait::async_trait; - use std::sync::atomic::Ordering; - use std::time::Duration; - use tokio::sync::mpsc; - use tokio::time::timeout; - - // Mock transport that simulates errors - struct MockTransport { - error_mode: ErrorMode, - } - - #[derive(Clone)] - enum ErrorMode { - ReadError, - WriteError, - ProcessTermination, - Nil, - } - - #[async_trait] - impl Transport for MockTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - let error_mode = self.error_mode.clone(); - - tokio::spawn(async move { - // For WriteError, don't wait for any writes, just drop the receiver to force an immediate failure. - // This ensures that the first attempt to send by the Session fails. - match error_mode { - ErrorMode::ReadError => { - // Wait a bit for the request to be sent and then send the error - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let _ = tx_read.send(Err(anyhow!("Simulated read error"))).await; - } - ErrorMode::WriteError => { - // Immediately drop the rx_write side - drop(rx_write); - } - ErrorMode::ProcessTermination => { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let _ = tx_read.send(Err(anyhow!("Child process terminated"))).await; - } - ErrorMode::Nil => { - // Test with initialize and then list_resources - while let Some(message) = rx_write.recv().await { - match message { - JsonRpcMessage::Request(req) => { - // Send a successful response for initialization or other calls - if req.method == "initialize" { - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some(json!({ - "protocolVersion": "2024-11-05", - "capabilities": { "resources": { "listChanged": false } }, - "serverInfo": { "name": "MockServer", "version": "1.0.0" } - })), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } else if req.method == "resources/list" { - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some( - json!({ "resources": [{ "uri": "file://res1", "name": "res1" }, { "uri": "file://res2", "name": "res2" }] }), - ), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } else { - // Default success for other calls - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: req.id, - result: Some(json!({ "ok": true })), - error: None, - }); - let _ = tx_read.send(Ok(response)).await; - } - } - JsonRpcMessage::Notification(_notif) => { - // For notifications, no response is required. - } - _ => {} - } - } - } - } - }); - - Ok((rx_read, tx_write)) - } - } - - #[tokio::test] - async fn test_session_can_initialize_and_list_resources() -> Result<()> { - let transport = MockTransport { - error_mode: ErrorMode::Nil, - }; - - let (read_stream, write_stream) = transport.connect().await?; - let mut session = Session::new(read_stream, write_stream).await?; - - // Initialize the session - let init_result = session.initialize().await?; - assert_eq!(init_result.protocol_version, "2024-11-05"); - assert_eq!( - init_result.capabilities.resources.unwrap().list_changed, - Some(false) - ); - - // Now list resources - let list_result = session.list_resources().await?; - assert_eq!( - list_result - .resources - .iter() - .map(|r| &r.name) - .collect::>(), - vec!["res1", "res2"] - ); - - // Make another call - just to verify multiple calls work fine - let _: serde_json::Value = session.rpc_call("someMethod", Some(json!({}))).await?; - Ok(()) - } - - #[tokio::test] - async fn test_read_error_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::ReadError, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // // Introduce a brief delay to ensure the request is fully sent and pending before the error occurs - // tokio::time::sleep(std::time::Duration::from_millis(20)).await; - - // Try to make an RPC call - should fail due to transport error - let result = session.list_resources().await; - assert!(result.is_err()); - - // Print the actual error message for debugging - let err = result.unwrap_err(); - println!("Actual error: {}", err); - assert!(err.to_string().contains("Simulated read error")); - - // Verify session is marked as closed - assert!( - session.is_closed.load(Ordering::SeqCst), - "Session did not close after error" - ); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_write_error_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::WriteError, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Try to make an RPC call - should fail due to transport error - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Failed to receive response")); - - // Verify session is marked as closed - assert!(session.is_closed.load(Ordering::SeqCst)); - println!("First call made"); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_process_termination_terminates_session() { - let transport = MockTransport { - error_mode: ErrorMode::ProcessTermination, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Try to make an RPC call - should fail due to process termination - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Child process terminated")); - - // Verify session is marked as closed - assert!(session.is_closed.load(Ordering::SeqCst)); - - // Subsequent calls should fail immediately - let result = session.list_tools().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - } - - #[tokio::test] - async fn test_session_cleanup_on_drop() { - let transport = MockTransport { - error_mode: ErrorMode::ProcessTermination, - }; - - let (read_stream, write_stream) = transport.connect().await.unwrap(); - let session = Session::new(read_stream, write_stream).await.unwrap(); - - // Get a clone of the background task handle - let background_task = session.background_task.clone(); - - // Drop the session - drop(session); - - // Verify that the background task completes - let timeout_result = timeout(Duration::from_secs(1), async { - if let Some(task) = background_task.lock().await.take() { - task.await.unwrap(); - } - }) - .await; - - assert!(timeout_result.is_ok(), "Background task did not complete"); - } - - #[tokio::test] - async fn test_explicit_shutdown() -> Result<()> { - let transport = MockTransport { - error_mode: ErrorMode::Nil, - }; - - let (read_stream, write_stream) = transport.connect().await?; - let session = Session::new(read_stream, write_stream).await?; - - // Verify we can make calls before shutdown - let _: serde_json::Value = session.rpc_call("someMethod", Some(json!({}))).await?; - - // Shutdown the session - session.shutdown().await?; - - // Verify calls fail after shutdown - let result = session.list_resources().await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Session is closed")); - - Ok(()) - } -} diff --git a/crates/mcp-client/src/sse_transport.rs b/crates/mcp-client/src/sse_transport.rs deleted file mode 100644 index bc5ea8852..000000000 --- a/crates/mcp-client/src/sse_transport.rs +++ /dev/null @@ -1,229 +0,0 @@ -use crate::transport::{ReadStream, Transport, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use async_trait::async_trait; -use futures_util::StreamExt; -use mcp_core::protocol::JsonRpcMessage; -use reqwest::{Client, Url}; -use reqwest_eventsource::{Event, EventSource}; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; -use tokio_retry::{ - strategy::{jitter, ExponentialBackoff}, - Retry, -}; -use tracing::{debug, error, info, warn}; - -pub struct SseTransportParams { - pub url: String, - pub headers: Option, -} - -pub struct SseTransport { - pub params: SseTransportParams, -} - -// Helper function to send a POST request with retry logic -async fn send_with_retry( - client: &Client, - endpoint: &str, - json: serde_json::Value, -) -> Result { - // Create retry strategy with exponential backoff - let retry_strategy = ExponentialBackoff::from_millis(100) // Start with 100ms - .factor(2) // Double the delay each time - .map(jitter) // Add randomness to prevent thundering herd - .take(3); // Maximum of 3 retries (4 attempts total) - - Retry::spawn(retry_strategy, || async { - let response = client.post(endpoint).json(&json).send().await?; - - // If we get a 5xx error or specific connection errors, we should retry - if response.status().is_server_error() - || matches!(response.error_for_status_ref(), Err(e) if e.is_connect()) - { - return Err(anyhow!("Server error: {}", response.status())); - } - - Ok(response) - }) - .await -} - -#[async_trait] -impl Transport for SseTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - info!("Connecting to SSE endpoint: {}", self.params.url); - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - let client = Client::new(); - let base_url = Url::parse(&self.params.url).context("Failed to parse SSE URL")?; - - // Create the event source request - let mut request_builder = client.get(base_url.clone()); - if let Some(headers) = &self.params.headers { - request_builder = headers - .iter() - .fold(request_builder, |req, (key, value)| req.header(key, value)); - } - - let event_source = EventSource::new(request_builder)?; - let client_for_post = client.clone(); - - // Shared state for the endpoint URL - let endpoint_url = Arc::new(Mutex::new(None::)); - let endpoint_url_reader = endpoint_url.clone(); - - // Spawn the SSE reader task - tokio::spawn({ - let tx_read = tx_read.clone(); - let base_url = base_url.clone(); - async move { - info!("Starting SSE reader task"); - let mut stream = event_source; - let mut got_endpoint = false; - - while let Some(event) = stream.next().await { - match event { - Ok(Event::Open) => { - info!("SSE connection opened"); - } - Ok(Event::Message(message)) => { - debug!("Received SSE event: {} - {}", message.event, message.data); - match message.event.as_str() { - "endpoint" => { - // Handle endpoint event - let endpoint = message.data; - info!("Received endpoint URL: {}", endpoint); - - // Join with base URL if relative - let endpoint_url_full = if endpoint.starts_with('/') { - match base_url.join(&endpoint) { - Ok(url) => url, - Err(e) => { - error!("Failed to join endpoint URL: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - } else { - match Url::parse(&endpoint) { - Ok(url) => url, - Err(e) => { - error!("Failed to parse endpoint URL: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - }; - - // Validate endpoint URL has same origin (scheme and host) - if base_url.scheme() != endpoint_url_full.scheme() - || base_url.host_str() != endpoint_url_full.host_str() - || base_url.port() != endpoint_url_full.port() - { - let error = format!( - "Endpoint origin does not match connection origin: {}", - endpoint_url_full - ); - error!("{}", error); - let _ = tx_read.send(Err(anyhow!(error))).await; - break; - } - - let endpoint_str = endpoint_url_full.to_string(); - info!("Using full endpoint URL: {}", endpoint_str); - let mut endpoint_guard = endpoint_url.lock().await; - *endpoint_guard = Some(endpoint_str); - got_endpoint = true; - debug!("Endpoint URL set successfully"); - } - "message" => { - if !got_endpoint { - warn!("Received message before endpoint URL"); - continue; - } - // Handle message event - match serde_json::from_str::(&message.data) { - Ok(msg) => { - debug!("Received server message: {:?}", msg); - if tx_read.send(Ok(msg)).await.is_err() { - error!("Failed to send message to read channel"); - break; - } - } - Err(e) => { - error!("Error parsing server message: {}", e); - if tx_read.send(Err(e.into())).await.is_err() { - error!("Failed to send error to read channel"); - break; - } - } - } - } - _ => { - debug!("Ignoring unknown event type: {}", message.event); - } - } - } - Err(e) => { - error!("SSE error: {}", e); - let _ = tx_read.send(Err(e.into())).await; - break; - } - } - } - info!("SSE reader task ended"); - } - }); - - // Spawn the writer task - tokio::spawn(async move { - info!("Starting writer task"); - // Wait for the endpoint URL before processing messages - let mut endpoint = None; - while endpoint.is_none() { - let guard = endpoint_url_reader.lock().await; - if let Some(url) = guard.as_ref() { - endpoint = Some(url.clone()); - break; - } - drop(guard); - debug!("Waiting for endpoint URL..."); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - } - - let endpoint = endpoint.unwrap(); - info!("Starting post writer with endpoint URL: {}", endpoint); - - while let Some(message) = rx_write.recv().await { - match serde_json::to_value(&message) { - Ok(json) => { - debug!("Sending client message: {:?}", json); - match send_with_retry(&client_for_post, &endpoint, json).await { - Ok(response) => { - if !response.status().is_success() { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - error!("Server returned error status {}: {}", status, text); - } else { - debug!("Message sent successfully: {}", response.status()); - } - } - Err(e) => { - error!("Failed to send message after retries: {}", e); - } - } - } - Err(e) => { - error!("Failed to serialize message: {}", e); - } - } - } - info!("Writer task ended"); - }); - - info!("SSE transport connected"); - Ok((rx_read, tx_write)) - } -} diff --git a/crates/mcp-client/src/stdio_transport.rs b/crates/mcp-client/src/stdio_transport.rs deleted file mode 100644 index 67f0fbf12..000000000 --- a/crates/mcp-client/src/stdio_transport.rs +++ /dev/null @@ -1,198 +0,0 @@ -use crate::transport::{ReadStream, Transport, WriteStream}; -use anyhow::{anyhow, Context, Result}; -use async_trait::async_trait; -use mcp_core::protocol::*; -use std::process::Stdio; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::process::{Child, Command}; -use tokio::sync::mpsc; - -pub struct StdioServerParams { - pub command: String, - pub args: Vec, - pub env: Option>, -} - -pub struct StdioTransport { - pub params: StdioServerParams, -} - -impl StdioTransport { - fn get_default_environment() -> std::collections::HashMap { - let default_vars = if cfg!(windows) { - vec!["APPDATA", "PATH", "TEMP", "USERNAME"] // Simplified list - } else { - vec!["HOME", "PATH", "SHELL", "USER"] // Simplified list - }; - - std::env::vars() - .filter(|(key, value)| default_vars.contains(&key.as_str()) && !value.starts_with("()")) - .collect() - } - - async fn monitor_child(mut child: Child, tx_read: mpsc::Sender>) { - match child.wait().await { - Ok(status) => { - let msg = if status.success() { - format!("Child process terminated normally with status: {}", status) - } else { - format!("Child process terminated with error status: {}", status) - }; - let _ = tx_read.send(Err(anyhow!(msg))).await; - } - Err(e) => { - let _ = tx_read - .send(Err(anyhow!("Child process error: {}", e))) - .await; - } - } - } -} - -#[async_trait] -impl Transport for StdioTransport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)> { - let mut child = Command::new(&self.params.command) - .args(&self.params.args) - .env_clear() - .envs( - self.params - .env - .clone() - .unwrap_or_else(Self::get_default_environment), - ) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::inherit()) - .spawn() - .context("Failed to spawn child process")?; - - let stdin = child.stdin.take().context("Failed to get stdin handle")?; - let stdout = child.stdout.take().context("Failed to get stdout handle")?; - - let (tx_read, rx_read) = mpsc::channel(100); - let (tx_write, mut rx_write) = mpsc::channel(100); - - // Clone tx_read for the child monitor - let tx_read_monitor = tx_read.clone(); - - // Spawn child process monitor - tokio::spawn(Self::monitor_child(child, tx_read_monitor)); - - // Spawn stdout reader task - let stdout_reader = BufReader::new(stdout); - tokio::spawn(async move { - let mut lines = stdout_reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - match serde_json::from_str::(&line) { - Ok(msg) => { - if tx_read.send(Ok(msg)).await.is_err() { - break; - } - } - Err(e) => { - let _ = tx_read.send(Err(e.into())).await; - } - } - } - }); - - // Spawn stdin writer task - let mut stdin = stdin; - tokio::spawn(async move { - while let Some(message) = rx_write.recv().await { - let json = serde_json::to_string(&message).expect("Failed to serialize message"); - if stdin - .write_all(format!("{}\n", json).as_bytes()) - .await - .is_err() - { - break; - } - } - }); - - Ok((rx_read, tx_write)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - use std::time::Duration; - use tokio::time::timeout; - - #[tokio::test] - async fn test_stdio_transport() { - let transport = StdioTransport { - params: StdioServerParams { - command: "tee".to_string(), // tee will echo back what it receives - args: vec![], - env: None, - }, - }; - - let (mut rx, tx) = transport.connect().await.unwrap(); - - // Create test messages - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Some(1), - method: "ping".to_string(), - params: None, - }); - - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: Some(2), - result: Some(json!({})), - error: None, - }); - - // Send messages - tx.send(request.clone()).await.unwrap(); - tx.send(response.clone()).await.unwrap(); - - // Receive and verify messages - let mut read_messages = Vec::new(); - - // Use timeout to avoid hanging if messages aren't received - for _ in 0..2 { - match timeout(Duration::from_secs(1), rx.recv()).await { - Ok(Some(Ok(msg))) => read_messages.push(msg), - Ok(Some(Err(e))) => panic!("Received error: {}", e), - Ok(None) => break, - Err(_) => panic!("Timeout waiting for message"), - } - } - - assert_eq!(read_messages.len(), 2, "Expected 2 messages"); - assert_eq!(read_messages[0], request); - assert_eq!(read_messages[1], response); - } - - #[tokio::test] - async fn test_process_termination() { - let transport = StdioTransport { - params: StdioServerParams { - command: "sleep".to_string(), - args: vec!["0.3".to_string()], - env: None, - }, - }; - let (mut rx, _tx) = transport.connect().await.unwrap(); - - // Try to receive a message - should get an error about process termination - match timeout(Duration::from_secs(1), rx.recv()).await { - Ok(Some(Err(e))) => { - assert!( - e.to_string().contains("Child process terminated normally"), - "Expected process termination error, got: {}", - e - ); - } - _ => panic!("Expected error, got a different message"), - } - } -} diff --git a/crates/mcp-client/src/transport.rs b/crates/mcp-client/src/transport.rs deleted file mode 100644 index 2ccca05e9..000000000 --- a/crates/mcp-client/src/transport.rs +++ /dev/null @@ -1,14 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use mcp_core::protocol::JsonRpcMessage; -use tokio::sync::mpsc::{Receiver, Sender}; - -// Stream types for consistent interface -pub type ReadStream = Receiver>; -pub type WriteStream = Sender; - -// Common trait for transport implementations -#[async_trait] -pub trait Transport { - async fn connect(&self) -> Result<(ReadStream, WriteStream)>; -} diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs new file mode 100644 index 000000000..55e17d10a --- /dev/null +++ b/crates/mcp-client/src/transport/mod.rs @@ -0,0 +1,42 @@ +use async_trait::async_trait; +use thiserror::Error; + +/// A generic error type for transport operations. +#[derive(Debug, Error)] +pub enum Error { + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + #[error("Transport was not connected or is already closed")] + NotConnected, + + #[error("Unexpected transport error: {0}")] + Other(String), +} + +/// A generic asynchronous transport trait. +/// +/// Implementations are expected to handle: +/// - starting the underlying communication channel (e.g., launching a child process, connecting a socket) +/// - sending JSON-RPC messages as strings +/// - receiving JSON-RPC messages as strings +/// - closing the transport cleanly +#[async_trait] +pub trait Transport: Send + 'static { + /// Start the transport and establish the underlying connection. + async fn start(&mut self) -> Result<(), Error>; + + /// Send a raw JSON-encoded message through the transport. + async fn send(&mut self, msg: String) -> Result<(), Error>; + + /// Receive a raw JSON-encoded message from the transport. + /// + /// This should return a single line representing one JSON message. + async fn receive(&mut self) -> Result; + + /// Close the transport and free any resources. + async fn close(&mut self) -> Result<(), Error>; +} + +pub mod stdio; +pub use stdio::StdioTransport; diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs new file mode 100644 index 000000000..cb223b4f6 --- /dev/null +++ b/crates/mcp-client/src/transport/stdio.rs @@ -0,0 +1,94 @@ +use super::{Error, Transport}; +use async_trait::async_trait; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; + +/// A `StdioTransport` uses a child process’s stdin/stdout as a communication channel. +/// +/// It starts the specified command with arguments and uses its stdin/stdout to send/receive +/// JSON-RPC messages line by line. This is useful for running MCP servers as subprocesses. +pub struct StdioTransport { + command: String, + args: Vec, + child: Option, + stdin: Option, + stdout: Option>, +} + +impl StdioTransport { + /// Create a new `StdioTransport` configured to run the given command with arguments. + /// + /// The transport will not start until `start()` is called. + pub fn new(command: S, args: I) -> Self + where + S: Into, + I: IntoIterator, + { + Self { + command: command.into(), + args: args.into_iter().map(Into::into).collect(), + child: None, + stdin: None, + stdout: None, + } + } +} + +#[async_trait] +impl Transport for StdioTransport { + async fn start(&mut self) -> Result<(), Error> { + if self.child.is_some() { + return Ok(()); // Already started + } + + let mut cmd = Command::new(&self.command); + cmd.args(&self.args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()); + + let mut child = cmd.spawn()?; + + let stdin = child.stdin.take().ok_or(Error::NotConnected)?; + let stdout = child.stdout.take().ok_or(Error::NotConnected)?; + + self.stdin = Some(stdin); + self.stdout = Some(BufReader::new(stdout)); + self.child = Some(child); + + Ok(()) + } + + async fn send(&mut self, msg: String) -> Result<(), Error> { + let stdin = self.stdin.as_mut().ok_or(Error::NotConnected)?; + // Write the message followed by a newline + stdin.write_all(msg.as_bytes()).await?; + stdin.write_all(b"\n").await?; + stdin.flush().await?; + Ok(()) + } + + async fn receive(&mut self) -> Result { + let stdout = self.stdout.as_mut().ok_or(Error::NotConnected)?; + let mut line = String::new(); + let n = stdout.read_line(&mut line).await?; + if n == 0 { + // End of stream + return Err(Error::NotConnected); + } + Ok(line) + } + + async fn close(&mut self) -> Result<(), Error> { + // Drop stdin to signal EOF + self.stdin.take(); + self.stdout.take(); + + if let Some(mut child) = self.child.take() { + // Wait for child to exit + let _status = child.wait().await?; + } + + Ok(()) + } +} diff --git a/crates/mcp-core/src/tool.rs b/crates/mcp-core/src/tool.rs index 6401b9632..adb99ce12 100644 --- a/crates/mcp-core/src/tool.rs +++ b/crates/mcp-core/src/tool.rs @@ -5,6 +5,7 @@ use serde_json::Value; /// A tool that can be used by a model. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct Tool { /// The name of the tool pub name: String, From 9ac0f5ee42f7c022edccde4adfb1f2f03d5f1c38 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 00:15:22 -0500 Subject: [PATCH 02/11] working: send initialized notification during initialization --- crates/mcp-client/src/client.rs | 60 ++++++++++++++++-------- crates/mcp-client/src/main.rs | 42 ++++------------- crates/mcp-client/src/service.rs | 7 ++- crates/mcp-client/src/transport/mod.rs | 33 +++++++++++-- crates/mcp-client/src/transport/stdio.rs | 47 +++++++++++-------- crates/mcp-core/src/protocol.rs | 1 + 6 files changed, 114 insertions(+), 76 deletions(-) diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 0cdab4391..3cb33277c 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -5,8 +5,12 @@ use tower::ServiceExt; // for Service::ready() use mcp_core::protocol::{ InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, - ListResourcesResult, ReadResourceResult, + ListResourcesResult, ListToolsResult, ReadResourceResult, }; +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::transport::{Error as TransportError, Transport}; /// Error type for MCP client operations. #[derive(Debug, Error)] @@ -48,12 +52,13 @@ pub struct InitializeParams { } /// The MCP client that sends requests via the provided service. -pub struct McpClient { +pub struct McpClient { service: S, + transport: Arc>, next_id: u64, } -impl McpClient +impl McpClient where S: tower::Service< JsonRpcRequest, @@ -61,18 +66,20 @@ where Error = super::service::ServiceError, > + Send, S::Future: Send, + T: Transport, { - pub fn new(service: S) -> Self { + pub fn new(service: S, transport: Arc>) -> Self { Self { service, + transport, next_id: 1, } } /// Send a JSON-RPC request and wait for a response. - async fn send_message(&mut self, method: &str, params: Value) -> Result + async fn send_message(&mut self, method: &str, params: Value) -> Result where - T: for<'de> Deserialize<'de>, + R: for<'de> Deserialize<'de>, { self.service.ready().await.map_err(|_| Error::NotReady)?; @@ -122,17 +129,21 @@ where } } - // /// Send a JSON-RPC notification. - // pub async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { - // let notification = mcp_core::protocol::JsonRpcNotification { - // jsonrpc: "2.0".to_string(), - // method: method.to_string(), - // params: Some(params), - // }; - // let msg = serde_json::to_string(¬ification)?; - // let mut transport = self.transport.lock().await; - // transport.send(msg).await - // } + /// Send a JSON-RPC notification. + pub async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { + let notification = mcp_core::protocol::JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: method.to_string(), + params: Some(params), + }; + let msg = serde_json::to_string(¬ification)?; + let transport = self.transport.lock().await; + // transport.send(msg).await + transport + .send(msg) + .await + .map_err(|e: TransportError| Error::Service(e.into())) + } /// Initialize the connection with the server. pub async fn initialize( @@ -145,8 +156,14 @@ where client_info: info, capabilities, }; - self.send_message("initialize", serde_json::to_value(params)?) - .await + let result: InitializeResult = self + .send_message("initialize", serde_json::to_value(params)?) + .await?; + + self.send_notification("notifications/initialized", serde_json::json!({})) + .await?; + + Ok(result) } /// List available resources. @@ -160,4 +177,9 @@ where let params = serde_json::json!({ "uri": uri }); self.send_message("resources/read", params).await } + + /// List tools + pub async fn list_tools(&mut self) -> Result { + self.send_message("tools/list", serde_json::json!({})).await + } } diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/src/main.rs index bd4490fd0..cdf86fe9b 100644 --- a/crates/mcp-client/src/main.rs +++ b/crates/mcp-client/src/main.rs @@ -1,29 +1,11 @@ use anyhow::Result; use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; use mcp_client::{service::TransportService, transport::StdioTransport}; +use std::sync::Arc; +use tokio::sync::Mutex; use tower::ServiceBuilder; use tracing_subscriber::EnvFilter; -// use mcp_client::{ -// service::{ServiceError}, -// transport::{Error as TransportError}, -// }; -// use std::time::Duration; -// use tower::timeout::error::Elapsed; - -// fn convert_box_error(err: Box) -> ServiceError { -// if let Some(elapsed) = err.downcast_ref::() { -// ServiceError::Transport(TransportError::Io( -// std::io::Error::new( -// std::io::ErrorKind::TimedOut, -// format!("Timeout elapsed: {}", elapsed), -// ), -// )) -// } else { -// ServiceError::Other(err.to_string()) -// } -// } - #[tokio::main] async fn main() -> Result<(), ClientError> { // Initialize logging @@ -35,14 +17,14 @@ async fn main() -> Result<(), ClientError> { ) .init(); - // Create the base transport - let transport = StdioTransport::new("uvx", ["mcp-server-git"]); + // Create the base transport as Arc> + let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); // Build service with middleware - let service = ServiceBuilder::new().service(TransportService::new(transport)); + let service = ServiceBuilder::new().service(TransportService::new(Arc::clone(&transport))); // Create client - let mut client = McpClient::new(service); + let mut client = McpClient::new(service, Arc::clone(&transport)); // Initialize let server_info = client @@ -54,15 +36,11 @@ async fn main() -> Result<(), ClientError> { ClientCapabilities::default(), ) .await?; - println!("Connected to server: {server_info:?}"); - - // List resources - let resources = client.list_resources().await?; - println!("Available resources: {resources:?}"); + println!("Connected to server: {server_info:?}\n"); - // Read a resource - let content = client.read_resource("file:///example.txt".into()).await?; - println!("Content: {content:?}"); + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}"); Ok(()) } diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index f422bc924..6748f896a 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -37,6 +37,11 @@ impl TransportService { initialized: AtomicBool::new(false), } } + + /// Provides a clone of the transport handle for external access (e.g., for sending notifications). + pub fn get_transport_handle(&self) -> Arc> { + Arc::clone(&self.transport) + } } impl Service for TransportService { @@ -54,7 +59,7 @@ impl Service for TransportService { let started = self.initialized.load(Ordering::SeqCst); Box::pin(async move { - let mut transport = transport.lock().await; + let transport = transport.lock().await; // Initialize (start) transport on the first call. if !started { diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 55e17d10a..f346af833 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,5 +1,7 @@ use async_trait::async_trait; +use std::sync::Arc; use thiserror::Error; +use tokio::sync::Mutex; /// A generic error type for transport operations. #[derive(Debug, Error)] @@ -24,18 +26,41 @@ pub enum Error { #[async_trait] pub trait Transport: Send + 'static { /// Start the transport and establish the underlying connection. - async fn start(&mut self) -> Result<(), Error>; + async fn start(&self) -> Result<(), Error>; /// Send a raw JSON-encoded message through the transport. - async fn send(&mut self, msg: String) -> Result<(), Error>; + async fn send(&self, msg: String) -> Result<(), Error>; /// Receive a raw JSON-encoded message from the transport. /// /// This should return a single line representing one JSON message. - async fn receive(&mut self) -> Result; + async fn receive(&self) -> Result; /// Close the transport and free any resources. - async fn close(&mut self) -> Result<(), Error>; + async fn close(&self) -> Result<(), Error>; +} + +#[async_trait] +impl Transport for Arc> { + async fn start(&self) -> Result<(), Error> { + let transport = self.lock().await; + transport.start().await + } + + async fn send(&self, msg: String) -> Result<(), Error> { + let transport = self.lock().await; + transport.send(msg).await + } + + async fn receive(&self) -> Result { + let transport = self.lock().await; + transport.receive().await + } + + async fn close(&self) -> Result<(), Error> { + let transport = self.lock().await; + transport.close().await + } } pub mod stdio; diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index cb223b4f6..cac927574 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -2,6 +2,7 @@ use super::{Error, Transport}; use async_trait::async_trait; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::sync::Mutex; /// A `StdioTransport` uses a child process’s stdin/stdout as a communication channel. /// @@ -10,9 +11,9 @@ use tokio::process::{Child, ChildStdin, ChildStdout, Command}; pub struct StdioTransport { command: String, args: Vec, - child: Option, - stdin: Option, - stdout: Option>, + child: Mutex>, + stdin: Mutex>, + stdout: Mutex>>, } impl StdioTransport { @@ -27,17 +28,17 @@ impl StdioTransport { Self { command: command.into(), args: args.into_iter().map(Into::into).collect(), - child: None, - stdin: None, - stdout: None, + child: Mutex::new(None), + stdin: Mutex::new(None), + stdout: Mutex::new(None), } } } #[async_trait] impl Transport for StdioTransport { - async fn start(&mut self) -> Result<(), Error> { - if self.child.is_some() { + async fn start(&self) -> Result<(), Error> { + if self.child.lock().await.is_some() { return Ok(()); // Already started } @@ -52,15 +53,16 @@ impl Transport for StdioTransport { let stdin = child.stdin.take().ok_or(Error::NotConnected)?; let stdout = child.stdout.take().ok_or(Error::NotConnected)?; - self.stdin = Some(stdin); - self.stdout = Some(BufReader::new(stdout)); - self.child = Some(child); + *self.stdin.lock().await = Some(stdin); + *self.stdout.lock().await = Some(BufReader::new(stdout)); + *self.child.lock().await = Some(child); Ok(()) } - async fn send(&mut self, msg: String) -> Result<(), Error> { - let stdin = self.stdin.as_mut().ok_or(Error::NotConnected)?; + async fn send(&self, msg: String) -> Result<(), Error> { + let mut stdin = self.stdin.lock().await; + let stdin = stdin.as_mut().ok_or(Error::NotConnected)?; // Write the message followed by a newline stdin.write_all(msg.as_bytes()).await?; stdin.write_all(b"\n").await?; @@ -68,8 +70,9 @@ impl Transport for StdioTransport { Ok(()) } - async fn receive(&mut self) -> Result { - let stdout = self.stdout.as_mut().ok_or(Error::NotConnected)?; + async fn receive(&self) -> Result { + let mut stdout = self.stdout.lock().await; + let stdout = stdout.as_mut().ok_or(Error::NotConnected)?; let mut line = String::new(); let n = stdout.read_line(&mut line).await?; if n == 0 { @@ -79,14 +82,18 @@ impl Transport for StdioTransport { Ok(line) } - async fn close(&mut self) -> Result<(), Error> { + async fn close(&self) -> Result<(), Error> { + let mut child = self.child.lock().await; + let mut stdin = self.stdin.lock().await; + let mut stdout = self.stdout.lock().await; + // Drop stdin to signal EOF - self.stdin.take(); - self.stdout.take(); + *stdin = None; + *stdout = None; - if let Some(mut child) = self.child.take() { + if let Some(mut c) = child.take() { // Wait for child to exit - let _status = child.wait().await?; + let _status = c.wait().await?; } Ok(()) diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index 259050a20..4402ff135 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -23,6 +23,7 @@ pub struct JsonRpcResponse { pub struct JsonRpcNotification { pub jsonrpc: String, pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] pub params: Option, } From a9c51dbfc98abf315f8262e304d34f8670e8ceac Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 00:16:07 -0500 Subject: [PATCH 03/11] update README --- crates/mcp-client/README.md | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/crates/mcp-client/README.md b/crates/mcp-client/README.md index 32e0c4c32..395b4ba93 100644 --- a/crates/mcp-client/README.md +++ b/crates/mcp-client/README.md @@ -1,13 +1,5 @@ ## Testing stdio ```bash -cargo run -p mcp_client -- --mode git -cargo run -p mcp_client -- --mode echo - -cargo run -p mcp_client --bin stdio +cargo run -p mcp-client ``` - -## Testing SSE - -1. Start the MCP server: `fastmcp run -t sse echo.py` -2. Run the client: `cargo run -p mcp_client --bin sse` From 58abcf5b73b7e845541c7f8b744fb70597c60987 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 08:23:22 -0500 Subject: [PATCH 04/11] implement Drop trait to close transport when its out of scope --- crates/mcp-client/src/service.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 6748f896a..98e3bb847 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -25,7 +25,7 @@ pub enum ServiceError { } /// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcRequests and JsonRpcMessages. -pub struct TransportService { +pub struct TransportService { transport: Arc>, initialized: AtomicBool, } @@ -71,9 +71,20 @@ impl Service for TransportService { transport.send(msg).await?; let line = transport.receive().await?; - let response_msg: JsonRpcMessage = serde_json::from_str(&line)?; + let response: JsonRpcMessage = serde_json::from_str(&line)?; - Ok(response_msg) + Ok(response) }) } } + +impl Drop for TransportService { + fn drop(&mut self) { + if self.initialized.load(Ordering::SeqCst) { + // Create a new runtime for cleanup if needed + let rt = tokio::runtime::Runtime::new().unwrap(); + let transport = rt.block_on(self.transport.lock()); + let _ = rt.block_on(transport.close()); + } + } +} From 6ee7c8216911f696390bd6f3f5bcf73101d2c953 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 08:50:17 -0500 Subject: [PATCH 05/11] add timeout middleware to the service --- crates/mcp-client/src/client.rs | 1 - crates/mcp-client/src/main.rs | 22 ++++++++++++++++++---- crates/mcp-client/src/service.rs | 3 +++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 3cb33277c..d568cf80f 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -138,7 +138,6 @@ where }; let msg = serde_json::to_string(¬ification)?; let transport = self.transport.lock().await; - // transport.send(msg).await transport .send(msg) .await diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/src/main.rs index cdf86fe9b..904083505 100644 --- a/crates/mcp-client/src/main.rs +++ b/crates/mcp-client/src/main.rs @@ -1,9 +1,14 @@ use anyhow::Result; use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; -use mcp_client::{service::TransportService, transport::StdioTransport}; +use mcp_client::{ + service::{ServiceError, TransportService}, + transport::StdioTransport, +}; use std::sync::Arc; +use std::time::Duration; use tokio::sync::Mutex; -use tower::ServiceBuilder; +use tower::timeout::TimeoutLayer; +use tower::{ServiceBuilder, ServiceExt}; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -20,8 +25,17 @@ async fn main() -> Result<(), ClientError> { // Create the base transport as Arc> let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); - // Build service with middleware - let service = ServiceBuilder::new().service(TransportService::new(Arc::clone(&transport))); + // Build service with middleware including timeout + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(Duration::from_secs(30))) + .service(TransportService::new(Arc::clone(&transport))) + .map_err(|e: Box| { + if e.is::() { + ServiceError::Timeout(tower::timeout::error::Elapsed::new()) + } else { + ServiceError::Other(e.to_string()) + } + }); // Create client let mut client = McpClient::new(service, Arc::clone(&transport)); diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 98e3bb847..83a9dd67c 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -17,6 +17,9 @@ pub enum ServiceError { #[error("Serialization error: {0}")] Serialization(#[from] serde_json::Error), + #[error("Request timed out")] + Timeout(#[from] tower::timeout::error::Elapsed), + #[error("Other error: {0}")] Other(String), From 1a7dde6f14592047ff3d1f2498a816128f02a468 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 10:45:30 -0500 Subject: [PATCH 06/11] add call_tool method, move to example dir --- crates/mcp-client/README.md | 4 ++-- crates/mcp-client/{src/main.rs => examples/stdio.rs} | 8 +++++++- crates/mcp-client/src/client.rs | 9 +++++++-- 3 files changed, 16 insertions(+), 5 deletions(-) rename crates/mcp-client/{src/main.rs => examples/stdio.rs} (87%) diff --git a/crates/mcp-client/README.md b/crates/mcp-client/README.md index 395b4ba93..05559abf7 100644 --- a/crates/mcp-client/README.md +++ b/crates/mcp-client/README.md @@ -1,5 +1,5 @@ -## Testing stdio +## Testing ```bash -cargo run -p mcp-client +cargo run -p mcp-client --example stdio ``` diff --git a/crates/mcp-client/src/main.rs b/crates/mcp-client/examples/stdio.rs similarity index 87% rename from crates/mcp-client/src/main.rs rename to crates/mcp-client/examples/stdio.rs index 904083505..04f6d7f65 100644 --- a/crates/mcp-client/src/main.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -54,7 +54,13 @@ async fn main() -> Result<(), ClientError> { // List tools let tools = client.list_tools().await?; - println!("Available tools: {tools:?}"); + println!("Available tools: {tools:?}\n"); + + // Call tool 'git_status' wtih arguments = {"repo_path": "."} + let tool_result = client + .call_tool("git_status", serde_json::json!({ "repo_path": "." })) + .await?; + println!("Tool result: {tool_result:?}"); Ok(()) } diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index d568cf80f..e864e140a 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -4,8 +4,7 @@ use thiserror::Error; use tower::ServiceExt; // for Service::ready() use mcp_core::protocol::{ - InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, - ListResourcesResult, ListToolsResult, ReadResourceResult, + CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult }; use std::sync::Arc; use tokio::sync::Mutex; @@ -181,4 +180,10 @@ where pub async fn list_tools(&mut self) -> Result { self.send_message("tools/list", serde_json::json!({})).await } + + // Call tool + pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { + let params = serde_json::json!({ "name": name, "arguments": arguments }); + self.send_message("tools/call", params).await + } } From 0faf107780e65411b45a10a736ad8ec8808bae8e Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 17 Dec 2024 10:56:05 -0500 Subject: [PATCH 07/11] working: add SSE transport and example --- crates/mcp-client/README.md | 8 +- crates/mcp-client/examples/sse.rs | 73 ++++++++++++ crates/mcp-client/src/client.rs | 9 +- crates/mcp-client/src/transport/mod.rs | 12 ++ crates/mcp-client/src/transport/sse.rs | 155 +++++++++++++++++++++++++ 5 files changed, 254 insertions(+), 3 deletions(-) create mode 100644 crates/mcp-client/examples/sse.rs create mode 100644 crates/mcp-client/src/transport/sse.rs diff --git a/crates/mcp-client/README.md b/crates/mcp-client/README.md index 05559abf7..a43c4c210 100644 --- a/crates/mcp-client/README.md +++ b/crates/mcp-client/README.md @@ -1,5 +1,11 @@ -## Testing +## Testing stdio transport ```bash cargo run -p mcp-client --example stdio ``` + +## Testing SSE transport + +1. Start the MCP server in one terminal: `fastmcp run -t sse echo.py` +2. Run the client example in new terminal: `cargo run -p mcp-client --example sse` + diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs new file mode 100644 index 000000000..d26aeb665 --- /dev/null +++ b/crates/mcp-client/examples/sse.rs @@ -0,0 +1,73 @@ +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient}; +use mcp_client::{ + service::{ServiceError, TransportService}, + transport::SseTransport, +}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tower::timeout::TimeoutLayer; +use tower::{ServiceBuilder, ServiceExt}; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("reqwest_eventsource=debug".parse().unwrap()), + ) + .init(); + + // Create the base transport as Arc> + let transport = Arc::new(Mutex::new(SseTransport::new("http://localhost:8000/sse")?)); + + // Build service with middleware including timeout + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(Duration::from_secs(30))) + .service(TransportService::new(Arc::clone(&transport))) + .map_err(|e: Box| { + if e.is::() { + ServiceError::Timeout(tower::timeout::error::Elapsed::new()) + } else { + ServiceError::Other(e.to_string()) + } + }); + + // Create client + let mut client = McpClient::new(service, Arc::clone(&transport)); + println!("Client created\n"); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // Sleep for 100ms to allow the server to start - surprisingly this is required! + tokio::time::sleep(Duration::from_millis(100)).await; + + // List tools + let tools = client.list_tools().await?; + println!("Available tools: {tools:?}\n"); + + // Call tool + let tool_result = client + .call_tool( + "echo_tool", + serde_json::json!({ "message": "Client with SSE transport - calling a tool" }), + ) + .await?; + println!("Tool result: {tool_result:?}"); + + Ok(()) +} diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index e864e140a..38ecf6c88 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -4,7 +4,8 @@ use thiserror::Error; use tower::ServiceExt; // for Service::ready() use mcp_core::protocol::{ - CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult + CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, + JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, }; use std::sync::Arc; use tokio::sync::Mutex; @@ -182,7 +183,11 @@ where } // Call tool - pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { + pub async fn call_tool( + &mut self, + name: &str, + arguments: Value, + ) -> Result { let params = serde_json::json!({ "name": name, "arguments": arguments }); self.send_message("tools/call", params).await } diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index f346af833..12da4003a 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -12,6 +12,15 @@ pub enum Error { #[error("Transport was not connected or is already closed")] NotConnected, + #[error("Invalid URL provided")] + InvalidUrl, + + #[error("Connection timeout")] + Timeout, + + #[error("Failed to send message")] + SendFailed, + #[error("Unexpected transport error: {0}")] Other(String), } @@ -65,3 +74,6 @@ impl Transport for Arc> { pub mod stdio; pub use stdio::StdioTransport; + +pub mod sse; +pub use sse::SseTransport; diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs new file mode 100644 index 000000000..7945197f3 --- /dev/null +++ b/crates/mcp-client/src/transport/sse.rs @@ -0,0 +1,155 @@ +use super::{Error, Transport}; +use async_trait::async_trait; +use futures_util::StreamExt; +use reqwest::{Client, Url}; +use reqwest_eventsource::{Event, EventSource}; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; +use tracing::{debug, error, info}; + +pub struct SseTransport { + connection_url: Url, + endpoint: Arc>>, + http_client: Client, + event_source: Arc>>, + message_rx: Arc>>>, + message_tx: mpsc::Sender, +} + +impl SseTransport { + pub fn new(url: &str) -> Result { + let (message_tx, message_rx) = mpsc::channel(100); + + Ok(Self { + connection_url: Url::parse(url).map_err(|_| Error::InvalidUrl)?, + endpoint: Arc::new(Mutex::new(None)), + http_client: Client::new(), + event_source: Arc::new(Mutex::new(None)), + message_rx: Arc::new(Mutex::new(Some(message_rx))), + message_tx, + }) + } +} + +/// Constructs the endpoint URL by removing "/sse" from the connection URL +/// and appending the given suffix. +fn construct_endpoint_url(base_url: &Url, url_suffix: &str) -> Result { + let trimmed_base = base_url.as_str().trim_end_matches("/sse"); + let trimmed_base = trimmed_base.trim_end_matches('/'); + let trimmed_suffix = url_suffix.trim_start_matches('/'); + let full_url = format!("{}/{}", trimmed_base, trimmed_suffix); + Url::parse(&full_url) +} + +#[async_trait] +impl Transport for SseTransport { + async fn start(&self) -> Result<(), Error> { + if self.event_source.lock().await.is_some() { + return Ok(()); + } + + let event_source = EventSource::get(self.connection_url.as_str()); + let message_tx = self.message_tx.clone(); + let endpoint = self.endpoint.clone(); + + // Store event source + *self.event_source.lock().await = Some(event_source); + + // Create a new event source for the task + let mut stream = EventSource::get(self.connection_url.as_str()); + + let connection_url = self.connection_url.clone(); + let cloned_connection_url = connection_url.clone(); + + // Spawn a task to handle incoming events + tokio::spawn(async move { + while let Some(event) = stream.next().await { + match event { + Ok(Event::Open) => { + // Connection established + info!("\nSSE connection opened"); + } + Ok(Event::Message(message)) => { + debug!("Received SSE event: {} - {}", message.event, message.data); + // Check if this is an endpoint event + if message.event == "endpoint" { + let url_suffix = &message.data; + debug!("Received endpoint URL suffix: {}", url_suffix); + match construct_endpoint_url(&cloned_connection_url, url_suffix) { + Ok(url) => { + info!("Endpoint URL: {}", url); + let mut endpoint_guard = endpoint.lock().await; + *endpoint_guard = Some(url); + } + Err(e) => { + error!("Failed to construct endpoint URL: {}", e); + // Optionally, handle the error (e.g., retry, notify, etc.) + } + } + } else { + // Regular message + // Assuming message.data is the message payload + if let Err(e) = message_tx.send(message.data).await { + error!("Failed to send message: {}", e); + } + } + } + Err(e) => { + error!("EventSource error: {}", e); + break; + } + } + } + }); + + // Wait for endpoint URL: every 100ms, check if the endpoint is set upto 30s timeout + let timeout = tokio::time::sleep(std::time::Duration::from_secs(30)); + tokio::pin!(timeout); + + loop { + tokio::select! { + _ = &mut timeout => { + return Err(Error::Timeout); + } + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + let endpoint_guard = self.endpoint.lock().await; + if endpoint_guard.is_some() { + break; + } + } + } + } + + Ok(()) + } + + async fn send(&self, msg: String) -> Result<(), Error> { + let endpoint = { + let endpoint_guard = self.endpoint.lock().await; + endpoint_guard.as_ref().ok_or(Error::NotConnected)?.clone() + }; + + self.http_client + .post(endpoint) + .header("Content-Type", "application/json") + .body(msg) + .send() + .await + .map_err(|_| Error::SendFailed)?; + + Ok(()) + } + + async fn receive(&self) -> Result { + let mut rx_guard = self.message_rx.lock().await; + let rx = rx_guard.as_mut().ok_or(Error::NotConnected)?; + + rx.recv().await.ok_or(Error::NotConnected) + } + + async fn close(&self) -> Result<(), Error> { + *self.event_source.lock().await = None; + *self.endpoint.lock().await = None; + Ok(()) + } +} From c073d19461cd39768d6dd60c8a611dfb1f71d8d7 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Wed, 18 Dec 2024 15:43:47 -0500 Subject: [PATCH 08/11] Remove transport field in McpClient and let Service handle json rpc msgs --- crates/mcp-client/examples/sse.rs | 2 +- crates/mcp-client/examples/stdio.rs | 2 +- crates/mcp-client/src/client.rs | 40 +++++++++++---------------- crates/mcp-client/src/service.rs | 42 ++++++++++++++++++++--------- 4 files changed, 48 insertions(+), 38 deletions(-) diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index d26aeb665..89d3560f2 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -38,7 +38,7 @@ async fn main() -> Result<()> { }); // Create client - let mut client = McpClient::new(service, Arc::clone(&transport)); + let mut client = McpClient::new(service); println!("Client created\n"); // Initialize diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 04f6d7f65..3b26af0d5 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -38,7 +38,7 @@ async fn main() -> Result<(), ClientError> { }); // Create client - let mut client = McpClient::new(service, Arc::clone(&transport)); + let mut client = McpClient::new(service); // Initialize let server_info = client diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 38ecf6c88..28a7a506e 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -4,13 +4,9 @@ use thiserror::Error; use tower::ServiceExt; // for Service::ready() use mcp_core::protocol::{ - CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcRequest, - JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, + CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, + JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, }; -use std::sync::Arc; -use tokio::sync::Mutex; - -use crate::transport::{Error as TransportError, Transport}; /// Error type for MCP client operations. #[derive(Debug, Error)] @@ -52,26 +48,23 @@ pub struct InitializeParams { } /// The MCP client that sends requests via the provided service. -pub struct McpClient { +pub struct McpClient { service: S, - transport: Arc>, next_id: u64, } -impl McpClient +impl McpClient where S: tower::Service< - JsonRpcRequest, + JsonRpcMessage, Response = JsonRpcMessage, Error = super::service::ServiceError, > + Send, S::Future: Send, - T: Transport, { - pub fn new(service: S, transport: Arc>) -> Self { + pub fn new(service: S) -> Self { Self { service, - transport, next_id: 1, } } @@ -83,12 +76,12 @@ where { self.service.ready().await.map_err(|_| Error::NotReady)?; - let request = JsonRpcRequest { + let request = JsonRpcMessage::Request(JsonRpcRequest { jsonrpc: "2.0".to_string(), id: Some(self.next_id), method: method.to_string(), params: Some(params), - }; + }); self.next_id += 1; @@ -130,18 +123,17 @@ where } /// Send a JSON-RPC notification. - pub async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { - let notification = mcp_core::protocol::JsonRpcNotification { + pub async fn send_notification(&mut self, method: &str, params: Value) -> Result<(), Error> { + self.service.ready().await.map_err(|_| Error::NotReady)?; + + let notification = JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), method: method.to_string(), params: Some(params), - }; - let msg = serde_json::to_string(¬ification)?; - let transport = self.transport.lock().await; - transport - .send(msg) - .await - .map_err(|e: TransportError| Error::Service(e.into())) + }); + + self.service.call(notification).await?; + Ok(()) } /// Initialize the connection with the server. diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 83a9dd67c..589f6681c 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -7,7 +7,7 @@ use tokio::sync::Mutex; use tower::Service; use crate::transport::{Error as TransportError, Transport}; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; +use mcp_core::protocol::{JsonRpcMessage, JsonRpcResponse}; #[derive(Debug, thiserror::Error)] pub enum ServiceError { @@ -27,7 +27,7 @@ pub enum ServiceError { UnexpectedResponse, } -/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcRequests and JsonRpcMessages. +/// A Tower `Service` implementation that uses a `Transport` to send/receive JsonRpcMessages and JsonRpcMessages. pub struct TransportService { transport: Arc>, initialized: AtomicBool, @@ -47,7 +47,7 @@ impl TransportService { } } -impl Service for TransportService { +impl Service for TransportService { type Response = JsonRpcMessage; type Error = ServiceError; type Future = Pin> + Send>>; @@ -57,7 +57,7 @@ impl Service for TransportService { Poll::Ready(Ok(())) } - fn call(&mut self, request: JsonRpcRequest) -> Self::Future { + fn call(&mut self, message: JsonRpcMessage) -> Self::Future { let transport = Arc::clone(&self.transport); let started = self.initialized.load(Ordering::SeqCst); @@ -69,14 +69,32 @@ impl Service for TransportService { transport.start().await?; } - // Serialize request to JSON line - let msg = serde_json::to_string(&request)?; - transport.send(msg).await?; - - let line = transport.receive().await?; - let response: JsonRpcMessage = serde_json::from_str(&line)?; - - Ok(response) + match message { + JsonRpcMessage::Notification(notification) => { + // Serialize notification + let msg = serde_json::to_string(¬ification)?; + transport.send(msg).await?; + // For notifications, the protocol does not require a response + // So we return an empty response here and this is not checked upstream + let response: JsonRpcMessage = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: None, + result: None, + error: None, + }); + + Ok(response) + } + JsonRpcMessage::Request(request) => { + // Serialize request & wait for response + let msg = serde_json::to_string(&request)?; + transport.send(msg).await?; + let line = transport.receive().await?; + let response: JsonRpcMessage = serde_json::from_str(&line)?; + Ok(response) + } + _ => return Err(ServiceError::Other("Invalid message type".to_string())), + } }) } } From 0a0efc32474c87c52732ae61e4e7f44105a5cfa3 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Wed, 18 Dec 2024 17:01:58 -0500 Subject: [PATCH 09/11] add example to create collection of clients --- crates/mcp-client/examples/clients.rs | 81 +++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 crates/mcp-client/examples/clients.rs diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs new file mode 100644 index 000000000..d6e2c6afc --- /dev/null +++ b/crates/mcp-client/examples/clients.rs @@ -0,0 +1,81 @@ +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; + +use mcp_client::{ + client::{ClientCapabilities, ClientInfo, McpClient}, + service::{ServiceError, TransportService}, + transport::StdioTransport, +}; +use tower::{ServiceBuilder, ServiceExt}; +use tower::timeout::TimeoutLayer; +use tracing_subscriber::EnvFilter; +use tower::util::BoxService; +use mcp_core::protocol::JsonRpcMessage; + +// Define a type alias for the boxed service using BoxService +type BoxedService = BoxService; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("reqwest_eventsource=debug".parse().unwrap()), + ) + .init(); + + // Create two separate clients with stdio transport + let client1 = create_client("client1", "1.0.0")?; + let client2 = create_client("client2", "1.0.0")?; + + // Initialize both clients + let mut clients: Vec> = vec![client1, client2]; + + // Initialize all clients + for (i, client) in clients.iter_mut().enumerate() { + let info = ClientInfo { + name: format!("example-client-{}", i + 1), + version: "1.0.0".to_string(), + }; + let capabilities = ClientCapabilities::default(); + + println!("\nInitializing client {}", i + 1); + let init_result = client.initialize(info, capabilities).await?; + println!("Client {} initialized: {:?}", i + 1, init_result); + } + + // List tools for each client + for (i, client) in clients.iter_mut().enumerate() { + let tools = client.list_tools().await?; + println!("\nClient {} tools: {:?}", i + 1, tools); + } + + Ok(()) +} + +fn create_client( + _name: &str, + _version: &str, +) -> Result>, Box> { + // Create the transport + let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); + + // Build service with middleware including timeout + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(Duration::from_secs(30))) + .service(TransportService::new(Arc::clone(&transport))) + .map_err(|e: Box| { + if e.is::() { + ServiceError::Timeout(tower::timeout::error::Elapsed::new()) + } else { + ServiceError::Other(e.to_string()) + } + }) + .boxed(); // Box the service to create a BoxService + + // Create the client + Ok(McpClient::new(service)) +} From 40d5beb53a73e836fcc8d9237e943cb688c0f27c Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 10:42:16 -0500 Subject: [PATCH 10/11] make McpClient a trait and current version (McpClientImpl) an implementation * checks out the changes from 'kalvin/mcp-client-trait' branch Co-authored-by: kalvinnchau --- crates/mcp-client/examples/clients.rs | 50 +++++++++++++++------ crates/mcp-client/examples/sse.rs | 4 +- crates/mcp-client/examples/stdio.rs | 6 ++- crates/mcp-client/src/client.rs | 62 +++++++++++++++++++-------- 4 files changed, 86 insertions(+), 36 deletions(-) diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index d6e2c6afc..43d31c4c2 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -3,18 +3,15 @@ use std::time::Duration; use tokio::sync::Mutex; use mcp_client::{ - client::{ClientCapabilities, ClientInfo, McpClient}, + client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}, service::{ServiceError, TransportService}, - transport::StdioTransport, + transport::{SseTransport, StdioTransport}, }; -use tower::{ServiceBuilder, ServiceExt}; +use mcp_core::protocol::JsonRpcMessage; use tower::timeout::TimeoutLayer; -use tracing_subscriber::EnvFilter; use tower::util::BoxService; -use mcp_core::protocol::JsonRpcMessage; - -// Define a type alias for the boxed service using BoxService -type BoxedService = BoxService; +use tower::{ServiceBuilder, ServiceExt}; +use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> Result<(), Box> { @@ -30,9 +27,13 @@ async fn main() -> Result<(), Box> { // Create two separate clients with stdio transport let client1 = create_client("client1", "1.0.0")?; let client2 = create_client("client2", "1.0.0")?; + let client3 = create_sse_client("client3", "1.0.0")?; // Initialize both clients - let mut clients: Vec> = vec![client1, client2]; + let mut clients: Vec> = Vec::new(); + clients.push(client1); + clients.push(client2); + clients.push(client3); // Initialize all clients for (i, client) in clients.iter_mut().enumerate() { @@ -59,7 +60,7 @@ async fn main() -> Result<(), Box> { fn create_client( _name: &str, _version: &str, -) -> Result>, Box> { +) -> Result, Box> { // Create the transport let transport = Arc::new(Mutex::new(StdioTransport::new("uvx", ["mcp-server-git"]))); @@ -73,9 +74,30 @@ fn create_client( } else { ServiceError::Other(e.to_string()) } - }) - .boxed(); // Box the service to create a BoxService + }); + + Ok(Box::new(McpClientImpl::new(service))) +} + +fn create_sse_client( + _name: &str, + _version: &str, +) -> Result, Box> { + let transport = Arc::new(Mutex::new( + SseTransport::new("http://localhost:8000/sse").unwrap(), + )); + + // Build service with middleware including timeout + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(Duration::from_secs(30))) + .service(TransportService::new(Arc::clone(&transport))) + .map_err(|e: Box| { + if e.is::() { + ServiceError::Timeout(tower::timeout::error::Elapsed::new()) + } else { + ServiceError::Other(e.to_string()) + } + }); - // Create the client - Ok(McpClient::new(service)) + Ok(Box::new(McpClientImpl::new(service))) } diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 89d3560f2..3d7e570d4 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient}; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientImpl}; use mcp_client::{ service::{ServiceError, TransportService}, transport::SseTransport, @@ -38,7 +38,7 @@ async fn main() -> Result<()> { }); // Create client - let mut client = McpClient::new(service); + let mut client = McpClientImpl::new(service); println!("Client created\n"); // Initialize diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 3b26af0d5..7300512dd 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -1,5 +1,7 @@ use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; +use mcp_client::client::{ + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientImpl, +}; use mcp_client::{ service::{ServiceError, TransportService}, transport::StdioTransport, @@ -38,7 +40,7 @@ async fn main() -> Result<(), ClientError> { }); // Create client - let mut client = McpClient::new(service); + let mut client = McpClientImpl::new(service); // Initialize let server_info = client diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 28a7a506e..270385f4e 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -47,13 +47,36 @@ pub struct InitializeParams { pub client_info: ClientInfo, } -/// The MCP client that sends requests via the provided service. -pub struct McpClient { +/// The MCP client trait defining the interface for MCP operations. +#[async_trait::async_trait] +pub trait McpClient { + /// Initialize the connection with the server. + async fn initialize( + &mut self, + info: ClientInfo, + capabilities: ClientCapabilities, + ) -> Result; + + /// List available resources. + async fn list_resources(&mut self) -> Result; + + /// Read a resource's content. + async fn read_resource(&mut self, uri: &str) -> Result; + + /// List available tools. + async fn list_tools(&mut self) -> Result; + + /// Call a specific tool with arguments. + async fn call_tool(&mut self, name: &str, arguments: Value) -> Result; +} + +/// Standard implementation of the MCP client that sends requests via the provided service. +pub struct McpClientImpl { service: S, next_id: u64, } -impl McpClient +impl McpClientImpl where S: tower::Service< JsonRpcMessage, @@ -123,7 +146,7 @@ where } /// Send a JSON-RPC notification. - pub async fn send_notification(&mut self, method: &str, params: Value) -> Result<(), Error> { + async fn send_notification(&mut self, method: &str, params: Value) -> Result<(), Error> { self.service.ready().await.map_err(|_| Error::NotReady)?; let notification = JsonRpcMessage::Notification(JsonRpcNotification { @@ -135,9 +158,20 @@ where self.service.call(notification).await?; Ok(()) } +} - /// Initialize the connection with the server. - pub async fn initialize( +#[async_trait::async_trait] +impl McpClient for McpClientImpl +where + S: tower::Service< + JsonRpcMessage, + Response = JsonRpcMessage, + Error = super::service::ServiceError, + > + Send + + Sync, + S::Future: Send, +{ + async fn initialize( &mut self, info: ClientInfo, capabilities: ClientCapabilities, @@ -157,29 +191,21 @@ where Ok(result) } - /// List available resources. - pub async fn list_resources(&mut self) -> Result { + async fn list_resources(&mut self) -> Result { self.send_message("resources/list", serde_json::json!({})) .await } - /// Read a resource's content. - pub async fn read_resource(&mut self, uri: &str) -> Result { + async fn read_resource(&mut self, uri: &str) -> Result { let params = serde_json::json!({ "uri": uri }); self.send_message("resources/read", params).await } - /// List tools - pub async fn list_tools(&mut self) -> Result { + async fn list_tools(&mut self) -> Result { self.send_message("tools/list", serde_json::json!({})).await } - // Call tool - pub async fn call_tool( - &mut self, - name: &str, - arguments: Value, - ) -> Result { + async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { let params = serde_json::json!({ "name": name, "arguments": arguments }); self.send_message("tools/call", params).await } From 67ea15e85f0b068dbcf687d66bc9496332e8757a Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 19 Dec 2024 10:46:47 -0500 Subject: [PATCH 11/11] Add JsonRpcMessage::Nil for responding to notifications --- crates/mcp-client/examples/clients.rs | 2 -- crates/mcp-client/src/service.rs | 11 ++--------- crates/mcp-core/src/protocol.rs | 1 + 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index 43d31c4c2..df8fac9c5 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -7,9 +7,7 @@ use mcp_client::{ service::{ServiceError, TransportService}, transport::{SseTransport, StdioTransport}, }; -use mcp_core::protocol::JsonRpcMessage; use tower::timeout::TimeoutLayer; -use tower::util::BoxService; use tower::{ServiceBuilder, ServiceExt}; use tracing_subscriber::EnvFilter; diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 589f6681c..76b081720 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -7,7 +7,7 @@ use tokio::sync::Mutex; use tower::Service; use crate::transport::{Error as TransportError, Transport}; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcResponse}; +use mcp_core::protocol::JsonRpcMessage; #[derive(Debug, thiserror::Error)] pub enum ServiceError { @@ -76,14 +76,7 @@ impl Service for TransportService { transport.send(msg).await?; // For notifications, the protocol does not require a response // So we return an empty response here and this is not checked upstream - let response: JsonRpcMessage = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: None, - result: None, - error: None, - }); - - Ok(response) + Ok(JsonRpcMessage::Nil) } JsonRpcMessage::Request(request) => { // Serialize request & wait for response diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index 4402ff135..87f846fe8 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -41,6 +41,7 @@ pub enum JsonRpcMessage { Response(JsonRpcResponse), Notification(JsonRpcNotification), Error(JsonRpcError), + Nil, // used to respond to notifications } // Standard JSON-RPC error codes