Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use custom exceptions instead of KeyboardInterrupts and exit() #45

Merged
merged 3 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions mentat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from .code_file_manager import CodeFileManager
from .config_manager import ConfigManager, mentat_dir_path
from .conversation import Conversation
from .errors import MentatError, UserError
from .git_handler import get_shared_git_root_for_paths
from .llm_api import CostTracker, count_tokens, setup_api_key
from .logging_config import setup_logging
from .user_input_manager import UserInputManager
from .user_input_manager import UserInputManager, UserQuitInterrupt


def run_cli():
Expand Down Expand Up @@ -56,8 +57,15 @@ def run(paths: Iterable[str], exclude_paths: Optional[Iterable[str]] = None):
cost_tracker = CostTracker()
try:
loop(paths, exclude_paths, cost_tracker)
except (EOFError, KeyboardInterrupt) as e:
print(e)
except (
EOFError,
KeyboardInterrupt,
UserQuitInterrupt,
UserError,
MentatError,
) as e:
if str(e):
cprint(e, "light_yellow")
finally:
cost_tracker.display_total_cost()

Expand Down
2 changes: 1 addition & 1 deletion mentat/code_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pygments.lexers import TextLexer, get_lexer_for_filename
from pygments.util import ClassNotFound

from .model_error import ModelError
from .errors import ModelError


