Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unified diff format #118

Merged
merged 7 commits into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@ Available formats:
* block
* replacement
* split-diff
* unified-diff
3 changes: 2 additions & 1 deletion mentat/parsers/block_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def get_system_prompt(self) -> str:
@override
def _could_be_special(self, cur_line: str) -> bool:
return any(
to_match.value.startswith(cur_line) for to_match in _BlockParserIndicator
to_match.value.startswith(cur_line.strip())
for to_match in _BlockParserIndicator
)

@override
Expand Down
25 changes: 19 additions & 6 deletions mentat/parsers/change_display_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ class FileActionType(Enum):
UpdateFile = "update"


def get_file_action_type(is_creation: bool, is_deletion: bool, new_name: Path | None):
    """
    Classify a file edit as one of the FileActionType members.

    The checks are ordered: creation wins over deletion, deletion wins over
    a rename (a non-None new_name), and anything else is a plain update.
    """
    if is_creation:
        return FileActionType.CreateFile
    if is_deletion:
        return FileActionType.DeleteFile
    if new_name is not None:
        return FileActionType.RenameFile
    return FileActionType.UpdateFile


@attr.define(slots=False)
class DisplayInformation:
file_name: Path = attr.field()
Expand Down Expand Up @@ -162,6 +174,11 @@ def get_removed_lines(
)


def highlight_text(display_information: DisplayInformation, text: str) -> str:
    """
    Return text highlighted with the lexer stored on display_information,
    formatted for a dark-background terminal.
    """
    # pygments doesn't have type hints on TerminalFormatter
    formatter = TerminalFormatter(bg="dark")  # type: ignore
    return highlight(text, display_information.lexer, formatter)  # type: ignore


