Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mutate sample #454

Merged
merged 8 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mentat/sampler/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ The evaluation procedure, in abstract, is:
a. Add code from files/lines in `paths` as a System message
b. Add messages from `message_history` as User or Assistant messages
c. Add `message_prompt` as a User message
3. Generate an LLM Completion for the conversation.
3. Generate an LLM Completion for the conversation.
4. If using a Coding Assistant tool, process the response to apply edits to codebase.
5. Return the text portion of the conversation and the git diff, corresponding to `message_edit` and `diff_edit`

We provide two implementations of this:
- Run `scripts/evaluate_samples.py [<id>...]` from the command line, in the mentat repo. Prints to terminal.
- Run `python scripts/sampler [<id>...]` from the command line, in the mentat repo. Prints to terminal.
- Import `Sample` and call `Sample.evalute()` in Python. Returns a dict with `response` and `diff_edit`

## Use Cases
Expand Down
88 changes: 87 additions & 1 deletion mentat/sampler/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,98 @@
import os
from pathlib import Path
from typing import Optional
from uuid import uuid4

from git import Repo # type: ignore
from git import GitCommandError, Repo # type: ignore

from mentat.errors import SampleError
from mentat.git_handler import get_non_gitignored_files
from mentat.utils import is_file_text_encoded

# Directory under which clone_repo() places local checkouts: the
# "benchmark_repos" folder located three `.parent` hops up from this file.
CLONE_TO_DIR = Path(__file__).parent.parent.parent / "benchmark_repos"


def clone_repo(
    url: str, local_dir_name: str, refresh: bool = False, depth: int = 0
) -> Path | None:
    """Ensure a local checkout of `url` exists under CLONE_TO_DIR.

    Args:
        url: Remote repository location handed to `Repo.clone_from`.
        local_dir_name: Name of the sub-directory of CLONE_TO_DIR to use.
        refresh: When the checkout already exists, hard-reset it, remove
            untracked files/directories and fetch from all remotes.
        depth: If greater than zero, clone shallowly with this depth.

    Returns:
        The path of the local checkout. (The annotation allows None, but
        this implementation always returns the path.)
    """
    local_dir = CLONE_TO_DIR / local_dir_name

    if not os.path.exists(local_dir):
        # Fresh clone; only forward `depth` when a shallow clone was requested.
        clone_kwargs = {"depth": depth} if depth > 0 else {}
        Repo.clone_from(url, local_dir, **clone_kwargs)
    elif refresh:
        # Existing checkout: discard local changes, then fetch remote updates.
        git_cmd = Repo(local_dir).git
        git_cmd.reset("--hard")
        git_cmd.clean("-fd")
        git_cmd.fetch("--all")

    return local_dir


def apply_diff_to_repo(diff: str, repo: Repo, commit: bool = False) -> str | None:
    """Apply a git diff to a repo. If commit is True, commit the changes.

    The diff text is written to a uniquely named temporary patch file, applied
    with `git apply`, and the patch file is removed again before anything is
    staged so that it never ends up in the commit.

    Args:
        diff: Patch text in a format `git apply` understands.
        repo: Repository whose working tree receives the patch.
        commit: When True, stage everything and commit it with the message
            `sample_<temp id>` after the patch applied cleanly.

    Returns:
        None on success, or the stringified GitCommandError if any git
        command (apply, add or commit) failed.
    """
    temp_id = uuid4().hex
    # Anchor the patch file in the repo's working directory: GitPython runs
    # git commands there, which is not necessarily the process's cwd.
    diff_path = Path(repo.working_dir) / f".sample_{temp_id}.diff"
    try:
        try:
            with open(diff_path, "w") as f:
                f.write(diff)
            repo.git.execute(["git", "apply", str(diff_path)])
        finally:
            # Always clean up, whatever went wrong, and do it *before*
            # `git add .` so the patch file itself is never committed.
            diff_path.unlink(missing_ok=True)
        if commit:
            repo.git.add(".")
            repo.git.commit("-m", f"sample_{temp_id}")
    except GitCommandError as e:
        return str(e)
    return None


