Skip to content

Commit

Permalink
fix: check server capability when client sends requests (#558)
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 authored Jan 8, 2025
1 parent bb706fb commit 640c38b
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 11 deletions.
2 changes: 1 addition & 1 deletion crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Capabilities {
/// Add a new MCP system based on the provided client type
// TODO IMPORTANT need to ensure this times out if the system command is broken!
pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> {
let client: McpClient = match config {
let mut client: McpClient = match config {
SystemConfig::Sse { ref uri } => {
let transport = SseTransport::new(uri);
McpClient::new(transport.start().await?)
Expand Down
10 changes: 7 additions & 3 deletions crates/goose/src/agents/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ impl DefaultAgent {
&resources,
Some(model_name),
);

let mut status_content: Vec<String> = Vec::new();

if approx_count > target_limit {
Expand Down Expand Up @@ -217,15 +216,20 @@ impl Agent for DefaultAgent {
}

// Update conversation history for the start of the reply
let resources = capabilities.get_resources().await?;
let mut messages = self
.prepare_inference(
&system_prompt,
&tools,
messages,
&Vec::new(),
estimated_limit,
&capabilities.provider().get_model_config().model_name,
&capabilities.get_resources().await?,
&capabilities
.provider()
.get_model_config()
.model_name
.clone(),
&resources,
)
.await?;

Expand Down
2 changes: 1 addition & 1 deletion crates/mcp-client/examples/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async fn main() -> Result<()> {
let handle = transport.start().await?;

// Create client
let client = McpClient::new(handle);
let mut client = McpClient::new(handle);
println!("Client created\n");

// Initialize
Expand Down
6 changes: 5 additions & 1 deletion crates/mcp-client/examples/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn main() -> Result<(), ClientError> {
let transport_handle = transport.start().await?;

// 3) Create the client
let client = McpClient::new(transport_handle);
let mut client = McpClient::new(transport_handle);

// Initialize
let server_info = client
Expand All @@ -45,5 +45,9 @@ async fn main() -> Result<(), ClientError> {
.await?;
println!("Tool result: {tool_result:?}\n");

// List resources
let resources = client.list_resources().await?;
println!("Available resources: {resources:?}\n");

Ok(())
}
2 changes: 1 addition & 1 deletion crates/mcp-client/examples/stdio_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async fn main() -> Result<(), ClientError> {
let transport_handle = transport.start().await.unwrap();

// Create client
let client = McpClient::new(transport_handle);
let mut client = McpClient::new(transport_handle);

// Initialize
let server_info = client
Expand Down
69 changes: 65 additions & 4 deletions crates/mcp-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::sync::atomic::{AtomicU64, Ordering};

use crate::transport::TransportHandle;
use mcp_core::protocol::{
CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification,
JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult,
ServerCapabilities, METHOD_NOT_FOUND,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::Mutex;
use tower::{Service, ServiceExt};

use crate::transport::TransportHandle; // for Service::ready()
use tower::{Service, ServiceExt}; // for Service::ready()

/// Error type for MCP client operations.
#[derive(Debug, Error)]
Expand All @@ -27,6 +27,9 @@ pub enum Error {
#[error("Unexpected response from server")]
UnexpectedResponse,

#[error("Not initialized")]
NotInitialized,

#[error("Timeout or service not ready")]
NotReady,
}
Expand Down Expand Up @@ -55,6 +58,7 @@ pub struct InitializeParams {
pub struct McpClient {
service: Mutex<TransportHandle>,
next_id: AtomicU64,
server_capabilities: Option<ServerCapabilities>,
}

impl McpClient {
Expand All @@ -63,6 +67,7 @@ impl McpClient {
Self {
service: Mutex::new(transport_handle),
next_id: AtomicU64::new(1),
server_capabilities: None, // set during initialization
}
}

Expand Down Expand Up @@ -135,7 +140,7 @@ impl McpClient {
}

pub async fn initialize(
&self,
&mut self,
info: ClientInfo,
capabilities: ClientCapabilities,
) -> Result<InitializeResult, Error> {
Expand All @@ -151,24 +156,80 @@ impl McpClient {
self.send_notification("notifications/initialized", serde_json::json!({}))
.await?;

self.server_capabilities = Some(result.capabilities.clone());

Ok(result)
}

fn completed_initialization(&self) -> bool {
self.server_capabilities.is_some()
}

pub async fn list_resources(&self) -> Result<ListResourcesResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If resources is not supported, return an empty list
if self
.server_capabilities
.as_ref()
.unwrap()
.resources
.is_none()
{
return Ok(ListResourcesResult { resources: vec![] });
}

self.send_request("resources/list", serde_json::json!({}))
.await
}

pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If resources is not supported, return an error
if self
.server_capabilities
.as_ref()
.unwrap()
.resources
.is_none()
{
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'resources' capability".to_string(),
});
}

let params = serde_json::json!({ "uri": uri });
self.send_request("resources/read", params).await
}

pub async fn list_tools(&self) -> Result<ListToolsResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If tools is not supported, return an empty list
if self.server_capabilities.as_ref().unwrap().tools.is_none() {
return Ok(ListToolsResult { tools: vec![] });
}

self.send_request("tools/list", serde_json::json!({})).await
}

pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If tools is not supported, return an error
if self.server_capabilities.as_ref().unwrap().tools.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'tools' capability".to_string(),
});
}

let params = serde_json::json!({ "name": name, "arguments": arguments });
self.send_request("tools/call", params).await
}
Expand Down

0 comments on commit 640c38b

Please sign in to comment.