Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Add Textual #502

Merged
merged 24 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ List the files you would like Mentat to read and edit as arguments. Mentat will

For more information on commands, configuration or using other models see [the documentation](https://docs.mentat.ai/en/latest/user/guides.html).

## MacOS Visual Artifacts

Mentat uses [Textual](https://textual.textualize.io/). On MacOS, Textual may not render the TUI correctly; if you run into this problem, use the fix [here](https://textual.textualize.io/FAQ/#why-doesnt-textual-look-good-on-macos).

# 👩‍💻 Roadmap and Contributing

We welcome contributions! To coordinate, make sure to join the Discord server: [![Discord Follow](https://dcbadge.vercel.app/api/server/XbPdxAMJte?style=flat)](https://discord.gg/zbvd9qx9Pb)
Expand Down
5 changes: 0 additions & 5 deletions docs/source/user/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ Commit all unstaged and staged changes to git.

Show or set a config option's value.

/context
--------

Show all files currently in context.

/exclude <path|glob pattern> ...
--------------------------------

Expand Down
24 changes: 0 additions & 24 deletions docs/source/user/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,6 @@ List of `glob patterns <https://docs.python.org/3/library/glob.html>`_ to exclud
"file-exclude-glob-list": ["**/.*, **/.*/**"]
}

input_style
^^^^^^^^^^^

A list of key-value pairs defining a custom `Pygment Style <https://pygments.org/docs/styledevelopment/>`_ to style the Mentat prompt.

.. code-block:: json

{
"input-style": [
[
"",
"#9835bd"
],
[
"prompt",
"#ffffff bold"
],
[
"continuation",
"#ffffff bold"
]
]
}

parser
^^^^^^

Expand Down
2 changes: 1 addition & 1 deletion docs/source/user/context.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Context refers to portions of your code/files which are sent to the LLM along wi

Files can be manually added to context as a command line argument when starting mentat or with the :code:`/include` command during a session. Files can be removed from context with the :code:`/exclude` command. Mentat always puts all included files into the system message sent to the LLM so you probably don't want to start mentat with :code:`mentat .`. If you do want mentat to intelligently select the context from your prompt you should run :code:`mentat -a` and mentat will build its own context. For more see :ref:`auto context`.

You can specify line ranges to add only a subset of a file to context by adding the starting line (inclusive) and ending line (exclusive) to the path. For example :code:`/include README.md:1-5,10-20` would add lines 1, 2, 3 and 4 and 10th to 19th lines to the LLMs context. You can see a summary of what is in context with the :code:`/context` command.
You can specify line ranges to add only a subset of a file to context by adding the starting line (inclusive) and ending line (exclusive) to the path. For example :code:`/include README.md:1-5,10-20` would add lines 1, 2, 3 and 4 and 10th to 19th lines to the LLMs context.

You can see the conversation exactly as the LLM sees it by running :code:`/viewer`. This command opens the transcript in a web browser. If you click a message from the LLM you will see the conversation as the LLM sees it. You can see past conversations by using the arrow keys.

Expand Down
117 changes: 71 additions & 46 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Union
from typing import Dict, Iterable, List, Optional, Set, TypedDict, Union

from openai.types.chat import ChatCompletionSystemMessageParam

from mentat.code_feature import (
CodeFeature,
Expand All @@ -17,25 +20,35 @@
from mentat.git_handler import get_paths_with_git_diffs
from mentat.include_files import (
PathType,
build_path_tree,
get_code_features_for_path,
get_path_type,
get_paths_for_directory,
is_file_text_encoded,
match_path_with_patterns,
print_path_tree,
validate_and_format_path,
)
from mentat.interval import parse_intervals, split_intervals_from_path
from mentat.llm_api_handler import (
count_tokens,
get_max_tokens,
prompt_tokens,
raise_if_context_exceeds_max,
)
from mentat.session_context import SESSION_CONTEXT
from mentat.session_stream import SessionStream


class ContextStreamMessage(TypedDict):
cwd: str
diff_context_display: Optional[str]
auto_context_tokens: int
features: List[str]
auto_features: List[str]
git_diff_paths: List[str]
total_tokens: int
total_cost: float


class CodeContext:
def __init__(
self,
Expand All @@ -60,60 +73,67 @@ def __init__(
self.ignore_files: Set[Path] = set()
self.auto_features: List[CodeFeature] = []

def display_context(self):
"""Display the baseline context: included files and auto-context settings"""
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
config = session_context.config
def refresh_context_display(self):
"""
Sends a message to the client with the new updated code context
Must be called whenever the context changes!
"""
ctx = SESSION_CONTEXT.get()

stream.send("Code Context:", style="code")
prefix = " "
stream.send(f"{prefix}Directory: {session_context.cwd}")
diff_context_display = None
if self.diff_context and self.diff_context.name:
stream.send(f"{prefix}Diff:", end=" ")
stream.send(self.diff_context.get_display_context(), style="success")
diff_context_display = self.diff_context.get_display_context()

if config.auto_context_tokens > 0:
stream.send(f"{prefix}Auto-Context: Enabled")
stream.send(f"{prefix}Auto-Context Tokens: {config.auto_context_tokens}")
else:
stream.send(f"{prefix}Auto-Context: Disabled")
features = get_consolidated_feature_refs(
[
feature
for file_features in self.include_files.values()
for feature in file_features
]
)
auto_features = get_consolidated_feature_refs(self.auto_features)
git_diff_paths = (
list(get_paths_with_git_diffs(self.git_root)) if self.git_root else []
)

if self.include_files:
stream.send(f"{prefix}Included files:")
stream.send(f"{prefix + prefix}{session_context.cwd.name}")
features = [
messages = ctx.conversation.get_messages()
code_message = get_code_message_from_features(
[
feature
for file_features in self.include_files.values()
for feature in file_features
]
refs = get_consolidated_feature_refs(features)
print_path_tree(
build_path_tree([Path(r) for r in refs], session_context.cwd),
get_paths_with_git_diffs(self.git_root) if self.git_root else set(),
session_context.cwd,
prefix + prefix,
)
else:
stream.send(f"{prefix}Included files: ", end="")
stream.send("None", style="warning")

if self.auto_features:
stream.send(f"{prefix}Auto-Included Features:")
refs = get_consolidated_feature_refs(self.auto_features)
print_path_tree(
build_path_tree([Path(r) for r in refs], session_context.cwd),
get_paths_with_git_diffs(self.git_root) if self.git_root else set(),
session_context.cwd,
prefix + prefix,
)
+ self.auto_features
)
total_tokens = prompt_tokens(
messages
+ [
ChatCompletionSystemMessageParam(
role="system", content="\n".join(code_message)
)
],
ctx.config.model,
)

total_cost = ctx.cost_tracker.total_cost

data = ContextStreamMessage(
cwd=str(ctx.cwd),
diff_context_display=diff_context_display,
auto_context_tokens=ctx.config.auto_context_tokens,
features=features,
auto_features=auto_features,
git_diff_paths=[str(p) for p in git_diff_paths],
total_tokens=total_tokens,
total_cost=total_cost,
)
ctx.stream.send(json.dumps(data), channel="context_update")

async def get_code_message(
self,
prompt_tokens: int,
prompt: Optional[str] = None,
expected_edits: Optional[list[str]] = None, # for training/benchmarking
loading_multiplier: float = 0.0,
suppress_context_check: bool = False,
) -> str:
"""
Expand All @@ -130,7 +150,9 @@ async def get_code_message(
# Setup code message metadata
code_message = list[str]()
if self.diff_context:
self.diff_context.clear_cache()
# Since there is no way of knowing when the git diff changes,
# we just refresh the cache every time get_code_message is called
self.diff_context.refresh_diff_files()
if self.diff_context.diff_files():
code_message += [
"Diff References:",
Expand Down Expand Up @@ -168,11 +190,11 @@ async def get_code_message(
auto_tokens,
prompt,
expected_edits,
loading_multiplier=loading_multiplier,
)
self.auto_features = list(
set(self.auto_features) | set(await feature_filter.filter(features))
)
self.refresh_context_display()

# Merge include file features and auto features and add to code message
code_message += get_code_message_from_features(
Expand Down Expand Up @@ -221,7 +243,8 @@ def clear_auto_context(self):
"""
Clears all auto-features added to the conversation so far.
"""
self.auto_features = []
self._auto_features = []
self.refresh_context_display()

def include_features(self, code_features: Iterable[CodeFeature]):
"""
Expand Down Expand Up @@ -249,6 +272,7 @@ def include_features(self, code_features: Iterable[CodeFeature]):
self.include_files[code_feature.path] = []
self.include_files[code_feature.path].append(code_feature)
included_paths.add(Path(str(code_feature)))
self.refresh_context_display()
return included_paths

def include(
Expand Down Expand Up @@ -404,6 +428,7 @@ def exclude(self, path: Path | str) -> Set[Path]:
except PathValidationError as e:
session_context.stream.send(str(e), style="error")

self.refresh_context_display()
return excluded_paths

async def search(
Expand Down
10 changes: 7 additions & 3 deletions mentat/code_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List

from mentat.edit_history import EditHistory
from mentat.errors import MentatError
Expand Down Expand Up @@ -67,6 +67,11 @@ def rename_file(self, abs_path: Path, new_abs_path: Path):
if new_abs_path not in code_context.include_files:
code_context.include(new_abs_path)

def write_to_file(self, abs_path: Path, new_lines: List[str]):
with open(abs_path, "w") as f:
f.write("\n".join(new_lines))
self.file_lines[abs_path] = new_lines

# Mainly does checks on if file is in context, file exists, file is unchanged, etc.
async def write_changes_to_files(
self,
Expand Down Expand Up @@ -135,8 +140,7 @@ async def write_changes_to_files(
file_path = file_edit.rename_file_path or file_edit.file_path
# We use the current lines rather than the stored lines for undo
file_edit.previous_file_lines = self.read_file(file_path)
with open(file_path, "w") as f:
f.write("\n".join(new_lines))
self.write_to_file(file_path, new_lines)
applied_edits.append(file_edit)

for applied_edit in applied_edits:
Expand Down
1 change: 0 additions & 1 deletion mentat/command/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .clear import ClearCommand
from .commit import CommitCommand
from .config import ConfigCommand
from .context import ContextCommand
from .exclude import ExcludeCommand
from .help import HelpCommand
from .include import IncludeCommand
Expand Down
32 changes: 0 additions & 32 deletions mentat/command/commands/context.py

This file was deleted.

8 changes: 3 additions & 5 deletions mentat/command/commands/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def apply(self, *args: str) -> None:
file_name = feature.rel_path(session_context.cwd)
stream.send(file_name, color="blue", end="")
file_interval = feature.interval_string()
stream.send(file_interval, color="light_cyan", end="")
stream.send(file_interval, color="bright_cyan", end="")

tokens = feature.count_tokens(config.model)
cumulative_tokens += tokens
Expand All @@ -81,7 +81,7 @@ async def apply(self, *args: str) -> None:
"(Y/n) for more results or to exit search mode.\nResults to"
' include in context: (eg: "1 3 4" or "1-4")'
)
user_input: str = (await collect_user_input(plain=True)).data.strip()
user_input: str = (await collect_user_input()).data.strip()
while user_input.lower() not in "yn":
to_include = _parse_include_input(user_input, i)
if to_include is not None:
Expand All @@ -94,9 +94,7 @@ async def apply(self, *args: str) -> None:
stream.send(f"{rel_path} added to context", style="success")
else:
stream.send("(Y/n)", style="input")
user_input: str = (
await collect_user_input(plain=True)
).data.strip()
user_input: str = (await collect_user_input()).data.strip()
if user_input.lower() == "n":
stream.send("Exiting search mode...", style="input")
break
Expand Down
Loading
Loading