Skip to content

Commit

Permalink
Merge branch 'v1.0' into micn/use-rust-tls
Browse files Browse the repository at this point in the history
* v1.0:
  fix: Ldelalande/fix scroll (#504)
  feat: MCP server sdk (simple version first) (#499)
  • Loading branch information
michaelneale committed Dec 19, 2024
2 parents c7d1d7a + 2cd1314 commit 577ba13
Show file tree
Hide file tree
Showing 17 changed files with 1,262 additions and 169 deletions.
129 changes: 62 additions & 67 deletions crates/goose/src/token_counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ impl TokenCounter {
#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
use mcp_core::role::Role;
use serde_json::json;

#[test]
fn test_add_tokenizer_and_count_tokens() {
Expand Down Expand Up @@ -242,73 +245,65 @@ mod tests {
assert_eq!(count, 3);
}

#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
use mcp_core::role::Role;
use serde_json::json;

#[test]
fn test_count_chat_tokens() {
let token_counter = TokenCounter::new();

let system_prompt =
"You are a helpful assistant that can answer questions about the weather.";

let messages = vec![
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::text(
"What's the weather like in San Francisco?",
)],
},
Message {
role: Role::Assistant,
created: 1,
content: vec![MessageContent::text(
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
)],
},
Message {
role: Role::User,
created: 2,
content: vec![MessageContent::text("How about New York?")],
},
];

let tools = vec![Tool {
name: "get_current_weather".to_string(),
description: "Get the current weather in a given location".to_string(),
input_schema: json!({
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"description": "The unit of temperature to return",
"enum": ["celsius", "fahrenheit"]
}
#[test]
fn test_count_chat_tokens() {
let token_counter = TokenCounter::new();

let system_prompt =
"You are a helpful assistant that can answer questions about the weather.";

let messages = vec![
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::text(
"What's the weather like in San Francisco?",
)],
},
Message {
role: Role::Assistant,
created: 1,
content: vec![MessageContent::text(
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
)],
},
Message {
role: Role::User,
created: 2,
content: vec![MessageContent::text("How about New York?")],
},
];

let tools = vec![Tool {
name: "get_current_weather".to_string(),
description: "Get the current weather in a given location".to_string(),
input_schema: json!({
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"required": ["location"]
}),
}];

let token_count_without_tools =
token_counter.count_chat_tokens(system_prompt, &messages, &vec![], Some("gpt-4o"));
println!("Total tokens without tools: {}", token_count_without_tools);

let token_count_with_tools =
token_counter.count_chat_tokens(system_prompt, &messages, &tools, Some("gpt-4o"));
println!("Total tokens with tools: {}", token_count_with_tools);

// The token count for messages without tools is calculated using the tokenizer - https://tiktokenizer.vercel.app/
// The token count for messages with tools is taken from tiktoken github repo example (notebook)
assert_eq!(token_count_without_tools, 56);
assert_eq!(token_count_with_tools, 124);
}
"unit": {
"type": "string",
"description": "The unit of temperature to return",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}),
}];

let token_count_without_tools =
token_counter.count_chat_tokens(system_prompt, &messages, &[], Some("gpt-4o"));
println!("Total tokens without tools: {}", token_count_without_tools);

let token_count_with_tools =
token_counter.count_chat_tokens(system_prompt, &messages, &tools, Some("gpt-4o"));
println!("Total tokens with tools: {}", token_count_with_tools);

// The token count for messages without tools is calculated using the tokenizer - https://tiktokenizer.vercel.app/
// The token count for messages with tools is taken from tiktoken github repo example (notebook)
assert_eq!(token_count_without_tools, 56);
assert_eq!(token_count_with_tools, 124);
}
}
50 changes: 0 additions & 50 deletions crates/mcp-client/src/stdio_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,59 +119,9 @@ impl Transport for StdioTransport {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::time::Duration;
use tokio::time::timeout;

#[tokio::test]
async fn test_stdio_transport() {
let transport = StdioTransport {
params: StdioServerParams {
command: "tee".to_string(), // tee will echo back what it receives
args: vec![],
env: None,
},
};

let (mut rx, tx) = transport.connect().await.unwrap();

// Create test messages
let request = JsonRpcMessage::Request(JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(1),
method: "ping".to_string(),
params: None,
});

let response = JsonRpcMessage::Response(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: Some(2),
result: Some(json!({})),
error: None,
});

// Send messages
tx.send(request.clone()).await.unwrap();
tx.send(response.clone()).await.unwrap();

// Receive and verify messages
let mut read_messages = Vec::new();

