Skip to content

Commit

Permalink
Interrupt running shell tool commands (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsibbison-square authored Nov 28, 2024
1 parent d8f43c6 commit ceb80ca
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
35 changes: 25 additions & 10 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ use std::{
convert::Infallible,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;

// Types matching the incoming JSON structure
Expand Down Expand Up @@ -293,18 +295,31 @@ async fn handler(
}
};

while let Some(response) = stream.next().await {
match response {
Ok(message) => {
if let Err(e) = stream_message(message, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
break;
loop {
tokio::select! {
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(message))) => {
if let Err(e) = stream_message(message, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
break;
}
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
break;
}
Ok(None) => {
break;
}
Err(_) => { // Heartbeat, used to detect disconnected clients and then end running tools.
if tx.is_closed() {
break;
}
continue;
}
}
}
Err(e) => {
tracing::error!("Error processing message: {}", e);
break;
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion crates/goose/src/developer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use serde_json::{json, Value};
use std::collections::{HashMap, HashSet};
use std::io::Cursor;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::Mutex;
use tokio::process::Command;
use xcap::Monitor;

use crate::errors::{AgentError, AgentResult};
Expand Down Expand Up @@ -192,9 +192,11 @@ impl DeveloperSystem {

// Execute the command
let output = Command::new("bash")
.kill_on_drop(true) // Critical so that the command is killed when the agent.reply stream is interrupted.
.arg("-c")
.arg(cmd_with_redirect)
.output()
.await
.map_err(|e| AgentError::ExecutionError(e.to_string()))?;

let output_str = format!(
Expand Down

0 comments on commit ceb80ca

Please sign in to comment.