Skip to content

Commit

Permalink
feat: add hints to goosed (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelneale authored Dec 3, 2024
1 parent 5b29e1d commit e2aa23f
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 7 deletions.
2 changes: 2 additions & 0 deletions crates/goose-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ path = "src/main.rs"

[dev-dependencies]
serial_test = "3.2.0"
tower = "0.5"
async-trait = "0.1"
160 changes: 154 additions & 6 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ use axum::{
};
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::{
models::content::Content,
models::message::{Message, MessageContent},
models::role::Role,
};
use goose::models::{content::Content, message::{Message, MessageContent}, role::Role};
use serde::Deserialize;
use serde_json::{json, Value};
use std::{
Expand Down Expand Up @@ -275,6 +271,7 @@ async fn handler(

// Get a lock on the shared agent
let agent = state.agent.clone();


// Spawn task to handle streaming
tokio::spawn(async move {
Expand Down Expand Up @@ -324,7 +321,7 @@ async fn handler(
Ok(SseResponse::new(stream))
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, serde::Serialize)]
struct AskRequest {
prompt: String,
}
Expand Down Expand Up @@ -386,3 +383,154 @@ pub fn routes(state: AppState) -> Router {
.route("/ask", post(ask_handler))
.with_state(state)
}

#[cfg(test)]
mod tests {
use super::*;
use goose::{
providers::{
base::Provider,
configs::OpenAiProviderConfig,
},
models::tool::Tool,
agent::Agent,
};

// Mock Provider implementation for testing
#[derive(Clone)]
struct MockProvider;

#[async_trait::async_trait]
impl Provider for MockProvider {
async fn complete(
&self,
_system_prompt: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, goose::providers::base::Usage), anyhow::Error> {
Ok((Message::assistant().with_text("Mock response"), goose::providers::base::Usage::default()))
}
}

#[test]
fn test_convert_messages_user_only() {
let incoming = vec![IncomingMessage {
role: "user".to_string(),
content: "Hello".to_string(),
tool_invocations: vec![],
}];

let messages = convert_messages(incoming);
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, Role::User);
assert!(matches!(&messages[0].content[0], MessageContent::Text(text) if text.text == "Hello"));
}

#[test]
fn test_convert_messages_with_tool_invocation() {
let tool_result = vec![Content::text("tool response").with_priority(0.0)];
let incoming = vec![IncomingMessage {
role: "assistant".to_string(),
content: "".to_string(),
tool_invocations: vec![ToolInvocation {
state: "result".to_string(),
tool_call_id: "123".to_string(),
tool_name: "test_tool".to_string(),
args: json!({"key": "value"}),
result: Some(tool_result.clone()),
}],
}];

let messages = convert_messages(incoming);
assert_eq!(messages.len(), 2); // Tool request and response

// Check tool request
assert_eq!(messages[0].role, Role::Assistant);
assert!(matches!(&messages[0].content[0], MessageContent::ToolRequest(req) if req.id == "123"));

// Check tool response
assert_eq!(messages[1].role, Role::User);
assert!(matches!(&messages[1].content[0], MessageContent::ToolResponse(resp) if resp.id == "123"));
}

#[test]
fn test_protocol_formatter() {
// Test text formatting
let text = "Hello world";
let formatted = ProtocolFormatter::format_text(text);
assert_eq!(formatted, "0:\"Hello world\"\n");

// Test tool call formatting
let formatted = ProtocolFormatter::format_tool_call(
"123",
"test_tool",
&json!({"key": "value"}),
);
assert!(formatted.starts_with("9:"));
assert!(formatted.contains("\"toolCallId\":\"123\""));
assert!(formatted.contains("\"toolName\":\"test_tool\""));

// Test tool response formatting
let result = vec![Content::text("response").with_priority(0.0)];
let formatted = ProtocolFormatter::format_tool_response("123", &result);
assert!(formatted.starts_with("a:"));
assert!(formatted.contains("\"toolCallId\":\"123\""));

// Test finish formatting
let formatted = ProtocolFormatter::format_finish("stop");
assert!(formatted.starts_with("d:"));
assert!(formatted.contains("\"finishReason\":\"stop\""));
}

mod integration_tests {
use super::*;
use axum::{
http::Request,
body::Body,
};
use tower::ServiceExt;
use std::sync::Arc;
use tokio::sync::Mutex;
use goose::providers::configs::ProviderConfig;

// This test requires tokio runtime
#[tokio::test]
async fn test_ask_endpoint() {
// Create a mock app state with mock provider
let mock_provider = Box::new(MockProvider);
let agent = Agent::new(mock_provider);
let state = AppState {
agent: Arc::new(Mutex::new(agent)),
provider_config: ProviderConfig::OpenAi(OpenAiProviderConfig {
host: "https://api.openai.com".to_string(),
api_key: "test-key".to_string(),
model: "test-model".to_string(),
temperature: None,
max_tokens: None,
}),
};

// Build router
let app = routes(state);

// Create request
let request = Request::builder()
.uri("/ask")
.method("POST")
.header("content-type", "application/json")
.body(Body::from(
serde_json::to_string(&AskRequest {
prompt: "test prompt".to_string(),
})
.unwrap(),
))
.unwrap();

// Send request
let response = app.oneshot(request).await.unwrap();

// Assert response status
assert_eq!(response.status(), StatusCode::OK);
}
}
}
4 changes: 3 additions & 1 deletion crates/goose-server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::Result;
use goose::{
agent::Agent,
developer::DeveloperSystem,
providers::{configs::ProviderConfig, factory},
providers::{configs::ProviderConfig, factory}, systems::goose_hints::GooseHintsSystem,
};
use std::sync::Arc;
use tokio::sync::Mutex;
Expand All @@ -18,6 +18,8 @@ impl AppState {
let provider = factory::get_provider(provider_config.clone())?;
let mut agent = Agent::new(provider);
agent.add_system(Box::new(DeveloperSystem::new()));
let goosehints_system = Box::new(GooseHintsSystem::new());
agent.add_system(goosehints_system);

Ok(Self {
provider_config,
Expand Down
Binary file modified ui/desktop/src/bin/goosed
Binary file not shown.

0 comments on commit e2aa23f

Please sign in to comment.