Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rename file action #46

Merged
merged 8 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mentat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from termcolor import cprint

from .code_change import CodeChange
from .code_change import CodeChange, CodeChangeAction
from .code_change_display import print_change
from .code_context import CodeContext
from .code_file import parse_intervals
Expand Down Expand Up @@ -205,8 +205,16 @@ def user_filter_changes(
indices = []
for index, change in enumerate(code_changes, start=1):
print_change(change)
# Allowing the user to remove rename file changes introduces a lot of edge cases
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What edge cases have you discovered/thought about so far?

One that comes to mind is a file being renamed and imports not being updated accordingly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking more of edge cases in how we handle it; the big one being that if we don't rename the file, then any changes the model suggests for the renamed file will target a non-existent file. It gets worse if the model is renaming multiple files to similar names, or one file twice, and so on.

if change.action == CodeChangeAction.RenameFile:
new_changes.append(change)
indices.append(index)
cprint("Cannot remove rename file change", "light_yellow")
continue

cprint("Keep this change?", "light_blue")
if user_input_manager.ask_yes_no(default_yes=True):
new_changes.append(change)
indices.append(index)

return new_changes, indices
44 changes: 30 additions & 14 deletions mentat/code_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ class CodeChangeAction(Enum):
Delete = "delete"
CreateFile = "create-file"
DeleteFile = "delete-file"
RenameFile = "rename-file"

def has_surrounding_lines(self):
return (
self != CodeChangeAction.CreateFile and self != CodeChangeAction.DeleteFile
self != CodeChangeAction.CreateFile
and self != CodeChangeAction.DeleteFile
and self != CodeChangeAction.RenameFile
)

def has_removals(self):
Expand All @@ -40,6 +43,7 @@ def __init__(
json_data: dict,
code_lines: list[str],
code_file_manager,
rename_map: dict = {},
):
self.json_data = json_data
# Sometimes GPT puts quotes around numbers, so we have to convert those
Expand Down Expand Up @@ -99,6 +103,10 @@ def __init__(
case CodeChangeAction.Delete:
self.first_changed_line = self.json_data["start-line"]
self.last_changed_line = self.json_data["end-line"]

case CodeChangeAction.RenameFile:
self.name = Path(self.json_data["name"])

except KeyError:
self.error = "Line numbers not given"

Expand All @@ -109,24 +117,32 @@ def __init__(
):
self.error = "Starting line of change is greater than ending line of change"

if self.action != CodeChangeAction.CreateFile:
rel_path = str(self.file)
try:
self.file_lines = code_file_manager.file_lines[rel_path]
self.line_number_buffer = len(str(len(self.file_lines) + 1)) + 1
except KeyError:
self.error = (
f"Model attempted to edit {rel_path}, which isn't in"
" current context or doesn't exist"
)
else:
if self.action == CodeChangeAction.CreateFile:
if self.file.exists():
self.error = (
f"Model attempted to create file that already exists: {self.file}"
)

self.file_lines = []
self.line_number_buffer = 2
else:
if self.action == CodeChangeAction.RenameFile and self.name.exists():
self.error = (
f"Model attempted to rename file {self.file} to a file that"
f" already exists: {self.name}"
)

# This rename_map is a bit hacky; it shouldn't be used outside of streaming/parsing
rel_path = str(
self.file if self.file not in rename_map else rename_map[self.file]
)
try:
self.file_lines = code_file_manager.file_lines[rel_path]
self.line_number_buffer = len(str(len(self.file_lines) + 1)) + 1
except KeyError:
self.error = (
f"Model attempted to edit {rel_path}, which isn't in current"
" context or doesn't exist"
)

def __lt__(self, other):
return self.last_changed_line < other.last_changed_line
Expand All @@ -148,7 +164,7 @@ def apply(self, cur_file_lines: list[str]) -> list[str]:
following_lines = cur_file_lines[self.last_changed_line :]
new_file_lines = previous_lines + following_lines

case CodeChangeAction.CreateFile | CodeChangeAction.DeleteFile:
case CodeChangeAction.CreateFile | CodeChangeAction.DeleteFile | CodeChangeAction.RenameFile:
raise Exception(
f"CodeChange with action={self.action} shouldn't have apply called"
)
Expand Down
25 changes: 13 additions & 12 deletions mentat/code_change_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ def _prefixed_lines(code_change, lines, prefix):
def print_change(code_change):
to_print = [
get_file_name(code_change),
change_delimiter,
change_delimiter if code_change.action != CodeChangeAction.RenameFile else "",
get_previous_lines(code_change),
get_removed_block(code_change),
get_added_block(code_change),
get_later_lines(code_change),
change_delimiter,
change_delimiter if code_change.action != CodeChangeAction.RenameFile else "",
]
for s in to_print:
if s:
Expand All @@ -59,16 +59,17 @@ def print_change(code_change):

def get_file_name(code_change):
file_name = code_change.file
action = code_change.action
color = (
"light_red"
if action == CodeChangeAction.DeleteFile
else ("light_green" if action == CodeChangeAction.CreateFile else "light_blue")
)
return colored(
f"\n{file_name}{'*' if action == CodeChangeAction.CreateFile else ''}",
color=color,
)
match code_change.action:
case CodeChangeAction.CreateFile:
return colored(f"\n{file_name}*", color="light_green")
case CodeChangeAction.DeleteFile:
return colored(f"\n{file_name}", color="light_red")
case CodeChangeAction.RenameFile:
return colored(
f"\nRename: {file_name} -> {code_change.name}", color="yellow"
)
case _:
return colored(f"\n{file_name}", color="light_blue")


def get_removed_block(code_change, prefix="-", color="red"):
Expand Down
3 changes: 1 addition & 2 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _get_files(
files = {}
for file in files_direct:
if file.path not in excluded_files | excluded_files_from_dir:
files[file.path] = file
files[Path(os.path.realpath(file.path))] = file
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the thinking for using realpath here? Are there issues with using symbolic links?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or did you mean to use relpath instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well the real reason was because we were actually just putting a relative Path in there (when it was supposed to be an absolute path!); but yeah, the reason I use realpath instead of abspath is because of symlinks


return files

Expand Down Expand Up @@ -135,7 +135,6 @@ def __init__(
exclude_paths: Iterable[str],
):
self.config = config

self.files = _get_files(self.config, paths, exclude_paths)

def display_context(self):
Expand Down
84 changes: 52 additions & 32 deletions mentat/code_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,34 @@ def get_code_message(self):

return "\n".join(code_message)

def _add_file(self, abs_path):
logging.info(f"Adding new file {abs_path} to context")
self.code_context.files[abs_path] = CodeFile(abs_path)
# create any missing directories in the path
abs_path.parent.mkdir(parents=True, exist_ok=True)

def _delete_file(self, abs_path: Path):
logging.info(f"Deleting file {abs_path}")
if abs_path in self.code_context.files:
del self.code_context.files[abs_path]
abs_path.unlink()

def _handle_delete(self, delete_change):
file_path = self.config.git_root / delete_change.file
if not file_path.exists():
logging.error(f"Path {file_path} non-existent on delete")
abs_path = self.config.git_root / delete_change.file
if not abs_path.exists():
logging.error(f"Path {abs_path} non-existent on delete")
return
Comment on lines +88 to 91
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using self.config.git_root / delete_change.file we can use os.path.abspath(delete_change.file) to handle deletes for files outside of a repo. I know we already have this behavior in the main branch but I just noticed it lol. Pretty sure we're assuming all files are under a git repo across Mentat so it might be better to address this separately.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually necessary; abspath prepends the CWD to the path, which might not necessarily be the root of the current git repo (it could be a subdirectory). This does remind me though, I don't think we have any tests where we run mentat from a subdirectory of the git repo it's actually in; we should probably think about adding some of those!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we have any tests where we run mentat from a subdirectory of the git repo it's actually in; we should probably think about adding some of those!

@PCSwingle can you add an issue for this


cprint(f"Are you sure you want to delete {delete_change.file}?", "red")
if self.user_input_manager.ask_yes_no(default_yes=False):
logging.info(f"Deleting file {file_path}")
cprint(f"Deleting {delete_change.file}...")
if file_path in self.code_context.files:
del self.code_context.files[file_path]
file_path.unlink()
self._delete_file(abs_path)
else:
cprint(f"Not deleting {delete_change.file}")

def _get_new_code_lines(self, changes) -> Iterable[str] | None:
def _get_new_code_lines(self, rel_path, changes) -> Iterable[str]:
if not changes:
return []
if len(set(map(lambda change: change.file, changes))) > 1:
raise Exception("All changes passed in must be for the same file")

Expand All @@ -103,7 +114,6 @@ def _get_new_code_lines(self, changes) -> Iterable[str] | None:
if not changes:
return []

rel_path = str(changes[0].file)
new_code_lines = self.file_lines[rel_path].copy()
if new_code_lines != self._read_file(rel_path):
logging.info(f"File '{rel_path}' changed while generating changes")
Expand Down Expand Up @@ -131,31 +141,41 @@ def _get_new_code_lines(self, changes) -> Iterable[str] | None:
return new_code_lines

def write_changes_to_files(self, code_changes: list[CodeChange]) -> None:
files_to_write = dict()
file_changes = defaultdict(list)
for code_change in code_changes:
# here keys are str not path object
rel_path = str(code_change.file)
if code_change.action == CodeChangeAction.CreateFile:
cprint(f"Creating new file {rel_path}", color="light_green")
files_to_write[rel_path] = code_change.code_lines
elif code_change.action == CodeChangeAction.DeleteFile:
self._handle_delete(code_change)
else:
file_changes[rel_path].append(code_change)

for file_path, changes in file_changes.items():
new_code_lines = self._get_new_code_lines(changes)
abs_path = self.config.git_root / rel_path
match code_change.action:
case CodeChangeAction.CreateFile:
cprint(f"Creating new file {rel_path}", color="light_green")
self._add_file(abs_path)
with open(abs_path, "w") as f:
f.write("\n".join(code_change.code_lines))
case CodeChangeAction.DeleteFile:
self._handle_delete(code_change)
case CodeChangeAction.RenameFile:
abs_new_path = self.config.git_root / code_change.name
self._add_file(abs_new_path)
code_lines = self.file_lines[rel_path]
with open(abs_new_path, "w") as f:
f.write("\n".join(code_lines))
self._delete_file(abs_path)
file_changes[str(code_change.name)] += file_changes[rel_path]
file_changes[rel_path] = []
self.file_lines[str(code_change.name)] = self._read_file(
abs_new_path
)
case _:
file_changes[rel_path].append(code_change)

for rel_path, changes in file_changes.items():
abs_path = self.config.git_root / rel_path
new_code_lines = self._get_new_code_lines(rel_path, changes)
if new_code_lines:
files_to_write[file_path] = new_code_lines

for rel_path, code_lines in files_to_write.items():
file_path = self.config.git_root / rel_path
if file_path not in self.code_context.files:
# newly created files added to Mentat's context
logging.info(f"Adding new file {file_path} to context")
self.code_context.files[file_path] = CodeFile(file_path)
# create any missing directories in the path
file_path.parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write("\n".join(code_lines))
if abs_path not in self.code_context.files:
raise MentatError(
f"Attempted to edit file {abs_path} not in context"
)
with open(abs_path, "w") as f:
f.write("\n".join(new_code_lines))
23 changes: 19 additions & 4 deletions mentat/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import logging
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from timeit import default_timer
from typing import Generator

import attr
import openai
from termcolor import cprint

from .code_change import CodeChange
from .code_change import CodeChange, CodeChangeAction
from .code_change_display import (
change_delimiter,
get_file_name,
Expand Down Expand Up @@ -43,6 +44,7 @@ class ParsingState:
code_changes: list[CodeChange] = attr.ib(factory=list)
json_lines: list[str] = attr.ib(factory=list)
code_lines: list[str] = attr.ib(factory=list)
rename_map: dict[Path, Path] = attr.ib(factory=dict)

def parse_line_printing(self, content):
to_print = ""
Expand Down Expand Up @@ -75,9 +77,14 @@ def create_code_change(self, code_file_manager: CodeFileManager):
already_added_to_changelist=False,
)

new_change = CodeChange(json_data, self.code_lines, code_file_manager)
new_change = CodeChange(
json_data, self.code_lines, code_file_manager, self.rename_map
)
self.code_changes.append(new_change)
self.json_lines, self.code_lines = [], []
if new_change.action == CodeChangeAction.RenameFile:
# This rename_map is a bit hacky; it shouldn't be used outside of streaming/parsing
self.rename_map[new_change.name] = new_change.file

def new_line(self, code_file_manager: CodeFileManager):
to_print = ""
Expand Down Expand Up @@ -277,13 +284,21 @@ def _process_content_line(
len(state.code_changes) < 2
or state.code_changes[-2].file != cur_change.file
or state.explained_since_change
or state.code_changes[-1].action == CodeChangeAction.RenameFile
):
printer.add_string(get_file_name(cur_change))
printer.add_string(change_delimiter)
if (
cur_change.action.has_additions()
or cur_change.action.has_removals()
):
printer.add_string(change_delimiter)
state.explained_since_change = False
printer.add_string(get_previous_lines(cur_change))
printer.add_string(get_removed_block(cur_change))
if not cur_change.action.has_additions():
if (
not cur_change.action.has_additions()
and cur_change.action.has_removals()
):
printer.add_string(get_later_lines(cur_change))
printer.add_string(change_delimiter)

Expand Down
Loading