Only run certain checks for official benchmarks
bkorycki committed Dec 6, 2024
1 parent 6e431bc commit b92b951
Showing 2 changed files with 36 additions and 21 deletions.
55 changes: 35 additions & 20 deletions src/modelbench/consistency_checker.py
@@ -363,32 +363,16 @@ def __init__(self, journal_path):
         self.journal_path = journal_path
 
         # Entities to run checks for.
+        self.benchmark = None
         self.suts = None
         self.tests = None
         self.annotators = None
         self._collect_entities()
 
         # Checks to run at each level.
-        self.test_sut_level_checker = JournalEntityLevelCheck(
-            "Test x SUT level checks",
-            [
-                EachPromptQueuedOnce,
-                EachPromptRespondedToOnce,
-                EachResponseTranslatedOnce,
-                EachItemMeasuredOnce,
-                NumItemsFinishedEqualsMeasuredItems,
-                AnnotationsMergedCorrectly,
-            ],
-            tests=self.tests,
-            suts=self.suts,
-        )
-        self.test_sut_annotator_level_checker = JournalEntityLevelCheck(
-            "Test x SUT x Annotator checks",
-            [EachResponseAnnotatedOnce, EachAnnotationTranslatedOnce, MinValidAnnotatorItems],
-            tests=self.tests,
-            suts=self.suts,
-            annotators=self.annotators,
-        )
+        self.test_sut_level_checker = None
+        self.test_sut_annotator_level_checker = None
+        self._init_checkers()
 
     @property
     def _check_groups(self):
@@ -401,6 +385,9 @@ def _collect_entities(self):
         starting_run_entry = search_engine.query("starting run")
         assert len(starting_run_entry) == 1
 
+        benchmarks = starting_run_entry[0]["benchmarks"]
+        assert len(benchmarks) == 1, "Consistency checker can only handle single-benchmark journals."
+        self.benchmark = benchmarks[0]
         self.suts = starting_run_entry[0]["suts"]
         self.tests = starting_run_entry[0]["tests"]
         # TODO: This assumes that all tests use the same annotators! Which is fine for now but may not hold-up later on.
@@ -421,6 +408,34 @@
             set([entry["annotator"] for entry in fetched_annotator_entries + cached_annotator_entries])
         )
 
+    def _init_checkers(self):
+        test_sut_checks = [
+            EachPromptQueuedOnce,
+            EachPromptRespondedToOnce,
+            EachResponseTranslatedOnce,
+            EachItemMeasuredOnce,
+            NumItemsFinishedEqualsMeasuredItems,
+        ]
+        test_sut_annotator_checks = [EachResponseAnnotatedOnce, EachAnnotationTranslatedOnce]
+
+        if "official" in self.benchmark:
+            test_sut_checks.append(AnnotationsMergedCorrectly)
+            test_sut_annotator_checks.append(MinValidAnnotatorItems)
+
+        self.test_sut_level_checker = JournalEntityLevelCheck(
+            "Test x SUT level checks",
+            test_sut_checks,
+            tests=self.tests,
+            suts=self.suts,
+        )
+        self.test_sut_annotator_level_checker = JournalEntityLevelCheck(
+            "Test x SUT x Annotator checks",
+            test_sut_annotator_checks,
+            tests=self.tests,
+            suts=self.suts,
+            annotators=self.annotators,
+        )
+
     def run(self, verbose=False):
         self._collect_results()
         self.display_results()
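In effect, the AnnotationsMergedCorrectly and MinValidAnnotatorItems checks now run only when the benchmark name contains "official". A minimal standalone sketch of that selection logic, with the check classes stubbed as strings and the benchmark names purely illustrative:

    # Sketch of the gating introduced above; check classes stubbed as strings.
    def select_checks(benchmark: str):
        test_sut_checks = [
            "EachPromptQueuedOnce",
            "EachPromptRespondedToOnce",
            "EachResponseTranslatedOnce",
            "EachItemMeasuredOnce",
            "NumItemsFinishedEqualsMeasuredItems",
        ]
        test_sut_annotator_checks = ["EachResponseAnnotatedOnce", "EachAnnotationTranslatedOnce"]
        if "official" in benchmark:  # substring match on the benchmark name
            test_sut_checks.append("AnnotationsMergedCorrectly")
            test_sut_annotator_checks.append("MinValidAnnotatorItems")
        return test_sut_checks, test_sut_annotator_checks

    assert "AnnotationsMergedCorrectly" in select_checks("official")[0]
    assert "AnnotationsMergedCorrectly" not in select_checks("demo")[0]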
2 changes: 1 addition & 1 deletion tests/modelbench_tests/test_consistency_checker.py
@@ -56,7 +56,7 @@ def make_basic_run(suts: List[str], test_prompts: Dict[str, List[str]], annotato
     Measurements/annotations are all safe."""
 
     journal = []
-    journal.append({"message": "starting run", "suts": suts, "tests": list(test_prompts.keys())})
+    journal.append({"message": "starting run", "suts": suts, "tests": list(test_prompts.keys()), "benchmarks": ["official"]})
     for sut in suts:
         for test, prompts in test_prompts.items():
             journal.append({"message": "using test items", "test": test, "using": len(prompts)})
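The fixture change keeps the new assertion in _collect_entities satisfied: the "starting run" entry must now carry exactly one benchmark. A sketch of the resulting entry shape (the SUT and test ids here are made up for illustration):

    # Illustrative "starting run" entry; "demo_sut" and "demo_test" are hypothetical ids.
    entry = {
        "message": "starting run",
        "suts": ["demo_sut"],
        "tests": ["demo_test"],
        "benchmarks": ["official"],  # exactly one benchmark, per the new assertion
    }
    assert len(entry["benchmarks"]) == 1

Note that the fixture hardcodes ["official"], so these tests exercise the path where all checks, including the merged-annotation and minimum-valid-items checks, are run.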
