diff --git a/mentat/code_feature.py b/mentat/code_feature.py index c70d44497..22b862cb1 100644 --- a/mentat/code_feature.py +++ b/mentat/code_feature.py @@ -152,8 +152,8 @@ def __repr__(self): ) def ref(self, cwd: Optional[Path] = None) -> str: - if cwd is not None and self.path.is_relative_to(cwd): - path_string = self.path.relative_to(cwd) + if cwd is not None: + path_string = str(get_relative_path(self.path, cwd)) else: path_string = str(self.path) diff --git a/mentat/code_file_manager.py b/mentat/code_file_manager.py index 479c96e8f..562bbb4bb 100644 --- a/mentat/code_file_manager.py +++ b/mentat/code_file_manager.py @@ -5,22 +5,15 @@ from pathlib import Path from typing import TYPE_CHECKING -from mentat.edit_history import ( - CreationAction, - DeletionAction, - EditAction, - EditHistory, - RenameAction, -) +from mentat.edit_history import EditHistory from mentat.errors import MentatError from mentat.interval import Interval from mentat.session_context import SESSION_CONTEXT from mentat.session_input import ask_yes_no -from mentat.utils import sha256 +from mentat.utils import get_relative_path, sha256 if TYPE_CHECKING: # This normally will cause a circular import - from mentat.code_context import CodeContext from mentat.parsers.file_edit import FileEdit @@ -38,26 +31,39 @@ def read_file(self, path: Path) -> list[str]: self.file_lines[abs_path] = lines return lines - def _create_file(self, code_context: CodeContext, abs_path: Path): + def create_file(self, abs_path: Path, content: str = ""): + ctx = SESSION_CONTEXT.get() + code_context = ctx.code_context + logging.info(f"Creating new file {abs_path}") # Create any missing directories in the path abs_path.parent.mkdir(parents=True, exist_ok=True) with open(abs_path, "w") as f: - f.write("") - code_context.include(abs_path) + f.write(content) + + if abs_path not in code_context.include_files: + code_context.include(abs_path) + + def delete_file(self, abs_path: Path): + ctx = SESSION_CONTEXT.get() + code_context = ctx.code_context - def _delete_file(self, code_context: CodeContext, abs_path: Path): logging.info(f"Deleting file {abs_path}") - code_context.exclude(abs_path) + + if abs_path in code_context.include_files: + code_context.exclude(abs_path) abs_path.unlink() - def _rename_file( - self, code_context: CodeContext, abs_path: Path, new_abs_path: Path - ): + def rename_file(self, abs_path: Path, new_abs_path: Path): + ctx = SESSION_CONTEXT.get() + code_context = ctx.code_context + logging.info(f"Renaming file {abs_path} to {new_abs_path}") - code_context.exclude(abs_path) + if abs_path in code_context.include_files: + code_context.exclude(abs_path) os.rename(abs_path, new_abs_path) - code_context.include(new_abs_path) + if new_abs_path not in code_context.include_files: + code_context.include(new_abs_path) # Mainly does checks on if file is in context, file exists, file is unchanged, etc. async def write_changes_to_files( @@ -66,7 +72,6 @@ async def write_changes_to_files( ) -> list[FileEdit]: session_context = SESSION_CONTEXT.get() stream = session_context.stream - code_context = session_context.code_context agent_handler = session_context.agent_handler if not file_edits: @@ -74,10 +79,7 @@ async def write_changes_to_files( applied_edits: list[FileEdit] = [] for file_edit in file_edits: - if file_edit.file_path.is_relative_to(session_context.cwd): - display_path = file_edit.file_path.relative_to(session_context.cwd) - else: - display_path = file_edit.file_path + display_path = get_relative_path(file_edit.file_path, session_context.cwd) if file_edit.is_creation: if file_edit.file_path.exists(): @@ -85,8 +87,7 @@ async def write_changes_to_files( f"Model attempted to create file {file_edit.file_path} which" " already exists" ) - self.history.add_action(CreationAction(file_edit.file_path)) - self._create_file(code_context, file_edit.file_path) + self.create_file(file_edit.file_path) elif not file_edit.file_path.exists(): raise MentatError( f"Attempted to edit non-existent file {file_edit.file_path}" @@ -100,16 +101,12 @@ async def write_changes_to_files( if await ask_yes_no(default_yes=False): stream.send(f"Deleting {display_path}...", color="red") # We use the current lines rather than the stored lines for undo - self.history.add_action( - DeletionAction( - file_edit.file_path, self.read_file(file_edit.file_path) - ) - ) - self._delete_file(code_context, file_edit.file_path) + file_edit.previous_file_lines = self.read_file(file_edit.file_path) + self.delete_file(file_edit.file_path) applied_edits.append(file_edit) - continue else: stream.send(f"Not deleting {display_path}", color="green") + continue if not file_edit.is_creation: stored_lines = self.file_lines[file_edit.file_path] @@ -134,23 +131,19 @@ async def write_changes_to_files( f"Attempted to rename file {file_edit.file_path} to existing" f" file {file_edit.rename_file_path}" ) - self.history.add_action( - RenameAction(file_edit.file_path, file_edit.rename_file_path) - ) - self._rename_file( - code_context, file_edit.file_path, file_edit.rename_file_path - ) - file_edit.file_path = file_edit.rename_file_path + self.rename_file(file_edit.file_path, file_edit.rename_file_path) new_lines = file_edit.get_updated_file_lines(stored_lines) if new_lines != stored_lines: + file_path = file_edit.rename_file_path or file_edit.file_path # We use the current lines rather than the stored lines for undo - self.history.add_action( - EditAction(file_edit.file_path, self.read_file(file_edit.file_path)) - ) - with open(file_edit.file_path, "w") as f: + file_edit.previous_file_lines = self.read_file(file_path) + with open(file_path, "w") as f: f.write("\n".join(new_lines)) applied_edits.append(file_edit) + + for applied_edit in applied_edits: + self.history.add_edit(applied_edit) if not agent_handler.agent_enabled: self.history.push_edits() return applied_edits diff --git a/mentat/edit_history.py b/mentat/edit_history.py index 2eee67627..d1a1540d7 100644 --- a/mentat/edit_history.py +++ b/mentat/edit_history.py @@ -1,107 +1,26 @@ -import os -from pathlib import Path from typing import Optional -import attr from termcolor import colored from mentat.errors import HistoryError -from mentat.parsers.file_edit import FileEdit, Replacement +from mentat.parsers.file_edit import FileEdit from mentat.session_context import SESSION_CONTEXT -# All paths should be abs paths -@attr.define() -class RenameAction: - old_file_name: Path = attr.field() - cur_file_name: Path = attr.field() - - def undo(self) -> FileEdit: - if self.old_file_name.exists(): - raise HistoryError( - f"File {self.old_file_name} already exists; unable to undo rename from" - f" {self.cur_file_name}" - ) - else: - os.rename(self.cur_file_name, self.old_file_name) - return FileEdit( - file_path=self.old_file_name, rename_file_path=self.cur_file_name - ) - - -@attr.define() -class CreationAction: - cur_file_name: Path = attr.field() - - def undo(self) -> FileEdit: - if not self.cur_file_name.exists(): - raise HistoryError( - f"File {self.cur_file_name} does not exist; unable to delete" - ) - else: - self.cur_file_name.unlink() - return FileEdit(file_path=self.cur_file_name, is_creation=True) - - -@attr.define() -class DeletionAction: - old_file_name: Path = attr.field() - old_file_lines: list[str] = attr.field() - - def undo(self) -> FileEdit: - if self.old_file_name.exists(): - raise HistoryError( - f"File {self.old_file_name} already exists; unable to re-create" - ) - else: - with open(self.old_file_name, "w") as f: - f.write("\n".join(self.old_file_lines)) - return FileEdit(file_path=self.old_file_name, is_deletion=True) - - -@attr.define() -class EditAction: - cur_file_name: Path = attr.field() - old_file_lines: list[str] = attr.field() - - def undo(self) -> FileEdit: - if not self.cur_file_name.exists(): - raise HistoryError( - f"File {self.cur_file_name} does not exist; unable to undo edit" - ) - else: - new_file_lines = self.cur_file_name.read_text().split("\n") - with open(self.cur_file_name, "w") as f: - f.write("\n".join(self.old_file_lines)) - return FileEdit( - file_path=self.cur_file_name, - replacements=[ - Replacement( - starting_line=0, - ending_line=len(self.old_file_lines), - new_lines=new_file_lines, - ) - ], - ) - - -HistoryAction = RenameAction | CreationAction | DeletionAction | EditAction - - # TODO: Keep track of when we create directories so we can undo those as well class EditHistory: def __init__(self): - self.edits = list[list[HistoryAction]]() - self.cur_edit = list[HistoryAction]() + self.edits = list[list[FileEdit]]() + self.cur_edits = list[FileEdit]() self.undone_edits = list[list[FileEdit]]() - def add_action(self, history_action: HistoryAction): - self.cur_edit.append(history_action) + def add_edit(self, file_edit: FileEdit): + self.cur_edits.append(file_edit) def push_edits(self): - if self.cur_edit: - self.edits.append(self.cur_edit) - self.cur_edit = list[HistoryAction]() + if self.cur_edits: + self.edits.append(self.cur_edits) + self.cur_edits = list[FileEdit]() def undo(self) -> str: if not self.edits: @@ -112,10 +31,10 @@ def undo(self) -> str: errors = list[str]() undone_edit = list[FileEdit]() while cur_edit: - cur_action = cur_edit.pop() + cur_file_edit = cur_edit.pop() try: - redo_edit = cur_action.undo() - undone_edit.append(redo_edit) + cur_file_edit.undo() + undone_edit.append(cur_file_edit) except HistoryError as e: errors.append(colored(str(e), color="light_red")) if undone_edit: @@ -131,6 +50,8 @@ async def redo(self) -> Optional[str]: edits_to_redo = self.undone_edits.pop() edits_to_redo.reverse() + for edit in edits_to_redo: + edit.display_full_edit(code_file_manager.file_lines[edit.file_path]) await code_file_manager.write_changes_to_files(edits_to_redo) def undo_all(self) -> str: diff --git a/mentat/parsers/change_display_helper.py b/mentat/parsers/change_display_helper.py index 2397da17f..10db7f44d 100644 --- a/mentat/parsers/change_display_helper.py +++ b/mentat/parsers/change_display_helper.py @@ -9,6 +9,9 @@ from pygments.util import ClassNotFound from termcolor import colored +from mentat.session_context import SESSION_CONTEXT +from mentat.utils import get_relative_path + change_delimiter = 60 * "=" @@ -58,9 +61,16 @@ class DisplayInformation: new_name: Path | None = attr.field(default=None) def __attrs_post_init__(self): + ctx = SESSION_CONTEXT.get() + self.line_number_buffer = get_line_number_buffer(self.file_lines) self.lexer = _get_lexer(self.file_name) + if self.file_name.is_absolute(): + self.file_name = get_relative_path(self.file_name, ctx.cwd) + if self.new_name is not None and self.new_name.is_absolute(): + self.new_name = get_relative_path(self.new_name, ctx.cwd) + def _remove_extra_empty_lines(lines: list[str]) -> list[str]: if not lines: @@ -101,12 +111,12 @@ def _get_code_block( ): lines = _prefixed_lines(line_number_buffer, code_lines, prefix) if lines: - return colored(lines, color=color) + return "\n".join(colored(line, color=color) for line in lines.split("\n")) else: return "" -def get_full_change(display_information: DisplayInformation): +def get_full_change(display_information: DisplayInformation, prefix: str = ""): to_print = [ get_file_name(display_information), ( @@ -125,7 +135,10 @@ def get_full_change(display_information: DisplayInformation): ), ] full_change = "\n".join([line for line in to_print if line]) - return full_change + prefixed_change = "\n".join( + (prefix + line) if line.strip() else line for line in full_change.split("\n") + ) + return prefixed_change def get_file_name( @@ -133,19 +146,23 @@ def get_file_name( ): match display_information.file_action_type: case FileActionType.CreateFile: - return colored(f"\n{display_information.file_name}*", color="light_green") + return "\n" + colored( + f"{display_information.file_name}*", color="light_green" + ) case FileActionType.DeleteFile: - return colored( - f"\nDeletion: {display_information.file_name}", color="light_red" + return "\n" + colored( + f"Deletion: {display_information.file_name}", color="light_red" ) case FileActionType.RenameFile: - return colored( - f"\nRename: {display_information.file_name} ->" + return "\n" + colored( + f"Rename: {display_information.file_name} ->" f" {display_information.new_name}", color="yellow", ) case FileActionType.UpdateFile: - return colored(f"\n{display_information.file_name}", color="light_blue") + return "\n" + colored( + f"{display_information.file_name}", color="light_blue" + ) def get_added_lines( diff --git a/mentat/parsers/file_edit.py b/mentat/parsers/file_edit.py index 3b190a078..9e7e2d333 100644 --- a/mentat/parsers/file_edit.py +++ b/mentat/parsers/file_edit.py @@ -5,7 +5,7 @@ import attr -from mentat.errors import MentatError +from mentat.errors import HistoryError, MentatError from mentat.parsers.change_display_helper import ( DisplayInformation, FileActionType, @@ -14,6 +14,7 @@ ) from mentat.session_context import SESSION_CONTEXT from mentat.session_input import ask_yes_no +from mentat.utils import get_relative_path # TODO: Add 'owner' to Replacement so that interactive mode can accept/reject multiple replacements at once @@ -39,12 +40,11 @@ def __lt__(self, other: Replacement): async def _ask_user_change( - display_information: DisplayInformation, text: str, ) -> bool: session_context = SESSION_CONTEXT.get() stream = session_context.stream - stream.send(get_full_change(display_information)) + stream.send(text, color="light_blue") return await ask_yes_no(default_yes=True) @@ -64,22 +64,91 @@ class FileEdit: # Should be abs path rename_file_path: Path | None = attr.field(default=None) + # Used for undo + previous_file_lines: list[str] | None = attr.field(default=None) + @file_path.validator # pyright: ignore def is_abs_path(self, attribute: attr.Attribute[Path], value: Any): if not isinstance(value, Path): - raise ValueError(f"file_path must be a Path, got {type(value)}") + raise ValueError(f"File_path must be a Path, got {type(value)}") if not value.is_absolute(): - raise ValueError(f"file_path must be an absolute path, got {value}") + raise ValueError(f"File_path must be an absolute path, got {value}") + + def _display_creation(self, prefix: str = ""): + ctx = SESSION_CONTEXT.get() + + added_lines = list[str]() + for replacement in self.replacements: + added_lines.extend(replacement.new_lines) + display_information = DisplayInformation( + self.file_path, [], added_lines, [], FileActionType.CreateFile + ) + ctx.stream.send(get_full_change(display_information, prefix=prefix)) + + def _display_deletion(self, file_lines: list[str], prefix: str = ""): + ctx = SESSION_CONTEXT.get() + + display_information = DisplayInformation( + self.file_path, + [], + [], + file_lines, + FileActionType.DeleteFile, + ) + ctx.stream.send(get_full_change(display_information, prefix=prefix)) + + def _display_rename(self, prefix: str = ""): + ctx = SESSION_CONTEXT.get() + + display_information = DisplayInformation( + self.file_path, + [], + [], + [], + FileActionType.RenameFile, + new_name=self.rename_file_path, + ) + ctx.stream.send(get_full_change(display_information, prefix=prefix)) + + def _display_replacement( + self, replacement: Replacement, file_lines: list[str], prefix: str = "" + ): + ctx = SESSION_CONTEXT.get() + + removed_block = file_lines[replacement.starting_line : replacement.ending_line] + display_information = DisplayInformation( + self.file_path, + file_lines, + replacement.new_lines, + removed_block, + FileActionType.UpdateFile, + replacement.starting_line, + replacement.ending_line, + self.rename_file_path, + ) + ctx.stream.send(get_full_change(display_information, prefix=prefix)) + + def _display_replacements(self, file_lines: list[str], prefix: str = ""): + for replacement in self.replacements: + self._display_replacement(replacement, file_lines, prefix=prefix) + + def display_full_edit(self, file_lines: list[str], prefix: str = ""): + """Displays the full edit as if it were altering a file with the lines given""" + if self.is_deletion: + self._display_deletion(file_lines, prefix=prefix) + if self.rename_file_path: + self._display_rename(prefix=prefix) + if self.is_creation: + self._display_creation(prefix=prefix) + else: + self._display_replacements(file_lines, prefix=prefix) def is_valid(self) -> bool: session_context = SESSION_CONTEXT.get() stream = session_context.stream code_context = session_context.code_context - if self.file_path.is_relative_to(session_context.cwd): - display_path = self.file_path.relative_to(session_context.cwd) - else: - display_path = self.file_path + display_path = get_relative_path(self.file_path, session_context.cwd) if self.is_creation: if self.file_path.exists(): @@ -128,52 +197,30 @@ async def filter_replacements( code_file_manager = session_context.code_file_manager if self.is_creation: - display_information = DisplayInformation( - self.file_path, [], [], [], FileActionType.CreateFile - ) - if not await _ask_user_change(display_information, "Create this file?"): + self._display_creation() + if not await _ask_user_change("Create this file?"): return False file_lines = [] else: file_lines = code_file_manager.file_lines[self.file_path] if self.is_deletion: - display_information = DisplayInformation( - self.file_path, [], [], file_lines, FileActionType.DeleteFile - ) - if not await _ask_user_change(display_information, "Delete this file?"): + self._display_deletion(file_lines) + if not await _ask_user_change("Delete this file?"): return False if self.rename_file_path is not None: - display_information = DisplayInformation( - self.file_path, - [], - [], - [], - FileActionType.RenameFile, - new_name=self.rename_file_path, - ) - if not await _ask_user_change(display_information, "Rename this file?"): + self._display_rename() + if not await _ask_user_change("Rename this file?"): self.rename_file_path = None - new_replacements = list[Replacement]() - for replacement in self.replacements: - removed_block = file_lines[ - replacement.starting_line : replacement.ending_line - ] - display_information = DisplayInformation( - self.file_path, - file_lines, - replacement.new_lines, - removed_block, - FileActionType.UpdateFile, - replacement.starting_line, - replacement.ending_line, - self.rename_file_path, - ) - if await _ask_user_change(display_information, "Keep this change?"): - new_replacements.append(replacement) - self.replacements = new_replacements + if not self.is_creation: + new_replacements = list[Replacement]() + for replacement in self.replacements: + self._display_replacement(replacement, file_lines) + if await _ask_user_change("Keep this change?"): + new_replacements.append(replacement) + self.replacements = new_replacements return ( self.is_creation @@ -191,7 +238,7 @@ def _print_resolution(self, first: Replacement, second: Replacement): stream.send(change_delimiter) for line in first.new_lines + second.new_lines: stream.send("+ " + line, color="green") - stream.send("") + stream.send(change_delimiter) def resolve_conflicts(self): self.replacements.sort(reverse=True) @@ -230,3 +277,75 @@ def get_updated_file_lines(self, file_lines: list[str]): + file_lines[replacement.ending_line :] ) return file_lines + + def undo(self): + ctx = SESSION_CONTEXT.get() + + prefix = "UNDO: " + + if self.is_creation: + if not self.file_path.exists(): + raise HistoryError( + f"File {self.file_path} does not exist; unable to delete" + ) + ctx.code_file_manager.delete_file(self.file_path) + + self._display_creation(prefix=prefix) + ctx.stream.send( + f"Creation of file {self.file_path} undone", color="light_blue" + ) + return + + if self.rename_file_path is not None: + if self.file_path.exists(): + raise HistoryError( + f"File {self.file_path} already exists; unable to undo rename to" + f" {self.rename_file_path}" + ) + if not self.rename_file_path.exists(): + raise HistoryError( + f"File {self.rename_file_path} does not exist; unable to undo" + f" rename from {self.file_path}" + ) + ctx.code_file_manager.rename_file(self.rename_file_path, self.file_path) + + self._display_rename(prefix=prefix) + ctx.stream.send( + f"Rename of file {self.file_path} to {self.rename_file_path} undone", + color="light_blue", + ) + + if self.is_deletion: + if self.file_path.exists(): + raise HistoryError( + f"File {self.file_path} already exists; unable to re-create" + ) + if not self.previous_file_lines: + # Should never happen + raise ValueError( + "Previous file lines not set when undoing file deletion" + ) + ctx.code_file_manager.create_file( + self.file_path, content="\n".join(self.previous_file_lines) + ) + + self._display_deletion(self.previous_file_lines, prefix=prefix) + ctx.stream.send( + f"Deletion of file {self.file_path} undone", color="light_red" + ) + elif self.replacements: + if not self.file_path.exists(): + raise HistoryError( + f"File {self.file_path} does not exist; unable to undo edit" + ) + if not self.previous_file_lines: + # Should never happen + raise ValueError("Previous file lines not set when undoing file edit") + + with open(self.file_path, "w") as f: + f.write("\n".join(self.previous_file_lines)) + + self._display_replacements(self.previous_file_lines, prefix=prefix) + ctx.stream.send( + f"Edits to file {self.file_path} undone", color="light_blue" + )