Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add load/save context commands #534

Merged
merged 21 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,23 @@ async def search(
return all_features_sorted
else:
return all_features_sorted[:max_results]

galer7 marked this conversation as resolved.
Show resolved Hide resolved
galer7 marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
simple_dict[str(path.absolute())] = [

This change ensures compatibility with different operating systems and file path formats.

def to_simple_context_dict(self) -> dict[str, list[str]]:
"""Return a simple dictionary representation of the code context"""
simple_dict: dict[str, list[str]] = {}
for path, features in self.include_files.items():
simple_dict[path.absolute().as_posix()] = [
galer7 marked this conversation as resolved.
Show resolved Hide resolved
str(feature) for feature in features
]
return simple_dict

def from_simple_context_dict(self, simple_dict: dict[str, list[str]]):
"""Load the code context from a simple dictionary representation"""

self.include_files = {}
for path_str, features_str in simple_dict.items():
path = Path(path_str)
code_features = [CodeFeature.from_string(f_str) for f_str in features_str]
self.include_files[path] = code_features

self.refresh_context_display()
galer7 marked this conversation as resolved.
Show resolved Hide resolved
22 changes: 22 additions & 0 deletions mentat/code_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ def count_tokens(self, model: str) -> int:
code_message = self.get_code_message()
return count_tokens("\n".join(code_message), model, full_message=False)

@staticmethod
def from_string(string: str) -> CodeFeature:
"""
Create a CodeFeature from a string.
"""

print("in CodeFeature.from_string", string)

# find last colon
colon_index = string.rfind(":")
galer7 marked this conversation as resolved.
Show resolved Hide resolved
if colon_index == -1:
path_string = string
interval_string = ""
else:
path_string = string[:colon_index]
interval_string = string[colon_index + 1 :]

path = Path(path_string)
interval = Interval.from_string(interval_string)

return CodeFeature(path, interval)


async def count_feature_tokens(features: list[CodeFeature], model: str) -> list[int]:
"""Return the number of tokens in each feature."""
Expand Down
2 changes: 2 additions & 0 deletions mentat/command/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .exclude import ExcludeCommand
from .help import HelpCommand
from .include import IncludeCommand
from .save import SaveCommand
from .load import LoadCommand
galer7 marked this conversation as resolved.
Show resolved Hide resolved
galer7 marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great to see the inclusion of the load and save commands in the __init__.py file, ensuring they are recognized and can be utilized within the application.

from .redo import RedoCommand
from .run import RunCommand
from .sample import SampleCommand
Expand Down
63 changes: 63 additions & 0 deletions mentat/command/commands/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from pathlib import Path
from typing import List
import json
galer7 marked this conversation as resolved.
Show resolved Hide resolved

from typing_extensions import override

from mentat.auto_completer import get_command_filename_completions
from mentat.command.command import Command, CommandArgument
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import mentat_dir_path
from mentat.errors import PathValidationError


class LoadCommand(Command, command_name="load"):
@override
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
context_file_path = mentat_dir_path / "context.json"

if len(args) > 1:
stream.send(
"Only one context file can be loaded at a time", style="warning"
)
galer7 marked this conversation as resolved.
Show resolved Hide resolved
return

if not args:
stream.send(
"No context file specified. Defaulting to context.json", style="warning"
)
else:
try:
context_file_path = Path(args[0]).expanduser().resolve()
except RuntimeError as e:
raise PathValidationError(
f"Invalid context file path provided: {args[0]}: {e}"
)

with open(context_file_path, "r") as file:
galer7 marked this conversation as resolved.
Show resolved Hide resolved
parsed_include_files = json.load(file)

# TODO: Do we remove already-included files when loading new context file?
galer7 marked this conversation as resolved.
Show resolved Hide resolved
code_context.from_simple_context_dict(parsed_include_files)

stream.send(f"Context loaded from {context_file_path}", style="success")

@override
@classmethod
def arguments(cls) -> List[CommandArgument]:
return [CommandArgument("required", ["path"])]
galer7 marked this conversation as resolved.
Show resolved Hide resolved
galer7 marked this conversation as resolved.
Show resolved Hide resolved

@override
@classmethod
def argument_autocompletions(
cls, arguments: list[str], argument_position: int
) -> list[str]:
return get_command_filename_completions(arguments[-1])

@override
@classmethod
def help_message(cls) -> str:
return "Loads a context file."
61 changes: 61 additions & 0 deletions mentat/command/commands/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from pathlib import Path
from typing import List
import json

from typing_extensions import override

from mentat.auto_completer import get_command_filename_completions
from mentat.command.command import Command, CommandArgument
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import mentat_dir_path
from mentat.errors import PathValidationError


class SaveCommand(Command, command_name="save"):
@override
async def apply(self, *args: str) -> None:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_context = session_context.code_context
context_file_path = mentat_dir_path / "context.json"

if len(args) > 1:
stream.send("Only one context file can be saved at a time", style="warning")
return

if len(args) == 1:
try:
context_file_path = Path(args[0]).expanduser().resolve()
except RuntimeError as e:
raise PathValidationError(
f"Invalid context file path provided: {args[0]}: {e}"
)

if len(args) == 0:
stream.send(
"No context file specified. Defaulting to context.json", style="warning"
galer7 marked this conversation as resolved.
Show resolved Hide resolved
)

serializable_context = code_context.to_simple_context_dict()

with open(context_file_path, "w") as file:
json.dump(serializable_context, file)

stream.send(f"Context saved to {context_file_path}", style="success")

@override
@classmethod
def arguments(cls) -> List[CommandArgument]:
return [CommandArgument("optional", ["path"])]

@override
@classmethod
def argument_autocompletions(
cls, arguments: list[str], argument_position: int
) -> list[str]:
return get_command_filename_completions(arguments[-1])

@override
@classmethod
def help_message(cls) -> str:
return "Saves the current context to a file."
15 changes: 15 additions & 0 deletions mentat/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,18 @@ def __str__(self) -> str:
return f"{self.start}"
else:
return f"{self.start}-{self.end}"

@staticmethod
def from_string(interval_string: str) -> Interval:
print("interval_string:", interval_string)
try:
interval_parts = interval_string.split("-")
except ValueError:
return Interval(1, INTERVAL_FILE_END)

if len(interval_parts) != 2:
# corrupt interval string, make it whole file
return Interval(1, INTERVAL_FILE_END)

start, end = interval_parts
return Interval(int(start), int(end))
4 changes: 2 additions & 2 deletions mentat/session_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class SessionStream:
default: Any data sent to the client over this channel should be displayed. Valid kwargs: color, style

*session_exit: Sent by the client, suggesting that the session should exit whenever possible.
client_exit: Sent by the server, client should shut down when recieved.
session_stopped: Sent by the server directly before server shuts down. Server can't be contacted after recieved.
client_exit: Sent by the server, client should shut down when received.
galer7 marked this conversation as resolved.
Show resolved Hide resolved
session_stopped: Sent by the server directly before server shuts down. Server can't be contacted after received.

loading: Used to tell the client to display a loading bar. Valid kwargs: terminate

Expand Down
63 changes: 63 additions & 0 deletions tests/commands_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
from pathlib import Path
from textwrap import dedent
import json

import pytest

Expand All @@ -9,6 +10,7 @@
from mentat.command.commands.help import HelpCommand
from mentat.session import Session
from mentat.session_context import SESSION_CONTEXT
from mentat.interval import Interval


def test_invalid_command():
Expand Down Expand Up @@ -80,6 +82,67 @@ async def test_exclude_command(temp_testbed, mock_collect_user_input):
assert not code_context.include_files


@pytest.mark.asyncio
async def test_save_command(temp_testbed, mock_collect_user_input):
default_context_path = Path(temp_testbed) / "context.json"
mock_collect_user_input.set_stream_messages(
[
"/include scripts",
f"/save {default_context_path}",
"q",
]
)

session = Session(cwd=temp_testbed)
session.start()
await session.stream.recv(channel="client_exit")

saved_code_context = json.load(open(default_context_path))
assert (Path(temp_testbed) / "scripts" / "calculator.py").absolute().as_posix() in (
saved_code_context.keys()
)


@pytest.mark.asyncio
async def test_load_command(temp_testbed, mock_collect_user_input):
scripts_dir = temp_testbed / "scripts"
features = [
CodeFeature(scripts_dir / "calculator.py", Interval(1, 10)),
CodeFeature(scripts_dir / "echo.py"),
]
context_file_path = Path(temp_testbed) / "context.json"

with open(context_file_path, "w") as f:
to_dump = {}
for feature in features:
to_dump[feature.path.absolute().as_posix()] = [str(feature)]
json.dump(to_dump, f)

mock_collect_user_input.set_stream_messages(
[
f"/load {context_file_path}",
"q",
]
)

session = Session(cwd=temp_testbed)
session.start()
await session.stream.recv(channel="client_exit")

code_context = SESSION_CONTEXT.get().code_context

print("code_context.include_files.keys()", code_context.include_files.keys())

assert scripts_dir / "calculator.py" in code_context.include_files.keys()
assert code_context.include_files[scripts_dir / "calculator.py"] == [
CodeFeature(scripts_dir / "calculator.py", Interval(1, 10)),
]
assert scripts_dir / "echo.py" in code_context.include_files.keys()
assert code_context.include_files[scripts_dir / "echo.py"] == [
CodeFeature(scripts_dir / "echo.py"),
]


@pytest.mark.asyncio
async def test_undo_command(temp_testbed, mock_collect_user_input, mock_call_llm_api):
temp_file_name = "temp.py"
Expand Down
Loading