diff --git a/mentat/app.py b/mentat/app.py index 453bae27c..34accbf06 100644 --- a/mentat/app.py +++ b/mentat/app.py @@ -6,8 +6,9 @@ from termcolor import cprint -from .code_change import CodeChange, CodeChangeAction -from .code_change_display import print_change +from mentat.parsers.block_parser import BlockParser +from mentat.parsers.file_edit import FileEdit + from .code_context import CodeContext from .code_file import parse_intervals from .code_file_manager import CodeFileManager @@ -140,6 +141,8 @@ def loop( pr_diff: Optional[str], ) -> None: git_root = get_shared_git_root_for_paths([Path(path) for path in paths]) + # The parser can be selected here + parser = BlockParser() config = ConfigManager(git_root) code_context = CodeContext( config, paths, exclude_paths or [], diff, pr_diff, no_code_map @@ -147,7 +150,7 @@ def loop( code_context.display_context() user_input_manager = UserInputManager(config, code_context) code_file_manager = CodeFileManager(user_input_manager, config, code_context) - conv = Conversation(config, cost_tracker, code_file_manager) + conv = Conversation(parser, config, cost_tracker, code_file_manager) cprint("Type 'q' or use Ctrl-C to quit at any time.\n", color="cyan") cprint("What can I do for you?", color="light_blue") @@ -157,22 +160,26 @@ def loop( user_response = user_input_manager.collect_user_input() conv.add_user_message(user_response) - _, code_changes = conv.get_model_response() - - if code_changes: - need_user_request = get_user_feedback_on_changes( - config, conv, user_input_manager, code_file_manager, code_changes + file_edits = conv.get_model_response(parser, config) + file_edits = [ + file_edit + for file_edit in file_edits + if file_edit.is_valid(code_file_manager, config) + ] + if file_edits: + need_user_request = get_user_feedback_on_edits( + config, conv, user_input_manager, code_file_manager, file_edits ) else: need_user_request = True -def get_user_feedback_on_changes( +def get_user_feedback_on_edits( config: ConfigManager, conv: Conversation, user_input_manager: UserInputManager, code_file_manager: CodeFileManager, - code_changes: list[CodeChange], + file_edits: list[FileEdit], ) -> bool: cprint( "Apply these changes? 'Y/n/i' or provide feedback.", @@ -183,36 +190,35 @@ def get_user_feedback_on_changes( need_user_request = True match user_response.lower(): case "y" | "": - code_changes_to_apply = code_changes + edits_to_apply = file_edits conv.add_user_message("User chose to apply all your changes.") case "n": - code_changes_to_apply = [] + edits_to_apply = [] conv.add_user_message("User chose not to apply any of your changes.") case "i": - code_changes_to_apply, indices = user_filter_changes( - user_input_manager, code_changes + edits_to_apply = user_filter_changes( + code_file_manager, user_input_manager, config, file_edits ) conv.add_user_message( "User chose to apply" - f" {len(code_changes_to_apply)}/{len(code_changes)} of your suggest" - " changes. The changes they applied were:" - f" {', '.join(map(str, indices))}" + f" {len(edits_to_apply)}/{len(file_edits)} of your suggested" + " changes." ) case _: need_user_request = False - code_changes_to_apply = [] + edits_to_apply = [] conv.add_user_message( "User chose not to apply any of your changes. User response:" f" {user_response}\n\nPlease adjust your previous plan and changes to" " reflect this. Respond with a full new set of changes." ) - if code_changes_to_apply: - code_file_manager.write_changes_to_files(code_changes_to_apply) - if len(code_changes_to_apply) == len(code_changes): - cprint("Changes applied.", color="light_blue") - else: - cprint("Selected changes applied.", color="light_blue") + for file_edit in edits_to_apply: + file_edit.resolve_conflicts(user_input_manager) + + if edits_to_apply: + code_file_manager.write_changes_to_files(edits_to_apply) + cprint("Changes applied.", color="light_blue") else: cprint("No changes applied.", color="light_blue") @@ -223,22 +229,14 @@ def get_user_feedback_on_changes( def user_filter_changes( - user_input_manager: UserInputManager, code_changes: list[CodeChange] -) -> tuple[list[CodeChange], list[int]]: - new_changes = list[CodeChange]() - indices = list[int]() - for index, change in enumerate(code_changes, start=1): - print_change(change) - # Allowing the user to remove rename file changes introduces a lot of edge cases - if change.action == CodeChangeAction.RenameFile: - new_changes.append(change) - indices.append(index) - cprint("Cannot remove rename file change", "light_yellow") - continue - - cprint("Keep this change?", "light_blue") - if user_input_manager.ask_yes_no(default_yes=True): - new_changes.append(change) - indices.append(index) - - return new_changes, indices + code_file_manager: CodeFileManager, + user_input_manager: UserInputManager, + config: ConfigManager, + file_edits: list[FileEdit], +) -> list[FileEdit]: + new_edits = list[FileEdit]() + for file_edit in file_edits: + if file_edit.filter_replacements(code_file_manager, user_input_manager, config): + new_edits.append(file_edit) + + return new_edits diff --git a/mentat/change_conflict_resolution.py b/mentat/change_conflict_resolution.py deleted file mode 100644 index 42c4c34df..000000000 --- a/mentat/change_conflict_resolution.py +++ /dev/null @@ -1,110 +0,0 @@ -from __future__ import annotations - -import logging -import string -from typing import TYPE_CHECKING - -from termcolor import cprint - -from .code_change import CodeChange, CodeChangeAction -from .code_change_display import get_added_block, get_removed_block -from .user_input_manager import UserInputManager - -if TYPE_CHECKING: - # This normally will cause a circular import - from .code_file_manager import CodeFileManager - - -def resolve_insertion_conflicts( - changes: list[CodeChange], - user_input_manager: UserInputManager, - code_file_manager: CodeFileManager, -) -> list[CodeChange]: - """merges insertion conflicts into one singular code change""" - insert_changes = list( - filter( - lambda change: change.action == CodeChangeAction.Insert, - sorted(changes, reverse=True), - ) - ) - new_insert_changes = list[CodeChange]() - cur = 0 - while cur < len(insert_changes): - end = cur + 1 - while ( - end < len(insert_changes) - and insert_changes[end].first_changed_line - == insert_changes[cur].first_changed_line - ): - end += 1 - if end > cur + 1: - logging.debug("insertion conflict") - cprint("Insertion conflict:", "red") - for i in range(end - cur): - cprint(f"({string.printable[i]})", "green") - cprint("\n".join(insert_changes[cur + i].code_lines), "light_cyan") - cprint( - "Type the order in which to insert changes (omit for no preference):" - ) - user_input = user_input_manager.collect_user_input() - new_code_lines = list[str]() - used = set[int]() - for c in user_input: - index = string.printable.index(c) if c in string.printable else -1 - if index < end - cur and index != -1: - new_code_lines += insert_changes[cur + index].code_lines - used.add(index) - for i in range(end - cur): - if i not in used: - new_code_lines += insert_changes[cur + i].code_lines - new_change = CodeChange( - insert_changes[cur].json_data, - new_code_lines, - code_file_manager, - ) - new_insert_changes.append(new_change) - else: - new_insert_changes.append(insert_changes[cur]) - cur = end - return sorted( - list(filter(lambda change: change.action != CodeChangeAction.Insert, changes)) - + new_insert_changes, - reverse=True, - ) - - -def resolve_non_insertion_conflicts( - changes: list[CodeChange], user_input_manager: UserInputManager -) -> list[CodeChange]: - """resolves delete-replace conflicts and asks user on delete-insert or replace-insert conflicts""" - min_changed_line = changes[0].last_changed_line + 1 - removed_changes = set[int]() - for i, change in enumerate(changes): - if change.last_changed_line >= min_changed_line: - if change.action == CodeChangeAction.Insert: - logging.debug("insertion inside removed block") - if changes[i - 1].action == CodeChangeAction.Delete: - keep = True - else: - cprint( - "\nInsertion conflict: Lines inserted inside replaced block\n", - "light_red", - ) - print(get_removed_block(changes[i - 1])) - print(get_added_block(change, prefix=">", color=None)) - print(get_added_block(changes[i - 1])) - cprint("Keep this insertion?") - keep = user_input_manager.ask_yes_no(default_yes=True) - if keep: - change.first_changed_line = changes[i - 1].first_changed_line - 0.5 - change.last_changed_line = change.first_changed_line - else: - removed_changes.add(i) - - else: - change.last_changed_line = min_changed_line - 1 - change.first_changed_line = min( - change.first_changed_line, changes[i - 1].first_changed_line - ) - min_changed_line = change.first_changed_line - return [change for i, change in enumerate(changes) if i not in removed_changes] diff --git a/mentat/code_change.py b/mentat/code_change.py deleted file mode 100644 index 840e774ea..000000000 --- a/mentat/code_change.py +++ /dev/null @@ -1,187 +0,0 @@ -from __future__ import annotations - -from enum import Enum -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from pygments.lexer import Lexer -from pygments.lexers import TextLexer, get_lexer_for_filename -from pygments.util import ClassNotFound - -from .errors import ModelError - -if TYPE_CHECKING: - # This normally will cause a circular import - from .code_file_manager import CodeFileManager - - -class CodeChangeAction(Enum): - Insert = "insert" - Replace = "replace" - Delete = "delete" - CreateFile = "create-file" - DeleteFile = "delete-file" - RenameFile = "rename-file" - - def has_surrounding_lines(self): - return ( - self != CodeChangeAction.CreateFile - and self != CodeChangeAction.DeleteFile - and self != CodeChangeAction.RenameFile - ) - - def has_removals(self): - return ( - self == CodeChangeAction.Delete - or self == CodeChangeAction.Replace - or self == CodeChangeAction.DeleteFile - ) - - def has_additions(self): - return ( - self == CodeChangeAction.Insert - or self == CodeChangeAction.Replace - or self == CodeChangeAction.CreateFile - ) - - -class CodeChange: - def __init__( - self, - json_data: dict[Any, Any], - code_lines: list[str], - code_file_manager: CodeFileManager, - rename_map: dict[Path, Path] = {}, - ): - self.json_data = json_data - # Sometimes GPT puts quotes around numbers, so we have to convert those - for json_key in [ - "insert-before-line", - "insert-after-line", - "start-line", - "end-line", - ]: - if json_key in self.json_data: - self.json_data[json_key] = int(self.json_data[json_key]) - self.code_lines = code_lines - self.file = Path(self.json_data["file"]) - # This is not ideal; however, it would be a lot of work to fix, - # and this class is going to completely change soon anyways - self.first_changed_line: int | float = None # type: ignore - self.last_changed_line: int | float = None # type: ignore - self.error = "" - try: - self.lexer: Lexer = get_lexer_for_filename(self.file) - self.lexer.stripnl = False - self.lexer.stripall = False - self.lexer.ensurenl = False - except ClassNotFound: - self.lexer = TextLexer() - - try: - self.action = CodeChangeAction(self.json_data["action"]) - except ValueError: - raise ModelError( - f"Model created change with unknown action {self.json_data['action']}", - already_added_to_changelist=False, - ) - - try: - match self.action: - case CodeChangeAction.Insert: - if "insert-before-line" in self.json_data: - self.first_changed_line = self.json_data["insert-before-line"] - if "insert-after-line" in self.json_data: - if ( - self.first_changed_line - 1 - != self.json_data["insert-after-line"] - ): - self.error = "Insert line numbers invalid" - elif "insert-after-line" in self.json_data: - self.first_changed_line = ( - self.json_data["insert-after-line"] + 1 - ) - else: - self.first_changed_line = 0 - self.error = "Insert line number not specified" - self.first_changed_line -= 0.5 - self.last_changed_line = self.first_changed_line - - case CodeChangeAction.Replace: - self.first_changed_line = self.json_data["start-line"] - self.last_changed_line = self.json_data["end-line"] - - case CodeChangeAction.Delete: - self.first_changed_line = self.json_data["start-line"] - self.last_changed_line = self.json_data["end-line"] - - case CodeChangeAction.RenameFile: - self.name = Path(self.json_data["name"]) - - case _: - pass - - except KeyError: - self.error = "Line numbers not given" - - if ( - self.first_changed_line - and self.last_changed_line - and self.first_changed_line > self.last_changed_line - ): - self.error = "Starting line of change is greater than ending line of change" - - if self.action == CodeChangeAction.CreateFile: - if self.file.exists(): - self.error = ( - f"Model attempted to create file that already exists: {self.file}" - ) - self.file_lines = [] - self.line_number_buffer = 2 - else: - if self.action == CodeChangeAction.RenameFile and self.name.exists(): - self.error = ( - f"Model attempted to rename file {self.file} to a file that" - f" already exists: {self.name}" - ) - - # This rename_map is a bit hacky; it shouldn't be used outside of streaming/parsing - rel_path = ( - self.file if self.file not in rename_map else rename_map[self.file] - ) - - try: - self.file_lines = code_file_manager.file_lines[rel_path] - self.line_number_buffer = len(str(len(self.file_lines) + 1)) + 1 - except KeyError: - self.error = ( - f"Model attempted to edit {rel_path}, which isn't in current" - " context or doesn't exist" - ) - - def __lt__(self, other: CodeChange) -> bool: - return self.last_changed_line < other.last_changed_line - - def apply(self, cur_file_lines: list[str]) -> list[str]: - match self.action: - case CodeChangeAction.Insert: - previous_lines = cur_file_lines[: int(self.first_changed_line)] - following_lines = cur_file_lines[int(self.first_changed_line) :] - new_file_lines = previous_lines + self.code_lines + following_lines - - case CodeChangeAction.Replace: - previous_lines = cur_file_lines[: self.first_changed_line - 1] - following_lines = cur_file_lines[self.last_changed_line :] - new_file_lines = previous_lines + self.code_lines + following_lines - - case CodeChangeAction.Delete: - previous_lines = cur_file_lines[: self.first_changed_line - 1] - following_lines = cur_file_lines[self.last_changed_line :] - new_file_lines = previous_lines + following_lines - - case CodeChangeAction.CreateFile | CodeChangeAction.DeleteFile | CodeChangeAction.RenameFile: - raise Exception( - f"CodeChange with action={self.action} shouldn't have apply called" - ) - - return new_file_lines diff --git a/mentat/code_change_display.py b/mentat/code_change_display.py deleted file mode 100644 index 27e82993f..000000000 --- a/mentat/code_change_display.py +++ /dev/null @@ -1,161 +0,0 @@ -import math - -from pygments import highlight # pyright: ignore[reportUnknownVariableType] -from pygments.formatters import TerminalFormatter -from termcolor import colored - -from .code_change import CodeChange, CodeChangeAction - -change_delimiter = 60 * "=" - - -def _remove_extra_empty_lines(lines: list[str]) -> list[str]: - if not lines: - return [] - - # Find the first non-empty line - start = 0 - while start < len(lines) and not lines[start].strip(): - start += 1 - - # Find the last non-empty line - end = len(lines) - 1 - while end > start and not lines[end].strip(): - end -= 1 - - # If all lines are empty, keep only one empty line - if start == len(lines): - return [" "] - - # Return the list with only a maximum of one empty line on either side - return lines[max(start - 1, 0) : end + 2] - - -def _prefixed_lines(code_change: CodeChange, lines: list[str], prefix: str): - return "\n".join( - [ - prefix - + " " * (code_change.line_number_buffer - len(prefix)) - + line.strip("\n") - for line in lines - ] - ) - - -def print_change(code_change: CodeChange): - to_print = [ - get_file_name(code_change), - change_delimiter if code_change.action != CodeChangeAction.RenameFile else "", - get_previous_lines(code_change), - get_removed_block(code_change), - get_added_block(code_change), - get_later_lines(code_change), - change_delimiter if code_change.action != CodeChangeAction.RenameFile else "", - ] - for s in to_print: - if s: - print(s) - - -def get_file_name(code_change: CodeChange): - file_name = code_change.file - match code_change.action: - case CodeChangeAction.CreateFile: - return colored(f"\n{file_name}*", color="light_green") - case CodeChangeAction.DeleteFile: - return colored(f"\n{file_name}", color="light_red") - case CodeChangeAction.RenameFile: - return colored( - f"\nRename: {file_name} -> {code_change.name}", color="yellow" - ) - case _: - return colored(f"\n{file_name}", color="light_blue") - - -def get_removed_block( - code_change: CodeChange, prefix: str = "-", color: str | None = "red" -): - if code_change.action.has_removals(): - if code_change.action == CodeChangeAction.DeleteFile: - changed_lines = code_change.file_lines - else: - changed_lines = code_change.file_lines[ - code_change.first_changed_line - 1 : code_change.last_changed_line - ] - - removed = _prefixed_lines(code_change, changed_lines, prefix) - if removed: - return colored(removed, color=color) - return "" - - -def get_added_block( - code_change: CodeChange, prefix: str = "+", color: str | None = "green" -): - if code_change.action.has_additions(): - added = _prefixed_lines(code_change, code_change.code_lines, prefix) - if added: - return colored(added, color=color) - return "" - - -def get_previous_lines(code_change: CodeChange, num: int = 2): - if not code_change.action.has_surrounding_lines(): - return "" - lines = _remove_extra_empty_lines( - [ - code_change.file_lines[i] - for i in range( - max(0, math.ceil(code_change.first_changed_line) - (num + 1)), - min( - math.ceil(code_change.first_changed_line) - 1, - len(code_change.file_lines), - ), - ) - ] - ) - numbered = [ - (str(math.ceil(code_change.first_changed_line) - len(lines) + i) + ":").ljust( - code_change.line_number_buffer - ) - + line - for i, line in enumerate(lines) - ] - - prev = "\n".join(numbered) - if prev: - # pygments doesn't have type hints on TerminalFormatter - h_prev: str = highlight(prev, code_change.lexer, TerminalFormatter(bg="dark")) # type: ignore - return h_prev - return "" - - -def get_later_lines(code_change: CodeChange, num: int = 2): - if not code_change.action.has_surrounding_lines(): - return "" - lines = _remove_extra_empty_lines( - [ - code_change.file_lines[i] - for i in range( - max(0, int(code_change.last_changed_line)), - min( - int(code_change.last_changed_line) + num, - len(code_change.file_lines), - ), - ) - ] - ) - numbered = [ - (str(int(code_change.last_changed_line) + 1 + i) + ":").ljust( - code_change.line_number_buffer - ) - + line - for i, line in enumerate(lines) - ] - - later = "\n".join(numbered) - if later: - # pygments doesn't have type hints on TerminalFormatter - h_later: str = highlight(later, code_change.lexer, TerminalFormatter(bg="dark")) # type: ignore - return h_later - return "" diff --git a/mentat/code_file_manager.py b/mentat/code_file_manager.py index 6eddf87d9..ca4c74fcf 100644 --- a/mentat/code_file_manager.py +++ b/mentat/code_file_manager.py @@ -1,19 +1,13 @@ import logging -import math import os -from collections import defaultdict from pathlib import Path from typing import Union from termcolor import cprint from mentat.llm_api import count_tokens, model_context_size +from mentat.parsers.file_edit import FileEdit -from .change_conflict_resolution import ( - resolve_insertion_conflicts, - resolve_non_insertion_conflicts, -) -from .code_change import CodeChange, CodeChangeAction from .code_context import CodeContext from .code_file import CodeFile from .config_manager import ConfigManager @@ -32,7 +26,7 @@ def __init__( self.config = config self.code_context = code_context - def _read_file(self, file: Union[Path, CodeFile]) -> list[str]: + def read_file(self, file: Union[Path, CodeFile]) -> list[str]: if isinstance(file, CodeFile): rel_path = file.path else: @@ -46,8 +40,9 @@ def _read_file(self, file: Union[Path, CodeFile]) -> list[str]: def _read_all_file_lines(self) -> None: self.file_lines = dict[Path, list[str]]() for file in self.code_context.files.values(): + # self.file_lines is relative to git root rel_path = Path(os.path.relpath(file.path, self.config.git_root)) - self.file_lines[rel_path] = self._read_file(file) + self.file_lines[rel_path] = self.read_file(file) def get_code_message(self, model: str) -> str: code_message: list[str] = [] @@ -126,6 +121,8 @@ def _add_file(self, abs_path: Path): self.code_context.files[abs_path] = CodeFile(abs_path) # create any missing directories in the path abs_path.parent.mkdir(parents=True, exist_ok=True) + with open(abs_path, "w") as f: + f.write("") def _delete_file(self, abs_path: Path): logging.info(f"Deleting file {abs_path}") @@ -133,99 +130,62 @@ def _delete_file(self, abs_path: Path): del self.code_context.files[abs_path] abs_path.unlink() - def _handle_delete(self, delete_change: CodeChange): - abs_path = self.config.git_root / delete_change.file - if not abs_path.exists(): - logging.error(f"Path {abs_path} non-existent on delete") - return + # Mainly does checks on if file is in context, file exists, file is unchanged, etc. + def write_changes_to_files(self, file_edits: list[FileEdit]): + for file_edit in file_edits: + rel_path = Path(os.path.relpath(file_edit.file_path, self.config.git_root)) + if file_edit.is_creation: + if file_edit.file_path.exists(): + raise MentatError( + f"Model attempted to create file {file_edit.file_path} which" + " already exists" + ) + self._add_file(file_edit.file_path) + else: + if not file_edit.file_path.exists(): + raise MentatError( + f"Attempted to edit non-existent file {file_edit.file_path}" + ) + elif file_edit.file_path not in self.code_context.files: + raise MentatError( + f"Attempted to edit file {file_edit.file_path} not in context" + ) - cprint(f"Are you sure you want to delete {delete_change.file}?", "red") - if self.user_input_manager.ask_yes_no(default_yes=False): - cprint(f"Deleting {delete_change.file}...") - self._delete_file(abs_path) - else: - cprint(f"Not deleting {delete_change.file}") - - def _get_new_code_lines( - self, rel_path: Path, changes: list[CodeChange] - ) -> list[str] | None: - if not changes: - return [] - if len(set(map(lambda change: change.file, changes))) > 1: - raise Exception("All changes passed in must be for the same file") - - changes = sorted(changes, reverse=True) - - # We resolve insertion conflicts twice because non-insertion conflicts - # might move insert blocks outside of replace/delete blocks and cause - # them to conflict again - changes = resolve_insertion_conflicts(changes, self.user_input_manager, self) - changes = resolve_non_insertion_conflicts(changes, self.user_input_manager) - changes = resolve_insertion_conflicts(changes, self.user_input_manager, self) - if not changes: - return [] - - new_code_lines = self.file_lines[rel_path].copy() - if new_code_lines != self._read_file(rel_path): - logging.info(f"File '{rel_path}' changed while generating changes") - cprint( - f"File '{rel_path}' changed while generating; current file changes" - " will be erased. Continue?", - color="light_yellow", - ) - if not self.user_input_manager.ask_yes_no(default_yes=False): - cprint(f"Not applying changes to file {rel_path}.") - return None - - # Necessary in case the model needs to insert past the end of the file - last_line = len(new_code_lines) + 1 - largest_changed_line = math.ceil(changes[0].last_changed_line) - if largest_changed_line > last_line: - new_code_lines += [""] * (largest_changed_line - last_line) - - min_changed_line = largest_changed_line + 1 - for change in changes: - if change.last_changed_line >= min_changed_line: - raise MentatError(f"Change line number overlap in file {change.file}") - min_changed_line = change.first_changed_line - new_code_lines = change.apply(new_code_lines) - return new_code_lines - - def write_changes_to_files(self, code_changes: list[CodeChange]) -> None: - file_changes = defaultdict[Path, list[CodeChange]](list) - for code_change in code_changes: - rel_path = code_change.file - abs_path = self.config.git_root / rel_path - match code_change.action: - case CodeChangeAction.CreateFile: - cprint(f"Creating new file {rel_path}", color="light_green") - self._add_file(abs_path) - with open(abs_path, "w") as f: - f.write("\n".join(code_change.code_lines)) - case CodeChangeAction.DeleteFile: - self._handle_delete(code_change) - case CodeChangeAction.RenameFile: - abs_new_path = self.config.git_root / code_change.name - self._add_file(abs_new_path) - code_lines = self.file_lines[rel_path] - with open(abs_new_path, "w") as f: - f.write("\n".join(code_lines)) - self._delete_file(abs_path) - file_changes[code_change.name] += file_changes[rel_path] - file_changes[rel_path] = [] - self.file_lines[code_change.name] = self._read_file( - code_change.name + if file_edit.is_deletion: + cprint(f"Are you sure you want to delete {rel_path}?", "red") + if self.user_input_manager.ask_yes_no(default_yes=False): + cprint(f"Deleting {rel_path}...", "red") + self._delete_file(file_edit.file_path) + continue + else: + cprint(f"Not deleting {rel_path}", "green") + + if not file_edit.is_creation: + stored_lines = self.file_lines[rel_path] + if stored_lines != self.read_file(rel_path): + logging.info( + f"File '{file_edit.file_path}' changed while generating changes" + ) + cprint( + f"File '{rel_path}' changed while generating; current" + " file changes will be erased. Continue?", + color="light_yellow", ) - case _: - file_changes[rel_path].append(code_change) - - for rel_path, changes in file_changes.items(): - abs_path = self.config.git_root / rel_path - new_code_lines = self._get_new_code_lines(rel_path, changes) - if new_code_lines: - if abs_path not in self.code_context.files: + if not self.user_input_manager.ask_yes_no(default_yes=False): + cprint(f"Not applying changes to file {rel_path}") + else: + stored_lines = [] + + if file_edit.rename_file_path is not None: + if file_edit.rename_file_path.exists(): raise MentatError( - f"Attempted to edit file {abs_path} not in context" + f"Attempted to rename file {file_edit.file_path} to existing" + f" file {file_edit.rename_file_path}" ) - with open(abs_path, "w") as f: - f.write("\n".join(new_code_lines)) + self._add_file(file_edit.rename_file_path) + self._delete_file(file_edit.file_path) + file_edit.file_path = file_edit.rename_file_path + + new_lines = file_edit.get_updated_file_lines(stored_lines) + with open(file_edit.file_path, "w") as f: + f.write("\n".join(new_lines)) diff --git a/mentat/conversation.py b/mentat/conversation.py index 1fd673a66..439bf087a 100644 --- a/mentat/conversation.py +++ b/mentat/conversation.py @@ -1,28 +1,36 @@ +import asyncio +from timeit import default_timer + +from openai.error import InvalidRequestError, RateLimitError from termcolor import cprint -from .code_change import CodeChange -from .code_file_manager import CodeFileManager -from .config_manager import ConfigManager, user_config_path -from .llm_api import ( +from mentat.config_manager import ConfigManager, user_config_path +from mentat.errors import MentatError, UserError +from mentat.llm_api import ( CostTracker, + call_llm_api, count_tokens, get_prompt_token_count, is_model_available, model_context_size, ) -from .parsing import run_async_stream_and_parse_llm_response -from .prompts import system_prompt +from mentat.parsers.file_edit import FileEdit +from mentat.parsers.parser import Parser + +from .code_file_manager import CodeFileManager class Conversation: def __init__( self, + parser: Parser, config: ConfigManager, cost_tracker: CostTracker, code_file_manager: CodeFileManager, ): self.messages = list[dict[str, str]]() - self.add_system_message(system_prompt) + prompt = parser.get_system_prompt() + self.add_system_message(prompt) self.cost_tracker = cost_tracker self.code_file_manager = code_file_manager self.model = config.model() @@ -47,8 +55,7 @@ def __init__( tokens = count_tokens( code_file_manager.get_code_message(self.model), self.model - ) + count_tokens(system_prompt, self.model) - + ) + count_tokens(prompt, self.model) context_size = model_context_size(self.model) maximum_context = config.maximum_context() if maximum_context: @@ -90,24 +97,57 @@ def add_user_message(self, message: str): def add_assistant_message(self, message: str): self.messages.append({"role": "assistant", "content": message}) - def get_model_response(self) -> tuple[str, list[CodeChange]]: + async def _run_async_stream( + self, parser: Parser, config: ConfigManager, messages: list[dict[str, str]] + ) -> tuple[str, list[FileEdit]]: + response = await call_llm_api(messages, self.model) + with parser.interrupt_catcher(): + print("\nStreaming... use control-c to interrupt the model at any point\n") + message, file_edits = await parser.stream_and_parse_llm_response( + response, self.code_file_manager, config + ) + return message, file_edits + + def _handle_async_stream( + self, + parser: Parser, + config: ConfigManager, + messages: list[dict[str, str]], + ) -> tuple[str, list[FileEdit], float]: + start_time = default_timer() + try: + message, file_edits = asyncio.run( + self._run_async_stream(parser, config, messages) + ) + except InvalidRequestError as e: + raise MentatError( + "Something went wrong - invalid request to OpenAI API. OpenAI" + " returned:\n" + + str(e) + ) + except RateLimitError as e: + raise UserError("OpenAI gave a rate limit error:\n" + str(e)) + + time_elapsed = default_timer() - start_time + return (message, file_edits, time_elapsed) + + def get_model_response( + self, parser: Parser, config: ConfigManager + ) -> list[FileEdit]: messages = self.messages.copy() code_message = self.code_file_manager.get_code_message(self.model) - messages.append({"role": "system", "content": code_message}) num_prompt_tokens = get_prompt_token_count(messages, self.model) - - state = run_async_stream_and_parse_llm_response( - messages, self.model, self.code_file_manager + message, file_edits, time_elapsed = self._handle_async_stream( + parser, config, messages ) - self.cost_tracker.display_api_call_stats( num_prompt_tokens, - count_tokens(state.message, self.model), + count_tokens(message, self.model), self.model, - state.time_elapsed, + time_elapsed, ) - self.add_assistant_message(state.message) - return state.explanation, state.code_changes + self.add_assistant_message(message) + return file_edits diff --git a/mentat/parsers/__init__.py b/mentat/parsers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mentat/parsers/block_parser.py b/mentat/parsers/block_parser.py new file mode 100644 index 000000000..a4a91471a --- /dev/null +++ b/mentat/parsers/block_parser.py @@ -0,0 +1,30 @@ +from typing import Any, AsyncGenerator + +from typing_extensions import override + +from mentat.code_file_manager import CodeFileManager +from mentat.config_manager import ConfigManager +from mentat.parsers.file_edit import FileEdit +from mentat.parsers.original_format.original_format_parsing import ( + stream_and_parse_llm_response, +) +from mentat.parsers.parser import Parser +from mentat.prompts import block_parser_prompt + + +class BlockParser(Parser): + @override + def get_system_prompt(self) -> str: + return block_parser_prompt + + @override + async def stream_and_parse_llm_response( + self, + response: AsyncGenerator[Any, None], + code_file_manager: CodeFileManager, + config: ConfigManager, + ) -> tuple[str, list[FileEdit]]: + # Uses the legacy parsing code + return await stream_and_parse_llm_response( + response, code_file_manager, config, self.shutdown + ) diff --git a/mentat/parsers/change_display_helper.py b/mentat/parsers/change_display_helper.py new file mode 100644 index 000000000..bb381bade --- /dev/null +++ b/mentat/parsers/change_display_helper.py @@ -0,0 +1,224 @@ +from enum import Enum +from pathlib import Path + +import attr +from pygments import highlight # pyright: ignore[reportUnknownVariableType] +from pygments.formatters import TerminalFormatter +from pygments.lexer import Lexer +from pygments.lexers import TextLexer, get_lexer_for_filename +from pygments.util import ClassNotFound +from termcolor import colored + +change_delimiter = 60 * "=" + + +def _get_lexer(file_path: Path): + try: + lexer: Lexer = get_lexer_for_filename(file_path) + except ClassNotFound: + lexer = TextLexer() + lexer.stripnl = False + lexer.stripall = False + lexer.ensurenl = False + return lexer + + +def get_line_number_buffer(file_lines: list[str]): + return len(str(len(file_lines) + 1)) + 1 + + +class FileActionType(Enum): + RenameFile = "rename" + CreateFile = "create" + DeleteFile = "delete" + UpdateFile = "update" + + +@attr.define(slots=False) +class DisplayInformation: + file_name: Path = attr.field() + file_lines: list[str] = attr.field() + added_block: list[str] = attr.field() + removed_block: list[str] = attr.field() + file_action_type: FileActionType = attr.field() + first_changed_line: int | None = attr.field(default=None) + last_changed_line: int | None = attr.field(default=None) + new_name: Path | None = attr.field(default=None) + + def __attrs_post_init__(self): + self.line_number_buffer = get_line_number_buffer(self.file_lines) + self.lexer = _get_lexer(self.file_name) + + +def _remove_extra_empty_lines(lines: list[str]) -> list[str]: + if not lines: + return [] + + # Find the first non-empty line + start = 0 + while start < len(lines) and not lines[start].strip(): + start += 1 + + # Find the last non-empty line + end = len(lines) - 1 + while end > start and not lines[end].strip(): + end -= 1 + + # If all lines are empty, keep only one empty line + if start == len(lines): + return [" "] + + # Return the list with only a maximum of one empty line on either side + return lines[max(start - 1, 0) : end + 2] + + +def _prefixed_lines(line_number_buffer: int, lines: list[str], prefix: str): + return "\n".join( + [ + prefix + " " * (line_number_buffer - len(prefix)) + line.strip("\n") + for line in lines + ] + ) + + +def _get_code_block( + code_lines: list[str], + line_number_buffer: int, + prefix: str, + color: str | None, +): + lines = _prefixed_lines(line_number_buffer, code_lines, prefix) + if lines: + return colored(lines, color=color) + else: + return "" + + +def get_full_change(display_information: DisplayInformation): + to_print = [ + get_file_name(display_information), + ( + change_delimiter + if display_information.added_block or display_information.removed_block + else "" + ), + get_previous_lines(display_information), + get_removed_lines(display_information), + get_added_lines(display_information), + get_later_lines(display_information), + ( + change_delimiter + if display_information.added_block or display_information.removed_block + else "" + ), + ] + full_change = "\n".join([line for line in to_print if line]) + return full_change + + +def get_file_name( + display_information: DisplayInformation, +): + match display_information.file_action_type: + case FileActionType.CreateFile: + return colored(f"\n{display_information.file_name}*", color="light_green") + case FileActionType.DeleteFile: + return colored(f"\n{display_information.file_name}", color="light_red") + case FileActionType.RenameFile: + return colored( + f"\nRename: {display_information.file_name} ->" + f" {display_information.new_name}", + color="yellow", + ) + case FileActionType.UpdateFile: + return colored(f"\n{display_information.file_name}", color="light_blue") + + +def get_added_lines( + display_information: DisplayInformation, + prefix: str = "+", + color: str | None = "green", +): + return _get_code_block( + display_information.added_block, + display_information.line_number_buffer, + prefix, + color, + ) + + +def get_removed_lines( + display_information: DisplayInformation, + prefix: str = "-", + color: str | None = "red", +): + return _get_code_block( + display_information.removed_block, + display_information.line_number_buffer, + prefix, + color, + ) + + +def get_previous_lines( + display_information: DisplayInformation, + num: int = 2, +): + if display_information.first_changed_line is None: + return "" + lines = _remove_extra_empty_lines( + [ + display_information.file_lines[i] + for i in range( + max(0, display_information.first_changed_line - (num + 1)), + min( + display_information.first_changed_line, + len(display_information.file_lines), + ), + ) + ] + ) + numbered = [ + (str(display_information.first_changed_line - len(lines) + i + 1) + ":").ljust( + display_information.line_number_buffer + ) + + line + for i, line in enumerate(lines) + ] + + prev = "\n".join(numbered) + # pygments doesn't have type hints on TerminalFormatter + h_prev: str = highlight(prev, display_information.lexer, TerminalFormatter(bg="dark")) # type: ignore + return h_prev + + +def get_later_lines( + display_information: DisplayInformation, + num: int = 2, +): + if display_information.last_changed_line is None: + return "" + lines = _remove_extra_empty_lines( + [ + display_information.file_lines[i] + for i in range( + max(0, display_information.last_changed_line), + min( + display_information.last_changed_line + num, + len(display_information.file_lines), + ), + ) + ] + ) + numbered = [ + (str(display_information.last_changed_line + 1 + i) + ":").ljust( + display_information.line_number_buffer + ) + + line + for i, line in enumerate(lines) + ] + + later = "\n".join(numbered) + # pygments doesn't have type hints on TerminalFormatter + h_later: str = highlight(later, display_information.lexer, TerminalFormatter(bg="dark")) # type: ignore + return h_later diff --git a/mentat/parsers/file_edit.py b/mentat/parsers/file_edit.py new file mode 100644 index 000000000..fe3e83733 --- /dev/null +++ b/mentat/parsers/file_edit.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import TYPE_CHECKING + +import attr +from termcolor import cprint + +from mentat.config_manager import ConfigManager +from mentat.errors import MentatError +from mentat.parsers.change_display_helper import ( + DisplayInformation, + FileActionType, + change_delimiter, + get_full_change, +) +from mentat.user_input_manager import UserInputManager + +if TYPE_CHECKING: + # This normally will cause a circular import + from mentat.code_file_manager import CodeFileManager + + +# TODO: Add 'owner' to Replacement so that interactive mode can accept/reject multiple replacements at once +@attr.define(order=False) +class Replacement: + """ + Represents that the lines from starting_line (inclusive) to ending_line (exclusive) + should be replaced with new_lines + """ + + # Inclusive + starting_line: int = attr.field() + # Exclusive + ending_line: int = attr.field() + + new_lines: list[str] = attr.field() + + def __lt__(self, other: Replacement): + return self.ending_line < other.ending_line or ( + self.ending_line == other.ending_line + and self.starting_line < other.ending_line + ) + + +def _ask_user_change( + user_input_manager: UserInputManager, + display_information: DisplayInformation, + text: str, +) -> bool: + print(get_full_change(display_information)) + cprint(text, "light_blue") + return user_input_manager.ask_yes_no(default_yes=True) + + +@attr.define +class FileEdit: + """ + Represents that this file_path content should have specified Replacements applied to it. + Can also represent that this file should be created, deleted, or is being renamed. + """ + + # Should be abs path + file_path: Path = attr.field() + replacements: list[Replacement] = attr.field(factory=list) + is_creation: bool = attr.field(default=False) + is_deletion: bool = attr.field(default=False) + # Should be abs path + rename_file_path: Path | None = attr.field(default=None) + + def is_valid( + self, code_file_manager: CodeFileManager, config: ConfigManager + ) -> bool: + rel_path = Path(os.path.relpath(self.file_path, config.git_root)) + if self.is_creation: + if self.file_path.exists(): + cprint(f"File {rel_path} already exists, canceling creation.") + return False + else: + if not self.file_path.exists(): + cprint(f"File {rel_path} does not exist, canceling all edits to file.") + return False + elif rel_path not in code_file_manager.file_lines: + cprint(f"File {rel_path} not in context, canceling all edits to file.") + return False + + if self.rename_file_path is not None and self.rename_file_path.exists(): + rel_rename_path = Path( + os.path.relpath(self.rename_file_path, config.git_root) + ) + cprint( + f"File {rel_path} being renamed to existing file {rel_rename_path}," + " canceling rename." + ) + self.rename_file_path = None + return True + + def filter_replacements( + self, + code_file_manager: CodeFileManager, + user_input_manager: UserInputManager, + config: ConfigManager, + ) -> bool: + if self.is_creation: + display_information = DisplayInformation( + self.file_path, [], [], [], FileActionType.CreateFile, None, None, None + ) + if not _ask_user_change( + user_input_manager, display_information, "Create this file?" + ): + return False + file_lines = [] + else: + rel_path = Path(os.path.relpath(self.file_path, config.git_root)) + file_lines = code_file_manager.file_lines[rel_path] + + if self.is_deletion: + display_information = DisplayInformation( + self.file_path, [], [], file_lines, FileActionType.DeleteFile + ) + if not _ask_user_change( + user_input_manager, display_information, "Delete this file?" + ): + return False + + if self.rename_file_path is not None: + display_information = DisplayInformation( + self.file_path, + [], + [], + [], + FileActionType.RenameFile, + new_name=self.rename_file_path, + ) + if not _ask_user_change( + user_input_manager, display_information, "Rename this file?" + ): + self.rename_file_path = None + + new_replacements = list[Replacement]() + for replacement in self.replacements: + removed_block = file_lines[ + replacement.starting_line : replacement.ending_line + ] + display_information = DisplayInformation( + self.file_path, + file_lines, + replacement.new_lines, + removed_block, + FileActionType.UpdateFile, + replacement.starting_line, + replacement.ending_line, + self.rename_file_path, + ) + if _ask_user_change( + user_input_manager, display_information, "Keep this change?" + ): + new_replacements.append(replacement) + self.replacements = new_replacements + + return ( + self.is_creation + or self.is_deletion + or (self.rename_file_path is not None) + or len(self.replacements) > 0 + ) + + def _print_resolution(self, first: Replacement, second: Replacement): + print("Change overlap detected, auto-merged back to back changes:\n") + print(self.file_path) + print(change_delimiter) + for line in first.new_lines + second.new_lines: + cprint("+ " + line, color="green") + print() + + def resolve_conflicts(self, user_input_manager: UserInputManager): + self.replacements.sort(reverse=True) + for index, replacement in enumerate(self.replacements): + for other in self.replacements[index + 1 :]: + if ( + other.ending_line > replacement.starting_line + and other.starting_line < replacement.ending_line + ): + # Overlap conflict + other.ending_line = replacement.starting_line + other.starting_line = min(other.starting_line, other.ending_line) + self._print_resolution(other, replacement) + elif ( + other.ending_line == other.starting_line + and replacement.ending_line == replacement.starting_line + and replacement.starting_line == other.starting_line + ): + # Insertion conflict + # This will be a bit wonky if there are more than 2 insertion conflicts on the same line + self._print_resolution(replacement, other) + + def get_updated_file_lines(self, file_lines: list[str]): + self.replacements.sort(reverse=True) + earliest_line = None + for replacement in self.replacements: + if earliest_line is not None and replacement.ending_line > earliest_line: + # This should never happen if resolve conflicts is called + raise MentatError("Error: Line overlap in Replacements") + if replacement.ending_line > len(file_lines): + file_lines += [""] * (replacement.ending_line - len(file_lines)) + earliest_line = replacement.starting_line + file_lines = ( + file_lines[: replacement.starting_line] + + replacement.new_lines + + file_lines[replacement.ending_line :] + ) + return file_lines diff --git a/mentat/parsers/original_format/__init__.py b/mentat/parsers/original_format/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mentat/parsers/original_format/original_format_change.py b/mentat/parsers/original_format/original_format_change.py new file mode 100644 index 000000000..812a6aa54 --- /dev/null +++ b/mentat/parsers/original_format/original_format_change.py @@ -0,0 +1,211 @@ +# This file is mainly kept as legacy so that we don't have to rewrite this code + +from __future__ import annotations + +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from mentat.config_manager import ConfigManager +from mentat.errors import ModelError +from mentat.parsers.change_display_helper import DisplayInformation, FileActionType +from mentat.parsers.file_edit import FileEdit, Replacement + +if TYPE_CHECKING: + # This normally will cause a circular import + from mentat.code_file_manager import CodeFileManager + + +class OriginalFormatChangeAction(Enum): + Insert = "insert" + Replace = "replace" + Delete = "delete" + CreateFile = "create-file" + DeleteFile = "delete-file" + RenameFile = "rename-file" + + +class OriginalFormatChange: + @classmethod + def to_file_edits( + cls, + changes: list[OriginalFormatChange], + config: ConfigManager, + ) -> list[FileEdit]: + file_edits = dict[Path, FileEdit]() + for code_change in changes: + rel_path = code_change.file + abs_path = config.git_root / rel_path + if abs_path not in file_edits: + file_edits[abs_path] = FileEdit(abs_path) + match code_change.action: + case OriginalFormatChangeAction.CreateFile: + file_edits[abs_path].replacements.append( + Replacement(0, 0, code_change.code_lines) + ) + file_edits[abs_path].is_creation = True + case OriginalFormatChangeAction.DeleteFile: + file_edits[abs_path].is_deletion = True + case OriginalFormatChangeAction.RenameFile: + abs_new_path = config.git_root / code_change.name + file_edits[abs_path].rename_file_path = abs_new_path + case _: + file_edits[abs_path].replacements.append( + Replacement( + code_change.first_changed_line, + code_change.last_changed_line, + code_change.code_lines, + ) + ) + return [file_edit for file_edit in file_edits.values()] + + def __init__( + self, + json_data: dict[Any, Any], + code_lines: list[str], + code_file_manager: CodeFileManager, + rename_map: dict[Path, Path] = {}, + ): + self.json_data = json_data + # Sometimes GPT puts quotes around numbers, so we have to convert those + for json_key in [ + "insert-before-line", + "insert-after-line", + "start-line", + "end-line", + ]: + if json_key in self.json_data: + self.json_data[json_key] = int(self.json_data[json_key]) + self.code_lines = code_lines + self.file = Path(self.json_data["file"]) + # This rename_map is a bit hacky; it shouldn't be used outside of streaming/parsing + if self.file in rename_map: + self.file = rename_map[self.file] + self.first_changed_line: int = 0 + self.last_changed_line: int = 0 + self.error = "" + + try: + self.action = OriginalFormatChangeAction(self.json_data["action"]) + except ValueError: + raise ModelError( + f"Model created change with unknown action {self.json_data['action']}", + already_added_to_changelist=False, + ) + + try: + match self.action: + case OriginalFormatChangeAction.Insert: + if "insert-before-line" in self.json_data: + self.first_changed_line = ( + self.json_data["insert-before-line"] - 1 + ) + if ( + "insert-after-line" in self.json_data + and self.first_changed_line + != self.json_data["insert-after-line"] + ): + self.error = "Insert line numbers invalid" + elif "insert-after-line" in self.json_data: + self.first_changed_line = self.json_data["insert-after-line"] + else: + self.first_changed_line = 0 + self.error = "Insert line number not specified" + self.last_changed_line = self.first_changed_line + + case OriginalFormatChangeAction.Replace: + self.first_changed_line = self.json_data["start-line"] - 1 + self.last_changed_line = self.json_data["end-line"] + + case OriginalFormatChangeAction.Delete: + self.first_changed_line = self.json_data["start-line"] - 1 + self.last_changed_line = self.json_data["end-line"] + + case OriginalFormatChangeAction.RenameFile: + self.name = Path(self.json_data["name"]) + + case _: + pass + + except KeyError: + self.error = "Line numbers not given" + + if ( + self.first_changed_line + and self.last_changed_line + and self.first_changed_line > self.last_changed_line + ): + self.error = "Starting line of change is greater than ending line of change" + + if self.action == OriginalFormatChangeAction.CreateFile: + if self.file.exists(): + self.error = ( + f"Model attempted to create file that already exists: {self.file}" + ) + self.file_lines = [] + self.line_number_buffer = 2 + else: + if ( + self.action == OriginalFormatChangeAction.RenameFile + and self.name.exists() + ): + self.error = ( + f"Model attempted to rename file {self.file} to a file that" + f" already exists: {self.name}" + ) + + rel_path = self.file + try: + self.file_lines = code_file_manager.file_lines[rel_path] + except KeyError: + self.error = ( + f"Model attempted to edit {rel_path}, which isn't in current" + " context or doesn't exist" + ) + + def get_change_display_information(self) -> DisplayInformation: + removed_block = ( + self.file_lines + if self.action == OriginalFormatChangeAction.DeleteFile + else ( + self.file_lines[self.first_changed_line : self.last_changed_line] + if self.has_removals() + else [] + ) + ) + display_information = DisplayInformation( + self.file, + self.file_lines, + self.code_lines, + removed_block, + self.get_file_action_type(), + self.first_changed_line, + self.last_changed_line, + self.name if self.action == OriginalFormatChangeAction.RenameFile else None, + ) + return display_information + + def has_removals(self): + return ( + self.action == OriginalFormatChangeAction.Delete + or self.action == OriginalFormatChangeAction.Replace + or self.action == OriginalFormatChangeAction.DeleteFile + ) + + def has_additions(self): + return ( + self.action == OriginalFormatChangeAction.Insert + or self.action == OriginalFormatChangeAction.Replace + or self.action == OriginalFormatChangeAction.CreateFile + ) + + def get_file_action_type(self): + match self.action: + case OriginalFormatChangeAction.CreateFile: + return FileActionType.CreateFile + case OriginalFormatChangeAction.DeleteFile: + return FileActionType.DeleteFile + case OriginalFormatChangeAction.RenameFile: + return FileActionType.RenameFile + case _: + return FileActionType.UpdateFile diff --git a/mentat/parsing.py b/mentat/parsers/original_format/original_format_parsing.py similarity index 71% rename from mentat/parsing.py rename to mentat/parsers/original_format/original_format_parsing.py index 66a64fea8..2c1aa1292 100644 --- a/mentat/parsing.py +++ b/mentat/parsers/original_format/original_format_parsing.py @@ -1,28 +1,34 @@ +# This file is mainly kept as legacy so that we don't have to rewrite this code + +from __future__ import annotations + import asyncio import json import logging +from asyncio import Event from enum import Enum from json import JSONDecodeError from pathlib import Path -from timeit import default_timer from typing import Any, AsyncGenerator import attr -from openai.error import InvalidRequestError, RateLimitError from termcolor import cprint -from .code_change import CodeChange, CodeChangeAction -from .code_change_display import ( +from mentat.code_file_manager import CodeFileManager +from mentat.config_manager import ConfigManager +from mentat.errors import ModelError +from mentat.parsers.change_display_helper import ( change_delimiter, get_file_name, get_later_lines, + get_line_number_buffer, get_previous_lines, - get_removed_block, + get_removed_lines, ) -from .code_file_manager import CodeFileManager -from .errors import MentatError, ModelError, UserError -from .llm_api import call_llm_api -from .streaming_printer import StreamingPrinter +from mentat.parsers.file_edit import FileEdit +from mentat.streaming_printer import StreamingPrinter + +from .original_format_change import OriginalFormatChange, OriginalFormatChangeAction class _BlockIndicator(Enum): @@ -31,20 +37,20 @@ class _BlockIndicator(Enum): End = "@@end" -@attr.s +@attr.define class ParsingState: - message: str = attr.ib(default="") - cur_line: str = attr.ib(default="") - cur_printed: bool = attr.ib(default=False) - time_elapsed: float = attr.ib(default=0) - in_special_lines: bool = attr.ib(default=False) - in_code_lines: bool = attr.ib(default=False) - explanation: str = attr.ib(default="") - explained_since_change: bool = attr.ib(default=True) - code_changes: list[CodeChange] = attr.ib(factory=list) - json_lines: list[str] = attr.ib(factory=list) - code_lines: list[str] = attr.ib(factory=list) - rename_map: dict[Path, Path] = attr.ib(factory=dict) + message: str = attr.field(default="") + cur_line: str = attr.field(default="") + cur_printed: bool = attr.field(default=False) + time_elapsed: float = attr.field(default=0) + in_special_lines: bool = attr.field(default=False) + in_code_lines: bool = attr.field(default=False) + explanation: str = attr.field(default="") + explained_since_change: bool = attr.field(default=True) + code_changes: list[OriginalFormatChange] = attr.field(factory=list) + json_lines: list[str] = attr.field(factory=list) + code_lines: list[str] = attr.field(factory=list) + rename_map: dict[Path, Path] = attr.field(factory=dict) def parse_line_printing(self, content: str): to_print = "" @@ -77,12 +83,12 @@ def create_code_change(self, code_file_manager: CodeFileManager): already_added_to_changelist=False, ) - new_change = CodeChange( + new_change = OriginalFormatChange( json_data, self.code_lines, code_file_manager, self.rename_map ) self.code_changes.append(new_change) self.json_lines, self.code_lines = [], [] - if new_change.action == CodeChangeAction.RenameFile: + if new_change.action == OriginalFormatChangeAction.RenameFile: # This rename_map is a bit hacky; it shouldn't be used outside of streaming/parsing self.rename_map[new_change.name] = new_change.file @@ -112,7 +118,7 @@ def new_line(self, code_file_manager: CodeFileManager): ) self.in_code_lines = True self.create_code_change(code_file_manager) - if not self.code_changes[-1].action.has_additions(): + if not self.code_changes[-1].has_additions(): raise ModelError( "Model gave code indicator for action without code", already_added_to_changelist=True, @@ -154,55 +160,23 @@ def new_line(self, code_file_manager: CodeFileManager): return to_print, entered_code_lines, exited_code_lines, created_code_change -def run_async_stream_and_parse_llm_response( - messages: list[dict[str, str]], - model: str, - code_file_manager: CodeFileManager, -) -> ParsingState: - state: ParsingState = ParsingState() - start_time = default_timer() - try: - asyncio.run( - stream_and_parse_llm_response(messages, model, state, code_file_manager) - ) - except InvalidRequestError as e: - raise MentatError( - "Something went wrong - invalid request to OpenAI API. OpenAI returned:\n" - + str(e) - ) - except RateLimitError as e: - raise UserError("OpenAI gave a rate limit error:\n" + str(e)) - except KeyboardInterrupt: - print("\n\nInterrupted by user. Using the response up to this point.") - # if the last change is incomplete, remove it - if state.in_code_lines: - state.code_changes = state.code_changes[:-1] - logging.info("User interrupted response.") - - state.code_changes = list( - filter(lambda change: not change.error, state.code_changes) - ) - - state.time_elapsed = default_timer() - start_time - return state - - async def stream_and_parse_llm_response( - messages: list[dict[str, str]], - model: str, - state: ParsingState, + response: AsyncGenerator[Any, None], code_file_manager: CodeFileManager, -) -> None: - response = await call_llm_api(messages, model) - - print("\nstreaming... use control-c to interrupt the model at any point\n") - + config: ConfigManager, + shutdown: Event, +) -> tuple[str, list[FileEdit]]: + state = ParsingState() printer = StreamingPrinter() printer_task = asyncio.create_task(printer.print_lines()) try: - await _process_response(state, response, printer, code_file_manager) - printer.wrap_it_up() - await printer_task + if await _process_response( + state, response, printer, code_file_manager, shutdown + ): + printer.wrap_it_up() + await printer_task + else: + printer_task.cancel() except ModelError as e: logging.info(f"Model created error {e}") printer.wrap_it_up() @@ -216,13 +190,17 @@ async def stream_and_parse_llm_response( finally: logging.debug(f"LLM response:\n{state.message}") + code_changes = list(filter(lambda change: not change.error, state.code_changes)) + return (state.message, OriginalFormatChange.to_file_edits(code_changes, config)) + async def _process_response( state: ParsingState, response: AsyncGenerator[Any, None], printer: StreamingPrinter, code_file_manager: CodeFileManager, -): + shutdown: Event, +) -> bool: def chunk_to_lines(chunk: Any) -> list[str]: return chunk["choices"][0]["delta"].get("content", "").splitlines(keepends=True) @@ -231,6 +209,10 @@ def chunk_to_lines(chunk: Any) -> list[str]: if content_line: state.message += content_line _process_content_line(state, content_line, printer, code_file_manager) + if shutdown.is_set(): + if state.in_code_lines: + state.code_changes = state.code_changes[:-1] + return False # This newline solves at least 5 edge cases singlehandedly _process_content_line(state, "\n", printer, code_file_manager) @@ -239,6 +221,7 @@ def chunk_to_lines(chunk: Any) -> list[str]: if state.in_special_lines: logging.info("Model forgot an @@end!") _process_content_line(state, "@@end\n", printer, code_file_manager) + return True def _process_content_line( @@ -255,7 +238,7 @@ def _process_content_line( ) or not state.in_special_lines: to_print = state.parse_line_printing(content) prefix = ( - "+" + " " * (state.code_changes[-1].line_number_buffer - 1) + "+" + " " * (get_line_number_buffer(state.code_changes[-1].file_lines) - 1) if state.in_code_lines and beginning else "" ) @@ -283,36 +266,35 @@ def _process_content_line( "Continuing model response...\n", color="light_green" ) else: + display_information = cur_change.get_change_display_information() if ( len(state.code_changes) < 2 or state.code_changes[-2].file != cur_change.file or state.explained_since_change - or state.code_changes[-1].action == CodeChangeAction.RenameFile + or state.code_changes[-1].action + == OriginalFormatChangeAction.RenameFile ): - printer.add_string(get_file_name(cur_change)) - if ( - cur_change.action.has_additions() - or cur_change.action.has_removals() - ): + printer.add_string(get_file_name(display_information)) + if cur_change.has_additions() or cur_change.has_removals(): printer.add_string(change_delimiter) state.explained_since_change = False - printer.add_string(get_previous_lines(cur_change)) - printer.add_string(get_removed_block(cur_change)) - if ( - not cur_change.action.has_additions() - and cur_change.action.has_removals() - ): - printer.add_string(get_later_lines(cur_change)) + printer.add_string(get_previous_lines(display_information)) + printer.add_string(get_removed_lines(display_information)) + if not cur_change.has_additions() and cur_change.has_removals(): + printer.add_string(get_later_lines(display_information)) printer.add_string(change_delimiter) if to_print and not (state.in_code_lines and state.code_changes[-1].error): prefix = ( - "+" + " " * (state.code_changes[-1].line_number_buffer - 1) + "+" + + " " * (get_line_number_buffer(state.code_changes[-1].file_lines) - 1) if state.in_code_lines and beginning else "" ) color = "green" if state.in_code_lines else None printer.add_string(prefix + to_print, end="", color=color) if exited_code_lines and not state.code_changes[-1].error: - printer.add_string(get_later_lines(state.code_changes[-1])) + printer.add_string( + get_later_lines(state.code_changes[-1].get_change_display_information()) + ) printer.add_string(change_delimiter) diff --git a/mentat/parsers/parser.py b/mentat/parsers/parser.py new file mode 100644 index 000000000..83b67e4e5 --- /dev/null +++ b/mentat/parsers/parser.py @@ -0,0 +1,43 @@ +import logging +import signal +from abc import ABC, abstractmethod +from asyncio import Event +from contextlib import contextmanager +from types import FrameType +from typing import Any, AsyncGenerator + +from mentat.code_file_manager import CodeFileManager +from mentat.config_manager import ConfigManager +from mentat.parsers.file_edit import FileEdit + + +class Parser(ABC): + def __init__(self): + self.shutdown = Event() + + def shutdown_handler(self, sig: int, frame: FrameType | None): + print("\n\nInterrupted by user. Using the response up to this point.") + logging.info("User interrupted response.") + self.shutdown.set() + + # Interface redesign will likely completely change interrupt handling + @contextmanager + def interrupt_catcher(self): + signal.signal(signal.SIGINT, self.shutdown_handler) + yield + # Reset to default interrupt handler + signal.signal(signal.SIGINT, signal.SIG_DFL) + self.shutdown.clear() + + @abstractmethod + def get_system_prompt(self) -> str: + pass + + @abstractmethod + async def stream_and_parse_llm_response( + self, + response: AsyncGenerator[Any, None], + code_file_manager: CodeFileManager, + config: ConfigManager, + ) -> tuple[str, list[FileEdit]]: + pass diff --git a/mentat/prompts.py b/mentat/prompts.py index a69adb772..7ccfe6be2 100644 --- a/mentat/prompts.py +++ b/mentat/prompts.py @@ -1,6 +1,6 @@ from textwrap import dedent -system_prompt = """ +block_parser_prompt = """ You are part of an automated coding system. As such, responses must adhere strictly to the required format, so they can be parsed programmaticaly. Your input will consist of a user request, the contents of code files, and sometimes the git diff of @@ -206,4 +206,4 @@ def main(): say_goodbye() @@end """ -system_prompt = dedent(system_prompt).strip() +block_parser_prompt = dedent(block_parser_prompt).strip() diff --git a/pyrightconfig.json b/pyrightconfig.json index 24d021f07..758ece608 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1 +1 @@ -{"include": ["mentat"], "typeCheckingMode": "strict"} +{"include": ["mentat"], "ignore": ["testbed", "tests"], "typeCheckingMode": "strict"} diff --git a/tests/conftest.py b/tests/conftest.py index 470b8555e..39d597361 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ import pytest +from mentat.code_context import CodeContext from mentat.config_manager import ConfigManager from mentat.streaming_printer import StreamingPrinter from mentat.user_input_manager import UserInputManager @@ -69,7 +70,7 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture def mock_call_llm_api(mocker): - mock = mocker.patch("mentat.parsing.call_llm_api") + mock = mocker.patch("mentat.conversation.call_llm_api") def set_generator_values(values): async def async_generator(): @@ -104,6 +105,16 @@ def mock_config(temp_testbed): return config +@pytest.fixture +def mock_context(mock_config): + return CodeContext(mock_config, [], []) + + +@pytest.fixture +def mock_user_input_manager(mock_config, mock_context): + return UserInputManager(mock_config, mock_context) + + def add_permissions(func, path, exc_info): """ Error handler for ``shutil.rmtree``. diff --git a/tests/model_error_test.py b/tests/parser_tests/block_format_error_test.py similarity index 100% rename from tests/model_error_test.py rename to tests/parser_tests/block_format_error_test.py diff --git a/tests/code_change_test.py b/tests/parser_tests/block_format_test.py similarity index 100% rename from tests/code_change_test.py rename to tests/parser_tests/block_format_test.py diff --git a/tests/parser_tests/file_edit_test.py b/tests/parser_tests/file_edit_test.py new file mode 100644 index 000000000..50be8a316 --- /dev/null +++ b/tests/parser_tests/file_edit_test.py @@ -0,0 +1,41 @@ +from pathlib import Path + +from mentat.parsers.file_edit import FileEdit, Replacement + +# Since file creation, deletion, and renaming is almost entirely handled in +# the CodeFileManager, no need to test that here + + +def test_replacement(mock_user_input_manager): + replacements = [ + Replacement(0, 2, ["# Line 0", "# Line 1", "# Line 2"]), + Replacement(3, 3, ["# Inserted"]), + ] + file_edit = FileEdit(file_path=Path("test.py"), replacements=replacements) + file_edit.resolve_conflicts(mock_user_input_manager) + original_lines = ["# Remove me", "# Remove me", "# Line 3", "# Line 4"] + new_lines = file_edit.get_updated_file_lines(original_lines) + assert new_lines == [ + "# Line 0", + "# Line 1", + "# Line 2", + "# Line 3", + "# Inserted", + "# Line 4", + ] + + +# When we add user conflict resolution, this test will need to be changed +def test_replacement_conflict(mock_user_input_manager): + replacements = [ + Replacement(0, 2, ["L0"]), + Replacement(1, 3, ["L1"]), + Replacement(4, 7, ["L3"]), + Replacement(5, 6, ["L2"]), + ] + file_edit = FileEdit(file_path=Path("test.py"), replacements=replacements) + file_edit.resolve_conflicts(mock_user_input_manager) + original_lines = ["O0", "O1", "O2", "O3", "O4", "O5", "O6"] + new_lines = file_edit.get_updated_file_lines(original_lines) + print(new_lines) + assert new_lines == ["L0", "L1", "O3", "L2", "L3"]