Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
refactor write_code_changes function, fix rename_file bug on windows,…
Browse files Browse the repository at this point in the history
… add new test, fix bug with change and rename on same file
  • Loading branch information
PCSwingle committed Aug 26, 2023
1 parent 5b435d2 commit ad0701f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 31 deletions.
4 changes: 3 additions & 1 deletion mentat/code_change_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def get_file_name(code_change):
case CodeChangeAction.DeleteFile:
return colored(f"\n{file_name}", color="light_red")
case CodeChangeAction.RenameFile:
return colored(f"\nRename: {file_name} -> {code_change.name}", color="yellow")
return colored(
f"\nRename: {file_name} -> {code_change.name}", color="yellow"
)
case _:
return colored(f"\n{file_name}", color="light_blue")

Expand Down
68 changes: 39 additions & 29 deletions mentat/code_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def get_code_message(self):

return "\n".join(code_message)

def _add_file(self, rel_path):
logging.info(f"Adding new file {rel_path} to context")
self.file_paths.append(rel_path)
def _add_file(self, abs_path):
logging.info(f"Adding new file {abs_path} to context")
self.file_paths.append(abs_path)
# create any missing directories in the path
rel_path.parent.mkdir(parents=True, exist_ok=True)
abs_path.parent.mkdir(parents=True, exist_ok=True)

def _delete_file(self, abs_path: Path):
logging.info(f"Deleting file {abs_path}")
Expand All @@ -210,7 +210,10 @@ def _handle_delete(self, delete_change):
else:
cprint(f"Not deleting {delete_change.file}")

def _get_new_code_lines(self, changes) -> Iterable[str]:
def _get_new_code_lines(self, rel_path, changes) -> Iterable[str]:
if not changes:
return []

if len(set(map(lambda change: change.file, changes))) > 1:
raise Exception("All changes passed in must be for the same file")

Expand Down Expand Up @@ -253,32 +256,39 @@ def _get_new_code_lines(self, changes) -> Iterable[str]:
return new_code_lines

def write_changes_to_files(self, code_changes: list[CodeChange]) -> None:
files_to_write = dict()
file_changes = defaultdict(list)
for code_change in code_changes:
# here keys are str not path object
rel_path = str(code_change.file)
if code_change.action == CodeChangeAction.CreateFile:
cprint(f"Creating new file {rel_path}", color="light_green")
files_to_write[rel_path] = code_change.code_lines
elif code_change.action == CodeChangeAction.DeleteFile:
self._handle_delete(code_change)
elif code_change.action == CodeChangeAction.RenameFile:
abs_path = os.path.join(self.git_root, rel_path)
code_lines = self.file_lines[abs_path]
files_to_write[code_change.name] = code_lines
self._delete_file(abs_path)
else:
file_changes[rel_path].append(code_change)

for file_path, changes in file_changes.items():
new_code_lines = self._get_new_code_lines(changes)
abs_path = os.path.join(self.git_root, rel_path)
match code_change.action:
case CodeChangeAction.CreateFile:
cprint(f"Creating new file {rel_path}", color="light_green")
self._add_file(abs_path)
with open(abs_path, "w") as f:
f.write("\n".join(code_change.code_lines))
case CodeChangeAction.DeleteFile:
self._handle_delete(code_change)
case CodeChangeAction.RenameFile:
abs_new_path = os.path.join(self.git_root, code_change.name)
self._add_file(abs_new_path)
code_lines = self.file_lines[abs_path]
with open(abs_new_path, "w") as f:
f.write("\n".join(code_lines))
self._delete_file(abs_path)
file_changes[str(code_change.name)] += file_changes[rel_path]
file_changes[rel_path] = []
self.file_lines[abs_new_path] = self._read_file(abs_new_path)
case _:
file_changes[rel_path].append(code_change)

for rel_path, changes in file_changes.items():
abs_path = os.path.join(self.git_root, rel_path)
new_code_lines = self._get_new_code_lines(rel_path, changes)
if new_code_lines:
files_to_write[file_path] = new_code_lines

for rel_path, code_lines in files_to_write.items():
file_path = self.git_root / rel_path
if file_path not in self.file_paths:
self._add_file(rel_path)
with open(file_path, "w") as f:
f.write("\n".join(code_lines))
if abs_path not in self.file_paths:
raise MentatError(
f"Attempted to edit file {abs_path} not in context"
)
with open(abs_path, "w") as f:
f.write("\n".join(new_code_lines))
3 changes: 2 additions & 1 deletion mentat/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import openai
from termcolor import cprint

from .code_change import CodeChange
from .code_change import CodeChange, CodeChangeAction
from .code_change_display import (
change_delimiter,
get_file_name,
Expand Down Expand Up @@ -277,6 +277,7 @@ def _process_content_line(
len(state.code_changes) < 2
or state.code_changes[-2].file != cur_change.file
or state.explained_since_change
or state.code_changes[-1].action == CodeChangeAction.RenameFile
):
printer.add_string(get_file_name(cur_change))
if (
Expand Down
48 changes: 48 additions & 0 deletions tests/code_change_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,54 @@ def test_rename_file(mock_call_llm_api, mock_collect_user_input, mock_setup_api_
assert content == expected_content


def test_change_then_rename_file(
mock_call_llm_api, mock_collect_user_input, mock_setup_api_key
):
# Make sure a change made before a rename works
temp_file_name = "temp.py"
temp_2_file_name = "temp_2.py"
with open(temp_file_name, "w") as f:
f.write("# Move me!")

mock_collect_user_input.side_effect = [
"Insert a comment then rename the file temp_2.py",
"y",
KeyboardInterrupt,
]
mock_call_llm_api.set_generator_values([dedent(f"""\
I will insert a comment then rename the file
Steps:
1. insert a comment
2. rename the file
@@start
{{
"file": "{temp_file_name}",
"action": "insert",
"insert-after-line": 0,
"insert-before-line": 1
}}
@@code
# I inserted this comment!
@@end
@@start
{{
"file": "{temp_file_name}",
"action": "rename-file",
"name": "{temp_2_file_name}"
}}
@@end""")])

run([temp_file_name])
with open(temp_2_file_name) as new_file:
content = new_file.read()
expected_content = "# I inserted this comment!\n# Move me!"

assert not os.path.exists(temp_file_name)
assert content == expected_content


def test_multiple_blocks(
mock_call_llm_api, mock_collect_user_input, mock_setup_api_key
):
Expand Down

0 comments on commit ad0701f

Please sign in to comment.