Journal consistency checker (#680)
* empty cli command

* first draft of class structure

* Display results in table + --verbose option

* restructure classes

* remove required_messages method

* another restructure

* Print names of failed checks + dividers

* add trivial check + display emojis

* waterfall check i.e. only compare to previous stage

* Fix difference check

* Add one annotator check

* get terminal size from shutil instead

* Add sut,test,etc.. to table header

* remove record_path args

* Basic test

* more unit tests
bkorycki authored Nov 13, 2024
1 parent 5050006 commit 2a36c92
Showing 4 changed files with 456 additions and 2 deletions.
306 changes: 306 additions & 0 deletions src/modelbench/consistency_checker.py
@@ -0,0 +1,306 @@
import json
import shutil
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from itertools import product
from tabulate import tabulate
from typing import Dict, List

from modelbench.run_journal import journal_reader

LINE_WIDTH = shutil.get_terminal_size(fallback=(120, 50)).columns


class JournalSearch:
    def __init__(self, journal_path):
        self.journal_path = journal_path
        self.message_entries: Dict[str, List] = defaultdict(list)  # or maybe a SQLite-backed dict?
        # Load journal into message_entries dict.
        self._read_journal()

    def _read_journal(self):
        # Might want to filter out irrelevant messages here.
        with journal_reader(self.journal_path) as f:
            for line in f:
                entry = json.loads(line)
                self.message_entries[entry["message"]].append(entry)

    def query(self, message: str, **kwargs):
        messages = self.message_entries[message]
        return [m for m in messages if all(m[k] == v for k, v in kwargs.items())]

    def num_test_prompts(self, test) -> int:
        # TODO: Implement cache.
        test_entry = self.query("using test items", test=test)
        assert len(test_entry) == 1, f"Expected exactly 1 `using test items` entry per test but found {len(test_entry)}."
        return test_entry[0]["using"]

    def test_prompt_uids(self, test) -> List[str]:
        """Returns all prompt UIDs queued for the test."""
        # TODO: Implement cache.
        return [item["prompt_id"] for item in self.query("queuing item", test=test)]

    def sut_response_prompt_uids_for_test(self, sut, test) -> List[str]:
        cached_responses = self.query("using cached sut response", sut=sut, test=test)
        fetched_responses = self.query("fetched sut response", sut=sut, test=test)
        all_prompts = [response["prompt_id"] for response in cached_responses + fetched_responses]
        return all_prompts


class JournalCheck(ABC):
    """All checks must inherit from this class."""

    @abstractmethod
    def check(self) -> bool:
        pass

    @abstractmethod
    def failure_message(self) -> str:
        """The message to display if the check fails."""
        pass


# TODO: Check that all prompts in a test are unique.

# TODO:
# class NumPromptsQueuedMatchesExpected(JournalCheck):
#     def __init__(self, search_engine: JournalSearch, sut, test):
#         # Load all data needed for the check.
#         self.num_test_prompts = search_engine.num_test_prompts(test)


class OneToOneCheck(JournalCheck):
    """Checks for a one-to-one mapping between two lists of prompt UIDs."""

    def __init__(self, expected_prompts: List[str], found_prompts: List[str]):
        found_counts = Counter(found_prompts)
        # TODO: Could probably make these 3 checks more efficient.
        self.duplicates = [uid for uid, count in found_counts.items() if count > 1]
        # Check for differences in the two sets.
        expected_prompts = set(expected_prompts)
        found_prompts = set(found_prompts)
        self.missing_prompts = list(expected_prompts - found_prompts)
        self.unknown_prompts = list(found_prompts - expected_prompts)

    def check(self) -> bool:
        return not any([len(self.duplicates), len(self.missing_prompts), len(self.unknown_prompts)])

    def failure_message(self) -> str:
        assert not self.check()
        messages = []
        if len(self.duplicates) > 0:
            messages.append(f"The following duplicate prompts were found: {self.duplicates}")
        if len(self.missing_prompts) > 0:
            messages.append(f"The following prompts were expected but missing: {self.missing_prompts}")
        if len(self.unknown_prompts) > 0:
            messages.append(f"The following prompts were found but were not expected: {self.unknown_prompts}")
        return "\n\t".join(messages)


class EachPromptRespondedToOnce(OneToOneCheck):
    def __init__(self, search_engine: JournalSearch, sut, test):
        super().__init__(
            search_engine.test_prompt_uids(test), search_engine.sut_response_prompt_uids_for_test(sut, test)
        )

    def failure_message(self) -> str:
        message = "Expected exactly 1 SUT response for each prompt in the test.\n\t"
        # Call super() to get specific details about duplicates/missing/extra prompts.
        return message + super().failure_message()


# TODO: Add class to check that fetched and cached responses are mutually exclusive.


class EachResponseTranslatedOnce(OneToOneCheck):
    def __init__(self, search_engine: JournalSearch, sut, test):
        all_sut_response_prompts = search_engine.sut_response_prompt_uids_for_test(sut, test)
        translated_responses = search_engine.query("translated sut response", sut=sut, test=test)
        super().__init__(all_sut_response_prompts, [response["prompt_id"] for response in translated_responses])

    def failure_message(self) -> str:
        message = "Expected each SUT response to be translated exactly once.\n\t"
        # Call super() to get specific details about duplicates/missing/extra prompts.
        return message + super().failure_message()


class EachItemMeasuredOnce(OneToOneCheck):
    def __init__(self, search_engine: JournalSearch, sut, test):
        translated_responses = search_engine.query("translated sut response", sut=sut, test=test)
        measured_items = search_engine.query("measured item quality", sut=sut, test=test)
        super().__init__(
            [response["prompt_id"] for response in translated_responses],
            [response["prompt_id"] for response in measured_items],
        )

    def failure_message(self) -> str:
        message = "Expected every prompt-response to be measured exactly once.\n\t"
        # Call super() to get specific details about duplicates/missing/extra prompts.
        return message + super().failure_message()


class EachResponseAnnotatedOnce(OneToOneCheck):
    def __init__(self, search_engine: JournalSearch, sut, test, annotator):
        self.annotator = annotator
        self.sut = sut
        self.test = test
        translated_responses = search_engine.query("translated sut response", sut=sut, test=test)
        cached_annotations = search_engine.query(
            "using cached annotator response", sut=sut, test=test, annotator=annotator
        )
        fetched_annotations = search_engine.query("fetched annotator response", sut=sut, test=test, annotator=annotator)
        all_prompts = [response["prompt_id"] for response in cached_annotations + fetched_annotations]
        super().__init__([response["prompt_id"] for response in translated_responses], all_prompts)

    def failure_message(self) -> str:
        message = f"Expected exactly 1 {self.annotator} annotation for each response from SUT {self.sut} in test {self.test}.\n\t"
        # Call super() to get specific details about duplicates/missing/extra prompts.
        return message + super().failure_message()


class JournalEntityLevelCheck:
    """A group of checks that occupy the same conceptual entity level in the journal.

    All checks in a group must accept the same entities in their init params."""

    def __init__(self, name, check_classes, **entity_sets):
        """Each entity_set kwarg is a list of a type of entity."""
        self.name = name
        self.check_classes: List = check_classes
        # Outer dictionary keys are the entity tuples; inner keys are the check names.
        # Values are boolean check results.
        self.results: Dict[str, Dict[str, bool | None]] | None = None
        self.row_names = None
        self.check_names = None
        self._init_results_table(**entity_sets)
        # List of warning messages for failed checks.
        self.warnings: List[str] = []

    def _init_results_table(self, **entity_sets):
        # Create an empty table where each row is an entity (or entity tuple) and each column is a check.
        self.results = defaultdict(dict)
        self.entity_names = sorted(entity_sets.keys())
        self.row_names = []
        self.check_names = []

        for col in self.check_classes:
            self.check_names.append(self._col_name(col))
        for entity_tuple in product(*entity_sets.values()):
            entity_dict = dict(zip(entity_sets.keys(), entity_tuple))
            row_key = self._row_key(**entity_dict)
            self.row_names.append(row_key)
            # Each check is initialized to None to indicate it hasn't been run yet.
            self.results[row_key] = {col: None for col in self.check_names}

    @staticmethod
    def _row_key(**entities) -> str:
        """Return string key for a given set of entities."""
        sorted_keys = sorted(entities.keys())
        return ", ".join([entities[k] for k in sorted_keys])

    @staticmethod
    def _col_name(check_cls) -> str:
        return check_cls.__name__

    def check_is_complete(self) -> bool:
        """Make sure the table is fully populated."""
        for row in self.row_names:
            for check in self.check_names:
                if self.results[row][check] is None:
                    return False
        return True

    def run_checks_for_row(self, search_engine, **entities):
        """Run all individual checks on a given entity tuple and store results and warnings."""
        for check_cls in self.check_classes:
            check = check_cls(search_engine, **entities)
            result = check.check()
            self.results[self._row_key(**entities)][self._col_name(check_cls)] = result
            if not result:
                self.warnings.append(f"{self._col_name(check_cls)}: {check.failure_message()}")


class ConsistencyChecker:

    def __init__(self, journal_path):
        # Object holding journal entries.
        self.search_engine = JournalSearch(journal_path)

        # Entities to run checks for.
        self.suts = None
        self.tests = None
        self.annotators = None
        self._collect_entities()

        # Checks to run at each level.
        self.test_sut_level_checker = JournalEntityLevelCheck(
            "Test x SUT level checks",
            [EachPromptRespondedToOnce, EachResponseTranslatedOnce, EachItemMeasuredOnce],
            tests=self.tests,
            suts=self.suts,
        )
        self.test_sut_annotator_level_checker = JournalEntityLevelCheck(
            "Test x SUT x Annotator checks",
            [EachResponseAnnotatedOnce],
            tests=self.tests,
            suts=self.suts,
            annotators=self.annotators,
        )

    def _collect_entities(self):
        # Get all SUTs and tests that were run in the journal. We will run checks for each (SUT, test) pair.
        starting_run_entry = self.search_engine.query("starting run")
        assert len(starting_run_entry) == 1

        self.suts = starting_run_entry[0]["suts"]
        self.tests = starting_run_entry[0]["tests"]
        # TODO: Find a more reliable way of getting all expected annotators for the tests. THIS WILL FAIL IF THEY ARE ALL CACHED.
        annotator_entries = self.search_engine.query("fetched annotator response", test=self.tests[0], sut=self.suts[0])
        self.annotators = list(set([entry["annotator"] for entry in annotator_entries]))

    def run(self, verbose=False):
        self.collect_results()
        self.display_results()
        if verbose:
            self.display_warnings()
        # TODO: Also run checks for the json record file.

    def collect_results(self):
        """Populate the results/warning tables of each check level."""
        for test in self.tests:
            for sut in self.suts:
                self.test_sut_level_checker.run_checks_for_row(self.search_engine, sut=sut, test=test)
                for annotator in self.annotators:
                    self.test_sut_annotator_level_checker.run_checks_for_row(
                        self.search_engine, sut=sut, test=test, annotator=annotator
                    )

    @staticmethod
    def _format_result(result: bool):
        return "✅" if result else "❌"

    def display_results(self):
        """Print a simple table where each row is a single entity (or entity tuple, e.g. test x SUT) and each column is a check."""
        check_groups = [self.test_sut_level_checker, self.test_sut_annotator_level_checker]
        for checker in check_groups:
            print("Results for", checker.name)
            assert checker.check_is_complete()
            results_table = []
            for entity, checks in checker.results.items():
                results_table.append([entity] + [self._format_result(checks[c]) for c in checker.check_names])
            print(tabulate(results_table, headers=[", ".join(checker.entity_names)] + list(checker.check_names)))
            print()

    def display_warnings(self):
        """Print details about the failed checks."""
        check_groups = [self.test_sut_level_checker, self.test_sut_annotator_level_checker]
        for checker in check_groups:
            print("-" * LINE_WIDTH)
            assert checker.check_is_complete()
            if len(checker.warnings) == 0:
                print(f"All {checker.name} checks passed!")
                continue  # Move on to the next check group.
            print(f"Failed checks for {checker.name}:")
            for warning in checker.warnings:
                print(warning)
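
For a feel of how the OneToOneCheck primitive behaves on its own, here is a minimal standalone sketch (not part of this diff; the prompt UIDs are invented for illustration):

from modelbench.consistency_checker import OneToOneCheck

# "p2" was expected but never found; "p1" appears twice; "p4" was never expected.
check = OneToOneCheck(expected_prompts=["p1", "p2", "p3"], found_prompts=["p1", "p1", "p3", "p4"])
assert not check.check()
print(check.failure_message())
# Reports duplicates ['p1'], missing ['p2'], and unexpected ['p4'].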
11 changes: 11 additions & 0 deletions src/modelbench/run.py
@@ -18,6 +18,7 @@
import modelgauge
from modelbench.benchmark_runner import BenchmarkRunner, TqdmRunTracker, JsonRunTracker
from modelbench.benchmarks import BenchmarkDefinition, GeneralPurposeAiChatBenchmark, GeneralPurposeAiChatBenchmarkV1
from modelbench.consistency_checker import ConsistencyChecker
from modelbench.hazards import STANDARDS
from modelbench.record import dump_json
from modelbench.static_site_generator import StaticContent, StaticSiteGenerator
@@ -133,6 +134,16 @@ def benchmark(
        json_path = output_dir / f"benchmark_record-{b.uid}.json"
        scores = [score for score in benchmark_scores if score.benchmark_definition == b]
        dump_json(json_path, start_time, b, scores)
    # TODO: Consistency check


@cli.command(help="Check the consistency of a benchmark run using its record and journal files.")
@click.option("--journal-path", "-j", type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path))
# @click.option("--record-path", "-r", type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path))
@click.option("--verbose", "-v", default=False, is_flag=True, help="Print details about the failed checks.")
def consistency_check(journal_path, verbose):
    checker = ConsistencyChecker(journal_path)
    checker.run(verbose)


def find_suts_for_sut_argument(sut_args: List[str]):
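One way to exercise the new command in-process is Click's test runner, sketched below; the journal filename is hypothetical, and the dashed command name assumes Click's default underscore-to-dash naming for consistency_check:

from click.testing import CliRunner

from modelbench.run import cli

runner = CliRunner()
# Equivalent to running: modelbench consistency-check -j journal-run.jsonl -v
result = runner.invoke(cli, ["consistency-check", "-j", "journal-run.jsonl", "-v"])
print(result.output)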
15 changes: 13 additions & 2 deletions src/modelbench/run_journal.py
@@ -4,12 +4,12 @@
from contextlib import AbstractContextManager
from datetime import datetime, timezone
from enum import Enum
-from io import IOBase
+from io import IOBase, TextIOWrapper
from typing import Sequence, Mapping
from unittest.mock import MagicMock

from pydantic import BaseModel
-from zstandard.backend_cffi import ZstdCompressor
+from zstandard.backend_cffi import ZstdCompressor, ZstdDecompressor

from modelbench.benchmark_runner_items import TestRunItem, Timer
from modelgauge.sut import SUTResponse
@@ -54,6 +54,17 @@ def for_journal(o):
    return o


def journal_reader(path):
    """Loads an existing journal file, decompressing if necessary."""
    if path.suffix == ".zst":
        raw_fh = open(path, "rb")
        dctx = ZstdDecompressor()
        sr = dctx.stream_reader(raw_fh)
        return TextIOWrapper(sr, encoding="utf-8")
    else:
        return open(path, "r")


class RunJournal(AbstractContextManager):

    def __init__(self, output=None):
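As a usage sketch for the new journal_reader helper (journal path hypothetical), the same code reads both plain and .zst-compressed journals:

import json
import pathlib
from collections import Counter

from modelbench.run_journal import journal_reader

# Tally how many times each message type appears in the journal.
counts = Counter()
with journal_reader(pathlib.Path("journal-run.jsonl.zst")) as f:
    for line in f:
        counts[json.loads(line)["message"]] += 1
print(counts.most_common(5))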