
Commit

Merge branch 'main' into remove-even-more-0.5-code
rogthefrog committed Dec 18, 2024
2 parents 65877ce + c79d7d4 commit c826a38
Showing 16 changed files with 281 additions and 318 deletions.
50 changes: 24 additions & 26 deletions src/modelbench/benchmark_runner.py
@@ -18,7 +18,6 @@
from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore
from modelbench.cache import DiskCache, MBCache
from modelbench.run_journal import RunJournal
from modelbench.suts import ModelGaugeSut
from modelgauge.annotator import CompletionAnnotator
from modelgauge.annotator_registry import ANNOTATORS
from modelgauge.base_test import PromptResponseTest, TestResult
@@ -27,7 +26,7 @@
from modelgauge.prompt import TextPrompt
from modelgauge.records import TestRecord
from modelgauge.single_turn_prompt_response import PromptWithContext, TestItem
from modelgauge.sut import SUTCompletion, SUTResponse
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse

logger = logging.getLogger(__name__)

@@ -144,12 +143,12 @@ def _add_test_annotators(self, test: PromptResponseTest):
annotators.append(ANNOTATORS.make_instance(annotator_uid, secrets=self.secrets))
self.test_annotators[test.uid] = annotators

def add_finished_item(self, item: "TestRunItem"):
def add_finished_item(self, item: TestRunItem):
if item.completion() and item.annotations and not item.exceptions:
self.finished_items[item.sut.key][item.test.uid].append(item)
self.finished_items[item.sut.uid][item.test.uid].append(item)
self.journal.item_entry("item finished", item)
else:
self.failed_items[item.sut.key][item.test.uid].append(item)
self.failed_items[item.sut.uid][item.test.uid].append(item)
self.journal.item_entry(
"item failed",
item,
@@ -164,10 +163,10 @@ def add_test_record(self, test_record: TestRecord):
self.test_records[test_record.test_uid][test_record.sut_uid] = test_record

def finished_items_for(self, sut, test) -> Sequence[TestItem]:
return self.finished_items[sut.key][test.uid]
return self.finished_items[sut.uid][test.uid]

def failed_items_for(self, sut, test) -> Sequence[TestItem]:
return self.failed_items[sut.key][test.uid]
return self.failed_items[sut.uid][test.uid]

def annotators_for_test(self, test: PromptResponseTest) -> Sequence[CompletionAnnotator]:
return self.test_annotators[test.uid]
@@ -202,7 +201,7 @@ def __init__(self, runner: "TestRunner"):


class BenchmarkRun(TestRunBase):
benchmark_scores: dict[BenchmarkDefinition, dict[ModelGaugeSut, BenchmarkScore]]
benchmark_scores: dict[BenchmarkDefinition, dict[PromptResponseTest, BenchmarkScore]]
benchmarks: Sequence[BenchmarkDefinition]

def __init__(self, runner: "BenchmarkRunner"):
@@ -283,8 +282,8 @@ def __init__(self, test_run: TestRunBase, cache: MBCache, thread_count=1):
self.test_run = test_run

def handle_item(self, item: TestRunItem):
mg_sut = item.sut.instance(self.test_run.secrets)
raw_request = mg_sut.translate_text_prompt(item.prompt_with_context().prompt)
sut = item.sut
raw_request = sut.translate_text_prompt(item.prompt_with_context().prompt)
cache_key = raw_request.model_dump_json(exclude_none=True)
self._debug(f"looking for {cache_key} in cache")
try:
@@ -297,16 +296,16 @@ def handle_item(self, item: TestRunItem):
self._debug(f"cache entry not found; processing and saving")
with Timer() as timer:
try:
raw_response = mg_sut.evaluate(raw_request)
raw_response = sut.evaluate(raw_request)
except Exception as e:
logger.error(f"failure fetching sut {mg_sut.uid} on first try: {raw_request}", exc_info=True)
raw_response = mg_sut.evaluate(raw_request)
logger.error(f"failure fetching sut {sut.uid} on first try: {raw_request}", exc_info=True)
raw_response = sut.evaluate(raw_request)
self.cache[cache_key] = raw_response
self.test_run.journal.item_entry(
"fetched sut response", item, run_time=timer, request=raw_request, response=raw_response
)

response = mg_sut.translate_response(raw_request, raw_response)
response = sut.translate_response(raw_request, raw_response)
item.sut_response = response
self.test_run.journal.item_entry("translated sut response", item, response=response)

@@ -418,7 +417,7 @@ def __init__(self, data_dir: pathlib.Path):
self.thread_count = 1
self.run_tracker = NullRunTracker()

def add_sut(self, sut: ModelGaugeSut):
def add_sut(self, sut: PromptResponseSUT):
self.suts.append(sut)

def _check_ready_to_run(self):
@@ -433,16 +432,15 @@ def _check_external_services(self, run: TestRunBase):
self._check_annotators_working(run)

def _check_suts_working(self, run: TestRunBase):
def check_sut(sut: ModelGaugeSut):
def check_sut(sut: PromptResponseSUT):
try:
mg_sut = sut.instance(self.secrets)
raw_request = mg_sut.translate_text_prompt(TextPrompt(text="Why did the chicken cross the road?"))
raw_response = mg_sut.evaluate(raw_request)
response: SUTResponse = mg_sut.translate_response(raw_request, raw_response)
raw_request = sut.translate_text_prompt(TextPrompt(text="Why did the chicken cross the road?"))
raw_response = sut.evaluate(raw_request)
response: SUTResponse = sut.translate_response(raw_request, raw_response)
return bool(response.completions)
except Exception as e:
logger.error(f"initial check failure for {sut}", exc_info=e)
print(f"initial check failure for {sut}")
logger.error(f"initial check failure for {sut.uid}", exc_info=e)
print(f"initial check failure for {sut.uid}")
traceback.print_exc()

return False
@@ -497,8 +495,8 @@ def _make_test_record(self, run, sut, test, test_result):
test_uid=test.uid,
test_initialization=test.initialization_record,
dependency_versions=test.dependency_helper.versions_used(),
sut_uid=sut._instance.uid,
sut_initialization=sut._instance.initialization_record,
sut_uid=sut.uid,
sut_initialization=sut.initialization_record,
test_item_records=[],
test_item_exceptions=[],
result=TestResult.from_instance(test_result),
@@ -628,10 +626,10 @@ def _calculate_benchmark_scores(self, benchmark_run):
test_records = {}
for test in hazard.tests(benchmark_run.secrets):
records = benchmark_run.test_records[test.uid][sut.uid]
assert records, f"No records found for {benchmark_definition} {sut} {hazard} {test.uid}"
assert records, f"No records found for {benchmark_definition} {sut.uid} {hazard} {test.uid}"
test_records[test.uid] = records

assert test_records, f"No records found for {benchmark_definition} {sut} {hazard}"
assert test_records, f"No records found for {benchmark_definition} {sut.uid} {hazard}"

hazard_score = hazard.score(test_records)
hazard_scores.append(hazard_score) # TODO: score needs way less
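
Taken together, the benchmark_runner.py changes drop the modelbench.suts.ModelGaugeSut wrapper: the runner now holds a modelgauge PromptResponseSUT directly, calls translate_text_prompt / evaluate / translate_response on it without the old .instance(secrets) step, and keys its bookkeeping and log messages by sut.uid rather than the wrapper's key. A minimal sketch of that calling pattern, assuming a PromptResponseSUT instance obtained elsewhere (the helper name and prompt text are illustrative, not part of the diff):

from modelgauge.prompt import TextPrompt
from modelgauge.sut import PromptResponseSUT, SUTResponse


def query_sut(sut: PromptResponseSUT, text: str) -> SUTResponse:
    # Translate the generic prompt into the SUT's native request type.
    raw_request = sut.translate_text_prompt(TextPrompt(text=text))
    # Call the SUT, then translate the raw response back into a SUTResponse.
    raw_response = sut.evaluate(raw_request)
    return sut.translate_response(raw_request, raw_response)


# Downstream bookkeeping now keys on the SUT's uid, e.g.
#   finished_items[sut.uid][test.uid].append(item)
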
15 changes: 7 additions & 8 deletions src/modelbench/benchmark_runner_items.py
@@ -3,24 +3,23 @@
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Mapping, List
from typing import List, Mapping

from modelbench.suts import ModelGaugeSut
from modelgauge.annotation import Annotation
from modelgauge.annotator import CompletionAnnotator
from modelgauge.base_test import PromptResponseTest
from modelgauge.dependency_helper import FromSourceDependencyHelper
from modelgauge.external_data import WebData
from modelgauge.single_turn_prompt_response import (
TestItem,
PromptWithContext,
MeasuredTestItem,
TestItemAnnotations,
PromptInteractionAnnotations,
SUTResponseAnnotations,
PromptWithContext,
SUTCompletionAnnotations,
SUTResponseAnnotations,
TestItem,
TestItemAnnotations,
)
from modelgauge.sut import SUTResponse, SUTCompletion
from modelgauge.sut import PromptResponseSUT, SUTResponse, SUTCompletion


# in their own file to solve circular import problems
@@ -100,7 +99,7 @@ class TestRunItem:

test: ModelgaugeTestWrapper
test_item: TestItem
sut: ModelGaugeSut = None
sut: PromptResponseSUT = None
sut_response: SUTResponse = None
annotations: dict[str, Annotation] = dataclasses.field(default_factory=dict)
measurements: dict[str, float] = dataclasses.field(default_factory=dict)
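
The companion change in benchmark_runner_items.py retypes TestRunItem.sut from the removed wrapper to PromptResponseSUT, so an item carries the modelgauge SUT itself and callers read item.sut.uid with nothing to unwrap. A simplified, self-contained stand-in for that dataclass (the stub and field names below are illustrative, not modelbench code):

import dataclasses


@dataclasses.dataclass
class _StubSut:
    uid: str  # stands in for a PromptResponseSUT in this sketch


@dataclasses.dataclass
class RunItemSketch:
    sut: _StubSut = None
    sut_response: object = None  # SUTResponse in the real class


item = RunItemSketch(sut=_StubSut(uid="demo-sut"))
print(item.sut.uid)  # "demo-sut", used directly as the bookkeeping key
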
8 changes: 4 additions & 4 deletions src/modelbench/benchmarks.py
@@ -5,19 +5,19 @@
from typing import List, Sequence

import casefy
from modelgauge.sut import PromptResponseSUT
from modelgauge.tests.safe_v1 import Locale

from modelbench.hazards import HazardDefinition, HazardScore, SafeHazardV1, Standards, STANDARDS
from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade
from modelbench.suts import ModelGaugeSut
from modelbench.scoring import LetterGradeMixin, score_to_ordinal_grade, score_to_ordinal_grade_v1
from modelbench.uid import HasUid


class BenchmarkScore(ABC, LetterGradeMixin):
def __init__(
self,
benchmark_definition: "BenchmarkDefinition",
sut: ModelGaugeSut,
sut: PromptResponseSUT,
hazard_scores: List["HazardScore"],
end_time: datetime,
):
@@ -62,7 +62,7 @@ def __repr__(self):
+ "("
+ str(self.benchmark_definition)
+ ", "
+ str(self.sut)
+ str(self.sut.uid)
+ ", "
+ str(self.hazard_scores)
+ ")"
11 changes: 7 additions & 4 deletions src/modelbench/consistency_checker.py
@@ -91,11 +91,13 @@ def failure_message(self) -> str:
assert not self.check()
messages = []
if len(self.duplicates) > 0:
messages.append(f"The following duplicate prompts were found: {self.duplicates}")
messages.append(f"{len(self.duplicates)} duplicate prompts were found: {self.duplicates}")
if len(self.missing_prompts) > 0:
messages.append(f"The prompts were expected but missing: {self.missing_prompts}")
messages.append(f"{len(self.missing_prompts)} prompts were expected but missing: {self.missing_prompts}")
if len(self.unknown_prompts) > 0:
messages.append(f"The following prompts were found but were not expected: {self.unknown_prompts}")
messages.append(
f"{len(self.unknown_prompts)} prompts were found but were not expected: {self.unknown_prompts}"
)
return "\n\t".join(messages)


@@ -113,12 +115,13 @@ def failure_message(self) -> str:

class EachPromptRespondedToOnce(OneToOneCheck):
def __init__(self, search_engine: JournalSearch, sut, test):
self.test = test
super().__init__(
search_engine.test_prompt_uids(test), search_engine.sut_response_prompt_uids_for_test(sut, test)
)

def failure_message(self) -> str:
message = "Expected exactly 1 SUT response for each prompt in the test.\n\t"
message = f"Expected exactly 1 SUT response for each prompt in the test {self.test}.\n\t"
# Call super() to get specific details about duplicates/missing/extra prompts.
return message + super().failure_message()

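
The consistency checker now prefixes each complaint with a count and names the offending test. A self-contained illustration of how the revised failure message reads (the prompt uids and test name are invented for the demo, not taken from the modelbench checks):

duplicates = {"prompt_003"}
missing_prompts = {"prompt_007", "prompt_009"}
unknown_prompts: set = set()

messages = []
if len(duplicates) > 0:
    messages.append(f"{len(duplicates)} duplicate prompts were found: {duplicates}")
if len(missing_prompts) > 0:
    messages.append(f"{len(missing_prompts)} prompts were expected but missing: {missing_prompts}")
if len(unknown_prompts) > 0:
    messages.append(f"{len(unknown_prompts)} prompts were found but were not expected: {unknown_prompts}")

print("Expected exactly 1 SUT response for each prompt in the test demo_test.\n\t" + "\n\t".join(messages))
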
8 changes: 3 additions & 5 deletions src/modelbench/record.py
@@ -8,11 +8,11 @@

import pydantic
from modelgauge.base_test import BaseTest
from modelgauge.sut import SUT

from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore
from modelbench.hazards import HazardDefinition, HazardScore
from modelbench.static_content import StaticContent
from modelbench.suts import ModelGaugeSut, SutDescription


def run_command(*args):
@@ -111,10 +111,8 @@ def default(self, o):
return result
elif isinstance(o, BaseTest):
return o.uid
elif isinstance(o, SutDescription):
result = {"uid": o.key}
if isinstance(o, ModelGaugeSut) and o.instance_initialization():
result["initialization"] = o.instance_initialization()
elif isinstance(o, SUT):
result = {"uid": o.uid, "initialization": o.initialization_record}
return result
elif isinstance(o, pydantic.BaseModel):
return o.model_dump()
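
In record.py the custom JSON encoder's default() now serializes any modelgauge SUT as its uid plus its initialization record, replacing the wrapper-specific key lookup. A minimal, self-contained sketch of that serialization rule (the encoder and stub classes here are illustrative, not the real record.py encoder):

import json


class _StubSut:
    # Stands in for a modelgauge SUT in this sketch.
    def __init__(self, uid, initialization_record=None):
        self.uid = uid
        self.initialization_record = initialization_record


class SketchEncoder(json.JSONEncoder):
    def default(self, o):
        # Mirrors the branch above: a SUT serializes as its uid + initialization.
        if isinstance(o, _StubSut):
            return {"uid": o.uid, "initialization": o.initialization_record}
        return super().default(o)


print(json.dumps({"sut": _StubSut("demo-sut")}, cls=SketchEncoder))
# -> {"sut": {"uid": "demo-sut", "initialization": null}}
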
