Skip to content

Commit

Permalink
refactor: improve safety rails speed and prompt
Browse files Browse the repository at this point in the history
- Removes the LLM check for speed, because the regexes have covered
  all of the cases it was covering previously
- Refactors the live display and notifier to allow it to be paused by a
  tool. This fixes the prompting
  • Loading branch information
baxen committed Sep 5, 2024
1 parent 72d927f commit 4aef5cb
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 84 deletions.
44 changes: 26 additions & 18 deletions src/goose/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,20 @@ def load_profile(name: Optional[str]) -> Profile:
class SessionNotifier(Notifier):
def __init__(self, status_indicator: Status) -> None:
self.status_indicator = status_indicator
self.live = Live(self.status_indicator, refresh_per_second=8, transient=True)

def log(self, content: RenderableType) -> None:
print(content)

def status(self, status: str) -> None:
self.status_indicator.update(status)

def start(self) -> None:
self.live.start()

def stop(self) -> None:
self.live.stop()


class Session:
"""A session handler for managing interactions between a user and the Goose exchange
Expand All @@ -87,9 +94,9 @@ def __init__(
) -> None:
self.name = name
self.status_indicator = Status("", spinner="dots")
notifier = SessionNotifier(self.status_indicator)
self.notifier = SessionNotifier(self.status_indicator)

self.exchange = build_exchange(profile=load_profile(profile), notifier=notifier)
self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier)

if name is not None and self.session_file_path.exists():
messages = self.load_session()
Expand Down Expand Up @@ -143,22 +150,23 @@ def run(self) -> None:
"""
message = self.process_first_message()
while message: # Loop until no input (empty string).
with Live(self.status_indicator, refresh_per_second=8, transient=True):
try:
self.exchange.add(message)
self.reply() # Process the user message.
except KeyboardInterrupt:
self.interrupt_reply()
except Exception:
print(traceback.format_exc())
if self.exchange.messages:
self.exchange.messages.pop()
print(
"\n[red]The error above was an exception we were not able to handle.\n\n[/]"
+ "These errors are often related to connection or authentication\n"
+ "We've removed your most recent input"
+ " - [yellow]depending on the error you may be able to continue[/]"
)
self.notifier.start()
try:
self.exchange.add(message)
self.reply() # Process the user message.
except KeyboardInterrupt:
self.interrupt_reply()
except Exception:
print(traceback.format_exc())
if self.exchange.messages:
self.exchange.messages.pop()
print(
"\n[red]The error above was an exception we were not able to handle.\n\n[/]"
+ "These errors are often related to connection or authentication\n"
+ "We've removed your most recent input"
+ " - [yellow]depending on the error you may be able to continue[/]"
)
self.notifier.stop()

print() # Print a newline for separation.
user_input = self.prompt_session.get_user_input()
Expand Down
13 changes: 12 additions & 1 deletion src/goose/notifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod

from typing import Optional
from rich.console import RenderableType


Expand All @@ -19,10 +20,20 @@ def log(self, content: RenderableType) -> None:
pass

@abstractmethod
def status(self, status: str) -> None:
def status(self, status: Optional[str]) -> None:
"""Log a status to ephemeral display
Args:
status (str): The status to display
"""
pass

@abstractmethod
def start(self) -> None:
"""Start the display for the notifier"""
pass

@abstractmethod
def stop(self) -> None:
"""Stop the display for the notifier"""
pass
27 changes: 7 additions & 20 deletions src/goose/toolkit/developer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,18 @@
from rich import box
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Confirm, PromptType
from rich.prompt import Confirm
from rich.table import Table
from rich.text import Text

from goose.toolkit.base import Toolkit, tool
from goose.toolkit.utils import get_language, render_template


def keep_unsafe_command_prompt(command: str) -> PromptType:
def keep_unsafe_command_prompt(command: str) -> bool:
command_text = Text(command, style="bold red")
message = (
Text("\nWe flagged the command: ")
+ command_text
+ Text(" as potentially unsafe, do you want to proceed? (yes/no)")
Text("\nWe flagged the command: ") + command_text + Text(" as potentially unsafe, do you want to proceed?")
)
return Confirm.ask(message, default=True)

Expand Down Expand Up @@ -148,26 +146,15 @@ def shell(self, command: str) -> str:
# logging and integrates with the overall UI logging system
self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell"))

safety_rails_exchange = self.exchange_view.processor.replace(
system=Message.load("prompts/safety_rails.jinja").text
)
# remove the previous message which was a tool_use Assistant message
safety_rails_exchange.messages.pop()

safety_rails_exchange.add(Message.assistant(f"Here is the command I'd like to run: `{command}`"))
safety_rails_exchange.add(Message.user("Please provide the danger rating of that command"))
rating = safety_rails_exchange.reply().text

try:
rating = int(rating)
except ValueError:
rating = 5 # if we can't interpret we default to unsafe
if is_dangerous_command(command) or int(rating) > 3:
if is_dangerous_command(command):
# Stop the notifications so we can prompt
self.notifier.stop()
if not keep_unsafe_command_prompt(command):
raise RuntimeError(
f"The command {command} was rejected as dangerous by the user."
+ " Do not proceed further, instead ask for instructions."
)
self.notifier.start()
self.notifier.status("running shell command")
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
if result.returncode == 0:
Expand Down
39 changes: 0 additions & 39 deletions src/goose/toolkit/prompts/safety_rails.jinja

This file was deleted.

8 changes: 4 additions & 4 deletions src/goose/utils/check_shell_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@ def is_dangerous_command(command: str) -> bool:
bool: True if the command is dangerous, False otherwise.
"""
dangerous_patterns = [
# Commands that are generally unsafe
r"\brm\b", # rm command
r"\bgit\s+push\b", # git push command
r"\bsudo\b", # sudo command
# Add more dangerous command patterns here
r"\bmv\b", # mv command
r"\bchmod\b", # chmod command
r"\bchown\b", # chown command
r"\bmkfs\b", # mkfs command
r"\bsystemctl\b", # systemctl command
r"\breboot\b", # reboot command
r"\bshutdown\b", # shutdown command
# Manipulating files in ~/ directly or dot files
r"^~/[^/]+$", # Files directly in home directory
r"/\.[^/]+$", # Dot files
# Target files that are unsafe
r"\b~\/\.|\/\.\w+", # commands that point to files or dirs in home that start with a dot (dotfiles)
]
for pattern in dangerous_patterns:
if re.search(pattern, command):
print(command, pattern)
return True
return False
17 changes: 15 additions & 2 deletions tests/utils/test_check_shell_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,26 @@
"systemctl stop nginx",
"reboot",
"shutdown now",
"echo hello > ~/.bashrc",
"cat ~/.hello.txt",
"cat ~/.config/example.txt",
],
)
def test_dangerous_commands(command):
assert is_dangerous_command(command)


@pytest.mark.parametrize("command", ["ls -la", 'echo "Hello World"', "cp ~/folder/file.txt /tmp/"])
@pytest.mark.parametrize(
"command",
[
"ls -la",
'echo "Hello World"',
"cp ~/folder/file.txt /tmp/",
"echo hello > ~/toplevel/sublevel.txt",
"cat hello.txt",
"cat ~/config/example.txt",
"ls -la path/to/visible/file",
"echo 'file.with.dot.txt'",
],
)
def test_safe_commands(command):
assert not is_dangerous_command(command)

0 comments on commit 4aef5cb

Please sign in to comment.