def setup_repo(
    url: str,
    cwd: Path | str | None = None,
    depth: int = 0,
    commit: Optional[str] = None,
    diff_merge_base: Optional[str] = None,
    diff_active: Optional[str] = None,
) -> Repo:
    """Prepare a repository checkout in a known state and return it.

    Note that this changes the process's current working directory to the
    repository directory.

    Args:
        url: Remote repository location; its last "/"-separated component is
            used as the local directory name when cloning.
        cwd: Existing checkout to use. When None, the repo is located or
            cloned via `clone_repo`.
        depth: Clone depth forwarded to `clone_repo`.
        commit: If given, checked out after the reset/clean/fetch.
        diff_merge_base: If non-empty, applied and committed.
        diff_active: If non-empty, applied on top without committing.

    Returns:
        The `Repo` opened on the prepared working directory.

    Raises:
        SampleError: If cloning yields no directory, `cwd` does not exist,
            or either diff fails to apply.
    """
    repo_name = url.split("/")[-1]

    # Resolve the directory to work in: caller-supplied or a local clone.
    if cwd is not None:
        cwd = Path(cwd)
        if not cwd.exists():
            raise SampleError(f"Error: {cwd} does not exist")
    else:
        cwd = clone_repo(
            url=url,
            local_dir_name=repo_name,
            refresh=False,  # the reset/clean/fetch happens right below
            depth=depth,
        )
        if cwd is None:
            raise SampleError(f"Error cloning {url}")
    os.chdir(cwd)

    # Bring the working tree to a clean, up-to-date state.
    repo = Repo(".")
    git_cmd = repo.git
    git_cmd.reset("--hard")
    git_cmd.clean("-fd")
    git_cmd.fetch("--all")

    if commit is not None:
        git_cmd.checkout(commit)

    # Layer the sample's diffs on top: merge-base diff is committed,
    # active diff is left as uncommitted changes.
    if diff_merge_base:
        merge_base_failure = apply_diff_to_repo(diff_merge_base, repo, commit=True)
        if merge_base_failure:
            raise SampleError(f"Error applying diff_merge_base: {merge_base_failure}")
    if diff_active:
        active_failure = apply_diff_to_repo(diff_active, repo)
        if active_failure:
            raise SampleError(f"Error applying diff_active: {active_failure}")

    return repo


def get_active_snapshot_commit(repo: Repo) -> str | None:
"""Returns the commit hash of the current active snapshot, or None if there are no active changes."""
Expand Down
23 changes: 0 additions & 23 deletions mentat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import hashlib
import os
import time
from importlib import resources
from importlib.abc import Traversable
Expand All @@ -11,7 +10,6 @@

import packaging.version
import requests
from git import Repo # type: ignore
from jinja2 import Environment, PackageLoader, select_autoescape
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
Expand Down Expand Up @@ -178,27 +176,6 @@ def get_relative_path(path: Path, target: Path) -> Path:
return relative_path


# Directory under which clone_repo() places local checkouts: the
# "benchmark_repos" folder located two `.parent` hops up from this file.
CLONE_TO_DIR = Path(__file__).parent.parent / "benchmark_repos"


def clone_repo(
    url: str, local_dir_name: str, refresh: bool = False, depth: int = 0
) -> Path | None:
    """Ensure a local checkout of `url` exists under CLONE_TO_DIR.

    Args:
        url: Remote repository location handed to `Repo.clone_from`.
        local_dir_name: Name of the sub-directory of CLONE_TO_DIR to use.
        refresh: When the checkout already exists, hard-reset it, remove
            untracked files/directories and pull from the `origin` remote.
        depth: If greater than zero, clone shallowly with this depth.

    Returns:
        The path of the local checkout. (The annotation allows None, but
        this implementation always returns the path.)
    """
    target = CLONE_TO_DIR / local_dir_name
    already_cloned = os.path.exists(target)

    if already_cloned and refresh:
        # Throw away local modifications before pulling the latest changes.
        checkout = Repo(target)
        checkout.git.reset("--hard")
        checkout.git.clean("-fd")
        checkout.remotes.origin.pull()

    if not already_cloned:
        if depth > 0:
            Repo.clone_from(url, target, depth=depth)
        else:
            Repo.clone_from(url, target)

    return target


# TODO: replace this with something that doesn't load the file into memory
def is_file_text_encoded(abs_path: Path):
"""Checks if a file is text encoded."""
Expand Down
Loading