Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove git root from CodeFileManager, FileEdit #274

Merged
merged 27 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
554a2a5
Fix Git subprocess calls to suppress error output
waydegg Nov 11, 2023
f261d7b
run black
waydegg Nov 11, 2023
8291594
Add cwd parameter to Session and TerminalClient
waydegg Nov 11, 2023
c4b1846
Add cwd parameter to PythonClient and
waydegg Nov 11, 2023
0857640
update tests
waydegg Nov 11, 2023
c9c3ea1
format
waydegg Nov 11, 2023
ea5fe4e
fix black check and cwd cli param
waydegg Nov 11, 2023
01e8585
sort imports for tests
waydegg Nov 11, 2023
e7f950e
Merge branch 'remove-git-root-0' into remove-git-root-1
waydegg Nov 11, 2023
6dd96bf
Remove `git_root` from Commands
waydegg Nov 11, 2023
f150291
Remove git root from CodeFileManager, FileEdit
waydegg Nov 12, 2023
38bf714
Run isort
waydegg Nov 12, 2023
293237d
Refactor file edit display paths in
waydegg Nov 14, 2023
a6522ec
Update path join style
waydegg Nov 14, 2023
1be559a
Merge main
waydegg Nov 21, 2023
dfc6f11
Run black
waydegg Nov 21, 2023
c529695
Fix merge errors
waydegg Nov 21, 2023
cd36016
Fix merge errors
waydegg Nov 21, 2023
64b5c29
Merge branch 'main' into remove-git-root-2
waydegg Nov 21, 2023
e334e71
Add cwd helper for `CodeFeature.ref`
waydegg Nov 21, 2023
38e0fd7
Merge branch 'main' into remove-git-root-2
waydegg Dec 1, 2023
518beff
merge remove-git-root-2
waydegg Dec 1, 2023
2af3d3f
Run black
waydegg Dec 1, 2023
8fa1e33
Fix file edit tests
waydegg Dec 1, 2023
f2eca86
Merge main
waydegg Dec 5, 2023
51421a4
Remove `git_root` from command modules
waydegg Dec 5, 2023
bd1c156
Merge branch 'remove-git-root-2' into remove-git-root-3
waydegg Dec 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions mentat/code_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,18 @@ def __repr__(self):
f" level={self.level.key}, diff={self.diff})"
)

def ref(self):
def ref(self, cwd: Optional[Path] = None) -> str:
if cwd is not None and self.path.is_relative_to(cwd):
path_string = self.path.relative_to(cwd)
else:
path_string = str(self.path)

if self.level == CodeMessageLevel.INTERVAL:
interval_string = f"{self.interval.start}-{self.interval.end}"
return f"{self.path}:{interval_string}"
return str(self.path)
interval_string = f":{self.interval.start}-{self.interval.end}"
else:
interval_string = ""

return f"{path_string}{interval_string}"

def contains_line(self, line_number: int):
return self.interval.contains(line_number)
Expand Down
30 changes: 17 additions & 13 deletions mentat/code_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@ def __init__(self):

def read_file(self, path: Path) -> list[str]:
session_context = SESSION_CONTEXT.get()
git_root = session_context.git_root

abs_path = path if path.is_absolute() else Path(git_root / path)
rel_path = Path(os.path.relpath(abs_path, git_root))
abs_path = path if path.is_absolute() else session_context.cwd / path
with open(abs_path, "r") as f:
lines = f.read().split("\n")
self.file_lines[rel_path] = lines
self.file_lines[abs_path] = lines
waydegg marked this conversation as resolved.
Show resolved Hide resolved
return lines

def _create_file(self, code_context: CodeContext, abs_path: Path):
Expand Down Expand Up @@ -69,11 +67,14 @@ async def write_changes_to_files(
) -> list[FileEdit]:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
git_root = session_context.git_root