// Use timeout to avoid hanging if messages aren't received
for _ in 0..2 {
match timeout(Duration::from_secs(1), rx.recv()).await {
Ok(Some(Ok(msg))) => read_messages.push(msg),
Ok(Some(Err(e))) => panic!("Received error: {}", e),
Ok(None) => break,
Err(_) => panic!("Timeout waiting for message"),
}
}

assert_eq!(read_messages.len(), 2, "Expected 2 messages");
assert_eq!(read_messages[0], request);
assert_eq!(read_messages[1], response);
}

#[tokio::test]
async fn test_process_termination() {
let transport = StdioTransport {
Expand Down
20 changes: 15 additions & 5 deletions crates/mcp-core/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,39 @@ pub enum ToolError {
SerializationError(#[from] serde_json::Error),
#[error("Schema error: {0}")]
SchemaError(String),
#[error("Tool not found: {0}")]
NotFound(String),
}

#[derive(Error, Debug)]
pub enum ResourceError {
#[error("Execution failed: {0}")]
ExecutionError(String),
#[error("Resource not found: {0}")]
NotFound(String),
}

pub type Result<T> = std::result::Result<T, ToolError>;

/// Trait for implementing MCP tools
#[async_trait]
pub trait Tool: Send + Sync + 'static {
pub trait ToolHandler: Send + Sync + 'static {
/// The name of the tool
fn name() -> &'static str;
fn name(&self) -> &'static str;

/// A description of what the tool does
fn description() -> &'static str;
fn description(&self) -> &'static str;

/// JSON schema describing the tool's parameters
fn schema() -> Value;
fn schema(&self) -> Value;

/// Execute the tool with the given parameters
async fn call(&self, params: Value) -> Result<Value>;
}

/// Trait for implementing MCP resources
#[async_trait]
pub trait Resource: Send + Sync + 'static {
pub trait ResourceTemplateHandler: Send + Sync + 'static {
/// The URL template for this resource
fn template() -> &'static str;

Expand Down
68 changes: 67 additions & 1 deletion crates/mcp-core/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,105 @@ use serde_json::Value;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<ErrorData>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JsonRpcNotification {
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JsonRpcError {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
pub error: ErrorData,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
#[serde(untagged, try_from = "JsonRpcRaw")]
pub enum JsonRpcMessage {
Request(JsonRpcRequest),
Response(JsonRpcResponse),
Notification(JsonRpcNotification),
Error(JsonRpcError),
}

#[derive(Debug, Serialize, Deserialize)]
struct JsonRpcRaw {
jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<u64>,
method: String,
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<ErrorData>,
}

impl TryFrom<JsonRpcRaw> for JsonRpcMessage {
type Error = String;

fn try_from(raw: JsonRpcRaw) -> Result<Self, <Self as TryFrom<JsonRpcRaw>>::Error> {
// If it has an error field, it's an error response
if raw.error.is_some() {
return Ok(JsonRpcMessage::Error(JsonRpcError {
jsonrpc: raw.jsonrpc,
id: raw.id,
error: raw.error.unwrap(),
}));
}

// If it has a result field, it's a response
if raw.result.is_some() {
return Ok(JsonRpcMessage::Response(JsonRpcResponse {
jsonrpc: raw.jsonrpc,
id: raw.id,
result: raw.result,
error: None,
}));
}

// If the method starts with "notifications/", it's a notification
if raw.method.starts_with("notifications/") {
return Ok(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: raw.jsonrpc,
method: raw.method,
params: raw.params,
}));
}

// Otherwise it's a request
Ok(JsonRpcMessage::Request(JsonRpcRequest {
jsonrpc: raw.jsonrpc,
id: raw.id,
method: raw.method,
params: raw.params,
}))
}
}

// Standard JSON-RPC error codes
pub const PARSE_ERROR: i32 = -32700;
pub const INVALID_REQUEST: i32 = -32600;
Expand Down Expand Up @@ -80,8 +143,11 @@ pub struct Implementation {

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ServerCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub prompts: Option<PromptsCapability>,
#[serde(skip_serializing_if = "Option::is_none")]
pub resources: Option<ResourcesCapability>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<ToolsCapability>,
// Add other capabilities as needed
}
Expand Down
2 changes: 2 additions & 0 deletions crates/mcp-core/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use serde_json::Value;

/// A tool that can be used by a model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
/// The name of the tool
pub name: String,
Expand All @@ -31,6 +32,7 @@ impl Tool {

/// A tool call request that a system can execute
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
/// The name of the tool to execute
pub name: String,
Expand Down
Loading

0 comments on commit 577ba13

Please sign in to comment.