diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a81c64251..e8bab484a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,7 +55,7 @@ in the goose package thanks to [plugin metadata][plugin]!), create a class that import os import platform -from goose.toolkit.base import Toolkit, tool +from goose.toolkit import Toolkit, tool class Demo(Toolkit): diff --git a/pyproject.toml b/pyproject.toml index f16472558..b71711000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,24 +22,24 @@ packages = ["src/goose"] goose-ai = "goose.module_name" [project.entry-points."goose.toolkit"] -developer = "goose.toolkit.developer:Developer" -github = "goose.toolkit.github:Github" -screen = "goose.toolkit.screen:Screen" -repo_context = "goose.toolkit.repo_context.repo_context:RepoContext" +developer = "goose._internal.toolkit:Developer" +github = "goose._internal.toolkit:Github" +screen = "goose._internal.toolkit:Screen" +repo_context = "goose._internal.toolkit:RepoContext" [project.entry-points."goose.profile"] -default = "goose.profile:default_profile" +default = "goose._internal.profile:default_profile" [project.entry-points."goose.command"] -file = "goose.command.file:FileCommand" +file = "goose._internal.cli.command.file:FileCommand" [project.entry-points."goose.cli.group"] -goose = "goose.cli.main:goose_cli" +goose = "goose._internal.cli.main:goose_cli" [project.entry-points."goose.cli.group_option"] [project.scripts] -goose = "goose.cli.main:cli" +goose = "goose._internal.cli.main:cli" [build-system] requires = ["hatchling"] diff --git a/src/goose/__init__.py b/src/goose/__init__.py index e69de29bb..6812fdbae 100644 --- a/src/goose/__init__.py +++ b/src/goose/__init__.py @@ -0,0 +1,5 @@ +from .pluginbase.notifier import Notifier # noqa: F401 +from .pluginbase.toolkit import Toolkit, tool # noqa: F401 +from .pluginbase.profile import Profile, ToolkitSpec # noqa: F401 +from .pluginbase.command import Command # noqa: F401 +from .pluginbase.utils import converter # noqa: F401 diff --git a/src/goose/cli/__init__.py b/src/goose/_internal/cli/__init__.py similarity index 100% rename from src/goose/cli/__init__.py rename to src/goose/_internal/cli/__init__.py diff --git a/src/goose/command/__init__.py b/src/goose/_internal/cli/command/__init__.py similarity index 76% rename from src/goose/command/__init__.py rename to src/goose/_internal/cli/command/__init__.py index d9fd674a4..423b66b4b 100644 --- a/src/goose/command/__init__.py +++ b/src/goose/_internal/cli/command/__init__.py @@ -1,8 +1,8 @@ from functools import cache from typing import Dict -from goose.command.base import Command -from goose.utils import load_plugins +from goose.pluginbase.command import Command +from ...utils import load_plugins @cache diff --git a/src/goose/command/file.py b/src/goose/_internal/cli/command/file.py similarity index 97% rename from src/goose/command/file.py rename to src/goose/_internal/cli/command/file.py index 7bbf7d9e3..ec061d5fd 100644 --- a/src/goose/command/file.py +++ b/src/goose/_internal/cli/command/file.py @@ -3,7 +3,7 @@ from prompt_toolkit.completion import Completion -from goose.command.base import Command +from goose.pluginbase.command import Command class FileCommand(Command): diff --git a/src/goose/cli/main.py b/src/goose/_internal/cli/main.py similarity index 95% rename from src/goose/cli/main.py rename to src/goose/_internal/cli/main.py index 4ebcc2811..79b2a6947 100644 --- a/src/goose/cli/main.py +++ b/src/goose/_internal/cli/main.py @@ -6,10 +6,10 @@ from rich import print from ruamel.yaml import YAML -from goose.cli.config import SESSIONS_PATH -from goose.cli.session import Session -from goose.utils import load_plugins -from goose.utils.session_file import list_sorted_session_files +from ...config import SESSIONS_PATH +from .session.session import Session +from ..utils import load_plugins +from goose.pluginbase.utils.session_file import list_sorted_session_files @click.group() diff --git a/src/goose/cli/prompt/__init__.py b/src/goose/_internal/cli/prompt/__init__.py similarity index 100% rename from src/goose/cli/prompt/__init__.py rename to src/goose/_internal/cli/prompt/__init__.py diff --git a/src/goose/cli/prompt/completer.py b/src/goose/_internal/cli/prompt/completer.py similarity index 97% rename from src/goose/cli/prompt/completer.py rename to src/goose/_internal/cli/prompt/completer.py index 6739d1530..7b5c3ec71 100644 --- a/src/goose/cli/prompt/completer.py +++ b/src/goose/_internal/cli/prompt/completer.py @@ -4,7 +4,7 @@ from prompt_toolkit.completion import CompleteEvent, Completer, Completion from prompt_toolkit.document import Document -from goose.command.base import Command +from goose.pluginbase.command import Command class GoosePromptCompleter(Completer): diff --git a/src/goose/cli/prompt/create.py b/src/goose/_internal/cli/prompt/create.py similarity index 93% rename from src/goose/cli/prompt/create.py rename to src/goose/_internal/cli/prompt/create.py index 628c86ccb..15245c9f9 100644 --- a/src/goose/cli/prompt/create.py +++ b/src/goose/_internal/cli/prompt/create.py @@ -4,9 +4,9 @@ from prompt_toolkit.keys import Keys from prompt_toolkit.styles import Style -from goose.cli.prompt.completer import GoosePromptCompleter -from goose.cli.prompt.lexer import PromptLexer -from goose.command import get_commands +from .completer import GoosePromptCompleter +from .lexer import PromptLexer +from ..command import get_commands def create_prompt() -> PromptSession: diff --git a/src/goose/cli/prompt/goose_prompt_session.py b/src/goose/_internal/cli/prompt/goose_prompt_session.py similarity index 87% rename from src/goose/cli/prompt/goose_prompt_session.py rename to src/goose/_internal/cli/prompt/goose_prompt_session.py index cfcedd80a..7b399fcfa 100644 --- a/src/goose/cli/prompt/goose_prompt_session.py +++ b/src/goose/_internal/cli/prompt/goose_prompt_session.py @@ -4,9 +4,9 @@ from prompt_toolkit.formatted_text import FormattedText from prompt_toolkit.validation import DummyValidator -from goose.cli.prompt.create import create_prompt -from goose.cli.prompt.prompt_validator import PromptValidator -from goose.cli.prompt.user_input import PromptAction, UserInput +from .create import create_prompt +from .prompt_validator import PromptValidator +from .user_input import PromptAction, UserInput class GoosePromptSession: diff --git a/src/goose/cli/prompt/lexer.py b/src/goose/_internal/cli/prompt/lexer.py similarity index 100% rename from src/goose/cli/prompt/lexer.py rename to src/goose/_internal/cli/prompt/lexer.py diff --git a/src/goose/cli/prompt/prompt_validator.py b/src/goose/_internal/cli/prompt/prompt_validator.py similarity index 100% rename from src/goose/cli/prompt/prompt_validator.py rename to src/goose/_internal/cli/prompt/prompt_validator.py diff --git a/src/goose/cli/prompt/user_input.py b/src/goose/_internal/cli/prompt/user_input.py similarity index 100% rename from src/goose/cli/prompt/user_input.py rename to src/goose/_internal/cli/prompt/user_input.py diff --git a/src/goose/toolkit/repo_context/__init__.py b/src/goose/_internal/cli/session/__init__.py similarity index 100% rename from src/goose/toolkit/repo_context/__init__.py rename to src/goose/_internal/cli/session/__init__.py diff --git a/src/goose/cli/session.py b/src/goose/_internal/cli/session/session.py similarity index 75% rename from src/goose/cli/session.py rename to src/goose/_internal/cli/session/session.py index 713bd1f8c..7f373c127 100644 --- a/src/goose/cli/session.py +++ b/src/goose/_internal/cli/session/session.py @@ -5,70 +5,20 @@ from exchange import Message, ToolResult, ToolUse, Text from prompt_toolkit.shortcuts import confirm from rich import print -from rich.console import RenderableType from rich.live import Live from rich.markdown import Markdown -from rich.panel import Panel from rich.status import Status -from goose.build import build_exchange -from goose.cli.config import ( - default_profiles, - ensure_config, - read_config, - session_path, -) -from goose.cli.prompt.goose_prompt_session import GoosePromptSession -from goose.notifier import Notifier -from goose.profile import Profile -from goose.utils import droid, load_plugins -from goose.utils.session_file import read_from_file, write_to_file +from .session_utils import random_session_name -RESUME_MESSAGE = "I see we were interrupted. How can I help you?" - - -def load_provider() -> str: - # We try to infer a provider, by going in order of what will auth - providers = load_plugins(group="exchange.provider") - for provider, cls in providers.items(): - try: - cls.from_env() - print(Panel(f"[green]Detected an available provider: [/]{provider}")) - return provider - except Exception: - pass - else: - # TODO link to auth docs - print( - Panel( - "[red]Could not authenticate any providers[/]\n" - + "Returning a default pointing to openai, but you will need to set an API token env variable." - ) - ) - return "openai" - - -def load_profile(name: Optional[str]) -> Profile: - if name is None: - name = "default" +from .session_notifier import SessionNotifier - # If the name is one of the default values, we ensure a valid configuration - if name in default_profiles(): - return ensure_config(name) +from ...profile.config import load_profile +from ...exchange.build import build_exchange +from ..prompt.goose_prompt_session import GoosePromptSession +from goose.pluginbase.utils.session_file import read_from_file, write_to_file, session_path - # Otherwise this is a custom config and we return it from the config file - return read_config()[name] - - -class SessionNotifier(Notifier): - def __init__(self, status_indicator: Status) -> None: - self.status_indicator = status_indicator - - def log(self, content: RenderableType) -> None: - print(content) - - def status(self, status: str) -> None: - self.status_indicator.update(status) +RESUME_MESSAGE = "I see we were interrupted. How can I help you?" class Session: @@ -92,7 +42,7 @@ def __init__( self.exchange = build_exchange(profile=load_profile(profile), notifier=notifier) if name is not None and self.session_file_path.exists(): - messages = self.load_session() + messages = self._load_session() if messages and messages[-1].role == "user": if type(messages[-1].content[-1]) is Text: @@ -112,11 +62,11 @@ def __init__( self.exchange.messages.extend(messages) if len(self.exchange.messages) == 0 and plan: - self.setup_plan(plan=plan) + self._setup_plan(plan=plan) self.prompt_session = GoosePromptSession.create_prompt_session() - def setup_plan(self, plan: dict) -> None: + def _setup_plan(self, plan: dict) -> None: if len(self.exchange.messages): raise ValueError("The plan can only be set on an empty session.") self.exchange.messages.append(Message.user(plan["kickoff_message"])) @@ -127,7 +77,7 @@ def setup_plan(self, plan: dict) -> None: plan_tool_use = ToolUse(id="initialplan", name="update_plan", parameters=dict(tasks=tasks)) self.exchange.add_tool_use(plan_tool_use) - def process_first_message(self) -> Optional[Message]: + def _process_first_message(self) -> Optional[Message]: # Get a first input unless it has been specified, such as by a plan if len(self.exchange.messages) == 0 or self.exchange.messages[-1].role == "assistant": user_input = self.prompt_session.get_user_input() @@ -141,14 +91,14 @@ def run(self) -> None: Runs the main loop to handle user inputs and responses. Continues until an empty string is returned from the prompt. """ - message = self.process_first_message() + message = self._process_first_message() while message: # Loop until no input (empty string). with Live(self.status_indicator, refresh_per_second=8, transient=True): try: self.exchange.add(message) - self.reply() # Process the user message. + self._reply() # Process the user message. except KeyboardInterrupt: - self.interrupt_reply() + self._interrupt_reply() except Exception: print(traceback.format_exc()) if self.exchange.messages: @@ -164,9 +114,9 @@ def run(self) -> None: user_input = self.prompt_session.get_user_input() message = Message.user(text=user_input.text) if user_input.to_continue() else None - self.save_session() + self._save_session() - def reply(self) -> None: + def _reply(self) -> None: """Reply to the last user message, calling tools as needed Args: @@ -190,7 +140,7 @@ def reply(self) -> None: if response.text: print(Markdown(response.text)) - def interrupt_reply(self) -> None: + def _interrupt_reply(self) -> None: """Recover from an interruption at an arbitrary state""" # Default recovery message if no user message is pending. recovery = "We interrupted before the next processing started." @@ -224,28 +174,28 @@ def interrupt_reply(self) -> None: def session_file_path(self) -> Path: return session_path(self.name) - def save_session(self) -> None: + def _save_session(self) -> None: """Save the current session to a file in JSON format.""" if self.name is None: - self.generate_session_name() + self._generate_session_name() try: if self.session_file_path.exists(): if not confirm(f"Session {self.name} exists in {self.session_file_path}, overwrite?"): - self.generate_session_name() + self._generate_session_name() write_to_file(self.session_file_path, self.exchange.messages) except PermissionError as e: raise RuntimeError(f"Failed to save session due to permissions: {e}") except (IOError, OSError) as e: raise RuntimeError(f"Failed to save session due to I/O error: {e}") - def load_session(self) -> List[Message]: + def _load_session(self) -> List[Message]: """Load a session from a JSON file.""" return read_from_file(self.session_file_path) - def generate_session_name(self) -> None: + def _generate_session_name(self) -> None: user_entered_session_name = self.prompt_session.get_save_session_name() - self.name = user_entered_session_name if user_entered_session_name else droid() + self.name = user_entered_session_name if user_entered_session_name else random_session_name() print(f"Saving to [bold cyan]{self.session_file_path}[/bold cyan]") diff --git a/src/goose/_internal/cli/session/session_notifier.py b/src/goose/_internal/cli/session/session_notifier.py new file mode 100644 index 000000000..43d2ecb7f --- /dev/null +++ b/src/goose/_internal/cli/session/session_notifier.py @@ -0,0 +1,15 @@ +from goose.pluginbase.notifier import Notifier +from rich.status import Status +from rich.console import RenderableType +from rich import print + + +class SessionNotifier(Notifier): + def __init__(self, status_indicator: Status) -> None: + self.status_indicator = status_indicator + + def log(self, content: RenderableType) -> None: + print(content) + + def status(self, status: str) -> None: + self.status_indicator.update(status) diff --git a/src/goose/_internal/cli/session/session_utils.py b/src/goose/_internal/cli/session/session_utils.py new file mode 100644 index 000000000..d3c295fe1 --- /dev/null +++ b/src/goose/_internal/cli/session/session_utils.py @@ -0,0 +1,12 @@ +import random +import string + +def random_session_name() -> str: + return "".join( + [ + random.choice(string.ascii_lowercase), + random.choice(string.digits), + random.choice(string.ascii_lowercase), + random.choice(string.digits), + ] + ) \ No newline at end of file diff --git a/src/goose/build.py b/src/goose/_internal/exchange/build.py similarity index 91% rename from src/goose/build.py rename to src/goose/_internal/exchange/build.py index dd539a9af..edebf86b5 100644 --- a/src/goose/build.py +++ b/src/goose/_internal/exchange/build.py @@ -4,11 +4,11 @@ from exchange.moderators import get_moderator from exchange.providers import get_provider -from goose.notifier import Notifier -from goose.profile import Profile -from goose.toolkit import get_toolkit -from goose.toolkit.base import Requirements -from goose.view import ExchangeView +from goose.pluginbase.notifier import Notifier +from goose.pluginbase.profile import Profile +from goose.pluginbase.toolkit import Requirements +from ..toolkit import get_toolkit +from .view import ExchangeView def build_exchange(profile: Profile, notifier: Notifier) -> Exchange: diff --git a/src/goose/system.jinja b/src/goose/_internal/exchange/system.jinja similarity index 100% rename from src/goose/system.jinja rename to src/goose/_internal/exchange/system.jinja diff --git a/src/goose/view.py b/src/goose/_internal/exchange/view.py similarity index 100% rename from src/goose/view.py rename to src/goose/_internal/exchange/view.py diff --git a/src/goose/_internal/profile/__init__.py b/src/goose/_internal/profile/__init__.py new file mode 100644 index 000000000..5647a50ac --- /dev/null +++ b/src/goose/_internal/profile/__init__.py @@ -0,0 +1,16 @@ +from typing import Any, Dict +from goose.pluginbase.profile import Profile, ToolkitSpec + + +def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile: + """Get the default profile""" + + # TODO consider if the providers should have recommended models + + return Profile( + provider=provider, + processor=processor, + accelerator=accelerator, + moderator="truncate", + toolkits=[ToolkitSpec("developer")], + ) diff --git a/src/goose/cli/config.py b/src/goose/_internal/profile/config.py similarity index 77% rename from src/goose/cli/config.py rename to src/goose/_internal/profile/config.py index f875b49e9..7ce57f57c 100644 --- a/src/goose/cli/config.py +++ b/src/goose/_internal/profile/config.py @@ -1,7 +1,6 @@ from functools import cache from io import StringIO -from pathlib import Path -from typing import Callable, Dict, Mapping, Tuple +from typing import Callable, Mapping, Dict, Optional, Tuple from rich import print from rich.panel import Panel @@ -9,27 +8,18 @@ from rich.text import Text from ruamel.yaml import YAML -from goose.profile import Profile -from goose.utils import load_plugins -from goose.utils.diff import pretty_diff - -GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser() -PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml") -SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions") -SESSION_FILE_SUFFIX = ".jsonl" +from goose.config import PROFILES_CONFIG_PATH +from goose.pluginbase.profile import Profile +from ..utils import load_plugins +from .diff import pretty_diff @cache -def default_profiles() -> Mapping[str, Callable]: +def _all_recommended_profiles() -> Mapping[str, Callable]: return load_plugins(group="goose.profile") -def session_path(name: str) -> Path: - SESSIONS_PATH.mkdir(parents=True, exist_ok=True) - return SESSIONS_PATH.joinpath(f"{name}{SESSION_FILE_SUFFIX}") - - -def write_config(profiles: Dict[str, Profile]) -> None: +def _write_config(profiles: Dict[str, Profile]) -> None: """Overwrite the config with the passed profiles""" PROFILES_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True) converted = {name: profile.to_dict() for name, profile in profiles.items()} @@ -38,13 +28,13 @@ def write_config(profiles: Dict[str, Profile]) -> None: yaml.dump(converted, f) -def ensure_config(name: str) -> Profile: +def _ensure_config(name: str) -> Profile: """Ensure that the config exists and has the default section""" # TODO we should copy a templated default config in to better document # but this is complicated a bit by autodetecting the provider - provider, processor, accelerator = default_model_configuration() - profile = default_profiles()[name](provider, processor, accelerator) + provider, processor, accelerator = _default_model_configuration() + profile = _all_recommended_profiles()[name](provider, processor, accelerator) profiles = {} if not PROFILES_CONFIG_PATH.exists(): @@ -57,14 +47,14 @@ def ensure_config(name: str) -> Profile: ) default = profile profiles = {name: default} - write_config(profiles) + _write_config(profiles) return profile - profiles = read_config() + profiles = _read_config() if name not in profiles: print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]")) profiles.update({name: profile}) - write_config(profiles) + _write_config(profiles) elif name in profiles: # if the profile stored differs from the default one, we should prompt the user to see if they want # to update it! we need to recursively compare the two profiles, as object comparison will always return false @@ -93,14 +83,14 @@ def ensure_config(name: str) -> Profile: ) if should_update: profiles[name] = profile - write_config(profiles) + _write_config(profiles) else: profile = profiles[name] return profile -def read_config() -> Dict[str, Profile]: +def _read_config() -> Dict[str, Profile]: """Return config from the configuration file and validates its contents""" yaml = YAML() @@ -110,7 +100,7 @@ def read_config() -> Dict[str, Profile]: return {name: Profile(**profile) for name, profile in data.items()} -def default_model_configuration() -> Tuple[str, str, str]: +def _default_model_configuration() -> Tuple[str, str, str]: providers = load_plugins(group="exchange.provider") for provider, cls in providers.items(): try: @@ -139,3 +129,15 @@ def default_model_configuration() -> Tuple[str, str, str]: } processor, accelerator = recommended.get(provider, ("gpt-4o", "gpt-4o-mini")) return provider, processor, accelerator + + +def load_profile(name: Optional[str]) -> Profile: + if name is None: + name = "default" + + # If the name is one of the default values, we ensure a valid configuration + if name in _all_recommended_profiles(): + return _ensure_config(name) + + # Otherwise this is a custom config and we return it from the config file + return _read_config()[name] diff --git a/src/goose/utils/diff.py b/src/goose/_internal/profile/diff.py similarity index 92% rename from src/goose/utils/diff.py rename to src/goose/_internal/profile/diff.py index e3583be01..e79199c4d 100644 --- a/src/goose/utils/diff.py +++ b/src/goose/_internal/profile/diff.py @@ -3,7 +3,7 @@ from rich.text import Text -def diff(a: str, b: str) -> List[str]: +def _diff(a: str, b: str) -> List[str]: """Returns a string containing the unified diff of two strings.""" import difflib @@ -23,7 +23,7 @@ def diff(a: str, b: str) -> List[str]: def pretty_diff(a: str, b: str) -> Text: """Returns a pretty-printed diff of two strings.""" - diff_lines = diff(a, b) + diff_lines = _diff(a, b) result = Text() for line in diff_lines: if line.startswith("+"): diff --git a/src/goose/_internal/toolkit/__init__.py b/src/goose/_internal/toolkit/__init__.py new file mode 100644 index 000000000..2a0030575 --- /dev/null +++ b/src/goose/_internal/toolkit/__init__.py @@ -0,0 +1,18 @@ +from .developer.developer import Developer # noqa: F401 +from .repo_context.repo_context import RepoContext # noqa: F401 +from .summarization.summarize_repo import SummarizeRepo # noqa +from .summarization.summarize_project import SummarizeProject # noqa +from .summarization.summarize_file import SummarizeFile # noqa: F401 +from .github.github import Github # noqa: F401 +from .screen import Screen # noqa: F401 + + +from functools import cache + +from goose.pluginbase.toolkit import Toolkit +from ..utils import load_plugins + + +@cache +def get_toolkit(name: str) -> type[Toolkit]: + return load_plugins(group="goose.toolkit")[name] diff --git a/tests/toolkit/__init__.py b/src/goose/_internal/toolkit/developer/__init__.py similarity index 100% rename from tests/toolkit/__init__.py rename to src/goose/_internal/toolkit/developer/__init__.py diff --git a/src/goose/toolkit/developer.py b/src/goose/_internal/toolkit/developer/developer.py similarity index 97% rename from src/goose/toolkit/developer.py rename to src/goose/_internal/toolkit/developer/developer.py index 1114df19d..63aca0299 100644 --- a/src/goose/toolkit/developer.py +++ b/src/goose/_internal/toolkit/developer/developer.py @@ -1,7 +1,7 @@ from pathlib import Path from subprocess import CompletedProcess, run from typing import List -from goose.utils.check_shell_command import is_dangerous_command +from .utils.check_shell_command import is_dangerous_command from exchange import Message from rich import box @@ -11,8 +11,8 @@ from rich.table import Table from rich.text import Text -from goose.toolkit.base import Toolkit, tool -from goose.toolkit.utils import get_language +from goose.pluginbase.toolkit import Toolkit, tool +from goose.pluginbase.utils.file_language import get_language def keep_unsafe_command_prompt(command: str) -> PromptType: @@ -34,7 +34,7 @@ class Developer(Toolkit): def system(self) -> str: """Retrieve system configuration details for developer""" - hints_path = Path('.goosehints') + hints_path = Path(".goosehints") system_prompt = Message.load("prompts/developer.jinja").text if hints_path.is_file(): goosehints = hints_path.read_text() diff --git a/src/goose/toolkit/prompts/developer.jinja b/src/goose/_internal/toolkit/developer/prompts/developer.jinja similarity index 100% rename from src/goose/toolkit/prompts/developer.jinja rename to src/goose/_internal/toolkit/developer/prompts/developer.jinja diff --git a/src/goose/toolkit/prompts/safety_rails.jinja b/src/goose/_internal/toolkit/developer/prompts/safety_rails.jinja similarity index 100% rename from src/goose/toolkit/prompts/safety_rails.jinja rename to src/goose/_internal/toolkit/developer/prompts/safety_rails.jinja diff --git a/src/goose/utils/check_shell_command.py b/src/goose/_internal/toolkit/developer/utils/check_shell_command.py similarity index 100% rename from src/goose/utils/check_shell_command.py rename to src/goose/_internal/toolkit/developer/utils/check_shell_command.py diff --git a/src/goose/toolkit/github.py b/src/goose/_internal/toolkit/github/github.py similarity index 87% rename from src/goose/toolkit/github.py rename to src/goose/_internal/toolkit/github/github.py index 4a7025926..102958b83 100644 --- a/src/goose/toolkit/github.py +++ b/src/goose/_internal/toolkit/github/github.py @@ -1,6 +1,6 @@ from exchange import Message -from goose.toolkit.base import Toolkit +from goose.pluginbase.toolkit import Toolkit class Github(Toolkit): diff --git a/src/goose/toolkit/prompts/github.jinja b/src/goose/_internal/toolkit/github/prompts/github.jinja similarity index 100% rename from src/goose/toolkit/prompts/github.jinja rename to src/goose/_internal/toolkit/github/prompts/github.jinja diff --git a/src/goose/_internal/toolkit/repo_context/__init__.py b/src/goose/_internal/toolkit/repo_context/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/goose/toolkit/repo_context/prompts/repo_context.jinja b/src/goose/_internal/toolkit/repo_context/prompts/repo_context.jinja similarity index 100% rename from src/goose/toolkit/repo_context/prompts/repo_context.jinja rename to src/goose/_internal/toolkit/repo_context/prompts/repo_context.jinja diff --git a/src/goose/toolkit/repo_context/repo_context.py b/src/goose/_internal/toolkit/repo_context/repo_context.py similarity index 91% rename from src/goose/toolkit/repo_context/repo_context.py rename to src/goose/_internal/toolkit/repo_context/repo_context.py index 89a01a76f..897f65a89 100644 --- a/src/goose/toolkit/repo_context/repo_context.py +++ b/src/goose/_internal/toolkit/repo_context/repo_context.py @@ -5,12 +5,12 @@ from exchange import Message -from goose.notifier import Notifier -from goose.toolkit import Toolkit -from goose.toolkit.base import Requirements, tool -from goose.toolkit.repo_context.utils import get_repo_size, goose_picks_files -from goose.toolkit.summarization.utils import load_summary_file_if_exists, summarize_files_concurrent -from goose.utils.ask import clear_exchange, replace_prompt +from goose.pluginbase.notifier import Notifier +from goose.pluginbase.toolkit import Toolkit +from goose.pluginbase.toolkit import Requirements, tool +from .repo_context_utils import get_repo_size, goose_picks_files +from goose.pluginbase.utils.summarization import load_summary_file_if_exists, summarize_files_concurrent +from goose.pluginbase.utils.ask import clear_exchange, replace_prompt class RepoContext(Toolkit): diff --git a/src/goose/toolkit/repo_context/utils.py b/src/goose/_internal/toolkit/repo_context/repo_context_utils.py similarity index 87% rename from src/goose/toolkit/repo_context/utils.py rename to src/goose/_internal/toolkit/repo_context/repo_context_utils.py index dca7f04b0..27fa0c55f 100644 --- a/src/goose/toolkit/repo_context/utils.py +++ b/src/goose/_internal/toolkit/repo_context/repo_context_utils.py @@ -6,10 +6,10 @@ from exchange import Exchange -from goose.utils.ask import ask_an_ai +from goose.pluginbase.utils.ask import ask_an_ai -def get_directory_size(directory: str) -> int: +def _get_directory_size(directory: str) -> int: total_size = 0 for dirpath, _, filenames in os.walk(directory): for f in filenames: @@ -23,10 +23,10 @@ def get_directory_size(directory: str) -> int: def get_repo_size(repo_path: str) -> int: """Returns repo size in MB""" git_dir = os.path.join(repo_path, ".git") - return get_directory_size(git_dir) / (1024**2) + return _get_directory_size(git_dir) / (1024**2) -def get_files_and_directories(root_dir: str) -> Dict[str, list]: +def _get_files_and_directories(root_dir: str) -> Dict[str, list]: """Gets file names and directory names. Checks that goose has correctly typed the file and directory names and that the files actually exist (to avoid downstream file read errors). @@ -70,7 +70,7 @@ def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> Li with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: while queue: current_batch = [queue.popleft() for _ in range(min(max_workers, len(queue)))] - futures = {executor.submit(process_directory, dir, exchange): dir for dir in current_batch} + futures = {executor.submit(_process_directory, dir, exchange): dir for dir in current_batch} for future in concurrent.futures.as_completed(futures): files, next_dirs = future.result() @@ -80,12 +80,12 @@ def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> Li return all_files -def process_directory(current_dir: str, exchange: Exchange) -> Tuple[List[str], List[str]]: +def _process_directory(current_dir: str, exchange: Exchange) -> Tuple[List[str], List[str]]: """Allows goose to pick files and subdirectories contained in a given directory (current_dir). Get the list of file and directory names in the current folder, then ask Goose to pick which ones to keep. """ - files_and_dirs = get_files_and_directories(current_dir) + files_and_dirs = _get_files_and_directories(current_dir) ai_response = ask_an_ai(str(files_and_dirs), exchange) # FIXME: goose response validation diff --git a/src/goose/toolkit/screen.py b/src/goose/_internal/toolkit/screen.py similarity index 95% rename from src/goose/toolkit/screen.py rename to src/goose/_internal/toolkit/screen.py index ce5524881..9a1bb9367 100644 --- a/src/goose/toolkit/screen.py +++ b/src/goose/_internal/toolkit/screen.py @@ -1,7 +1,7 @@ import subprocess import uuid -from goose.toolkit.base import Toolkit, tool +from goose.pluginbase.toolkit import Toolkit, tool class Screen(Toolkit): diff --git a/src/goose/_internal/toolkit/summarization/__init__.py b/src/goose/_internal/toolkit/summarization/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/goose/_internal/toolkit/summarization/__init__.py @@ -0,0 +1 @@ + diff --git a/src/goose/toolkit/summarization/summarize_file.py b/src/goose/_internal/toolkit/summarization/summarize_file.py similarity index 84% rename from src/goose/toolkit/summarization/summarize_file.py rename to src/goose/_internal/toolkit/summarization/summarize_file.py index d4685eec7..5535a9313 100644 --- a/src/goose/toolkit/summarization/summarize_file.py +++ b/src/goose/_internal/toolkit/summarization/summarize_file.py @@ -1,8 +1,7 @@ from typing import Optional -from goose.toolkit import Toolkit -from goose.toolkit.base import tool -from goose.toolkit.summarization.utils import summarize_file +from goose.pluginbase.toolkit import Toolkit, tool +from goose.pluginbase.utils.summarization import summarize_file class SummarizeFile(Toolkit): diff --git a/src/goose/toolkit/summarization/summarize_project.py b/src/goose/_internal/toolkit/summarization/summarize_project.py similarity index 90% rename from src/goose/toolkit/summarization/summarize_project.py rename to src/goose/_internal/toolkit/summarization/summarize_project.py index d910fbc47..06f830a63 100644 --- a/src/goose/toolkit/summarization/summarize_project.py +++ b/src/goose/_internal/toolkit/summarization/summarize_project.py @@ -1,9 +1,8 @@ import os from typing import List, Optional -from goose.toolkit import Toolkit -from goose.toolkit.base import tool -from goose.toolkit.summarization.utils import summarize_directory +from goose.pluginbase.toolkit import Toolkit, tool +from goose.pluginbase.utils.summarization import summarize_directory class SummarizeProject(Toolkit): diff --git a/src/goose/toolkit/summarization/summarize_repo.py b/src/goose/_internal/toolkit/summarization/summarize_repo.py similarity index 91% rename from src/goose/toolkit/summarization/summarize_repo.py rename to src/goose/_internal/toolkit/summarization/summarize_repo.py index 18c7da428..13f2e20d1 100644 --- a/src/goose/toolkit/summarization/summarize_repo.py +++ b/src/goose/_internal/toolkit/summarization/summarize_repo.py @@ -1,8 +1,7 @@ from typing import List, Optional -from goose.toolkit import Toolkit -from goose.toolkit.base import tool -from goose.toolkit.summarization.utils import summarize_repo +from goose.pluginbase.toolkit import Toolkit, tool +from goose.pluginbase.utils.summarization import summarize_repo class SummarizeRepo(Toolkit): diff --git a/src/goose/_internal/utils.py b/src/goose/_internal/utils.py new file mode 100644 index 000000000..29085621b --- /dev/null +++ b/src/goose/_internal/utils.py @@ -0,0 +1,26 @@ +from importlib.metadata import entry_points + + +def load_plugins(group: str) -> dict: + """ + Load plugins based on a specified entry point group. + + This function iterates through all entry points registered under a specified group + + Args: + group (str): The entry point group to load plugins from. This should match the group specified + in the package setup where plugins are defined. + + Returns: + dict: A dictionary where each key is the entry point name, and the value is the loaded plugin object. + + Raises: + Exception: Propagates exceptions raised by entry point loading, which might occur if a plugin + is not found or if there are issues with the plugin's code. + """ + plugins = {} + # Access all entry points for the specified group and load each. + for entrypoint in entry_points(group=group): + plugin = entrypoint.load() # Load the plugin. + plugins[entrypoint.name] = plugin # Store the loaded plugin in the dictionary. + return plugins diff --git a/src/goose/config/__init__.py b/src/goose/config/__init__.py new file mode 100644 index 000000000..963e86c5b --- /dev/null +++ b/src/goose/config/__init__.py @@ -0,0 +1,7 @@ +from pathlib import Path +from typing import Final + +GOOSE_GLOBAL_PATH: Final = Path("~/.config/goose").expanduser() +PROFILES_CONFIG_PATH: Final = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml") +SESSIONS_PATH: Final = GOOSE_GLOBAL_PATH.joinpath("sessions") +SESSION_FILE_SUFFIX: Final = ".jsonl" diff --git a/src/goose/command/base.py b/src/goose/pluginbase/command.py similarity index 100% rename from src/goose/command/base.py rename to src/goose/pluginbase/command.py diff --git a/src/goose/notifier.py b/src/goose/pluginbase/notifier.py similarity index 100% rename from src/goose/notifier.py rename to src/goose/pluginbase/notifier.py diff --git a/src/goose/profile.py b/src/goose/pluginbase/profile.py similarity index 66% rename from src/goose/profile.py rename to src/goose/pluginbase/profile.py index ec0a12a54..2745067b7 100644 --- a/src/goose/profile.py +++ b/src/goose/pluginbase/profile.py @@ -2,7 +2,7 @@ from attrs import asdict, define, field -from goose.utils import ensure_list +from .utils.converter import ensure_list @define @@ -24,7 +24,7 @@ class Profile: toolkits: List[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec)) @toolkits.validator - def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[ToolkitSpec]) -> None: + def _check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[ToolkitSpec]) -> None: # checks that the list of toolkits in the profile have their requirements installed_toolkits = set([toolkit.name for toolkit in toolkits]) @@ -38,17 +38,3 @@ def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[Tool def to_dict(self) -> Dict[str, Any]: return asdict(self) - - -def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile: - """Get the default profile""" - - # TODO consider if the providers should have recommended models - - return Profile( - provider=provider, - processor=processor, - accelerator=accelerator, - moderator="truncate", - toolkits=[ToolkitSpec("developer")], - ) diff --git a/src/goose/toolkit/base.py b/src/goose/pluginbase/toolkit.py similarity index 98% rename from src/goose/toolkit/base.py rename to src/goose/pluginbase/toolkit.py index d26630ca4..0ff856c5f 100644 --- a/src/goose/toolkit/base.py +++ b/src/goose/pluginbase/toolkit.py @@ -5,7 +5,7 @@ from attrs import define, field from exchange import Tool -from goose.notifier import Notifier +from .notifier import Notifier # Create a type variable that can represent any function signature F = TypeVar("F", bound=Callable) diff --git a/src/goose/pluginbase/utils/__init__.py b/src/goose/pluginbase/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/goose/utils/ask.py b/src/goose/pluginbase/utils/ask.py similarity index 100% rename from src/goose/utils/ask.py rename to src/goose/pluginbase/utils/ask.py diff --git a/src/goose/pluginbase/utils/converter.py b/src/goose/pluginbase/utils/converter.py new file mode 100644 index 000000000..5e3a24e50 --- /dev/null +++ b/src/goose/pluginbase/utils/converter.py @@ -0,0 +1,30 @@ +from typing import Any, Callable, Dict, List, Type, TypeVar + +T = TypeVar("T") + +def ensure(cls: Type[T]) -> Callable[[Any], T]: + """Convert dictionary to a class instance""" + + def converter(val: Any) -> T: # noqa: ANN401 + if isinstance(val, cls): + return val + elif isinstance(val, dict): + return cls(**val) + elif isinstance(val, list): + return cls(*val) + else: + return cls(val) + + return converter + + +def ensure_list(cls: Type[T]) -> Callable[[List[Dict[str, Any]]], Type[T]]: + """Convert a list of dictionaries to class instances""" + + def converter(val: List[Dict[str, Any]]) -> List[T]: + output = [] + for entry in val: + output.append(ensure(cls)(entry)) + return output + + return converter diff --git a/src/goose/toolkit/utils.py b/src/goose/pluginbase/utils/file_language.py similarity index 100% rename from src/goose/toolkit/utils.py rename to src/goose/pluginbase/utils/file_language.py diff --git a/src/goose/utils/session_file.py b/src/goose/pluginbase/utils/session_file.py similarity index 71% rename from src/goose/utils/session_file.py rename to src/goose/pluginbase/utils/session_file.py index a47efcb1e..af36fc6a8 100644 --- a/src/goose/utils/session_file.py +++ b/src/goose/pluginbase/utils/session_file.py @@ -4,7 +4,7 @@ from exchange import Message -from goose.cli.config import SESSION_FILE_SUFFIX +from goose.config import SESSIONS_PATH, SESSION_FILE_SUFFIX def write_to_file(file_path: Path, messages: List[Message]) -> None: @@ -25,15 +25,20 @@ def read_from_file(file_path: Path) -> List[Message]: def list_sorted_session_files(session_files_directory: Path) -> Dict[str, Path]: - logs = list_session_files(session_files_directory) + logs = _list_session_files(session_files_directory) return {log.stem: log for log in sorted(logs, key=lambda x: x.stat().st_mtime, reverse=True)} -def list_session_files(session_files_directory: Path) -> Iterator[Path]: +def _list_session_files(session_files_directory: Path) -> Iterator[Path]: return session_files_directory.glob(f"*{SESSION_FILE_SUFFIX}") def session_file_exists(session_files_directory: Path) -> bool: if not session_files_directory.exists(): return False - return any(list_session_files(session_files_directory)) + return any(_list_session_files(session_files_directory)) + + +def session_path(name: str) -> Path: + SESSIONS_PATH.mkdir(parents=True, exist_ok=True) + return SESSIONS_PATH.joinpath(f"{name}{SESSION_FILE_SUFFIX}") diff --git a/src/goose/toolkit/summarization/utils.py b/src/goose/pluginbase/utils/summarization.py similarity index 84% rename from src/goose/toolkit/summarization/utils.py rename to src/goose/pluginbase/utils/summarization.py index d398713cc..a89f4cf1e 100644 --- a/src/goose/toolkit/summarization/utils.py +++ b/src/goose/pluginbase/utils/summarization.py @@ -1,3 +1,4 @@ +import glob import json import subprocess from concurrent.futures import ThreadPoolExecutor, as_completed @@ -7,15 +8,14 @@ from exchange import Exchange from exchange.providers.utils import InitialMessageTooLargeError -from goose.utils.ask import ask_an_ai -from goose.utils.file_utils import create_file_list +from .ask import ask_an_ai SUMMARIES_FOLDER = ".goose/summaries" CLONED_REPOS_FOLDER = ".goose/cloned_repos" # TODO: move git stuff -def run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]: +def _run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]: result = subprocess.run(["git"] + command, capture_output=True, text=True, check=False) if result.returncode != 0: @@ -24,8 +24,8 @@ def run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]: return result -def clone_repo(repo_url: str, target_directory: str) -> None: - run_git_command(["clone", repo_url, target_directory]) +def _clone_repo(repo_url: str, target_directory: str) -> None: + _run_git_command(["clone", repo_url, target_directory]) def load_summary_file_if_exists(project_name: str) -> Optional[Dict]: @@ -99,7 +99,7 @@ def summarize_repo( summary_instructions_prompt=summary_instructions_prompt, ) - clone_repo(repo_url, target_directory=repo_dir) + _clone_repo(repo_url, target_directory=repo_dir) return summarize_directory( directory=repo_dir, @@ -139,7 +139,7 @@ def summarize_directory( Path(SUMMARIES_FOLDER).mkdir(exist_ok=True, parents=True) # select a subset of files to summarize based on file extension - files_to_summarize = create_file_list(directory, extensions=extensions) + files_to_summarize = _create_file_list(directory, extensions=extensions) file_summaries = summarize_files_concurrent( exchange=exchange, @@ -197,3 +197,29 @@ def summarize_files_concurrent( json.dump(file_summaries, f, indent=2) return file_summaries + + +def _create_file_list(dir_path: str, extensions: List[str]) -> List[str]: + """Creates a list of files with certain extensions + + Args: + dir_path (str): Directory to list files of. Will include files recursively in sub-directories. + extensions (List[str]): List of file extensions to select for. If empty list, return all files + + Returns: + final_file_list (List[str]): List of file paths with specified extensions. + """ + # if extensions is empty list, return all files + if not extensions: + return glob.glob(f"{dir_path}/**/*", recursive=True) + + # prune out files that do not end with any of the extensions in extensions + final_file_list = [] + for ext in extensions: + if ext and not ext.startswith("."): + ext = f".{ext}" + + files = glob.glob(f"{dir_path}/**/*{ext}", recursive=True) + final_file_list += files + + return final_file_list diff --git a/src/goose/toolkit/__init__.py b/src/goose/toolkit/__init__.py deleted file mode 100644 index a3a97d41f..000000000 --- a/src/goose/toolkit/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from functools import cache - -from goose.toolkit.base import Toolkit -from goose.utils import load_plugins - - -@cache -def get_toolkit(name: str) -> type[Toolkit]: - return load_plugins(group="goose.toolkit")[name] diff --git a/src/goose/toolkit/summarization/__init__.py b/src/goose/toolkit/summarization/__init__.py deleted file mode 100644 index 3a4b25916..000000000 --- a/src/goose/toolkit/summarization/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .summarize_repo import SummarizeRepo # noqa -from .summarize_project import SummarizeProject # noqa -from .summarize_file import SummarizeFile # noqa diff --git a/src/goose/utils/__init__.py b/src/goose/utils/__init__.py deleted file mode 100644 index 69887e7f0..000000000 --- a/src/goose/utils/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -import random -import string -from importlib.metadata import entry_points -from typing import Any, Callable, Dict, List, Type, TypeVar - -T = TypeVar("T") - - -def load_plugins(group: str) -> dict: - """ - Load plugins based on a specified entry point group. - - This function iterates through all entry points registered under a specified group - - Args: - group (str): The entry point group to load plugins from. This should match the group specified - in the package setup where plugins are defined. - - Returns: - dict: A dictionary where each key is the entry point name, and the value is the loaded plugin object. - - Raises: - Exception: Propagates exceptions raised by entry point loading, which might occur if a plugin - is not found or if there are issues with the plugin's code. - """ - plugins = {} - # Access all entry points for the specified group and load each. - for entrypoint in entry_points(group=group): - plugin = entrypoint.load() # Load the plugin. - plugins[entrypoint.name] = plugin # Store the loaded plugin in the dictionary. - return plugins - - -def ensure(cls: Type[T]) -> Callable[[Any], T]: - """Convert dictionary to a class instance""" - - def converter(val: Any) -> T: # noqa: ANN401 - if isinstance(val, cls): - return val - elif isinstance(val, dict): - return cls(**val) - elif isinstance(val, list): - return cls(*val) - else: - return cls(val) - - return converter - - -def ensure_list(cls: Type[T]) -> Callable[[List[Dict[str, Any]]], Type[T]]: - """Convert a list of dictionaries to class instances""" - - def converter(val: List[Dict[str, Any]]) -> List[T]: - output = [] - for entry in val: - output.append(ensure(cls)(entry)) - return output - - return converter - - -def droid() -> str: - return "".join( - [ - random.choice(string.ascii_lowercase), - random.choice(string.digits), - random.choice(string.ascii_lowercase), - random.choice(string.digits), - ] - ) diff --git a/src/goose/utils/file_utils.py b/src/goose/utils/file_utils.py deleted file mode 100644 index eabc50f73..000000000 --- a/src/goose/utils/file_utils.py +++ /dev/null @@ -1,103 +0,0 @@ -import glob -import os -from collections import Counter -from pathlib import Path -from typing import Dict, List, Optional - - -def create_extensions_list(project_root: str, max_n: int) -> list: - """Get the top N file extensions in the current project - Args: - project_root (str): Root of the project to analyze - max_n (int): The number of file extensions to return - Returns: - extensions (List[str]): A list of the top N file extensions - """ - if max_n == 0: - raise (ValueError("Number of file extensions must be greater than 0")) - - files = create_file_list(project_root, []) - - counter = Counter() - - for file in files: - file_path = Path(file) - if file_path.suffix: # omit '' - counter[file_path.suffix] += 1 - - top_n = counter.most_common(max_n) - extensions = [ext for ext, _ in top_n] - - return extensions - - -def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]: - """Calculate language weighting by file size to match GitHub's methodology. - - Args: - files_in_directory (List[str]): Paths to files in the project directory - - Returns: - Dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values - """ - - # Initialize counters for sizes - size_by_language = Counter() - - # Calculate size for files by language - for file_path in files_in_directory: - path = Path(file_path) - if path.suffix: - size_by_language[path.suffix] += os.path.getsize(file_path) - - # Calculate total size and language percentages - total_size = sum(size_by_language.values()) - language_percentages = { - lang: (size / total_size * 100) if total_size else 0 for lang, size in size_by_language.items() - } - - return dict(sorted(language_percentages.items(), key=lambda item: item[1], reverse=True)) - - -def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> List[str]: - """List all files in a directory with a given extension. Set extension to '' to return all files. - - Args: - dir_path (str): The path to the directory - extension (Optional[str]): extension to lookup. Defaults to '' which will return all files. - - Returns: - files (List[str]): List of file paths - """ - # add a leading '.' to extension if needed - if extension and not extension.startswith("."): - extension = f".{extension}" - - files = glob.glob(f"{dir_path}/**/*{extension}", recursive=True) - return files - - -def create_file_list(dir_path: str, extensions: List[str]) -> List[str]: - """Creates a list of files with certain extensions - - Args: - dir_path (str): Directory to list files of. Will include files recursively in sub-directories. - extensions (List[str]): List of file extensions to select for. If empty list, return all files - - Returns: - final_file_list (List[str]): List of file paths with specified extensions. - """ - # if extensions is empty list, return all files - if not extensions: - return glob.glob(f"{dir_path}/**/*", recursive=True) - - # prune out files that do not end with any of the extensions in extensions - final_file_list = [] - for ext in extensions: - if ext and not ext.startswith("."): - ext = f".{ext}" - - files = glob.glob(f"{dir_path}/**/*{ext}", recursive=True) - final_file_list += files - - return final_file_list diff --git a/tests/test_completer.py b/tests/_internal/cli/prompt/test_completer.py similarity index 92% rename from tests/test_completer.py rename to tests/_internal/cli/prompt/test_completer.py index 8749975f3..56d8676a3 100644 --- a/tests/test_completer.py +++ b/tests/_internal/cli/prompt/test_completer.py @@ -1,8 +1,8 @@ from unittest.mock import Mock import pytest -from goose.cli.prompt.completer import GoosePromptCompleter -from goose.command.base import Command +from goose._internal.cli.prompt.completer import GoosePromptCompleter +from goose.pluginbase.command import Command from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document diff --git a/tests/cli/prompt/test_goose_prompt_session.py b/tests/_internal/cli/prompt/test_goose_prompt_session.py similarity index 90% rename from tests/cli/prompt/test_goose_prompt_session.py rename to tests/_internal/cli/prompt/test_goose_prompt_session.py index eca44cc67..bde568262 100644 --- a/tests/cli/prompt/test_goose_prompt_session.py +++ b/tests/_internal/cli/prompt/test_goose_prompt_session.py @@ -1,8 +1,8 @@ from unittest.mock import patch import pytest -from goose.cli.prompt.goose_prompt_session import GoosePromptSession -from goose.cli.prompt.user_input import PromptAction, UserInput +from goose._internal.cli.prompt.goose_prompt_session import GoosePromptSession +from goose._internal.cli.prompt.user_input import PromptAction, UserInput @pytest.fixture diff --git a/tests/cli/prompt/test_lexer.py b/tests/_internal/cli/prompt/test_lexer.py similarity index 99% rename from tests/cli/prompt/test_lexer.py rename to tests/_internal/cli/prompt/test_lexer.py index 585bead9b..32f501873 100644 --- a/tests/cli/prompt/test_lexer.py +++ b/tests/_internal/cli/prompt/test_lexer.py @@ -1,4 +1,4 @@ -from goose.cli.prompt.lexer import ( +from goose._internal.cli.prompt.lexer import ( PromptLexer, command_itself, completion_for_command, diff --git a/tests/cli/prompt/test_prompt_validator.py b/tests/_internal/cli/prompt/test_prompt_validator.py similarity index 93% rename from tests/cli/prompt/test_prompt_validator.py rename to tests/_internal/cli/prompt/test_prompt_validator.py index 380a7dd51..f019d87fc 100644 --- a/tests/cli/prompt/test_prompt_validator.py +++ b/tests/_internal/cli/prompt/test_prompt_validator.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from goose.cli.prompt.prompt_validator import PromptValidator +from goose._internal.cli.prompt.prompt_validator import PromptValidator from prompt_toolkit.validation import ValidationError diff --git a/tests/cli/prompt/test_user_input.py b/tests/_internal/cli/prompt/test_user_input.py similarity index 84% rename from tests/cli/prompt/test_user_input.py rename to tests/_internal/cli/prompt/test_user_input.py index 029fadfad..c3220a43d 100644 --- a/tests/cli/prompt/test_user_input.py +++ b/tests/_internal/cli/prompt/test_user_input.py @@ -1,4 +1,4 @@ -from goose.cli.prompt.user_input import PromptAction, UserInput +from goose._internal.cli.prompt.user_input import PromptAction, UserInput def test_user_input_with_action_continue(): diff --git a/tests/cli/test_session.py b/tests/_internal/cli/session/test_session.py similarity index 80% rename from tests/cli/test_session.py rename to tests/_internal/cli/session/test_session.py index 79a7c4a2b..5b28d779c 100644 --- a/tests/cli/test_session.py +++ b/tests/_internal/cli/session/test_session.py @@ -2,9 +2,9 @@ import pytest from exchange import Message, ToolUse, ToolResult -from goose.cli.prompt.goose_prompt_session import GoosePromptSession -from goose.cli.prompt.user_input import PromptAction, UserInput -from goose.cli.session import Session +from goose._internal.cli.prompt.goose_prompt_session import GoosePromptSession +from goose._internal.cli.prompt.user_input import PromptAction, UserInput +from goose._internal.cli.session.session import Session from prompt_toolkit import PromptSession SPECIFIED_SESSION_NAME = "mySession" @@ -18,11 +18,17 @@ def mock_specified_session_name(): @pytest.fixture -def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory): - with patch("goose.cli.session.build_exchange", return_value=exchange_factory()), patch( - "goose.cli.session.load_profile", return_value=profile_factory() - ), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch( - "goose.cli.session.load_provider", return_value="provider" +def mock_sessions_path(tmp_path): + with patch("goose.config.SESSIONS_PATH", tmp_path) as mock_path: + yield mock_path + + +@pytest.fixture +def create_session_with_mock_configs(exchange_factory, profile_factory, tmp_path): + with patch("goose._internal.cli.session.session.build_exchange", return_value=exchange_factory()), patch( + "goose._internal.cli.session.session.load_profile", return_value=profile_factory() + ), patch("goose._internal.cli.session.session.SessionNotifier") as mock_session_notifier, patch( + "goose.pluginbase.utils.session_file.SESSIONS_PATH", tmp_path ): mock_session_notifier.return_value = MagicMock() @@ -83,11 +89,11 @@ def test_save_session_create_session(mock_sessions_path, create_session_with_moc session = create_session_with_mock_configs() session.exchange.messages.append(Message.assistant("Hello")) - session.save_session() + session._save_session() session_file = mock_sessions_path / f"{SPECIFIED_SESSION_NAME}.jsonl" assert session_file.exists() - saved_messages = session.load_session() + saved_messages = session._load_session() assert len(saved_messages) == 1 assert saved_messages[0].text == "Hello" @@ -95,7 +101,7 @@ def test_save_session_create_session(mock_sessions_path, create_session_with_moc def test_save_session_resume_session_new_file( mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name, create_session_file ): - with patch("goose.cli.session.confirm", return_value=False): + with patch("goose._internal.cli.session.session.confirm", return_value=False): existing_messages = [Message.assistant("existing_message")] existing_session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" create_session_file(existing_messages, existing_session_file) @@ -106,19 +112,19 @@ def test_save_session_resume_session_new_file( session = create_session_with_mock_configs({"name": SESSION_NAME}) session.exchange.messages.append(Message.assistant("new_message")) - session.save_session() + session._save_session() assert new_session_file.exists() assert existing_session_file.exists() - saved_messages = session.load_session() + saved_messages = session._load_session() assert [message.text for message in saved_messages] == ["existing_message", "new_message"] def test_save_session_resume_session_existing_session_file( mock_sessions_path, create_session_with_mock_configs, create_session_file ): - with patch("goose.cli.session.confirm", return_value=True): + with patch("goose._internal.cli.session.session.confirm", return_value=True): existing_messages = [Message.assistant("existing_message")] existing_session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" create_session_file(existing_messages, existing_session_file) @@ -126,9 +132,9 @@ def test_save_session_resume_session_existing_session_file( session = create_session_with_mock_configs({"name": SESSION_NAME}) session.exchange.messages.append(Message.assistant("new_message")) - session.save_session() + session._save_session() - saved_messages = session.load_session() + saved_messages = session._load_session() assert [message.text for message in saved_messages] == ["existing_message", "new_message"] @@ -137,7 +143,7 @@ def test_process_first_message_return_message(create_session_with_mock_configs): with patch.object( GoosePromptSession, "get_user_input", return_value=UserInput(action=PromptAction.CONTINUE, text="Hello") ): - message = session.process_first_message() + message = session._process_first_message() assert message.text == "Hello" assert len(session.exchange.messages) == 0 @@ -146,7 +152,7 @@ def test_process_first_message_return_message(create_session_with_mock_configs): def test_process_first_message_to_exit(create_session_with_mock_configs): session = create_session_with_mock_configs() with patch.object(GoosePromptSession, "get_user_input", return_value=UserInput(action=PromptAction.EXIT)): - message = session.process_first_message() + message = session._process_first_message() assert message is None @@ -155,7 +161,7 @@ def test_process_first_message_return_last_exchange_message(create_session_with_ session = create_session_with_mock_configs() session.exchange.messages.append(Message.user("Hi")) - message = session.process_first_message() + message = session._process_first_message() assert message.text == "Hi" assert len(session.exchange.messages) == 0 @@ -164,6 +170,6 @@ def test_process_first_message_return_last_exchange_message(create_session_with_ def test_generate_session_name(create_session_with_mock_configs): session = create_session_with_mock_configs() with patch.object(GoosePromptSession, "get_save_session_name", return_value=SPECIFIED_SESSION_NAME): - session.generate_session_name() + session._generate_session_name() assert session.name == SPECIFIED_SESSION_NAME diff --git a/tests/_internal/cli/session/test_session_utils.py b/tests/_internal/cli/session/test_session_utils.py new file mode 100644 index 000000000..eebbfc74e --- /dev/null +++ b/tests/_internal/cli/session/test_session_utils.py @@ -0,0 +1,11 @@ +import string +from goose._internal.cli.session.session_utils import random_session_name + +def test_droid(): + result = random_session_name() + assert isinstance(result, str) + assert len(result) == 4 + for character in [result[i] for i in [0, 2]]: + assert character in string.ascii_lowercase, "should be in lower case" + for character in [result[i] for i in [1, 3]]: + assert character in string.digits, "should be a digit" diff --git a/tests/cli/test_main.py b/tests/_internal/cli/test_main.py similarity index 89% rename from tests/cli/test_main.py rename to tests/_internal/cli/test_main.py index 617b3d5c1..4ba3c052a 100644 --- a/tests/cli/test_main.py +++ b/tests/_internal/cli/test_main.py @@ -7,24 +7,24 @@ import pytest from click.testing import CliRunner from exchange import Message -from goose.cli.main import cli, goose_cli +from goose._internal.cli.main import cli, goose_cli @pytest.fixture def mock_print(): - with patch("goose.cli.main.print") as mock_print: + with patch("goose._internal.cli.main.print") as mock_print: yield mock_print @pytest.fixture def mock_session_files_path(tmp_path): - with patch("goose.cli.main.SESSIONS_PATH", tmp_path) as session_files_path: + with patch("goose._internal.cli.main.SESSIONS_PATH", tmp_path) as session_files_path: yield session_files_path @pytest.fixture def mock_session(): - with patch("goose.cli.main.Session") as mock_session_class: + with patch("goose._internal.cli.main.Session") as mock_session_class: mock_session_instance = MagicMock() mock_session_class.return_value = mock_session_instance yield mock_session_class, mock_session_instance @@ -83,7 +83,7 @@ def test_session_clear_command(mock_session_files_path, create_session_file): def test_combined_group_option(): - with patch("goose.utils.load_plugins") as mock_load_plugin: + with patch("goose._internal.utils.load_plugins") as mock_load_plugin: group_option_name = "--describe-commands" def option_callback(ctx, *_): @@ -107,10 +107,10 @@ def side_effect_func(param): mock_load_plugin.side_effect = side_effect_func # reload cli after mocking - importlib.reload(importlib.import_module("goose.cli.main")) - import goose.cli.main + importlib.reload(importlib.import_module("goose._internal.cli.main")) + import goose._internal.cli.main - cli = goose.cli.main.cli + cli = goose._internal.cli.main.cli runner = CliRunner() result = runner.invoke(cli, [group_option_name]) diff --git a/tests/_internal/profile/test_config.py b/tests/_internal/profile/test_config.py new file mode 100644 index 000000000..4635fa092 --- /dev/null +++ b/tests/_internal/profile/test_config.py @@ -0,0 +1,83 @@ +from unittest.mock import patch + +import pytest +from goose._internal.profile import default_profile +from goose._internal.profile.config import _ensure_config, _read_config, _write_config + + +@pytest.fixture +def mock_profile_config_path(tmp_path): + with patch("goose._internal.profile.config.PROFILES_CONFIG_PATH", tmp_path / "profiles.yaml") as mock_path: + yield mock_path + + +@pytest.fixture +def mock_default_model_configuration(): + with patch( + "goose._internal.profile.config._default_model_configuration", + return_value=("provider", "processor", "accelerator"), + ) as mock_default_model_configuration: + yield mock_default_model_configuration + + +def test_read_write_config(mock_profile_config_path, profile_factory): + profiles = { + "profile1": profile_factory({"provider": "providerA"}), + } + _write_config(profiles) + + assert _read_config() == profiles + + +def test_ensure_config_create_profiles_file_with_default_profile( + mock_profile_config_path, mock_default_model_configuration +): + assert not mock_profile_config_path.exists() + + _ensure_config(name="default") + assert mock_profile_config_path.exists() + + assert _read_config() == {"default": default_profile(*mock_default_model_configuration())} + + +@patch("goose._internal.profile.config.print") +def test_ensure_config_add_default_profile( + mock_print, mock_profile_config_path, profile_factory, mock_default_model_configuration +): + existing_profile = profile_factory({"provider": "providerA"}) + _write_config({"profile1": existing_profile}) + + _ensure_config(name="default") + + assert _read_config() == { + "profile1": existing_profile, + "default": default_profile(*mock_default_model_configuration()), + } + + +@patch("goose._internal.profile.config.Confirm.ask", return_value=True) +@patch("goose._internal.profile.config.print") +def test_ensure_config_overwrite_default_profile( + mock_confirm, mock_print, mock_profile_config_path, profile_factory, mock_default_model_configuration +): + existing_profile = profile_factory({"provider": "providerA"}) + profile_name = "default" + _write_config({profile_name: existing_profile}) + + expected_default_profile = default_profile(*mock_default_model_configuration()) + assert _ensure_config(name="default") == expected_default_profile + assert _read_config() == {"default": expected_default_profile} + + +@patch("goose._internal.profile.config.Confirm.ask", return_value=False) +@patch("goose._internal.profile.config.print") +def test_ensure_config_keep_original_default_profile( + mock_confirm, mock_print, mock_profile_config_path, profile_factory, mock_default_model_configuration +): + existing_profile = profile_factory({"provider": "providerA"}) + profile_name = "default" + _write_config({profile_name: existing_profile}) + + assert _ensure_config(name="default") == existing_profile + + assert _read_config() == {"default": existing_profile} diff --git a/tests/_internal/test_utils.py b/tests/_internal/test_utils.py new file mode 100644 index 000000000..85f5735dc --- /dev/null +++ b/tests/_internal/test_utils.py @@ -0,0 +1,7 @@ +from goose._internal.utils import load_plugins + + +def test_load_plugins(): + plugins = load_plugins("exchange.provider") + assert isinstance(plugins, dict) + assert len(plugins) > 0 diff --git a/tests/_internal/toolkit/__init__.py b/tests/_internal/toolkit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utils/test_check_shell_command.py b/tests/_internal/toolkit/developer/test_check_shell_command.py similarity index 87% rename from tests/utils/test_check_shell_command.py rename to tests/_internal/toolkit/developer/test_check_shell_command.py index 7d94a46d6..d66694e3e 100644 --- a/tests/utils/test_check_shell_command.py +++ b/tests/_internal/toolkit/developer/test_check_shell_command.py @@ -1,5 +1,5 @@ import pytest -from goose.utils.check_shell_command import is_dangerous_command +from goose._internal.toolkit.developer.utils.check_shell_command import is_dangerous_command @pytest.mark.parametrize( diff --git a/tests/toolkit/test_developer.py b/tests/_internal/toolkit/developer/test_developer.py similarity index 94% rename from tests/toolkit/test_developer.py rename to tests/_internal/toolkit/developer/test_developer.py index e049ee9f2..6d14611fd 100644 --- a/tests/toolkit/test_developer.py +++ b/tests/_internal/toolkit/developer/test_developer.py @@ -5,8 +5,8 @@ from unittest.mock import MagicMock, Mock import pytest -from goose.toolkit.base import Requirements -from goose.toolkit.developer import Developer +from goose.pluginbase.toolkit import Requirements +from goose._internal.toolkit.developer.developer import Developer @pytest.fixture @@ -68,5 +68,3 @@ def test_write_file(temp_dir, developer_toolkit): content = "Hello World" developer_toolkit.write_file(test_file.as_posix(), content) assert test_file.read_text() == content - - diff --git a/tests/cli/test_config.py b/tests/cli/test_config.py deleted file mode 100644 index b857f8b99..000000000 --- a/tests/cli/test_config.py +++ /dev/null @@ -1,81 +0,0 @@ -from unittest.mock import patch - -import pytest -from goose.cli.config import ensure_config, read_config, session_path, write_config -from goose.profile import default_profile - - -@pytest.fixture -def mock_profile_config_path(tmp_path): - with patch("goose.cli.config.PROFILES_CONFIG_PATH", tmp_path / "profiles.yaml") as mock_path: - yield mock_path - - -@pytest.fixture -def mock_default_model_configuration(): - with patch( - "goose.cli.config.default_model_configuration", return_value=("provider", "processor", "accelerator") - ) as mock_default_model_configuration: - yield mock_default_model_configuration - - -def test_read_write_config(mock_profile_config_path, profile_factory): - profiles = { - "profile1": profile_factory({"provider": "providerA"}), - } - write_config(profiles) - - assert read_config() == profiles - - -def test_ensure_config_create_profiles_file_with_default_profile( - mock_profile_config_path, mock_default_model_configuration -): - assert not mock_profile_config_path.exists() - - ensure_config(name="default") - assert mock_profile_config_path.exists() - - assert read_config() == {"default": default_profile(*mock_default_model_configuration())} - - -def test_ensure_config_add_default_profile(mock_profile_config_path, profile_factory, mock_default_model_configuration): - existing_profile = profile_factory({"provider": "providerA"}) - write_config({"profile1": existing_profile}) - - ensure_config(name="default") - - assert read_config() == { - "profile1": existing_profile, - "default": default_profile(*mock_default_model_configuration()), - } - - -@patch("goose.cli.config.Confirm.ask", return_value=True) -def test_ensure_config_overwrite_default_profile( - mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration -): - existing_profile = profile_factory({"provider": "providerA"}) - profile_name = "default" - write_config({profile_name: existing_profile}) - - expected_default_profile = default_profile(*mock_default_model_configuration()) - assert ensure_config(name="default") == expected_default_profile - assert read_config() == {"default": expected_default_profile} - - -@patch("goose.cli.config.Confirm.ask", return_value=False) -def test_ensure_config_keep_original_default_profile( - mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration -): - existing_profile = profile_factory({"provider": "providerA"}) - profile_name = "default" - write_config({profile_name: existing_profile}) - - assert ensure_config(name="default") == existing_profile - - assert read_config() == {"default": existing_profile} - - -def test_session_path(mock_sessions_path): - assert session_path("session1") == mock_sessions_path / "session1.jsonl" diff --git a/tests/conftest.py b/tests/conftest.py index 975066579..486416745 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ import json import os from time import time -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest from exchange import Exchange -from goose.profile import Profile +from goose.pluginbase.profile import Profile @pytest.fixture @@ -40,12 +40,6 @@ def _create_exchange(attributes={}): return _create_exchange -@pytest.fixture -def mock_sessions_path(tmp_path): - with patch("goose.cli.config.SESSIONS_PATH", tmp_path) as mock_path: - yield mock_path - - @pytest.fixture def create_session_file(): def _create_session_file(messages, session_file_path, mtime=time()): diff --git a/tests/utils/test_ask.py b/tests/pluginbase/utils/test_ask.py similarity index 93% rename from tests/utils/test_ask.py rename to tests/pluginbase/utils/test_ask.py index b7bd8269e..a1948bfe4 100644 --- a/tests/utils/test_ask.py +++ b/tests/pluginbase/utils/test_ask.py @@ -2,7 +2,7 @@ import pytest from exchange import Exchange, CheckpointData -from goose.utils.ask import ask_an_ai, clear_exchange, replace_prompt +from goose.pluginbase.utils.ask import ask_an_ai, clear_exchange, replace_prompt # tests for `ask_an_ai` @@ -16,7 +16,7 @@ def test_ask_an_ai_empty_input(): def test_ask_an_ai_no_history(): """Test the no_history functionality.""" exchange = MagicMock(spec=Exchange) - with patch("goose.utils.ask.clear_exchange") as mock_clear: + with patch("goose.pluginbase.utils.ask.clear_exchange") as mock_clear: ask_an_ai("Test input", exchange, no_history=True) mock_clear.assert_called_once_with(exchange) @@ -26,7 +26,7 @@ def test_ask_an_ai_prompt_replacement(): exchange = MagicMock(spec=Exchange) prompt = "New prompt" - with patch("goose.utils.ask.replace_prompt") as mock_replace_prompt: + with patch("goose.pluginbase.utils.ask.replace_prompt") as mock_replace_prompt: # Configure the mock to return a new mock object with the same spec modified_exchange = MagicMock(spec=Exchange) mock_replace_prompt.return_value = modified_exchange @@ -46,7 +46,7 @@ def test_ask_an_ai_exchange_usage(): input_text = "Test input" message_mock = MagicMock(return_value="Mocked Message") - with patch("goose.utils.ask.Message.user", new=message_mock): + with patch("goose.pluginbase.utils.ask.Message.user", new=message_mock): ask_an_ai(input_text, exchange, no_history=False) # Assert that Message.user was called with the correct input diff --git a/tests/utils/test_utils.py b/tests/pluginbase/utils/test_converter.py similarity index 64% rename from tests/utils/test_utils.py rename to tests/pluginbase/utils/test_converter.py index b7a9f992b..e4e594b45 100644 --- a/tests/utils/test_utils.py +++ b/tests/pluginbase/utils/test_converter.py @@ -1,7 +1,5 @@ -import string - import pytest -from goose.utils import droid, ensure, ensure_list, load_plugins +from goose.pluginbase.utils.converter import ensure, ensure_list class MockClass: @@ -11,13 +9,6 @@ def __init__(self, name): def __eq__(self, other): return self.name == other.name - -def test_load_plugins(): - plugins = load_plugins("exchange.provider") - assert isinstance(plugins, dict) - assert len(plugins) > 0 - - def test_ensure_with_class(): mock_class = MockClass("foo") assert ensure(MockClass)(mock_class) == mock_class @@ -52,12 +43,3 @@ def test_ensure_list(): obj_list = ensure_list(MockClass)(["foo", "bar"]) assert obj_list == [MockClass("foo"), MockClass("bar")] - -def test_droid(): - result = droid() - assert isinstance(result, str) - assert len(result) == 4 - for character in [result[i] for i in [0, 2]]: - assert character in string.ascii_lowercase, "should be in lower case" - for character in [result[i] for i in [1, 3]]: - assert character in string.digits, "should be a digit" diff --git a/tests/utils/test_session_file.py b/tests/pluginbase/utils/test_session_file.py similarity index 84% rename from tests/utils/test_session_file.py rename to tests/pluginbase/utils/test_session_file.py index d922bd81d..ed394d3f6 100644 --- a/tests/utils/test_session_file.py +++ b/tests/pluginbase/utils/test_session_file.py @@ -1,8 +1,10 @@ from pathlib import Path +from unittest.mock import patch import pytest from exchange import Message -from goose.utils.session_file import list_sorted_session_files, read_from_file, session_file_exists, write_to_file +from goose.pluginbase.utils.session_file import list_sorted_session_files, read_from_file +from goose.pluginbase.utils.session_file import session_file_exists, session_path, write_to_file @pytest.fixture @@ -75,3 +77,8 @@ def create_session_file(file_path, file_name) -> Path: file = file_path / f"{file_name}.jsonl" file.touch() return file + + +def test_session_path(tmp_path): + with patch("goose.pluginbase.utils.session_file.SESSIONS_PATH", tmp_path) as mock_session_path: + assert session_path("session1") == mock_session_path / "session1.jsonl" diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py deleted file mode 100644 index e64739541..000000000 --- a/tests/utils/test_file_utils.py +++ /dev/null @@ -1,192 +0,0 @@ -from unittest.mock import patch - -import pytest -from goose.utils.file_utils import ( - create_extensions_list, - create_language_weighting, -) # Adjust the import path as necessary - - -# tests for `create_extensions_list` -def test_create_extensions_list_valid_input(): - """Test with valid input and multiple file extensions.""" - project_root = "/fake/project/root" - max_n = 3 - files = [ - "/fake/project/root/file1.py", - "/fake/project/root/file2.py", - "/fake/project/root/file3.md", - "/fake/project/root/file4.md", - "/fake/project/root/file5.txt", - "/fake/project/root/file6.py", - "/fake/project/root/file7.md", - ] - - with patch("goose.utils.file_utils.create_file_list", return_value=files): - extensions = create_extensions_list(project_root, max_n) - assert extensions == [".py", ".md", ".txt"], "Should return the top 3 extensions in the correct order" - - -def test_create_extensions_list_zero_max_n(): - """Test that a ValueError is raised when max_n is 0.""" - project_root = "/fake/project/root" - max_n = 0 - - with pytest.raises(ValueError, match="Number of file extensions must be greater than 0"): - create_extensions_list(project_root, max_n) - - -def test_create_extensions_list_no_files(): - """Test with a project root that contains no files.""" - project_root = "/fake/project/root" - max_n = 3 - - with patch("goose.utils.file_utils.create_file_list", return_value=[]): - extensions = create_extensions_list(project_root, max_n) - assert extensions == [], "Should return an empty list when no files are present" - - -def test_create_extensions_list_fewer_extensions_than_max_n(): - """Test when there are fewer unique extensions than max_n.""" - project_root = "/fake/project/root" - max_n = 5 - files = [ - "/fake/project/root/file1.py", - "/fake/project/root/file2.py", - "/fake/project/root/file3.md", - ] - - with patch("goose.utils.file_utils.create_file_list", return_value=files): - extensions = create_extensions_list(project_root, max_n) - assert extensions == [".py", ".md"], "Should return all available extensions when fewer than max_n" - - -def test_create_extensions_list_files_without_extensions(): - """Test that files without extensions are ignored.""" - project_root = "/fake/project/root" - max_n = 3 - files = [ - "/fake/project/root/file1", - "/fake/project/root/file2.py", - "/fake/project/root/file3", - "/fake/project/root/file4.md", - ] - - with patch("goose.utils.file_utils.create_file_list", return_value=files): - extensions = create_extensions_list(project_root, max_n) - assert extensions == [".py", ".md"], "Should ignore files without extensions" - - -# tests for `create_language_weighting` -def test_create_language_weighting_normal_case(): - """Test the function with multiple files and different sizes.""" - files = [ - "/fake/project/file1.py", - "/fake/project/file2.py", - "/fake/project/file3.md", - "/fake/project/file4.txt", - ] - - sizes = { - "/fake/project/file1.py": 100, - "/fake/project/file2.py": 200, - "/fake/project/file3.md": 50, - "/fake/project/file4.txt": 150, - } - - # Mocking os.path.getsize to return different sizes for different files - with patch("os.path.getsize") as mock_getsize: - mock_getsize.side_effect = lambda file: sizes[file] - - result = create_language_weighting(files) - - total = sum(sizes.values()) - - expected_result = { - ".py": 300 / total * 100, # 300 out of 600 total - ".txt": 150 / total * 100, # 150 out of 600 total - ".md": 50 / total * 100, # 50 out of 600 total - } - - # Check if the result matches the expected output - assert result[".py"] == pytest.approx(expected_result.get(".py"), 0.01) - assert result[".txt"] == pytest.approx(expected_result.get(".txt"), 0.01) - assert result[".md"] == pytest.approx(expected_result.get(".md"), 0.01) - - -def test_create_language_weighting_no_files(): - """Test the function when no files are provided.""" - files = [] - - result = create_language_weighting(files) - assert result == {}, "Should return an empty dictionary when no files are provided" - - -def test_create_language_weighting_files_without_extensions(): - """Test the function when files have no extensions.""" - files = [ - "/fake/project/file1", - "/fake/project/file2", - ] - - with patch("os.path.getsize", return_value=100): - result = create_language_weighting(files) - - assert result == {}, "Should return an empty dictionary when files have no extensions" - - -def test_create_language_weighting_zero_total_size(): - """Test the function when all files have a size of 0.""" - files = [ - "/fake/project/file1.py", - "/fake/project/file2.py", - ] - - with patch("os.path.getsize", return_value=0): - result = create_language_weighting(files) - - assert result == {".py": 0} - - -def test_create_language_weighting_single_file(): - """Test the function with a single file.""" - files = [ - "/fake/project/file1.py", - ] - - with patch("os.path.getsize", return_value=100): - result = create_language_weighting(files) - - assert result == {".py": 100.0}, "Should return 100% for the single file's extension" - - -def test_create_language_weighting_mixed_extensions(): - """Test the function with files of mixed extensions and sizes.""" - files = [ - "/fake/project/file1.py", - "/fake/project/file2.py", - "/fake/project/file3.md", - "/fake/project/file4.txt", - "/fake/project/file5.md", - ] - - with patch("os.path.getsize") as mock_getsize: - mock_getsize.side_effect = lambda file: { - "/fake/project/file1.py": 100, - "/fake/project/file2.py": 100, - "/fake/project/file3.md": 200, - "/fake/project/file4.txt": 300, - "/fake/project/file5.md": 100, - }[file] - - result = create_language_weighting(files) - - expected_result = { - ".txt": 37.5, # 300 out of 800 total - ".md": 37.5, # 300 out of 800 total - ".py": 25.0, # 200 out of 800 total - } - - assert result[".txt"] == pytest.approx(expected_result.get(".txt"), 0.01) - assert result[".md"] == pytest.approx(expected_result.get(".md"), 0.01) - assert result[".py"] == pytest.approx(expected_result.get(".py"), 0.01)