diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index 3c5c7451c..9b089cd54 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -32,3 +32,5 @@ path = "src/main.rs" [dev-dependencies] serial_test = "3.2.0" +tower = "0.5" +async-trait = "0.1" diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 70be62bd1..37c1367ce 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -8,11 +8,7 @@ use axum::{ }; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; -use goose::{ - models::content::Content, - models::message::{Message, MessageContent}, - models::role::Role, -}; +use goose::models::{content::Content, message::{Message, MessageContent}, role::Role}; use serde::Deserialize; use serde_json::{json, Value}; use std::{ @@ -275,6 +271,7 @@ async fn handler( // Get a lock on the shared agent let agent = state.agent.clone(); + // Spawn task to handle streaming tokio::spawn(async move { @@ -324,7 +321,7 @@ async fn handler( Ok(SseResponse::new(stream)) } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, serde::Serialize)] struct AskRequest { prompt: String, } @@ -386,3 +383,154 @@ pub fn routes(state: AppState) -> Router { .route("/ask", post(ask_handler)) .with_state(state) } + +#[cfg(test)] +mod tests { + use super::*; + use goose::{ + providers::{ + base::Provider, + configs::OpenAiProviderConfig, + }, + models::tool::Tool, + agent::Agent, + }; + + // Mock Provider implementation for testing + #[derive(Clone)] + struct MockProvider; + + #[async_trait::async_trait] + impl Provider for MockProvider { + async fn complete( + &self, + _system_prompt: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, goose::providers::base::Usage), anyhow::Error> { + Ok((Message::assistant().with_text("Mock response"), goose::providers::base::Usage::default())) + } + } + + #[test] + fn test_convert_messages_user_only() { + let incoming = vec![IncomingMessage { + role: "user".to_string(), + content: "Hello".to_string(), + tool_invocations: vec![], + }]; + + let messages = convert_messages(incoming); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, Role::User); + assert!(matches!(&messages[0].content[0], MessageContent::Text(text) if text.text == "Hello")); + } + + #[test] + fn test_convert_messages_with_tool_invocation() { + let tool_result = vec![Content::text("tool response").with_priority(0.0)]; + let incoming = vec![IncomingMessage { + role: "assistant".to_string(), + content: "".to_string(), + tool_invocations: vec![ToolInvocation { + state: "result".to_string(), + tool_call_id: "123".to_string(), + tool_name: "test_tool".to_string(), + args: json!({"key": "value"}), + result: Some(tool_result.clone()), + }], + }]; + + let messages = convert_messages(incoming); + assert_eq!(messages.len(), 2); // Tool request and response + + // Check tool request + assert_eq!(messages[0].role, Role::Assistant); + assert!(matches!(&messages[0].content[0], MessageContent::ToolRequest(req) if req.id == "123")); + + // Check tool response + assert_eq!(messages[1].role, Role::User); + assert!(matches!(&messages[1].content[0], MessageContent::ToolResponse(resp) if resp.id == "123")); + } + + #[test] + fn test_protocol_formatter() { + // Test text formatting + let text = "Hello world"; + let formatted = ProtocolFormatter::format_text(text); + assert_eq!(formatted, "0:\"Hello world\"\n"); + + // Test tool call formatting + let formatted = ProtocolFormatter::format_tool_call( + "123", + "test_tool", + &json!({"key": "value"}), + ); + assert!(formatted.starts_with("9:")); + assert!(formatted.contains("\"toolCallId\":\"123\"")); + assert!(formatted.contains("\"toolName\":\"test_tool\"")); + + // Test tool response formatting + let result = vec![Content::text("response").with_priority(0.0)]; + let formatted = ProtocolFormatter::format_tool_response("123", &result); + assert!(formatted.starts_with("a:")); + assert!(formatted.contains("\"toolCallId\":\"123\"")); + + // Test finish formatting + let formatted = ProtocolFormatter::format_finish("stop"); + assert!(formatted.starts_with("d:")); + assert!(formatted.contains("\"finishReason\":\"stop\"")); + } + + mod integration_tests { + use super::*; + use axum::{ + http::Request, + body::Body, + }; + use tower::ServiceExt; + use std::sync::Arc; + use tokio::sync::Mutex; + use goose::providers::configs::ProviderConfig; + + // This test requires tokio runtime + #[tokio::test] + async fn test_ask_endpoint() { + // Create a mock app state with mock provider + let mock_provider = Box::new(MockProvider); + let agent = Agent::new(mock_provider); + let state = AppState { + agent: Arc::new(Mutex::new(agent)), + provider_config: ProviderConfig::OpenAi(OpenAiProviderConfig { + host: "https://api.openai.com".to_string(), + api_key: "test-key".to_string(), + model: "test-model".to_string(), + temperature: None, + max_tokens: None, + }), + }; + + // Build router + let app = routes(state); + + // Create request + let request = Request::builder() + .uri("/ask") + .method("POST") + .header("content-type", "application/json") + .body(Body::from( + serde_json::to_string(&AskRequest { + prompt: "test prompt".to_string(), + }) + .unwrap(), + )) + .unwrap(); + + // Send request + let response = app.oneshot(request).await.unwrap(); + + // Assert response status + assert_eq!(response.status(), StatusCode::OK); + } + } +} diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index ebd027a31..67d318a6f 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -2,7 +2,7 @@ use anyhow::Result; use goose::{ agent::Agent, developer::DeveloperSystem, - providers::{configs::ProviderConfig, factory}, + providers::{configs::ProviderConfig, factory}, systems::goose_hints::GooseHintsSystem, }; use std::sync::Arc; use tokio::sync::Mutex; @@ -18,6 +18,8 @@ impl AppState { let provider = factory::get_provider(provider_config.clone())?; let mut agent = Agent::new(provider); agent.add_system(Box::new(DeveloperSystem::new())); + let goosehints_system = Box::new(GooseHintsSystem::new()); + agent.add_system(goosehints_system); Ok(Self { provider_config, diff --git a/ui/desktop/src/bin/goosed b/ui/desktop/src/bin/goosed index b0703a1ee..bf305713f 100755 Binary files a/ui/desktop/src/bin/goosed and b/ui/desktop/src/bin/goosed differ