diff --git a/src/modelbench/consistency_checker.py b/src/modelbench/consistency_checker.py
new file mode 100644
index 00000000..194df885
--- /dev/null
+++ b/src/modelbench/consistency_checker.py
@@ -0,0 +1,306 @@
+import json
+import shutil
+from abc import ABC, abstractmethod
+from collections import Counter, defaultdict
+from itertools import product
+from tabulate import tabulate
+from typing import Dict, List
+
+from modelbench.run_journal import journal_reader
+
+LINE_WIDTH = shutil.get_terminal_size(fallback=(120, 50)).columns
+
+
+class JournalSearch:
+    def __init__(self, journal_path):
+        self.journal_path = journal_path
+        self.message_entries: Dict[str, List] = defaultdict(list)  # or maybe a sqlite-backed dict?
+        # Load journal into message_entries dict.
+        self._read_journal()
+
+    def _read_journal(self):
+        # Might want to filter out irrelevant messages here.
+        with journal_reader(self.journal_path) as f:
+            for line in f:
+                entry = json.loads(line)
+                self.message_entries[entry["message"]].append(entry)
+
+    def query(self, message: str, **kwargs):
+        messages = self.message_entries[message]
+        return [m for m in messages if all(m[k] == v for k, v in kwargs.items())]
+
+    def num_test_prompts(self, test) -> int:
+        # TODO: Implement cache.
+        test_entry = self.query("using test items", test=test)
+        assert len(test_entry) == 1, "Expected exactly 1 `using test items` entry per test."
+        return test_entry[0]["using"]
+
+    def test_prompt_uids(self, test) -> List[str]:
+        """Returns all prompt UIDs queued for the test."""
+        # TODO: Implement cache.
+        return [item["prompt_id"] for item in self.query("queuing item", test=test)]
+
+    def sut_response_prompt_uids_for_test(self, sut, test) -> List[str]:
+        cached_responses = self.query("using cached sut response", sut=sut, test=test)
+        fetched_responses = self.query("fetched sut response", sut=sut, test=test)
+        all_prompts = [response["prompt_id"] for response in cached_responses + fetched_responses]
+        return all_prompts
+
+
+class JournalCheck(ABC):
+    """All checks must inherit from this class."""
+
+    @abstractmethod
+    def check(self) -> bool:
+        pass
+
+    @abstractmethod
+    def failure_message(self) -> str:
+        """The message to display if the check fails."""
+        pass
+
+
+# TODO: Check that all prompts in a test are unique.
+
+# TODO:
+# class NumPromptsQueuedMatchesExpected(JournalCheck):
+#     def __init__(self, search_engine: JournalSearch, sut, test):
+#         # Load all data needed for the check.
+#         self.num_test_prompts = search_engine.num_test_prompts(test)
+
+
+class OneToOneCheck(JournalCheck):
+    """Checks for a one-to-one mapping between two lists of prompt UIDs."""
+
+    def __init__(self, expected_prompts: List[str], found_prompts: List[str]):
+        found_counts = Counter(found_prompts)
+        # TODO: Could probably make these 3 checks more efficient.
+        self.duplicates = [uid for uid, count in found_counts.items() if count > 1]
+        # Check for differences in the two sets.
+        expected_prompts = set(expected_prompts)
+        found_prompts = set(found_prompts)
+        self.missing_prompts = list(expected_prompts - found_prompts)
+        self.unknown_prompts = list(found_prompts - expected_prompts)
+
+    def check(self) -> bool:
+        return not any([len(self.duplicates), len(self.missing_prompts), len(self.unknown_prompts)])
+
+    def failure_message(self) -> str:
+        assert not self.check()
+        messages = []
+        if len(self.duplicates) > 0:
+            messages.append(f"The following duplicate prompts were found: {self.duplicates}")
+        if len(self.missing_prompts) > 0:
+            messages.append(f"The following prompts were expected but missing: {self.missing_prompts}")
+        if len(self.unknown_prompts) > 0:
+            messages.append(f"The following prompts were found but were not expected: {self.unknown_prompts}")
+        return "\n\t".join(messages)
+
+
+class EachPromptRespondedToOnce(OneToOneCheck):
+    def __init__(self, search_engine: JournalSearch, sut, test):
+        super().__init__(
+            search_engine.test_prompt_uids(test), search_engine.sut_response_prompt_uids_for_test(sut, test)
+        )
+
+    def failure_message(self) -> str:
+        message = "Expected exactly 1 SUT response for each prompt in the test.\n\t"
+        # Call super() to get specific details about duplicates/missing/extra prompts.
+        return message + super().failure_message()
+
+
+# TODO: Add class to check that fetched and cached responses are mutually exclusive.
+
+
+class EachResponseTranslatedOnce(OneToOneCheck):
+    def __init__(self, search_engine: JournalSearch, sut, test):
+        all_sut_response_prompts = search_engine.sut_response_prompt_uids_for_test(sut, test)
+        translated_responses = search_engine.query("translated sut response", sut=sut, test=test)
+        super().__init__(all_sut_response_prompts, [response["prompt_id"] for response in translated_responses])
+
+    def failure_message(self) -> str:
+        message = "Expected each SUT response to be translated exactly once.\n\t"
+        # Call super() to get specific details about duplicates/missing/extra prompts.
+        return message + super().failure_message()
+
+
+class EachItemMeasuredOnce(OneToOneCheck):
+    def __init__(self, search_engine: JournalSearch, sut, test):
+        translated_responses = search_engine.query("translated sut response", sut=sut, test=test)
+        measured_items = search_engine.query("measured item quality", sut=sut, test=test)
+        super().__init__(
+            [response["prompt_id"] for response in translated_responses],
+            [response["prompt_id"] for response in measured_items],
+        )
+
+    def failure_message(self) -> str:
+        message = "Expected every prompt-response to be measured exactly once.\n\t"
+        # Call super() to get specific details about duplicates/missing/extra prompts.
+        return message + super().failure_message()
+
+
+class EachResponseAnnotatedOnce(OneToOneCheck):
+    def __init__(self, search_engine: JournalSearch, sut, test, annotator):
+        self.annotator = annotator
+        self.sut = sut
+        self.test = test
+        translated_responses = search_engine.query("translated sut response", sut=sut, test=test)
+        cached_annotations = search_engine.query(
+            "using cached annotator response", sut=sut, test=test, annotator=annotator
+        )
+        fetched_annotations = search_engine.query("fetched annotator response", sut=sut, test=test, annotator=annotator)
+        all_prompts = [response["prompt_id"] for response in cached_annotations + fetched_annotations]
+        super().__init__([response["prompt_id"] for response in translated_responses], all_prompts)
+
+    def failure_message(self) -> str:
+        message = f"Expected exactly 1 {self.annotator} annotation for each response from SUT {self.sut} in test {self.test}\n\t"
+        # Call super() to get specific details about duplicates/missing/extra prompts.
+        return message + super().failure_message()
+
+
+class JournalEntityLevelCheck:
+    """A group of checks that occupy the same conceptual entity-level in the journal.
+
+    All checks in a group must accept the same entities in their init params."""
+
+    def __init__(self, name, check_classes, **entity_sets):
+        """Each entity_set kwarg is a list of one type of entity."""
+        self.name = name
+        self.check_classes: List = check_classes
+        # Outer-level dictionary keys are the entity tuples, inner dict keys are the check names.
+        # Values are boolean check results.
+        self.results: Dict[str, Dict[str, bool | None]] | None = None
+        self.row_names = None
+        self.check_names = None
+        self._init_results_table(**entity_sets)
+        # List of warning messages for failed checks.
+        self.warnings: List[str] = []
+
+    def _init_results_table(self, **entity_sets):
+        # Create an empty table where each row is an entity (or entity tuple) and each column is a check.
+        self.results = defaultdict(dict)
+        self.entity_names = sorted(list(entity_sets.keys()))
+        self.row_names = []
+        self.check_names = []
+
+        for col in self.check_classes:
+            self.check_names.append(self._col_name(col))
+        for entity_tuple in product(*entity_sets.values()):
+            entity_dict = dict(zip(entity_sets.keys(), entity_tuple))
+            row_key = self._row_key(**entity_dict)
+            self.row_names.append(row_key)
+            # Each check is initialized to None to indicate it hasn't been run yet.
+            self.results[row_key] = {col: None for col in self.check_names}
+
+    @staticmethod
+    def _row_key(**entities) -> str:
+        """Return string key for a given set of entities."""
+        sorted_keys = sorted(entities.keys())
+        return ", ".join([entities[k] for k in sorted_keys])
+
+    @staticmethod
+    def _col_name(check_cls) -> str:
+        return check_cls.__name__
+
+    def check_is_complete(self) -> bool:
+        """Make sure table is fully populated."""
+        for row in self.row_names:
+            for check in self.check_names:
+                if self.results[row][check] is None:
+                    return False
+        return True
+
+    def run_checks_for_row(self, search_engine, **entities):
+        """Run all individual checks on a given entity tuple and store results and warnings."""
+        for check_cls in self.check_classes:
+            check = check_cls(search_engine, **entities)
+            result = check.check()
+            self.results[self._row_key(**entities)][self._col_name(check_cls)] = result
+            if not result:
+                self.warnings.append(f"{self._col_name(check_cls)}: {check.failure_message()}")
+
+
+class ConsistencyChecker:
+
+    def __init__(self, journal_path):
+        # Object holding journal entries
+        self.search_engine = JournalSearch(journal_path)
+
+        # Entities to run checks for.
+        self.suts = None
+        self.tests = None
+        self.annotators = None
+        self._collect_entities()
+
+        # Checks to run at each level.
+        self.test_sut_level_checker = JournalEntityLevelCheck(
+            "Test x SUT level checks",
+            [EachPromptRespondedToOnce, EachResponseTranslatedOnce, EachItemMeasuredOnce],
+            tests=self.tests,
+            suts=self.suts,
+        )
+        self.test_sut_annotator_level_checker = JournalEntityLevelCheck(
+            "Test x SUT x Annotator checks",
+            [EachResponseAnnotatedOnce],
+            tests=self.tests,
+            suts=self.suts,
+            annotators=self.annotators,
+        )
+
+    def _collect_entities(self):
+        # Get all SUTs and tests that were run in the journal. We will run checks for each (SUT, test) pair.
+        starting_run_entry = self.search_engine.query("starting run")
+        assert len(starting_run_entry) == 1
+
+        self.suts = starting_run_entry[0]["suts"]
+        self.tests = starting_run_entry[0]["tests"]
+        # TODO: Find a more reliable way of getting all expected annotators for the tests. THIS WILL FAIL IF THEY ARE ALL CACHED.
+        annotator_entries = self.search_engine.query("fetched annotator response", test=self.tests[0], sut=self.suts[0])
+        self.annotators = list(set([entry["annotator"] for entry in annotator_entries]))
+
+    def run(self, verbose=False):
+        self.collect_results()
+        self.display_results()
+        if verbose:
+            self.display_warnings()
+        # TODO: Also run checks for the json record file.
+
+    def collect_results(self):
+        """Populate the results/warning tables of each check level."""
+        for test in self.tests:
+            for sut in self.suts:
+                self.test_sut_level_checker.run_checks_for_row(self.search_engine, sut=sut, test=test)
+                for annotator in self.annotators:
+                    self.test_sut_annotator_level_checker.run_checks_for_row(
+                        self.search_engine, sut=sut, test=test, annotator=annotator
+                    )
+
+    @staticmethod
+    def _format_result(result: bool):
+        return "✅" if result else "❌"
+
+    def display_results(self):
+        """Print a simple table where each row is a single entity (or entity tuple, e.g. test x SUT) and each
+        column is a check."""
+        check_groups = [self.test_sut_level_checker, self.test_sut_annotator_level_checker]
+        for checker in check_groups:
+            print("Results for", checker.name)
+            assert checker.check_is_complete()
+            results_table = []
+            for entity, checks in checker.results.items():
+                results_table.append([entity] + [self._format_result(checks[c]) for c in checker.check_names])
+            print(tabulate(results_table, headers=[", ".join(checker.entity_names)] + list(checker.check_names)))
+            print()
+
+    def display_warnings(self):
+        """Print details about the failed checks."""
+        check_groups = [self.test_sut_level_checker, self.test_sut_annotator_level_checker]
+        for checker in check_groups:
+            print("-" * LINE_WIDTH)
+            assert checker.check_is_complete()
+            if len(checker.warnings) == 0:
+                print(f"All {checker.name} checks passed!")
+                continue
+            print(f"Failed checks for {checker.name}:")
+            for warning in checker.warnings:
+                print(warning)
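Note for reviewers: new checks plug into this module by subclassing JournalCheck (usually via OneToOneCheck) and adding the class to a JournalEntityLevelCheck. Below is a minimal sketch of what that could look like, assuming the names defined in consistency_checker.py above; the check class and its registration are illustrative only and not part of this diff.

# Illustrative only (not part of this diff): a hypothetical extra "Test x SUT" level check,
# built on the OneToOneCheck helper defined above.
class EachQueuedItemMeasuredOnce(OneToOneCheck):
    def __init__(self, search_engine: JournalSearch, sut, test):
        super().__init__(
            search_engine.test_prompt_uids(test),
            [item["prompt_id"] for item in search_engine.query("measured item quality", sut=sut, test=test)],
        )

    def failure_message(self) -> str:
        # Prepend a summary, then let OneToOneCheck list duplicates/missing/extra prompts.
        return "Expected each queued prompt to be measured exactly once.\n\t" + super().failure_message()

# Registering it would just mean appending the class to the check list passed to the
# "Test x SUT level checks" JournalEntityLevelCheck in ConsistencyChecker.__init__.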
diff --git a/src/modelbench/run.py b/src/modelbench/run.py
index 3aa3aef5..ae945f10 100644
--- a/src/modelbench/run.py
+++ b/src/modelbench/run.py
@@ -18,6 +18,7 @@ import modelgauge
 from modelbench.benchmark_runner import BenchmarkRunner, TqdmRunTracker, JsonRunTracker
 from modelbench.benchmarks import BenchmarkDefinition, GeneralPurposeAiChatBenchmark, GeneralPurposeAiChatBenchmarkV1
+from modelbench.consistency_checker import ConsistencyChecker
 from modelbench.hazards import STANDARDS
 from modelbench.record import dump_json
 from modelbench.static_site_generator import StaticContent, StaticSiteGenerator
@@ -133,6 +134,16 @@ def benchmark(
         json_path = output_dir / f"benchmark_record-{b.uid}.json"
         scores = [score for score in benchmark_scores if score.benchmark_definition == b]
         dump_json(json_path, start_time, b, scores)
+    # TODO: Consistency check
+
+
+@cli.command(help="check the consistency of a benchmark run using its record and journal files.")
+@click.option("--journal-path", "-j", type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path))
+# @click.option("--record-path", "-r", type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path))
+@click.option("--verbose", "-v", default=False, is_flag=True, help="Print details about the failed checks.")
+def consistency_check(journal_path, verbose):
+    checker = ConsistencyChecker(journal_path)
+    checker.run(verbose)
 
 
 def find_suts_for_sut_argument(sut_args: List[str]):
diff --git a/src/modelbench/run_journal.py b/src/modelbench/run_journal.py
index e7da8451..bd43ebb5 100644
--- a/src/modelbench/run_journal.py
+++ b/src/modelbench/run_journal.py
@@ -4,12 +4,12 @@
 from contextlib import AbstractContextManager
 from datetime import datetime, timezone
 from enum import Enum
-from io import IOBase
+from io import IOBase, TextIOWrapper
 from typing import Sequence, Mapping
 from unittest.mock import MagicMock
 
 from pydantic import BaseModel
-from zstandard.backend_cffi import ZstdCompressor
+from zstandard.backend_cffi import ZstdCompressor, ZstdDecompressor
 
 from modelbench.benchmark_runner_items import TestRunItem, Timer
 from modelgauge.sut import SUTResponse
@@ -54,6 +54,17 @@ def for_journal(o):
     return o
 
 
+def journal_reader(path):
+    """Loads existing journal file, decompressing if necessary."""
+    if path.suffix == ".zst":
+        raw_fh = open(path, "rb")
+        dctx = ZstdDecompressor()
+        sr = dctx.stream_reader(raw_fh)
+        return TextIOWrapper(sr, encoding="utf-8")
+    else:
+        return open(path, "r")
+
+
 class RunJournal(AbstractContextManager):
 
     def __init__(self, output=None):
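Side note: journal_reader gives callers one line-oriented interface over both plain and zstandard-compressed journals. A rough usage sketch follows; the journal path is hypothetical.

import json
from collections import Counter
from pathlib import Path

from modelbench.run_journal import journal_reader

# Hypothetical path; a plain "journal.jsonl" works the same way.
journal_path = Path("run_output/journal.jsonl.zst")
with journal_reader(journal_path) as f:
    # Each journal line is a JSON object with at least a "message" field.
    message_counts = Counter(json.loads(line)["message"] for line in f)
print(message_counts.most_common(5))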
diff --git a/tests/modelbench_tests/test_consistency_checker.py b/tests/modelbench_tests/test_consistency_checker.py
new file mode 100644
index 00000000..e920d45f
--- /dev/null
+++ b/tests/modelbench_tests/test_consistency_checker.py
@@ -0,0 +1,126 @@
+import json
+import pytest
+from typing import Dict, List
+
+from modelbench.consistency_checker import ConsistencyChecker
+
+
+def make_basic_run(suts: List[str], test_prompts: Dict[str, List[str]], annotators: List[str]):
+    """Successful "fresh" benchmark run with all SUT/annotator responses fetched (not cached)."""
+    journal = []
+    journal.append({"message": "starting run", "suts": suts, "tests": list(test_prompts.keys())})
+    for sut in suts:
+        for test, prompts in test_prompts.items():
+            journal.append({"message": "using test items", "test": test, "using": len(prompts)})
+            for prompt in prompts:
+                # Normal pipeline.
+                sut_messages = [
+                    "queuing item",
+                    "fetched sut response",
+                    "translated sut response",
+                    "measured item quality",
+                ]
+                for message in sut_messages:
+                    journal.append({"message": message, "test": test, "prompt_id": prompt, "sut": sut})
+                for annotator in annotators:
+                    journal.append(
+                        {
+                            "message": "fetched annotator response",
+                            "test": test,
+                            "sut": sut,
+                            "prompt_id": prompt,
+                            "annotator": annotator,
+                        }
+                    )
+    return journal
+
+
+@pytest.fixture
+def basic_benchmark_run():
+    return make_basic_run(
+        suts=["sut1", "sut2"], test_prompts={"test1": ["prompt1", "prompt2"]}, annotators=["annotator1", "annotator2"]
+    )
+
+
+def init_checker_for_journal(tmp_path, journal):
+    journal_path = tmp_path / "journal.jsonl"
+    with open(journal_path, "w") as f:
+        for item in journal:
+            f.write(json.dumps(item) + "\n")
+
+    checker = ConsistencyChecker(journal_path=journal_path)
+    return checker
+
+
+def test_normal_run(tmp_path, basic_benchmark_run):
+    checker = init_checker_for_journal(tmp_path, basic_benchmark_run)
+    checker.run()
+
+    for subchecker in [checker.test_sut_level_checker, checker.test_sut_annotator_level_checker]:
+        assert subchecker.check_is_complete()
+        for row in subchecker.results.values():
+            assert all(row.values())
+        assert subchecker.warnings == []
+
+
+@pytest.mark.parametrize(
+    "duplicate_message,failed_check",
+    [
+        ("fetched sut response", "EachPromptRespondedToOnce"),
+        ("using cached sut response", "EachPromptRespondedToOnce"),
+        ("translated sut response", "EachResponseTranslatedOnce"),
+        ("measured item quality", "EachItemMeasuredOnce"),
+    ],
+)
+def test_run_with_duplicate_sut_stuff(tmp_path, basic_benchmark_run, duplicate_message, failed_check):
+    basic_benchmark_run.append({"message": duplicate_message, "test": "test1", "sut": "sut1", "prompt_id": "prompt1"})
+    checker = init_checker_for_journal(tmp_path, basic_benchmark_run)
+    checker.run()
+
+    subchecker = checker.test_sut_level_checker
+    failed_row = subchecker._row_key(sut="sut1", test="test1")
+    assert subchecker.check_is_complete()
+    assert subchecker.results[failed_row][failed_check] is False
+    # TODO: Check warnings
+
+
+@pytest.mark.parametrize(
+    "extra_earlier_message,failed_check",
+    [
+        ("queuing item", "EachPromptRespondedToOnce"),
+        ("fetched sut response", "EachResponseTranslatedOnce"),
+        ("translated sut response", "EachItemMeasuredOnce"),
+    ],
+)
+def test_run_with_missing_sut_stuff(tmp_path, basic_benchmark_run, extra_earlier_message, failed_check):
+    basic_benchmark_run.append(
+        {"message": extra_earlier_message, "test": "test1", "sut": "sut1", "prompt_id": "NEW PROMPT"}
+    )
+    checker = init_checker_for_journal(tmp_path, basic_benchmark_run)
+    checker.run()
+
+    subchecker = checker.test_sut_level_checker
+    failed_row = subchecker._row_key(sut="sut1", test="test1")
+    assert subchecker.check_is_complete()
+    assert subchecker.results[failed_row][failed_check] is False
+    # TODO: Check warnings
+
+
+@pytest.mark.parametrize(
+    "extra_message,failed_check",
+    [
+        ("fetched sut response", "EachPromptRespondedToOnce"),
+        ("translated sut response", "EachResponseTranslatedOnce"),
+        ("measured item quality", "EachItemMeasuredOnce"),
+    ],
+)
+def test_run_with_extra_sut_stuff(tmp_path, basic_benchmark_run, extra_message, failed_check):
+    basic_benchmark_run.append({"message": extra_message, "test": "test1", "sut": "sut1", "prompt_id": "NEW PROMPT"})
+    checker = init_checker_for_journal(tmp_path, basic_benchmark_run)
+    checker.run()
+
+    subchecker = checker.test_sut_level_checker
+    failed_row = subchecker._row_key(sut="sut1", test="test1")
+    assert subchecker.check_is_complete()
+    assert subchecker.results[failed_row][failed_check] is False
+    # TODO: Check warnings
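To try the checker locally, it can be driven either through the new consistency-check CLI command or directly in Python. A quick sketch, assuming a modelbench console entry point and an existing journal file (both the entry-point name and the path are assumptions, not taken from this diff):

# Roughly equivalent to: modelbench consistency-check -j <journal-file> -v
import pathlib

from modelbench.consistency_checker import ConsistencyChecker

journal = pathlib.Path("run_output/journal.jsonl.zst")  # hypothetical path
checker = ConsistencyChecker(journal_path=journal)
checker.run(verbose=True)  # prints one results table per check level, plus failure details for any ❌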