applied_edits: list[FileEdit] = []
for file_edit in file_edits:
rel_path = Path(os.path.relpath(file_edit.file_path, git_root))
if file_edit.file_path.is_relative_to(session_context.cwd):
display_path = file_edit.file_path.relative_to(session_context.cwd)
else:
display_path = file_edit.file_path

if file_edit.is_creation:
if file_edit.file_path.exists():
raise MentatError(
Expand All @@ -88,9 +89,12 @@ async def write_changes_to_files(
)

if file_edit.is_deletion:
stream.send(f"Are you sure you want to delete {rel_path}?", color="red")
stream.send(
f"Are you sure you want to delete {display_path}?",
color="red",
)
if await ask_yes_no(default_yes=False):
stream.send(f"Deleting {rel_path}...", color="red")
stream.send(f"Deleting {display_path}...", color="red")
# We use the current lines rather than the stored lines for undo
self.history.add_action(
DeletionAction(
Expand All @@ -101,21 +105,21 @@ async def write_changes_to_files(
applied_edits.append(file_edit)
continue
else:
stream.send(f"Not deleting {rel_path}", color="green")
stream.send(f"Not deleting {display_path}", color="green")

if not file_edit.is_creation:
stored_lines = self.file_lines[rel_path]
stored_lines = self.file_lines[file_edit.file_path]
if stored_lines != self.read_file(file_edit.file_path):
logging.info(
f"File '{file_edit.file_path}' changed while generating changes"
)
stream.send(
f"File '{rel_path}' changed while generating; current"
" file changes will be erased. Continue?",
f"File '{display_path}' changed while"
" generating; current file changes will be erased. Continue?",
color="light_yellow",
)
if not await ask_yes_no(default_yes=False):
stream.send(f"Not applying changes to file {rel_path}")
stream.send(f"Not applying changes to file {display_path}")
continue
else:
stored_lines = []
Expand Down
8 changes: 5 additions & 3 deletions mentat/command/commands/exclude.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
git_root = session_context.git_root

if len(args) == 0:
stream.send("No files specified", color="yellow")
Expand All @@ -22,8 +21,11 @@ async def apply(self, *args: str) -> None:
for invalid_path in invalid_paths:
print_invalid_path(invalid_path)
for excluded_path in excluded_paths:
rel_path = excluded_path.relative_to(git_root)
stream.send(f"{rel_path} removed from context", color="red")
if excluded_path.is_relative_to(session_context.cwd):
display_path = excluded_path.relative_to(session_context.cwd)
else:
display_path = excluded_path
stream.send(f"{display_path} removed from context", color="red")

@classmethod
def argument_names(cls) -> list[str]:
Expand Down
8 changes: 5 additions & 3 deletions mentat/command/commands/include.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
git_root = session_context.git_root

if len(args) == 0:
stream.send("No files specified", color="yellow")
Expand All @@ -22,8 +21,11 @@ async def apply(self, *args: str) -> None:
for invalid_path in invalid_paths:
print_invalid_path(invalid_path)
for included_path in included_paths:
rel_path = included_path.relative_to(git_root)
stream.send(f"{rel_path} added to context", color="green")
if included_path.is_relative_to(session_context.cwd):
display_path = included_path.relative_to(session_context.cwd)
else:
display_path = included_path
stream.send(f"{display_path} added to context", color="green")

@classmethod
def argument_names(cls) -> list[str]:
Expand Down
5 changes: 2 additions & 3 deletions mentat/command/commands/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
git_root = session_context.git_root

if len(args) == 0:
stream.send("No search query specified", color="yellow")
Expand All @@ -24,8 +23,8 @@ async def apply(self, *args: str) -> None:

for i, (feature, score) in enumerate(results, start=1):
label = feature.ref()
if label.startswith(str(git_root)):
label = label[len(str(git_root)) + 1 :]
if label.startswith(str(session_context.cwd)):
label = label[len(str(session_context.cwd)) + 1 :]
if feature.name:
label += f' "{feature.name}"'
stream.send(f"{i:3} | {score:.3f} | {label}")
Expand Down
35 changes: 22 additions & 13 deletions mentat/parsers/file_edit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import Any

import attr

Expand Down Expand Up @@ -64,24 +64,34 @@ class FileEdit:
# Should be abs path
rename_file_path: Path | None = attr.field(default=None)

@file_path.validator # pyright: ignore
def is_abs_path(self, attribute: attr.Attribute[Path], value: Any):
if not isinstance(value, Path):
raise ValueError(f"file_path must be a Path, got {type(value)}")
if not value.is_absolute():
raise ValueError(f"file_path must be an absolute path, got {value}")

def is_valid(self) -> bool:
session_context = SESSION_CONTEXT.get()
git_root = session_context.git_root
stream = session_context.stream
code_context = session_context.code_context

rel_path = Path(os.path.relpath(self.file_path, git_root))
if self.file_path.is_relative_to(session_context.cwd):
display_path = self.file_path.relative_to(session_context.cwd)
else:
display_path = self.file_path

if self.is_creation:
if self.file_path.exists():
stream.send(
f"File {rel_path} already exists, canceling creation.",
f"File {display_path} already exists, canceling creation.",
color="light_yellow",
)
return False
else:
if not self.file_path.exists():
stream.send(
f"File {rel_path} does not exist, canceling all edits to file.",
f"File {display_path} does not exist, canceling all edits to file.",
color="light_yellow",
)
return False
Expand All @@ -94,17 +104,18 @@ def is_valid(self) -> bool:
for i in range(r.starting_line, r.ending_line)
):
stream.send(
f"Edits to {rel_path} include lines not in context, "
"canceling all edits to file.",
f"File {display_path} not in context, canceling all edits to file.",
color="light_yellow",
)
return False

if self.rename_file_path is not None and self.rename_file_path.exists():
rel_rename_path = Path(os.path.relpath(self.rename_file_path, git_root))
rel_rename_path = None
if self.rename_file_path.is_relative_to(session_context.cwd):
rel_rename_path = self.rename_file_path.relative_to(session_context.cwd)
stream.send(
f"File {rel_path} being renamed to existing file {rel_rename_path},"
" canceling rename.",
f"File {display_path} being renamed to existing file"
f" {rel_rename_path or self.rename_file_path}, canceling rename.",
color="light_yellow",
)
self.rename_file_path = None
Expand All @@ -114,7 +125,6 @@ async def filter_replacements(
self,
) -> bool:
session_context = SESSION_CONTEXT.get()
git_root = session_context.git_root
code_file_manager = session_context.code_file_manager

if self.is_creation:
Expand All @@ -125,8 +135,7 @@ async def filter_replacements(
return False
file_lines = []
else:
rel_path = Path(os.path.relpath(self.file_path, git_root))
file_lines = code_file_manager.file_lines[rel_path]
file_lines = code_file_manager.file_lines[self.file_path]

if self.is_deletion:
display_information = DisplayInformation(
Expand Down
9 changes: 8 additions & 1 deletion mentat/terminal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
class TerminalClient:
def __init__(
self,
cwd: Path = Path.cwd(),
paths: List[str] = [],
exclude_paths: List[str] = [],
ignore_paths: List[str] = [],
diff: str | None = None,
pr_diff: str | None = None,
config: Config = Config(),
):
self.cwd = cwd
self.paths = [Path(path) for path in paths]
self.exclude_paths = [Path(path) for path in exclude_paths]
self.ignore_paths = [Path(path) for path in ignore_paths]
Expand Down Expand Up @@ -123,7 +125,7 @@ def _init_signal_handlers(self):
async def _run(self):
self._init_signal_handlers()
self.session = Session(
Path.cwd(),
self.cwd,
self.paths,
self.exclude_paths,
self.ignore_paths,
Expand Down Expand Up @@ -223,18 +225,23 @@ def run_cli():
default=None,
help="A git tree-ish to diff against the latest common ancestor of",
)
parser.add_argument(
"--cwd", default=Path.cwd(), help="The current working directory"
)

Config.add_fields_to_argparse(parser)
args = parser.parse_args()

config = Config.create(args)
cwd = args.cwd
paths = args.paths
exclude_paths = args.exclude
ignore_paths = args.ignore
diff = args.diff
pr_diff = args.pr_diff

terminal_client = TerminalClient(
cwd,
paths,
exclude_paths,
ignore_paths,
Expand Down
3 changes: 2 additions & 1 deletion tests/clients/terminal_client_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import subprocess
from pathlib import Path
from textwrap import dedent
from unittest.mock import AsyncMock

Expand Down Expand Up @@ -95,7 +96,7 @@ def test_request_and_command(
# I created this file
@@end""")])

terminal_client = TerminalClient(["."])
terminal_client = TerminalClient(cwd=Path.cwd(), paths=["."])
terminal_client.run()

with open(file_name, "r") as f:
Expand Down
1 change: 1 addition & 0 deletions tests/code_file_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ async def test_change_after_creation(


@pytest.mark.asyncio
@pytest.mark.no_git_testbed
waydegg marked this conversation as resolved.
Show resolved Hide resolved
async def test_changed_file(
mocker,
temp_testbed,
Expand Down
3 changes: 3 additions & 0 deletions tests/commands_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def test_commit_command(temp_testbed, mock_collect_user_input):
assert subprocess.check_output(["git", "status", "-s"], text=True) == ""


# TODO: test without git
@pytest.mark.asyncio
async def test_include_command(temp_testbed, mock_collect_user_input):
mock_collect_user_input.set_stream_messages(
Expand All @@ -63,6 +64,7 @@ async def test_include_command(temp_testbed, mock_collect_user_input):
)


# TODO: test without git
@pytest.mark.asyncio
async def test_exclude_command(temp_testbed, mock_collect_user_input):
mock_collect_user_input.set_stream_messages(
Expand Down Expand Up @@ -188,6 +190,7 @@ async def test_clear_command(temp_testbed, mock_collect_user_input, mock_call_ll
assert len(conversation.get_messages()) == 1


# TODO: test without git
@pytest.mark.asyncio
async def test_search_command(
mocker, temp_testbed, mock_call_llm_api, mock_collect_user_input
Expand Down
16 changes: 10 additions & 6 deletions tests/parser_tests/file_edit_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

import pytest

from mentat.parsers.file_edit import FileEdit, Replacement
Expand All @@ -9,12 +7,15 @@


@pytest.mark.asyncio
async def test_replacement(mock_call_llm_api):
async def test_replacement(mock_session_context):
replacements = [
Replacement(0, 2, ["# Line 0", "# Line 1", "# Line 2"]),
Replacement(3, 3, ["# Inserted"]),
]
file_edit = FileEdit(file_path=Path("test.py"), replacements=replacements)
file_edit = FileEdit(
file_path=mock_session_context.cwd.joinpath("test.py"),
replacements=replacements,
)
file_edit.resolve_conflicts()
original_lines = ["# Remove me", "# Remove me", "# Line 3", "# Line 4"]
new_lines = file_edit.get_updated_file_lines(original_lines)
Expand All @@ -30,14 +31,17 @@ async def test_replacement(mock_call_llm_api):

# When we add user conflict resolution, this test will need to be changed
@pytest.mark.asyncio
async def test_replacement_conflict(mock_call_llm_api):
async def test_replacement_conflict(mock_session_context):
replacements = [
Replacement(0, 2, ["L0"]),
Replacement(1, 3, ["L1"]),
Replacement(4, 7, ["L3"]),
Replacement(5, 6, ["L2"]),
]
file_edit = FileEdit(file_path=Path("test.py"), replacements=replacements)
file_edit = FileEdit(
file_path=mock_session_context.cwd.joinpath("test.py"),
replacements=replacements,
)
file_edit.resolve_conflicts()
original_lines = ["O0", "O1", "O2", "O3", "O4", "O5", "O6"]
new_lines = file_edit.get_updated_file_lines(original_lines)
Expand Down
Loading