def get_previous_lines(
display_information: DisplayInformation,
num: int = 2,
Expand Down Expand Up @@ -189,9 +206,7 @@ def get_previous_lines(
]

prev = "\n".join(numbered)
# pygments doesn't have type hints on TerminalFormatter
h_prev: str = highlight(prev, display_information.lexer, TerminalFormatter(bg="dark")) # type: ignore
return h_prev
return highlight_text(display_information, prev)


def get_later_lines(
Expand Down Expand Up @@ -221,6 +236,4 @@ def get_later_lines(
]

later = "\n".join(numbered)
# pygments doesn't have type hints on TerminalFormatter
h_later: str = highlight(later, display_information.lexer, TerminalFormatter(bg="dark")) # type: ignore
return h_later
return highlight_text(display_information, later)
28 changes: 28 additions & 0 deletions mentat/parsers/diff_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
def matching_index(orig_lines: list[str], new_lines: list[str]) -> int:
    """
    Find where new_lines occurs as a contiguous run inside orig_lines.

    The comparison is loosened step by step, stopping at the first success:
      1. exact match
      2. case-insensitive match
      3. case-insensitive match with surrounding whitespace stripped
      4. as in 3, additionally ignoring empty lines on both sides

    Returns the index into orig_lines of the first matched line, or -1 if no
    step finds a match. For step 4 this is the position of the first matched
    non-empty line in orig_lines. Neither input list is modified.
    """
    index = _exact_match(orig_lines, new_lines)
    if index != -1:
        return index

    lowered_orig = [s.lower() for s in orig_lines]
    lowered_new = [s.lower() for s in new_lines]
    index = _exact_match(lowered_orig, lowered_new)
    if index != -1:
        return index

    stripped_orig = [s.strip() for s in lowered_orig]
    stripped_new = [s.strip() for s in lowered_new]
    index = _exact_match(stripped_orig, stripped_new)
    if index != -1:
        return index

    # Keep each non-empty line together with its position in orig_lines so a
    # match in the filtered list maps back to the exact line that matched.
    # Looking the text up again with list.index() would return the first
    # line with that content, which is wrong when an earlier line is a
    # duplicate of the matched one.
    kept = [(pos, s) for pos, s in enumerate(stripped_orig) if s]
    index = _exact_match([s for _, s in kept], [s for s in stripped_new if s])
    if index == -1:
        return -1
    return kept[index][0]


def _exact_match(orig_lines: list[str], new_lines: list[str]) -> int:
    """
    Return the first index at which new_lines appears as a contiguous slice
    of orig_lines, or -1 if it never does.

    When both lists contain nothing but whitespace, they are treated as
    matching at index 0 regardless of their lengths.
    """
    if not "".join(new_lines).strip() and not "".join(orig_lines).strip():
        return 0
    width = len(new_lines)
    # The last window that still fits starts at len(orig_lines) - width.
    for i in range(len(orig_lines) - width + 1):
        if orig_lines[i : i + width] == new_lines:
            return i
    return -1
36 changes: 27 additions & 9 deletions mentat/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,22 +105,26 @@ async def stream_and_parse_llm_response(
# Print if not in special lines and line is confirmed not special
if not in_special_lines:
if not line_printed:
if not self._could_be_special(cur_line.strip()):
if not self._could_be_special(cur_line):
line_printed = True
to_print = (
cur_line
if not in_code_lines or display_information is None
else self._code_line_beginning(
display_information, cur_block
)
+ self._code_line_content(cur_line, cur_block)
+ self._code_line_content(
display_information, cur_line, cur_line, cur_block
)
)
printer.add_string(to_print, end="")
else:
to_print = (
content
if not in_code_lines
else self._code_line_content(content, cur_block)
if not in_code_lines or display_information is None
else self._code_line_content(
display_information, content, cur_line, cur_block
)
)
printer.add_string(to_print, end="")

Expand All @@ -131,15 +135,22 @@ async def stream_and_parse_llm_response(

# New line handling
if "\n" in cur_line:
# Always print whitespace lines (even though they 'match' could_be_special)
if not cur_line.strip() and not line_printed:
# Now that full line is in, give _could_be_special full line (including newline)
# and see if it should be printed or not
if (
not in_special_lines
and not line_printed
and not self._could_be_special(cur_line)
):
to_print = (
cur_line
if not in_code_lines or display_information is None
else self._code_line_beginning(
display_information, cur_block
)
+ self._code_line_content(cur_line, cur_block)
+ self._code_line_content(
display_information, cur_line, cur_line, cur_block
)
)
printer.add_string(to_print, end="")
line_printed = True
Expand Down Expand Up @@ -310,7 +321,13 @@ def _code_line_beginning(
"+" + " " * (display_information.line_number_buffer - 1), color="green"
)

def _code_line_content(self, content: str, cur_block: str) -> str:
def _code_line_content(
self,
display_information: DisplayInformation,
content: str,
cur_line: str,
cur_block: str,
) -> str:
"""
Part of a code line; normally this means printing in green
"""
Expand All @@ -319,7 +336,8 @@ def _code_line_content(self, content: str, cur_block: str) -> str:
@abstractmethod
def _could_be_special(self, cur_line: str) -> bool:
"""
Returns if this current line could be a special line and therefore shouldn't be printed yet
Returns if this current line could be a special line and therefore shouldn't be printed yet.
Once line is completed, will include a newline character.
"""
pass

Expand Down
53 changes: 20 additions & 33 deletions mentat/parsers/split_diff_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

from mentat.code_file_manager import CodeFileManager
from mentat.config_manager import ConfigManager
from mentat.parsers.change_display_helper import DisplayInformation, FileActionType
from mentat.parsers.change_display_helper import (
DisplayInformation,
FileActionType,
get_file_action_type,
)
from mentat.parsers.diff_utils import matching_index
from mentat.parsers.file_edit import FileEdit, Replacement
from mentat.parsers.parser import Parser
from mentat.prompts.prompts import read_prompt
Expand Down Expand Up @@ -44,7 +49,13 @@ def _code_line_beginning(
return colored("-", color="red")

@override
def _code_line_content(self, content: str, cur_block: str) -> str:
def _code_line_content(
self,
display_information: DisplayInformation,
content: str,
cur_line: str,
cur_block: str,
) -> str:
lines = cur_block.split("\n")
if SplitDiffDelimiters.Middle.value in lines:
return colored(content, color="green")
Expand Down Expand Up @@ -87,15 +98,7 @@ def _special_block(
file_name, new_name = Path(info), None

file_lines = self._get_file_lines(code_file_manager, rename_map, file_name)
if is_creation:
file_action_type = FileActionType.CreateFile
elif is_deletion:
file_action_type = FileActionType.DeleteFile
elif new_name is not None:
file_action_type = FileActionType.RenameFile
else:
file_action_type = FileActionType.UpdateFile

file_action_type = get_file_action_type(is_creation, is_deletion, new_name)
display_information = DisplayInformation(
file_name=file_name,
file_lines=file_lines,
Expand Down Expand Up @@ -136,35 +139,19 @@ def _add_code_block(
# excluding case and stripped. If we don't find one, we throw away this change.
file_lines = self._get_file_lines(
code_file_manager, rename_map, display_information.file_name
).copy()
)
# Remove the delimiters, ending fence, and new line after ending fence
lines = code_block.split("\n")[1:-3]
middle_index = lines.index(SplitDiffDelimiters.Middle.value)
removed_lines = lines[:middle_index]
added_lines = lines[middle_index + 1 :]
index = self._matching_index(file_lines, removed_lines)
index = matching_index(file_lines, removed_lines)
if index == -1:
file_lines = [s.lower() for s in file_lines]
removed_lines = [s.lower() for s in removed_lines]
index = self._matching_index(file_lines, removed_lines)
if index == -1:
file_lines = [s.strip() for s in file_lines]
removed_lines = [s.strip() for s in removed_lines]
index = self._matching_index(file_lines, removed_lines)
if index == -1:
return colored(
"Error: Original lines not found. Discarding this change.",
color="red",
)
return colored(
"Error: Original lines not found. Discarding this change.",
color="red",
)
file_edit.replacements.append(
Replacement(index, index + len(removed_lines), added_lines)
)
return ""

def _matching_index(self, orig_lines: list[str], new_lines: list[str]) -> int:
if "".join(new_lines).strip() == "" and "".join(orig_lines).strip() == "":
return 0
for i in range(len(orig_lines) - (len(new_lines) - 1)):
if orig_lines[i : i + len(new_lines)] == new_lines:
return i
return -1
Loading