diff --git a/mentat/code_change_display.py b/mentat/code_change_display.py index 9ca9fdf54..fe85331d2 100644 --- a/mentat/code_change_display.py +++ b/mentat/code_change_display.py @@ -65,7 +65,9 @@ def get_file_name(code_change): case CodeChangeAction.DeleteFile: return colored(f"\n{file_name}", color="light_red") case CodeChangeAction.RenameFile: - return colored(f"\nRename: {file_name} -> {code_change.name}", color="yellow") + return colored( + f"\nRename: {file_name} -> {code_change.name}", color="yellow" + ) case _: return colored(f"\n{file_name}", color="light_blue") diff --git a/mentat/code_file_manager.py b/mentat/code_file_manager.py index 3d743f273..b2ef27980 100644 --- a/mentat/code_file_manager.py +++ b/mentat/code_file_manager.py @@ -186,11 +186,11 @@ def get_code_message(self): return "\n".join(code_message) - def _add_file(self, rel_path): - logging.info(f"Adding new file {rel_path} to context") - self.file_paths.append(rel_path) + def _add_file(self, abs_path): + logging.info(f"Adding new file {abs_path} to context") + self.file_paths.append(abs_path) # create any missing directories in the path - rel_path.parent.mkdir(parents=True, exist_ok=True) + abs_path.parent.mkdir(parents=True, exist_ok=True) def _delete_file(self, abs_path: Path): logging.info(f"Deleting file {abs_path}") @@ -210,7 +210,10 @@ def _handle_delete(self, delete_change): else: cprint(f"Not deleting {delete_change.file}") - def _get_new_code_lines(self, changes) -> Iterable[str]: + def _get_new_code_lines(self, rel_path, changes) -> Iterable[str]: + if not changes: + return [] + if len(set(map(lambda change: change.file, changes))) > 1: raise Exception("All changes passed in must be for the same file") @@ -253,32 +256,39 @@ def _get_new_code_lines(self, changes) -> Iterable[str]: return new_code_lines def write_changes_to_files(self, code_changes: list[CodeChange]) -> None: - files_to_write = dict() file_changes = defaultdict(list) for code_change in code_changes: # here keys are str not path object rel_path = str(code_change.file) - if code_change.action == CodeChangeAction.CreateFile: - cprint(f"Creating new file {rel_path}", color="light_green") - files_to_write[rel_path] = code_change.code_lines - elif code_change.action == CodeChangeAction.DeleteFile: - self._handle_delete(code_change) - elif code_change.action == CodeChangeAction.RenameFile: - abs_path = os.path.join(self.git_root, rel_path) - code_lines = self.file_lines[abs_path] - files_to_write[code_change.name] = code_lines - self._delete_file(abs_path) - else: - file_changes[rel_path].append(code_change) - - for file_path, changes in file_changes.items(): - new_code_lines = self._get_new_code_lines(changes) + abs_path = os.path.join(self.git_root, rel_path) + match code_change.action: + case CodeChangeAction.CreateFile: + cprint(f"Creating new file {rel_path}", color="light_green") + self._add_file(abs_path) + with open(abs_path, "w") as f: + f.write("\n".join(code_change.code_lines)) + case CodeChangeAction.DeleteFile: + self._handle_delete(code_change) + case CodeChangeAction.RenameFile: + abs_new_path = os.path.join(self.git_root, code_change.name) + self._add_file(abs_new_path) + code_lines = self.file_lines[abs_path] + with open(abs_new_path, "w") as f: + f.write("\n".join(code_lines)) + self._delete_file(abs_path) + file_changes[str(code_change.name)] += file_changes[rel_path] + file_changes[rel_path] = [] + self.file_lines[abs_new_path] = self._read_file(abs_new_path) + case _: + file_changes[rel_path].append(code_change) + + for rel_path, changes in file_changes.items(): + abs_path = os.path.join(self.git_root, rel_path) + new_code_lines = self._get_new_code_lines(rel_path, changes) if new_code_lines: - files_to_write[file_path] = new_code_lines - - for rel_path, code_lines in files_to_write.items(): - file_path = self.git_root / rel_path - if file_path not in self.file_paths: - self._add_file(rel_path) - with open(file_path, "w") as f: - f.write("\n".join(code_lines)) + if abs_path not in self.file_paths: + raise MentatError( + f"Attempted to edit file {abs_path} not in context" + ) + with open(abs_path, "w") as f: + f.write("\n".join(new_code_lines)) diff --git a/mentat/parsing.py b/mentat/parsing.py index e9539ab5e..4c906f8c1 100644 --- a/mentat/parsing.py +++ b/mentat/parsing.py @@ -10,7 +10,7 @@ import openai from termcolor import cprint -from .code_change import CodeChange +from .code_change import CodeChange, CodeChangeAction from .code_change_display import ( change_delimiter, get_file_name, @@ -277,6 +277,7 @@ def _process_content_line( len(state.code_changes) < 2 or state.code_changes[-2].file != cur_change.file or state.explained_since_change + or state.code_changes[-1].action == CodeChangeAction.RenameFile ): printer.add_string(get_file_name(cur_change)) if ( diff --git a/tests/code_change_test.py b/tests/code_change_test.py index 2f9b29091..515414980 100644 --- a/tests/code_change_test.py +++ b/tests/code_change_test.py @@ -231,6 +231,54 @@ def test_rename_file(mock_call_llm_api, mock_collect_user_input, mock_setup_api_ assert content == expected_content +def test_change_then_rename_file( + mock_call_llm_api, mock_collect_user_input, mock_setup_api_key +): + # Make sure a change made before a rename works + temp_file_name = "temp.py" + temp_2_file_name = "temp_2.py" + with open(temp_file_name, "w") as f: + f.write("# Move me!") + + mock_collect_user_input.side_effect = [ + "Insert a comment then rename the file temp_2.py", + "y", + KeyboardInterrupt, + ] + mock_call_llm_api.set_generator_values([dedent(f"""\ + I will insert a comment then rename the file + + Steps: + 1. insert a comment + 2. rename the file + + @@start + {{ + "file": "{temp_file_name}", + "action": "insert", + "insert-after-line": 0, + "insert-before-line": 1 + }} + @@code + # I inserted this comment! + @@end + @@start + {{ + "file": "{temp_file_name}", + "action": "rename-file", + "name": "{temp_2_file_name}" + }} + @@end""")]) + + run([temp_file_name]) + with open(temp_2_file_name) as new_file: + content = new_file.read() + expected_content = "# I inserted this comment!\n# Move me!" + + assert not os.path.exists(temp_file_name) + assert content == expected_content + + def test_multiple_blocks( mock_call_llm_api, mock_collect_user_input, mock_setup_api_key ):