class CodeChangeAction(Enum):
Expand Down
11 changes: 4 additions & 7 deletions mentat/code_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from .code_change import CodeChange, CodeChangeAction
from .config_manager import ConfigManager
from .errors import UserError
from .git_handler import (
get_git_diff_for_path,
get_non_gitignored_files,
Expand Down Expand Up @@ -76,11 +77,7 @@ def _abs_file_paths_from_list(paths: Iterable[str], check_for_text: bool = True)
if path.is_file():
if check_for_text and not _is_file_text_encoded(path):
logging.info(f"File path {path} is not text encoded.")
cprint(
f"Filepath {path} is not text encoded.",
"light_yellow",
)
raise KeyboardInterrupt
raise UserError(f"File path {path} is not text encoded.")
file_paths_direct.add(os.path.realpath(path))
elif path.is_dir():
nonignored_files = set(
Expand Down Expand Up @@ -137,7 +134,7 @@ def _set_file_paths(
cprint("The following paths do not exist:")
print("\n".join(invalid_paths))
print("Exiting...")
exit()
raise UserError()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could change the error to be the following paths don't exist instead of cprinting, and maybe get rid of the 'exiting...'? Since that should be standard (either always have it, or always don't have it).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out the warning for bad file paths got broken when we switched to globbing so I moved this chunk of code to expand_paths() in app.py.


excluded_files, excluded_files_from_dir = _abs_file_paths_from_list(
exclude_paths, check_for_text=False
Expand Down Expand Up @@ -249,7 +246,7 @@ def _get_new_code_lines(self, changes) -> Iterable[str]:
min_changed_line = largest_changed_line + 1
for i, change in enumerate(changes):
if change.last_changed_line >= min_changed_line:
raise ValueError(f"Change line number overlap in file {change.file}")
raise UserError(f"Change line number overlap in file {change.file}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a MentatError; the collision handling is supposed to handle this

min_changed_line = change.first_changed_line
new_code_lines = change.apply(new_code_lines)
return new_code_lines
Expand Down
10 changes: 10 additions & 0 deletions mentat/model_error.py → mentat/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,13 @@ class ModelError(Exception):
def __init__(self, message, already_added_to_changelist):
super().__init__(message)
self.already_added_to_changelist = already_added_to_changelist


# Used to indicate an issue with Mentat's code
class MentatError(Exception):
pass


# Used to indicate an issue with the user's usage of Mentat
class UserError(Exception):
pass
8 changes: 5 additions & 3 deletions mentat/git_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import subprocess
from pathlib import Path

from mentat.errors import UserError


def get_git_diff_for_path(git_root, path: str) -> str:
return subprocess.check_output(["git", "diff", path], cwd=git_root).decode("utf-8")
Expand Down Expand Up @@ -63,7 +65,7 @@ def _get_git_root_for_path(path) -> str:
return os.path.realpath(git_root)
except subprocess.CalledProcessError:
logging.error(f"File {path} isn't part of a git project.")
exit()
raise UserError()


def get_shared_git_root_for_paths(paths) -> str:
Expand All @@ -80,9 +82,9 @@ def get_shared_git_root_for_paths(paths) -> str:
"All paths must be part of the same git project! Projects provided:"
f" {git_roots}"
)
exit()
raise UserError()
elif len(git_roots) == 0:
logging.error("No git projects provided.")
exit()
raise UserError()

return git_roots.pop()
19 changes: 8 additions & 11 deletions mentat/llm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
from termcolor import cprint

from .config_manager import mentat_dir_path, user_config_path
from .errors import MentatError, UserError

package_name = __name__.split(".")[0]


# Check for .env file or already exported API key
# If no api key found, exit and warn user
# If no api key found, raise an error
def setup_api_key():
if not load_dotenv(os.path.join(mentat_dir_path, ".env")):
load_dotenv()
Expand All @@ -24,12 +25,10 @@ def setup_api_key():
openai.api_key = key
openai.Model.list() # Test the API key
except openai.error.AuthenticationError:
cprint(
raise UserError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We call this function before the try except in app.py; we should move that call into loop() so this error gets properly caught

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

"No valid OpenAI api key detected.\nEither place your key into a .env"
" file or export it as an environment variable.",
"red",
" file or export it as an environment variable."
)
sys.exit(0)


async def call_llm_api(messages: list[dict[str, str]], model) -> Generator:
Expand All @@ -38,8 +37,8 @@ async def call_llm_api(messages: list[dict[str, str]], model) -> Generator:
and "--benchmark" not in sys.argv
and os.getenv("MENTAT_BENCHMARKS_RUNNING") == "false"
):
logging.critical("OpenAI call made in non benchmark test environment!")
sys.exit(1)
logging.critical("OpenAI call attempted in non benchmark test environment!")
raise MentatError("OpenAI call attempted in non benchmark test environment!")

response = await openai.ChatCompletion.acreate(
model=model,
Expand Down Expand Up @@ -73,12 +72,10 @@ def check_model_availability(allow_32k: bool) -> bool:
if not allow_32k:
# check if user has access to gpt-4
if "gpt-4-0314" not in available_models:
cprint(
raise UserError(
"Sorry, but your OpenAI API key doesn't have access to gpt-4-0314,"
" which is currently required to run Mentat.",
"red",
" which is currently required to run Mentat."
)
raise KeyboardInterrupt

return allow_32k

Expand Down
4 changes: 2 additions & 2 deletions mentat/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
get_removed_block,
)
from .code_file_manager import CodeFileManager
from .errors import MentatError, ModelError
from .llm_api import call_llm_api
from .model_error import ModelError
from .streaming_printer import StreamingPrinter


Expand Down Expand Up @@ -163,7 +163,7 @@ def run_async_stream_and_parse_llm_response(
)
except (openai.error.InvalidRequestError, openai.error.RateLimitError) as e:
cprint(e, "red")
exit()
MentatError("Something went wrong - invalid request to OpenAI API.")
except KeyboardInterrupt:
print("\n\nInterrupted by user. Using the response up to this point.")
# if the last change is incomplete, remove it
Expand Down
6 changes: 5 additions & 1 deletion mentat/user_input_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def append_string(self, string):
super().append_string(string)


class UserQuitInterrupt(Exception):
pass


class UserInputManager:
def __init__(self, config: ConfigManager):
self.config = config
Expand Down Expand Up @@ -77,7 +81,7 @@ def collect_user_input(self) -> str:
user_input = self.session.prompt().strip()
logging.debug(f"User input:\n{user_input}")
if user_input.lower() == "q":
raise KeyboardInterrupt("User used 'q' to quit")
raise UserQuitInterrupt()
return user_input

def ask_yes_no(self, default_yes: bool) -> bool:
Expand Down
5 changes: 3 additions & 2 deletions tests/code_file_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mentat.app import expand_paths
from mentat.code_file_manager import CodeFileManager
from mentat.config_manager import ConfigManager
from mentat.errors import UserError


def test_path_gitignoring(temp_testbed, mock_config):
Expand Down Expand Up @@ -165,7 +166,7 @@ def test_text_encoding_checking(temp_testbed, mock_config):
)
assert os.path.join(temp_testbed, nontext_path) not in code_file_manager.file_paths

with pytest.raises(KeyboardInterrupt) as e_info:
with pytest.raises(UserError) as e_info:
nontext_path_requested = "iamalsonottext.py"
with open(nontext_path_requested, "wb") as f:
# 0x81 is invalid in UTF-8 (single byte > 127), and undefined in cp1252 and iso-8859-1
Expand All @@ -179,7 +180,7 @@ def test_text_encoding_checking(temp_testbed, mock_config):
config=mock_config,
git_root=temp_testbed,
)
assert e_info.type == KeyboardInterrupt
assert e_info.type == UserError


# Make sure we always give posix paths to GPT
Expand Down
5 changes: 3 additions & 2 deletions tests/git_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from mentat.errors import UserError
from mentat.git_handler import get_shared_git_root_for_paths


Expand All @@ -20,9 +21,9 @@ def test_paths_given(temp_testbed):

def test_two_git_roots_given():
# Exits when given 2 paths with separate git roots
with pytest.raises(SystemExit) as e_info:
with pytest.raises(UserError) as e_info:
os.makedirs("git_testing_dir")
subprocess.run(["git", "init"], cwd="git_testing_dir")

_ = get_shared_git_root_for_paths(["./", "git_testing_dir"])
assert e_info.type == SystemExit
assert e_info.type == UserError