diff --git a/mentat/app.py b/mentat/app.py index 2401505e2..1553efeae 100644 --- a/mentat/app.py +++ b/mentat/app.py @@ -11,10 +11,11 @@ from .code_file_manager import CodeFileManager from .config_manager import ConfigManager, mentat_dir_path from .conversation import Conversation +from .errors import MentatError, UserError from .git_handler import get_shared_git_root_for_paths from .llm_api import CostTracker, count_tokens, setup_api_key from .logging_config import setup_logging -from .user_input_manager import UserInputManager +from .user_input_manager import UserInputManager, UserQuitInterrupt def run_cli(): @@ -42,22 +43,41 @@ def run_cli(): def expand_paths(paths: Iterable[str]) -> Iterable[str]: globbed_paths = set() + invalid_paths = [] for path in paths: - globbed_paths.update(glob.glob(pathname=path, recursive=True)) + new_paths = glob.glob(pathname=path, recursive=True) + if new_paths: + globbed_paths.update(new_paths) + else: + invalid_paths.append(path) + if invalid_paths: + cprint( + "The following paths do not exist:", + "light_yellow", + ) + print("\n".join(invalid_paths)) + exit() return globbed_paths def run(paths: Iterable[str], exclude_paths: Optional[Iterable[str]] = None): os.makedirs(mentat_dir_path, exist_ok=True) setup_logging() - setup_api_key() logging.debug(f"Paths: {paths}") cost_tracker = CostTracker() try: + setup_api_key() loop(paths, exclude_paths, cost_tracker) - except (EOFError, KeyboardInterrupt) as e: - print(e) + except ( + EOFError, + KeyboardInterrupt, + UserQuitInterrupt, + UserError, + MentatError, + ) as e: + if str(e): + cprint(e, "light_yellow") finally: cost_tracker.display_total_cost() diff --git a/mentat/code_change.py b/mentat/code_change.py index 50916ec80..61b614570 100644 --- a/mentat/code_change.py +++ b/mentat/code_change.py @@ -5,7 +5,7 @@ from pygments.lexers import TextLexer, get_lexer_for_filename from pygments.util import ClassNotFound -from .model_error import ModelError +from .errors import ModelError class CodeChangeAction(Enum): diff --git a/mentat/code_file_manager.py b/mentat/code_file_manager.py index fd606cb58..1c41400b5 100644 --- a/mentat/code_file_manager.py +++ b/mentat/code_file_manager.py @@ -14,6 +14,7 @@ ) from .code_change import CodeChange, CodeChangeAction from .config_manager import ConfigManager +from .errors import MentatError, UserError from .git_handler import ( get_git_diff_for_path, get_non_gitignored_files, @@ -76,11 +77,7 @@ def _abs_file_paths_from_list(paths: Iterable[str], check_for_text: bool = True) if path.is_file(): if check_for_text and not _is_file_text_encoded(path): logging.info(f"File path {path} is not text encoded.") - cprint( - f"Filepath {path} is not text encoded.", - "light_yellow", - ) - raise KeyboardInterrupt + raise UserError(f"File path {path} is not text encoded.") file_paths_direct.add(os.path.realpath(path)) elif path.is_dir(): nonignored_files = set( @@ -128,17 +125,6 @@ def __init__( def _set_file_paths( self, paths: Iterable[str], exclude_paths: Iterable[str] ) -> None: - invalid_paths = [] - for path in paths: - if not os.path.exists(path): - invalid_paths.append(path) - if invalid_paths: - cprint("Error:", "red", end=" ") - cprint("The following paths do not exist:") - print("\n".join(invalid_paths)) - print("Exiting...") - exit() - excluded_files, excluded_files_from_dir = _abs_file_paths_from_list( exclude_paths, check_for_text=False ) @@ -249,7 +235,7 @@ def _get_new_code_lines(self, changes) -> Iterable[str]: min_changed_line = largest_changed_line + 1 for i, change in enumerate(changes): if change.last_changed_line >= min_changed_line: - raise ValueError(f"Change line number overlap in file {change.file}") + raise MentatError(f"Change line number overlap in file {change.file}") min_changed_line = change.first_changed_line new_code_lines = change.apply(new_code_lines) return new_code_lines diff --git a/mentat/model_error.py b/mentat/errors.py similarity index 59% rename from mentat/model_error.py rename to mentat/errors.py index 73c8d049f..247d92177 100644 --- a/mentat/model_error.py +++ b/mentat/errors.py @@ -3,3 +3,13 @@ class ModelError(Exception): def __init__(self, message, already_added_to_changelist): super().__init__(message) self.already_added_to_changelist = already_added_to_changelist + + +# Used to indicate an issue with Mentat's code +class MentatError(Exception): + pass + + +# Used to indicate an issue with the user's usage of Mentat +class UserError(Exception): + pass diff --git a/mentat/git_handler.py b/mentat/git_handler.py index 54bdb6503..02a91f404 100644 --- a/mentat/git_handler.py +++ b/mentat/git_handler.py @@ -3,6 +3,8 @@ import subprocess from pathlib import Path +from mentat.errors import UserError + def get_git_diff_for_path(git_root, path: str) -> str: return subprocess.check_output(["git", "diff", path], cwd=git_root).decode("utf-8") @@ -63,7 +65,7 @@ def _get_git_root_for_path(path) -> str: return os.path.realpath(git_root) except subprocess.CalledProcessError: logging.error(f"File {path} isn't part of a git project.") - exit() + raise UserError() def get_shared_git_root_for_paths(paths) -> str: @@ -80,9 +82,9 @@ def get_shared_git_root_for_paths(paths) -> str: "All paths must be part of the same git project! Projects provided:" f" {git_roots}" ) - exit() + raise UserError() elif len(git_roots) == 0: logging.error("No git projects provided.") - exit() + raise UserError() return git_roots.pop() diff --git a/mentat/llm_api.py b/mentat/llm_api.py index af8aee7b1..618054e41 100644 --- a/mentat/llm_api.py +++ b/mentat/llm_api.py @@ -10,12 +10,13 @@ from termcolor import cprint from .config_manager import mentat_dir_path, user_config_path +from .errors import MentatError, UserError package_name = __name__.split(".")[0] # Check for .env file or already exported API key -# If no api key found, exit and warn user +# If no api key found, raise an error def setup_api_key(): if not load_dotenv(os.path.join(mentat_dir_path, ".env")): load_dotenv() @@ -24,12 +25,10 @@ def setup_api_key(): openai.api_key = key openai.Model.list() # Test the API key except openai.error.AuthenticationError: - cprint( + raise UserError( "No valid OpenAI api key detected.\nEither place your key into a .env" - " file or export it as an environment variable.", - "red", + " file or export it as an environment variable." ) - sys.exit(0) async def call_llm_api(messages: list[dict[str, str]], model) -> Generator: @@ -38,8 +37,8 @@ async def call_llm_api(messages: list[dict[str, str]], model) -> Generator: and "--benchmark" not in sys.argv and os.getenv("MENTAT_BENCHMARKS_RUNNING") == "false" ): - logging.critical("OpenAI call made in non benchmark test environment!") - sys.exit(1) + logging.critical("OpenAI call attempted in non benchmark test environment!") + raise MentatError("OpenAI call attempted in non benchmark test environment!") response = await openai.ChatCompletion.acreate( model=model, @@ -73,12 +72,10 @@ def check_model_availability(allow_32k: bool) -> bool: if not allow_32k: # check if user has access to gpt-4 if "gpt-4-0314" not in available_models: - cprint( + raise UserError( "Sorry, but your OpenAI API key doesn't have access to gpt-4-0314," - " which is currently required to run Mentat.", - "red", + " which is currently required to run Mentat." ) - raise KeyboardInterrupt return allow_32k diff --git a/mentat/parsing.py b/mentat/parsing.py index 206175821..35985612c 100644 --- a/mentat/parsing.py +++ b/mentat/parsing.py @@ -19,8 +19,8 @@ get_removed_block, ) from .code_file_manager import CodeFileManager +from .errors import MentatError, ModelError from .llm_api import call_llm_api -from .model_error import ModelError from .streaming_printer import StreamingPrinter @@ -163,7 +163,7 @@ def run_async_stream_and_parse_llm_response( ) except (openai.error.InvalidRequestError, openai.error.RateLimitError) as e: cprint(e, "red") - exit() + MentatError("Something went wrong - invalid request to OpenAI API.") except KeyboardInterrupt: print("\n\nInterrupted by user. Using the response up to this point.") # if the last change is incomplete, remove it diff --git a/mentat/user_input_manager.py b/mentat/user_input_manager.py index e232d2922..24debb341 100644 --- a/mentat/user_input_manager.py +++ b/mentat/user_input_manager.py @@ -23,6 +23,10 @@ def append_string(self, string): super().append_string(string) +class UserQuitInterrupt(Exception): + pass + + class UserInputManager: def __init__(self, config: ConfigManager): self.config = config @@ -77,7 +81,7 @@ def collect_user_input(self) -> str: user_input = self.session.prompt().strip() logging.debug(f"User input:\n{user_input}") if user_input.lower() == "q": - raise KeyboardInterrupt("User used 'q' to quit") + raise UserQuitInterrupt() return user_input def ask_yes_no(self, default_yes: bool) -> bool: diff --git a/tests/code_file_manager_test.py b/tests/code_file_manager_test.py index 89d82937e..4669bcc30 100644 --- a/tests/code_file_manager_test.py +++ b/tests/code_file_manager_test.py @@ -6,6 +6,7 @@ from mentat.app import expand_paths from mentat.code_file_manager import CodeFileManager from mentat.config_manager import ConfigManager +from mentat.errors import UserError def test_path_gitignoring(temp_testbed, mock_config): @@ -165,7 +166,7 @@ def test_text_encoding_checking(temp_testbed, mock_config): ) assert os.path.join(temp_testbed, nontext_path) not in code_file_manager.file_paths - with pytest.raises(KeyboardInterrupt) as e_info: + with pytest.raises(UserError) as e_info: nontext_path_requested = "iamalsonottext.py" with open(nontext_path_requested, "wb") as f: # 0x81 is invalid in UTF-8 (single byte > 127), and undefined in cp1252 and iso-8859-1 @@ -179,7 +180,7 @@ def test_text_encoding_checking(temp_testbed, mock_config): config=mock_config, git_root=temp_testbed, ) - assert e_info.type == KeyboardInterrupt + assert e_info.type == UserError # Make sure we always give posix paths to GPT diff --git a/tests/git_handler_test.py b/tests/git_handler_test.py index 5f467f8fb..182395898 100644 --- a/tests/git_handler_test.py +++ b/tests/git_handler_test.py @@ -3,6 +3,7 @@ import pytest +from mentat.errors import UserError from mentat.git_handler import get_shared_git_root_for_paths @@ -20,9 +21,9 @@ def test_paths_given(temp_testbed): def test_two_git_roots_given(): # Exits when given 2 paths with separate git roots - with pytest.raises(SystemExit) as e_info: + with pytest.raises(UserError) as e_info: os.makedirs("git_testing_dir") subprocess.run(["git", "init"], cwd="git_testing_dir") _ = get_shared_git_root_for_paths(["./", "git_testing_dir"]) - assert e_info.type == SystemExit + assert e_info.type == UserError