Skip to content

Commit

Permalink
Merge pull request #8 from mark3labs/implement-notification-handling
Browse files Browse the repository at this point in the history
Implement notification handling
  • Loading branch information
ezynda3 authored Dec 14, 2024
2 parents 48c485e + d1c3cfc commit c61624c
Show file tree
Hide file tree
Showing 10 changed files with 529 additions and 283 deletions.
46 changes: 23 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ func main() {
}
}

func helloHandler(ctx context.Context, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
name, ok := arguments["name"].(string)
func helloHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
name, ok := request.Params.Arguments["name"].(string)
if !ok {
return mcp.NewToolResultError("name must be a string"), nil
}
Expand Down Expand Up @@ -137,10 +137,10 @@ func main() {
)

// Add the calculator handler
s.AddTool(calculatorTool, func(args map[string]interface{}) (*mcp.CallToolResult, error) {
op := args["operation"].(string)
x := args["x"].(float64)
y := args["y"].(float64)
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
op := request.Params.Arguments["operation"].(string)
x := request.Params.Arguments["x"].(float64)
y := request.Params.Arguments["y"].(float64)

var result float64
switch op {
Expand Down Expand Up @@ -223,7 +223,7 @@ resource := mcp.NewResource(
)

// Add resource with its handler
s.AddResource(resource, func(ctx context.Context) ([]interface{}, error) {
s.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]interface{}, error) {
content, err := os.ReadFile("README.md")
if err != nil {
return nil, err
Expand Down Expand Up @@ -254,8 +254,8 @@ template := mcp.NewResourceTemplate(
)

// Add template with its handler
s.AddResourceTemplate(template, func(ctx context.Context, args map[string]interface{}) ([]interface{}, error) {
userID := args["id"].(string)
s.AddResourceTemplate(template, func(ctx context.Context, request mcp.ReadResourceRequest) ([]interface{}, error) {
userID := request.Params.URI // Extract ID from the full URI

profile, err := getUserProfile(userID) // Your DB/API call here
if err != nil {
Expand Down Expand Up @@ -303,10 +303,10 @@ calculatorTool := mcp.NewTool("calculate",
),
)

s.AddTool(calculatorTool, func(args map[string]interface{}) (*mcp.CallToolResult, error) {
op := args["operation"].(string)
x := args["x"].(float64)
y := args["y"].(float64)
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
op := request.Params.Arguments["operation"].(string)
x := request.Params.Arguments["x"].(float64)
y := request.Params.Arguments["y"].(float64)

var result float64
switch op {
Expand Down Expand Up @@ -346,11 +346,11 @@ httpTool := mcp.NewTool("http_request",
),
)

s.AddTool(httpTool, func(args map[string]interface{}) (*mcp.CallToolResult, error) {
method := args["method"].(string)
url := args["url"].(string)
s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
method := request.Params.Arguments["method"].(string)
url := request.Params.Arguments["url"].(string)
body := ""
if b, ok := args["body"].(string); ok {
if b, ok := request.Params.Arguments["body"].(string); ok {
body = b
}

Expand Down Expand Up @@ -413,8 +413,8 @@ s.AddPrompt(mcp.NewPrompt("greeting",
mcp.WithArgument("name",
mcp.ArgumentDescription("Name of the person to greet"),
),
), func(args map[string]string) (*mcp.GetPromptResult, error) {
name := args["name"]
), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
name := request.Params.Arguments["name"].(string)
if name == "" {
name = "friend"
}
Expand All @@ -437,8 +437,8 @@ s.AddPrompt(mcp.NewPrompt("code_review",
mcp.ArgumentDescription("Pull request number to review"),
mcp.RequiredArgument(),
),
), func(args map[string]string) (*mcp.GetPromptResult, error) {
prNumber := args["pr_number"]
), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
prNumber := request.Params.Arguments["pr_number"].(string)
if prNumber == "" {
return nil, fmt.Errorf("pr_number is required")
}
Expand Down Expand Up @@ -468,8 +468,8 @@ s.AddPrompt(mcp.NewPrompt("query_builder",
mcp.ArgumentDescription("Name of the table to query"),
mcp.RequiredArgument(),
),
), func(args map[string]string) (*mcp.GetPromptResult, error) {
tableName := args["table"]
), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
tableName := request.Params.Arguments["table"].(string)
if tableName == "" {
return nil, fmt.Errorf("table name is required")
}
Expand Down
4 changes: 2 additions & 2 deletions client/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ func TestSSEMCPClient(t *testing.T) {
Type: "object",
Properties: map[string]interface{}{},
},
}, func(arguments map[string]interface{}) (*mcp.CallToolResult, error) {
}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{}, nil
})

// Create test server
// Initialize
testServer := server.NewTestServer(mcpServer)
defer testServer.Close()

Expand Down
134 changes: 105 additions & 29 deletions examples/everything/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"context"
"flag"
"fmt"
"log"
"time"
Expand Down Expand Up @@ -79,6 +81,12 @@ func NewMCPServer() *MCPServer {
mcp.Required(),
),
), s.handleEchoTool)

s.server.AddTool(
mcp.NewTool("notify"),
s.handleSendNotification,
)

s.server.AddTool(mcp.NewTool(string(ADD),
mcp.WithDescription("Adds two numbers"),
mcp.WithNumber("a",
Expand Down Expand Up @@ -127,7 +135,7 @@ func NewMCPServer() *MCPServer {
mcp.WithDescription("Returns the MCP_TINY_IMAGE"),
), s.handleGetTinyImageTool)

s.server.AddNotificationHandler(s.handleNotification)
s.server.AddNotificationHandler("notification", s.handleNotification)

go s.runUpdateInterval()

Expand Down Expand Up @@ -177,6 +185,7 @@ func (s *MCPServer) runUpdateInterval() {
}

func (s *MCPServer) handleReadResource(
ctx context.Context,
request mcp.ReadResourceRequest,
) ([]interface{}, error) {
return []interface{}{
Expand All @@ -191,6 +200,7 @@ func (s *MCPServer) handleReadResource(
}

func (s *MCPServer) handleResourceTemplate(
ctx context.Context,
request mcp.ReadResourceRequest,
) ([]interface{}, error) {
return []interface{}{
Expand All @@ -205,7 +215,8 @@ func (s *MCPServer) handleResourceTemplate(
}

func (s *MCPServer) handleSimplePrompt(
arguments map[string]string,
ctx context.Context,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, error) {
return &mcp.GetPromptResult{
Description: "A simple prompt without arguments",
Expand All @@ -222,8 +233,10 @@ func (s *MCPServer) handleSimplePrompt(
}

func (s *MCPServer) handleComplexPrompt(
arguments map[string]string,
ctx context.Context,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, error) {
arguments := request.Params.Arguments
return &mcp.GetPromptResult{
Description: "A complex prompt with arguments",
Messages: []mcp.PromptMessage{
Expand Down Expand Up @@ -258,8 +271,10 @@ func (s *MCPServer) handleComplexPrompt(
}

func (s *MCPServer) handleEchoTool(
arguments map[string]interface{},
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
message, ok := arguments["message"].(string)
if !ok {
return nil, fmt.Errorf("invalid message argument")
Expand All @@ -275,8 +290,10 @@ func (s *MCPServer) handleEchoTool(
}

func (s *MCPServer) handleAddTool(
arguments map[string]interface{},
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
a, ok1 := arguments["a"].(float64)
b, ok2 := arguments["b"].(float64)
if !ok1 || !ok2 {
Expand All @@ -293,35 +310,65 @@ func (s *MCPServer) handleAddTool(
}, nil
}

func (s *MCPServer) handleSendNotification(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {

server := server.ServerFromContext(ctx)

err := server.SendNotificationToClient(
"notifications/progress",
map[string]interface{}{
"progress": 10,
"total": 10,
"progressToken": 0,
},
)
if err != nil {
return nil, fmt.Errorf("failed to send notification: %w", err)
}

return &mcp.CallToolResult{
Content: []interface{}{
mcp.TextContent{
Type: "text",
Text: "notification sent successfully",
},
},
}, nil
}

func (s *MCPServer) ServeSSE(addr string) *server.SSEServer {
return server.NewSSEServer(s.server, fmt.Sprintf("http://%s", addr))
}

func (s *MCPServer) ServeStdio() error {
return server.ServeStdio(s.server)
}

func (s *MCPServer) handleLongRunningOperationTool(
arguments map[string]interface{},
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
arguments := request.Params.Arguments
progressToken := request.Params.Meta.ProgressToken
duration, _ := arguments["duration"].(float64)
steps, _ := arguments["steps"].(float64)
stepDuration := duration / steps
progressToken, _ := arguments["_meta"].(map[string]interface{})["progressToken"].(mcp.ProgressToken)
server := server.ServerFromContext(ctx)

for i := 1; i < int(steps)+1; i++ {
time.Sleep(time.Duration(stepDuration * float64(time.Second)))
if progressToken != nil {
// s.server.HandleMessage(
// context.Background(),
// mcp.JSONRPCNotification{
// JSONRPC: mcp.JSONRPC_VERSION,
// Notification: mcp.Notification{
// Method: "progress",
// Params: struct {
// Meta map[string]interface{} `json:"_meta,omitempty"`
// }{
// Meta: map[string]interface{}{
// "progress": i,
// "total": int(steps),
// "progressToken": progressToken,
// },
// },
// },
// },
// )
server.SendNotificationToClient(
"notifications/progress",
map[string]interface{}{
"progress": i,
"total": int(steps),
"progressToken": progressToken,
},
)
}
}

Expand Down Expand Up @@ -361,7 +408,8 @@ func (s *MCPServer) handleLongRunningOperationTool(
// }

func (s *MCPServer) handleGetTinyImageTool(
arguments map[string]interface{},
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []interface{}{
Expand All @@ -382,7 +430,10 @@ func (s *MCPServer) handleGetTinyImageTool(
}, nil
}

func (s *MCPServer) handleNotification(notification mcp.JSONRPCNotification) {
func (s *MCPServer) handleNotification(
ctx context.Context,
notification mcp.JSONRPCNotification,
) {
log.Printf("Received notification: %s", notification.Method)
}

Expand All @@ -391,9 +442,34 @@ func (s *MCPServer) Serve() error {
}

func main() {
var transport string
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)")
flag.StringVar(
&transport,
"transport",
"stdio",
"Transport type (stdio or sse)",
)
flag.Parse()

server := NewMCPServer()
if err := server.Serve(); err != nil {
log.Fatalf("Server error: %v", err)

switch transport {
case "stdio":
if err := server.ServeStdio(); err != nil {
log.Fatalf("Server error: %v", err)
}
case "sse":
sseServer := server.ServeSSE("localhost:8080")
log.Printf("SSE server listening on :8080")
if err := sseServer.Start(":8080"); err != nil {
log.Fatalf("Server error: %v", err)
}
default:
log.Fatalf(
"Invalid transport type: %s. Must be 'stdio' or 'sse'",
transport,
)
}
}

Expand Down
Loading

0 comments on commit c61624c

Please sign in to comment.