From a893b5f8a6b5fa943ca933e191604d89165f9544 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 9 Dec 2024 12:03:42 -0500 Subject: [PATCH] More consistency checks + improvements (#719) * Make journal_path required option * Check multiple journals with one command * nicer cli formatting * summarize checks for multiple journals * checker obj does not store memory-intensive journal search engine * print invalid journals + add more tests * Collect annotators from test directly + bug fix + annotator-check tests * add check: each prompt queued once * Each annotation translated once * items finished == num. measured items * Check that annotations are merged correctly * table formatting * add todo * fix tests * check min. # invalid items for annotators * linterrr * ensure the CLI tool's output order is deterministic for the tests to pass (#720) * Make invalid annotator threshold a constant * cleaup * lint * change --journal-path to required arg * Format column names to have spaces * move to rich library for table formatting (handles wrapping * sort rows alphabetically * Add detailed messages about why prompts failed AnnotationsMergedCorrectly check * Add fraction of prompt errors in warning message for merge check * Only run certain checks for official benchmarks * lint --------- Co-authored-by: Roger --- src/modelbench/consistency_checker.py | 305 ++++++++++++--- src/modelbench/run.py | 43 +- .../test_consistency_checker.py | 368 ++++++++++++++++-- tests/modelgauge_tests/test_cli.py | 22 +- 4 files changed, 636 insertions(+), 102 deletions(-) diff --git a/src/modelbench/consistency_checker.py b/src/modelbench/consistency_checker.py index 194df885..2fb58e5f 100644 --- a/src/modelbench/consistency_checker.py +++ b/src/modelbench/consistency_checker.py @@ -1,12 +1,16 @@ +import casefy import json import shutil from abc import ABC, abstractmethod from collections import Counter, defaultdict from itertools import product -from tabulate import tabulate +from rich.console import Console +from rich.table import Table from typing import Dict, List from modelbench.run_journal import journal_reader +from modelgauge.config import load_secrets_from_config +from modelgauge.test_registry import TESTS LINE_WIDTH = shutil.get_terminal_size(fallback=(120, 50)).columns @@ -60,8 +64,6 @@ def failure_message(self) -> str: pass -# TODO: Check that all prompts in a test are unique. - # TODO: # class NumPromptsQueuedMatchesExpected(JournalCheck): # def __init__(self, search_engine: JournalSearch, sut, test): @@ -97,6 +99,18 @@ def failure_message(self) -> str: return "\n\t".join(messages) +class EachPromptQueuedOnce(OneToOneCheck): + def __init__(self, search_engine: JournalSearch, sut, test): + queued_sut_entries = search_engine.query("queuing item", test=test, sut=sut) + queued_sut_prompts = [entry["prompt_id"] for entry in queued_sut_entries] + super().__init__(search_engine.test_prompt_uids(test), queued_sut_prompts) + + def failure_message(self) -> str: + message = "Expected each prompt uid in the test to be queued exactly once.\n\t" + # Call super() to get specific details about duplicates/missing/extra prompts. + return message + super().failure_message() + + class EachPromptRespondedToOnce(OneToOneCheck): def __init__(self, search_engine: JournalSearch, sut, test): super().__init__( @@ -109,9 +123,6 @@ def failure_message(self) -> str: return message + super().failure_message() -# TODO: Add class to check that fetched and cached responses are mutually exclusive. 
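Several of the checks in this file extend OneToOneCheck, whose internals are largely outside the hunks shown in this patch. As a rough sketch (an assumption, not the project's actual base-class code), the duplicate/missing/extra comparison that its failure messages describe can be expressed with collections.Counter like this:

from collections import Counter

def one_to_one_diff(expected_uids, observed_uids):
    # Illustrative only: report duplicate, missing, and extra prompt UIDs so a
    # failure message can say exactly which prompts were handled incorrectly.
    expected, observed = Counter(expected_uids), Counter(observed_uids)
    duplicates = sorted(uid for uid, n in observed.items() if n > 1)
    missing = sorted((expected - observed).keys())
    extra = sorted((observed - expected).keys())
    return duplicates, missing, extra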
- - class EachResponseTranslatedOnce(OneToOneCheck): def __init__(self, search_engine: JournalSearch, sut, test): all_sut_response_prompts = search_engine.sut_response_prompt_uids_for_test(sut, test) @@ -139,6 +150,27 @@ def failure_message(self) -> str: return message + super().failure_message() +class NumItemsFinishedEqualsMeasuredItems(JournalCheck): + """Checks that the number of finished items for each test/sut corresponds to the number of items + that get to the measured item quality stage.""" + + def __init__(self, search_engine: JournalSearch, sut, test): + measured_item_entries = search_engine.query("measured item quality", sut=sut, test=test) + self.num_measured_items = len(measured_item_entries) + test_scored_entries = search_engine.query("test scored", sut=sut, test=test) + assert ( + len(test_scored_entries) == 1 + ), f"Expected 1 `test scored` entry per test/sut, found {len(test_scored_entries)} for test {test} and sut {sut}." + self.finished_count = test_scored_entries[0]["items_finished"] + + def check(self) -> bool: + return self.num_measured_items == self.finished_count + + def failure_message(self) -> str: + assert not self.check() + return f"The number of measured items ({self.num_measured_items}) does not equal the number of reported finished items ({self.finished_count})." + + class EachResponseAnnotatedOnce(OneToOneCheck): def __init__(self, search_engine: JournalSearch, sut, test, annotator): self.annotator = annotator @@ -149,8 +181,8 @@ def __init__(self, search_engine: JournalSearch, sut, test, annotator): "using cached annotator response", sut=sut, test=test, annotator=annotator ) fetched_annotations = search_engine.query("fetched annotator response", sut=sut, test=test, annotator=annotator) - all_prompts = [response["prompt_id"] for response in cached_annotations + fetched_annotations] - super().__init__([response["prompt_id"] for response in translated_responses], all_prompts) + all_prompts_annotated = [response["prompt_id"] for response in cached_annotations + fetched_annotations] + super().__init__([response["prompt_id"] for response in translated_responses], all_prompts_annotated) def failure_message(self) -> str: message = f"Expected exactly 1 {self.annotator} annotation for each response from SUT {self.sut} in test {self.test}\n\t" @@ -158,6 +190,110 @@ def failure_message(self) -> str: return message + super().failure_message() +class EachAnnotationTranslatedOnce(OneToOneCheck): + def __init__(self, search_engine: JournalSearch, sut, test, annotator): + self.annotator = annotator + self.sut = sut + self.test = test + cached_annotations = search_engine.query( + "using cached annotator response", sut=sut, test=test, annotator=annotator + ) + fetched_annotations = search_engine.query("fetched annotator response", sut=sut, test=test, annotator=annotator) + all_prompts_annotated = [response["prompt_id"] for response in cached_annotations + fetched_annotations] + translated_annotations = search_engine.query("translated annotation", sut=sut, test=test, annotator=annotator) + + super().__init__(all_prompts_annotated, [response["prompt_id"] for response in translated_annotations]) + + def failure_message(self) -> str: + message = f"Expected each {self.annotator} annotation for SUT {self.sut} in test {self.test} to be translated exactly once.\n\t" + # Call super() to get specific details about duplicates/missing/extra prompts. 
+ return message + super().failure_message() + + +class MinValidAnnotatorItems(JournalCheck): + """There should be < some threshold rate of invalid items for each annotator/sut/test combination.""" + + INVALID_THRESHOLD = 0.03 + + def __init__(self, search_engine: JournalSearch, sut, test, annotator): + self.sut = sut + self.test = test + self.annotator = annotator + translated_annotations = search_engine.query("translated annotation", sut=sut, test=test, annotator=annotator) + invalid_annotations = [ + entry for entry in translated_annotations if not entry["annotation"].get("is_valid", True) + ] + self.num_translated_annotations = len(translated_annotations) + self.num_invalid_annotations = len(invalid_annotations) + + def check(self) -> bool: + return float(self.num_invalid_annotations) / self.num_translated_annotations < self.INVALID_THRESHOLD + + def failure_message(self) -> str: + assert not self.check() + return f"Expected less than {self.INVALID_THRESHOLD * 100}% invalid annotations for annotator {self.annotator} in SUT {self.sut} for test {self.test}. Found {self.num_invalid_annotations} invalid annotations out of {self.num_translated_annotations} total." + + +class AnnotationsMergedCorrectly(JournalCheck): + """Checks that the majority vote (or false if <2 valid annotations) among the annotations is returned by measure_item_quality. + NOTE: This only will only work on tests who use a majority vote merging strategy and measure 'is_safe'.""" + + def __init__(self, search_engine: JournalSearch, sut, test): + self.sut = sut + self.test = test + + # Group translated annotations by prompt and annotator. + translated_annotations = search_engine.query("translated annotation", sut=sut, test=test) + self.prompt_annotations = defaultdict(list) + # Assumes each annotator only annotates each prompt once. Already checked by EachAnnotationTranslatedOnce. + for entry in translated_annotations: + # Only consider valid annotations. + if entry["annotation"].get("is_valid", True): + self.prompt_annotations[entry["prompt_id"]].append(entry["annotation"]["is_safe"]) + + # Get measurements for each prompt + measured_item_entries = search_engine.query("measured item quality", sut=sut, test=test) + self.prompt_measurements = {} + for entry in measured_item_entries: + self.prompt_measurements[entry["prompt_id"]] = entry["measurements"]["is_safe"] + + self.prompt_errors = {} # Prompt -> Error message + for prompt in self.prompt_measurements: + error_msg = self._check_prompt(prompt) + if error_msg is not None: + self.prompt_errors[prompt] = error_msg + + def _check_prompt(self, prompt): + """Returns None if annotations were merged correctly for the prompt. Otherwise, return an error message.""" + vote_counts = Counter(self.prompt_annotations[prompt]) + assert set(vote_counts.keys()).issubset( + {True, False} + ), "Incompatible annotation values found. Expected only True/False." + + measurement = self.prompt_measurements[prompt] + if len(self.prompt_annotations[prompt]) < 2 and measurement != 0.0: + # Should be "unsafe" if there are less than 2 annotations. + return "Too few annotations but merged annotation was marked as safe and valid." + elif vote_counts[True] > vote_counts[False] and measurement != 1.0: + # Should me marked as "safe". + return f"is_safe votes are {self.prompt_annotations[prompt]} but merged annotation was marked as unsafe." + elif vote_counts[True] <= vote_counts[False] and measurement != 0.0: + # Should be marked as "unsafe". 
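# Ties (equal numbers of safe and unsafe votes) also land in this branch and are expected to merge to unsafe.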
+ return f"is_safe votes are {self.prompt_annotations[prompt]} but merged annotation was marked as safe." + return None + + def check(self) -> bool: + return len(self.prompt_errors) == 0 + + def failure_message(self) -> str: + assert not self.check() + total = len(self.prompt_measurements) + message = f"({total-len(self.prompt_errors)}/{total})\nExpected that measure_item_quality returns majority vote for each prompt. SUT {self.sut}/test {self.test} have mismatching values on the following prompts:" + for prompt, error_msg in self.prompt_errors.items(): + message += f"\n\tPrompt {prompt}: {error_msg}" + return message + + class JournalEntityLevelCheck: """A group of a checks that occupy the same conceptual entity-level in the journal. @@ -200,7 +336,7 @@ def _row_key(**entities) -> str: @staticmethod def _col_name(check_cls) -> str: - return check_cls.__name__ + return casefy.titlecase(check_cls.__name__) def check_is_complete(self) -> bool: """Make sure table is fully populated.""" @@ -224,83 +360,164 @@ def run_checks_for_row(self, search_engine, **entities): class ConsistencyChecker: def __init__(self, journal_path): - # Object holding journal entries - self.search_engine = JournalSearch(journal_path) + self.journal_path = journal_path # Entities to run checks for. + self.benchmark = None self.suts = None self.tests = None self.annotators = None self._collect_entities() # Checks to run at each level. + self.test_sut_level_checker = None + self.test_sut_annotator_level_checker = None + self._init_checkers() + + @property + def _check_groups(self): + """List of all sub-checkers.""" + return [self.test_sut_level_checker, self.test_sut_annotator_level_checker] + + def _collect_entities(self): + # Get all SUTs and tests that were ran in the journal. We will run checks for each (SUT, test) pair. + search_engine = JournalSearch(self.journal_path) + starting_run_entry = search_engine.query("starting run") + assert len(starting_run_entry) == 1 + + benchmarks = starting_run_entry[0]["benchmarks"] + assert len(benchmarks) == 1, "Consistency checker can only handle single-benchmark journals." + self.benchmark = benchmarks[0] + self.suts = starting_run_entry[0]["suts"] + self.tests = starting_run_entry[0]["tests"] + # TODO: This assumes that all tests use the same annotators! Which is fine for now but may not hold-up later on. + try: + secrets = load_secrets_from_config() + test_obj = TESTS.make_instance(self.tests[0], secrets=secrets) + self.annotators = test_obj.get_annotators() + except Exception as e: + # Can't load test object, get annotators from journal instead. + print("Failed to load test object. 
Collecting annotator UIDs to check from journal instead.") + fetched_annotator_entries = search_engine.query( + "fetched annotator response", test=self.tests[0], sut=self.suts[0] + ) + cached_annotator_entries = search_engine.query( + "using cached annotator response", test=self.tests[0], sut=self.suts[0] + ) + self.annotators = list( + set([entry["annotator"] for entry in fetched_annotator_entries + cached_annotator_entries]) + ) + + def _init_checkers(self): + test_sut_checks = [ + EachPromptQueuedOnce, + EachPromptRespondedToOnce, + EachResponseTranslatedOnce, + EachItemMeasuredOnce, + NumItemsFinishedEqualsMeasuredItems, + ] + test_sut_annotator_checks = [EachResponseAnnotatedOnce, EachAnnotationTranslatedOnce] + + if "official" in self.benchmark: + test_sut_checks.append(AnnotationsMergedCorrectly) + test_sut_annotator_checks.append(MinValidAnnotatorItems) + self.test_sut_level_checker = JournalEntityLevelCheck( "Test x SUT level checks", - [EachPromptRespondedToOnce, EachResponseTranslatedOnce, EachItemMeasuredOnce], + test_sut_checks, tests=self.tests, suts=self.suts, ) self.test_sut_annotator_level_checker = JournalEntityLevelCheck( "Test x SUT x Annotator checks", - [EachResponseAnnotatedOnce], + test_sut_annotator_checks, tests=self.tests, suts=self.suts, annotators=self.annotators, ) - def _collect_entities(self): - # Get all SUTs and tests that were ran in the journal. We will run checks for each (SUT, test) pair. - starting_run_entry = self.search_engine.query("starting run") - assert len(starting_run_entry) == 1 - - self.suts = starting_run_entry[0]["suts"] - self.tests = starting_run_entry[0]["tests"] - # TODO: Find a more reliable way of getting all expected annotators for the tests. THIS WILL FAIL IF THEY ARE ALL CACHED. - annotator_entries = self.search_engine.query("fetched annotator response", test=self.tests[0], sut=self.suts[0]) - self.annotators = list(set([entry["annotator"] for entry in annotator_entries])) - def run(self, verbose=False): - self.collect_results() + self._collect_results() self.display_results() if verbose: self.display_warnings() # TODO: Also run checks for the json record file. - def collect_results(self): + def _collect_results(self): """Populate the results/warning tables of each check level.""" + search_engine = JournalSearch(self.journal_path) for test in self.tests: for sut in self.suts: - self.test_sut_level_checker.run_checks_for_row(self.search_engine, sut=sut, test=test) + self.test_sut_level_checker.run_checks_for_row(search_engine, sut=sut, test=test) for annotator in self.annotators: self.test_sut_annotator_level_checker.run_checks_for_row( - self.search_engine, sut=sut, test=test, annotator=annotator + search_engine, sut=sut, test=test, annotator=annotator ) @staticmethod - def _format_result(result: bool): - return "✅" if result else "❌" + def format_result(result: bool) -> str: + # Emojis + return ":white_check_mark:" if result else ":x:" + + def checks_are_complete(self) -> bool: + for checker in self._check_groups: + if not checker.check_is_complete(): + return False + return True + + def checks_all_passed(self) -> bool: + assert self.checks_are_complete(), "Cannot determine pass/fail for this journal until all checks have been run." + for checker in self._check_groups: + if any(not result for results in checker.results.values() for result in results.values()): + return False + return True def display_results(self): """Print simple table where each row is a single entity (or entity tuple e.g. 
test x SUT) and each column is a check.""" - check_groups = [self.test_sut_level_checker, self.test_sut_annotator_level_checker] - for checker in check_groups: + assert self.checks_are_complete(), "Cannot display results until all checks have been run." + for checker in self._check_groups: print("Results for", checker.name) - assert checker.check_is_complete() - results_table = [] - for entity, checks in checker.results.items(): - results_table.append([entity] + [self._format_result(checks[c]) for c in checker.check_names]) - print(tabulate(results_table, headers=[", ".join(checker.entity_names)] + list(checker.check_names))) + + table = Table() + # Format header + table.add_column(", ".join(checker.entity_names)) + for check in checker.check_names: + table.add_column(check, max_width=20, justify="center") + # Format rows + sorted_row_names = sorted(checker.row_names) + for entity in sorted_row_names: + entity_results = checker.results[entity] + entity_results_list = [self.format_result(entity_results[c]) for c in checker.check_names] + table.add_row(entity, *entity_results_list) + + console = Console() + console.print(table) print() def display_warnings(self): """Print details about the failed checks.""" - check_groups = [self.test_sut_level_checker, self.test_sut_annotator_level_checker] - for checker in check_groups: + assert self.checks_are_complete(), "Cannot display results until all checks have been run." + for checker in self._check_groups: print("-" * LINE_WIDTH) - assert checker.check_is_complete() if len(checker.warnings) == 0: print(f"All {checker.name} checks passed!") - return - print(f"Failed checks for {checker.name}:") - for warning in checker.warnings: - print(warning) # or something + else: + print(f"Failed checks for {checker.name}:") + for warning in checker.warnings: + print(warning) # or something + + +def summarize_consistency_check_results(checkers: List[ConsistencyChecker]): + """Print a table summarizing the overall pass/fail results for multiple consistency checks.""" + table = Table(min_width=200) + table.add_column("Journal", overflow="fold", no_wrap=False) + table.add_column("All checks passed", justify="center") + for checker in checkers: + if checker.checks_are_complete(): + result = ConsistencyChecker.format_result(checker.checks_all_passed()) + else: + result = "INCOMPLETE" + table.add_row(str(checker.journal_path), result) + + console = Console() + console.print(table) diff --git a/src/modelbench/run.py b/src/modelbench/run.py index 5c283f04..5f207d0a 100644 --- a/src/modelbench/run.py +++ b/src/modelbench/run.py @@ -21,7 +21,7 @@ import modelgauge from modelbench.benchmark_runner import BenchmarkRunner, TqdmRunTracker, JsonRunTracker from modelbench.benchmarks import BenchmarkDefinition, GeneralPurposeAiChatBenchmark, GeneralPurposeAiChatBenchmarkV1 -from modelbench.consistency_checker import ConsistencyChecker +from modelbench.consistency_checker import ConsistencyChecker, summarize_consistency_check_results from modelbench.hazards import STANDARDS from modelbench.record import dump_json from modelbench.static_site_generator import StaticContent, StaticSiteGenerator @@ -153,13 +153,46 @@ def benchmark( # TODO: Consistency check -@cli.command(help="check the consistency of a benchmark run using it's record and journal files.") -@click.option("--journal-path", "-j", type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path)) +@cli.command( + help="Check the consistency of a benchmark run using it's journal file. 
You can pass the name of the file OR a directory containing multiple journal files (will be searched recursively)" +) +@click.argument("journal-path", type=click.Path(exists=True, dir_okay=True, path_type=pathlib.Path)) # @click.option("--record-path", "-r", type=click.Path(exists=True, dir_okay=False, path_type=pathlib.Path)) @click.option("--verbose", "-v", default=False, is_flag=True, help="Print details about the failed checks.") def consistency_check(journal_path, verbose): - checker = ConsistencyChecker(journal_path) - checker.run(verbose) + journal_paths = [] + if journal_path.is_dir(): + # Search for all journal files in the directory. + for p in journal_path.rglob("*"): + if p.name.startswith("journal-run") and (p.suffix == ".jsonl" or p.suffix == ".zst"): + journal_paths.append(p) + if len(journal_paths) == 0: + raise click.BadParameter( + f"No journal files starting with 'journal-run' and ending with '.jsonl' or '.zst' found in the directory '{journal_path}'." + ) + else: + journal_paths = [journal_path] + + checkers = [] + checking_error_journals = [] + for p in journal_paths: + echo(termcolor.colored(f"\nChecking consistency of journal {p} ..........", "green")) + try: + checker = ConsistencyChecker(p) + checker.run(verbose) + checkers.append(checker) + except Exception as e: + print("Error running consistency check", e) + checking_error_journals.append(p) + + # Summarize results and unsuccessful checks. + if len(checkers) > 1: + echo(termcolor.colored("\nSummary of consistency checks for all journals:", "green")) + summarize_consistency_check_results(checkers) + if len(checking_error_journals) > 0: + echo(termcolor.colored(f"\nCould not run checks on the following journals:", "red")) + for j in checking_error_journals: + print("\t", j) def find_suts_for_sut_argument(sut_args: List[str]): diff --git a/tests/modelbench_tests/test_consistency_checker.py b/tests/modelbench_tests/test_consistency_checker.py index e920d45f..57f9aef5 100644 --- a/tests/modelbench_tests/test_consistency_checker.py +++ b/tests/modelbench_tests/test_consistency_checker.py @@ -1,53 +1,103 @@ import json import pytest +import re from typing import Dict, List -from modelbench.consistency_checker import ConsistencyChecker +from modelbench.consistency_checker import ( + AnnotationsMergedCorrectly, + ConsistencyChecker, + EachAnnotationTranslatedOnce, + EachItemMeasuredOnce, + EachPromptQueuedOnce, + EachPromptRespondedToOnce, + EachResponseAnnotatedOnce, + EachResponseTranslatedOnce, + MinValidAnnotatorItems, + NumItemsFinishedEqualsMeasuredItems, + summarize_consistency_check_results, +) + +DEFAULT_SUT = "sut1" +DEFAULT_TEST = "test1" +DEFAULT_ANNOTATOR = "annotator1" + + +def make_sut_entry( + message, + test=DEFAULT_TEST, + sut=DEFAULT_SUT, + prompt_id="prompt1", + annotator=DEFAULT_ANNOTATOR, + measurements_is_safe=1.0, + translated_is_safe=True, +): + """Helper function to create a basic SUT journal entry. Not all args will be used depending on the messge.""" + # SUT messages + base_sut_entry = {"test": test, "sut": sut, "prompt_id": prompt_id} + if message in ["queuing item", "fetched sut response", "using cached sut response", "translated sut response"]: + # Basic SUT journal entries that share the same shape. 
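# e.g. make_sut_entry("queuing item") -> {"message": "queuing item", "test": "test1", "sut": "sut1", "prompt_id": "prompt1"}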
+ return {"message": message, **base_sut_entry} + if message == "measured item quality": + return {"message": message, "measurements": {"is_safe": measurements_is_safe}, **base_sut_entry} + + # SUT + Annotator messages + base_annotator_entry = {"annotator": annotator, **base_sut_entry} + if message in ["fetched annotator response", "using cached annotator response"]: + # Basic annotator journal entries that share the same shape. + return {"message": message, **base_annotator_entry} + if message == "translated annotation": + return {"message": message, "annotation": {"is_safe": translated_is_safe}, **base_annotator_entry} + + raise ValueError(f"Can't make SUT entry for message: {message}") def make_basic_run(suts: List[str], test_prompts: Dict[str, List[str]], annotators: List[str]): - """Successful "fresh" benchmark run with all SUT/annotator responses fetched (not cached).""" + """Successful "fresh" benchmark run with all SUT/annotator responses fetched (not cached). + Measurements/annotations are all safe.""" + journal = [] - journal.append({"message": "starting run", "suts": suts, "tests": list(test_prompts.keys())}) + journal.append( + {"message": "starting run", "suts": suts, "tests": list(test_prompts.keys()), "benchmarks": ["official"]} + ) for sut in suts: for test, prompts in test_prompts.items(): journal.append({"message": "using test items", "test": test, "using": len(prompts)}) for prompt in prompts: - # Normal pipeline. - sut_messages = [ + # Normal SUT pipeline. + base_sut_entry = {"test": test, "sut": sut, "prompt_id": prompt} + for message in [ "queuing item", "fetched sut response", "translated sut response", "measured item quality", - ] - for message in sut_messages: - journal.append({"message": message, "test": test, "prompt_id": prompt, "sut": sut}) + ]: + journal.append(make_sut_entry(message, **base_sut_entry)) + # Annotator pipeline. 
for annotator in annotators: - journal.append( - { - "message": "fetched annotator response", - "test": test, - "sut": sut, - "prompt_id": prompt, - "annotator": annotator, - } - ) + for message in ["fetched annotator response", "translated annotation"]: + journal.append(make_sut_entry(message, annotator=annotator, **base_sut_entry)) + journal.append({"message": "test scored", "test": test, "sut": sut, "items_finished": len(prompts)}) return journal @pytest.fixture def basic_benchmark_run(): return make_basic_run( - suts=["sut1", "sut2"], test_prompts={"test1": ["prompt1", "prompt2"]}, annotators=["annotator1", "annotator2"] + suts=["sut1", "sut2"], + test_prompts={"test1": ["prompt1", "prompt2"]}, + annotators=["annotator1", "annotator2", "annotator3"], ) -def init_checker_for_journal(tmp_path, journal): - journal_path = tmp_path / "journal.jsonl" - with open(journal_path, "w") as f: +def write_journal_to_file(journal, path): + with open(path, "w") as f: for item in journal: f.write(json.dumps(item) + "\n") + +def init_checker_for_journal(tmp_path, journal): + journal_path = tmp_path / "journal.jsonl" + write_journal_to_file(journal, journal_path) checker = ConsistencyChecker(journal_path=journal_path) return checker @@ -63,64 +113,298 @@ def test_normal_run(tmp_path, basic_benchmark_run): assert subchecker.warnings == [] +def test_entities_collected(tmp_path, basic_benchmark_run): + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + + assert sorted(checker.suts) == ["sut1", "sut2"] + assert checker.tests == ["test1"] + assert sorted(checker.annotators) == ["annotator1", "annotator2", "annotator3"] + + +def test_cached_and_fetched_only_annotators_also_collected(tmp_path, basic_benchmark_run): + basic_benchmark_run.append(make_sut_entry("fetched annotator response", annotator="annotator4")) + basic_benchmark_run.append(make_sut_entry("using cached annotator response", annotator="annotator5")) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + + assert "annotator4" in checker.annotators + assert "annotator5" in checker.annotators + + @pytest.mark.parametrize( "duplicate_message,failed_check", [ - ("fetched sut response", "EachPromptRespondedToOnce"), - ("using cached sut response", "EachPromptRespondedToOnce"), - ("translated sut response", "EachResponseTranslatedOnce"), - ("measured item quality", "EachItemMeasuredOnce"), + ("queuing item", EachPromptQueuedOnce), + ("fetched sut response", EachPromptRespondedToOnce), + ("using cached sut response", EachPromptRespondedToOnce), + ("translated sut response", EachResponseTranslatedOnce), + ("measured item quality", EachItemMeasuredOnce), ], ) def test_run_with_duplicate_sut_stuff(tmp_path, basic_benchmark_run, duplicate_message, failed_check): - basic_benchmark_run.append({"message": duplicate_message, "test": "test1", "sut": "sut1", "prompt_id": "prompt1"}) + basic_benchmark_run.append(make_sut_entry(duplicate_message)) checker = init_checker_for_journal(tmp_path, basic_benchmark_run) checker.run() subchecker = checker.test_sut_level_checker - failed_row = subchecker._row_key(sut="sut1", test="test1") + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST) assert subchecker.check_is_complete() - assert subchecker.results[failed_row][failed_check] is False + assert subchecker.results[failed_row][subchecker._col_name(failed_check)] is False # TODO: Check warnings @pytest.mark.parametrize( "extra_earlier_message,failed_check", [ - ("queuing item", "EachPromptRespondedToOnce"), - ("fetched sut 
response", "EachResponseTranslatedOnce"), - ("translated sut response", "EachItemMeasuredOnce"), + ("queuing item", EachPromptRespondedToOnce), + ("fetched sut response", EachResponseTranslatedOnce), + ("translated sut response", EachItemMeasuredOnce), ], ) def test_run_with_missing_sut_stuff(tmp_path, basic_benchmark_run, extra_earlier_message, failed_check): - basic_benchmark_run.append( - {"message": extra_earlier_message, "test": "test1", "sut": "sut1", "prompt_id": "NEW PROMPT"} - ) + basic_benchmark_run.append(make_sut_entry(extra_earlier_message, prompt_id="NEW PROMPT")) checker = init_checker_for_journal(tmp_path, basic_benchmark_run) checker.run() subchecker = checker.test_sut_level_checker - failed_row = subchecker._row_key(sut="sut1", test="test1") + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST) assert subchecker.check_is_complete() - assert subchecker.results[failed_row][failed_check] is False + assert subchecker.results[failed_row][subchecker._col_name(failed_check)] is False + # TODO: Check warnings + + +def test_run_with_missing_queued_item_for_sut(tmp_path, basic_benchmark_run): + # Add extra test item by adding an entry for another sut. + basic_benchmark_run.append(make_sut_entry("queuing item", sut="another_sut", prompt_id="NEW PROMPT")) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST) + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(EachPromptQueuedOnce)] is False # TODO: Check warnings @pytest.mark.parametrize( "extra_message,failed_check", [ - ("fetched sut response", "EachPromptRespondedToOnce"), - ("translated sut response", "EachResponseTranslatedOnce"), - ("measured item quality", "EachItemMeasuredOnce"), + ("fetched sut response", EachPromptRespondedToOnce), + ("translated sut response", EachResponseTranslatedOnce), + ("measured item quality", EachItemMeasuredOnce), + ("measured item quality", NumItemsFinishedEqualsMeasuredItems), ], ) def test_run_with_extra_sut_stuff(tmp_path, basic_benchmark_run, extra_message, failed_check): - basic_benchmark_run.append({"message": extra_message, "test": "test1", "sut": "sut1", "prompt_id": "NEW PROMPT"}) + basic_benchmark_run.append(make_sut_entry(extra_message, prompt_id="NEW PROMPT")) checker = init_checker_for_journal(tmp_path, basic_benchmark_run) checker.run() subchecker = checker.test_sut_level_checker - failed_row = subchecker._row_key(sut="sut1", test="test1") + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST) + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(failed_check)] is False + # TODO: Check warnings + + +@pytest.mark.parametrize( + "duplicate_message,failed_check", + [ + ("fetched annotator response", EachResponseAnnotatedOnce), + ("using cached annotator response", EachResponseAnnotatedOnce), + ("translated annotation", EachAnnotationTranslatedOnce), + ], +) +def test_run_with_duplicate_annotator_stuff(tmp_path, basic_benchmark_run, duplicate_message, failed_check): + basic_benchmark_run.append(make_sut_entry(duplicate_message)) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_annotator_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST, annotator=DEFAULT_ANNOTATOR) + assert subchecker.check_is_complete() + assert 
subchecker.results[failed_row][subchecker._col_name(failed_check)] is False + # TODO: Check warnings + + +def test_run_with_missing_annotations(tmp_path, basic_benchmark_run): + basic_benchmark_run.append(make_sut_entry("translated sut response", prompt_id="NEW PROMPT")) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_annotator_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST, annotator=DEFAULT_ANNOTATOR) assert subchecker.check_is_complete() - assert subchecker.results[failed_row][failed_check] is False + assert subchecker.results[failed_row][subchecker._col_name(EachResponseAnnotatedOnce)] is False # TODO: Check warnings + + +@pytest.mark.parametrize( + "extra_earlier_message,failed_check", + [ + ("fetched annotator response", EachAnnotationTranslatedOnce), + ("using cached annotator response", EachAnnotationTranslatedOnce), + ], +) +def test_run_with_missing_annotator_translations(tmp_path, basic_benchmark_run, extra_earlier_message, failed_check): + basic_benchmark_run.append(make_sut_entry(extra_earlier_message, prompt_id="NEW PROMPT")) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_annotator_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST, annotator=DEFAULT_ANNOTATOR) + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(failed_check)] is False + # TODO: Check warnings + + +@pytest.mark.parametrize( + "extra_message,failed_check", + [ + ("fetched annotator response", EachResponseAnnotatedOnce), + ("using cached annotator response", EachResponseAnnotatedOnce), + ("translated annotation", EachAnnotationTranslatedOnce), + ], +) +def test_run_with_extra_annotator_stuff(tmp_path, basic_benchmark_run, extra_message, failed_check): + basic_benchmark_run.append(make_sut_entry(extra_message, prompt_id="NEW PROMPT")) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_annotator_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST, annotator=DEFAULT_ANNOTATOR) + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(failed_check)] is False + # TODO: Check warnings + + +@pytest.mark.parametrize("is_safe", [True, False]) +def test_min_valid_items_checker(tmp_path, basic_benchmark_run, is_safe): + # Add some invalid translated annotations for one annotator. 
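# A single invalid annotation out of the handful produced for this annotator/SUT/test is well above the 3% MinValidAnnotatorItems threshold.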
+ entry = make_sut_entry("translated annotation", prompt_id="NEW PROMPT", translated_is_safe=is_safe) + entry["annotation"]["is_valid"] = False + basic_benchmark_run.append(entry) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_annotator_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST, annotator=DEFAULT_ANNOTATOR) + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(MinValidAnnotatorItems)] is False + # TODO: Check warnings + + +def test_annotations_merged_correctly(tmp_path, basic_benchmark_run): + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST) + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(AnnotationsMergedCorrectly)] is True + + +def test_annotations_merged_correctly_fails_with_one_annotator(tmp_path): + # Simulate run with only 1 annotator. + run = make_basic_run( + suts=["sut1"], + test_prompts={"test1": ["prompt1"]}, + annotators=["annotator1"], + ) + checker = init_checker_for_journal(tmp_path, run) + checker.run() + + subchecker = checker.test_sut_level_checker + failed_row = subchecker._row_key(sut="sut1", test="test1") + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(AnnotationsMergedCorrectly)] is False + + +def test_annotations_merged_correctly_false_safe(tmp_path, basic_benchmark_run): + # Add a bunch of fake unsafe annotations for existing prompt that was measured safe. + entry = make_sut_entry("translated annotation", translated_is_safe=False) + for _ in range(4): + basic_benchmark_run.append(entry) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST) + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(AnnotationsMergedCorrectly)] is False + + +def test_annotations_merged_correctly_false_unsafe(tmp_path, basic_benchmark_run): + # Create safe annotations for new prompt. + entry = make_sut_entry("translated annotation", prompt_id="NEW PROMPT", translated_is_safe=True) + for _ in range(4): + basic_benchmark_run.append(entry) + # Measure that prompt as unsafe (wrongly). 
+ basic_benchmark_run.append( + make_sut_entry("measured item quality", prompt_id="NEW PROMPT", measurements_is_safe=0.0) + ) + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + checker.run() + + subchecker = checker.test_sut_level_checker + failed_row = subchecker._row_key(sut=DEFAULT_SUT, test=DEFAULT_TEST) + assert subchecker.check_is_complete() + assert subchecker.results[failed_row][subchecker._col_name(AnnotationsMergedCorrectly)] is False + + +def _manually_set_results_to_pass(sub_checker): + for row_key in sub_checker.results: + for col_key in sub_checker.check_names: + sub_checker.results[row_key][col_key] = True + + +def test_empty_run_is_not_complete(tmp_path, basic_benchmark_run): + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + assert checker.checks_are_complete() is False + + +def test_partial_run_is_not_complete(tmp_path, basic_benchmark_run): + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + # Manually set results for only one sub-checker. + _manually_set_results_to_pass(checker.test_sut_level_checker) + + assert checker.checks_are_complete() is False + + +def test_finished_run_is_complete(tmp_path, basic_benchmark_run): + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + # Manually set results for all sub-checkers. + for sub_checker in checker._check_groups: + _manually_set_results_to_pass(sub_checker) + + assert checker.checks_are_complete() + + +def journal_result_is_expected_in_summary(journal_path, expected_result, output): + f_result = ConsistencyChecker.format_result(expected_result) + return re.search(rf"{re.escape(str(journal_path))}\s*.*\s*{f_result}", output) + + +def test_summarize_results_pass(tmp_path, capsys, basic_benchmark_run): + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + # Manually set results for all sub-checkers. + for sub_checker in checker._check_groups: + _manually_set_results_to_pass(sub_checker) + summarize_consistency_check_results([checker]) + + captured = capsys.readouterr() + assert "✅" in captured.out + assert "❌" not in captured.out + + +def test_summarize_results_fail(tmp_path, capsys, basic_benchmark_run): + checker = init_checker_for_journal(tmp_path, basic_benchmark_run) + for sub_checker in checker._check_groups: + _manually_set_results_to_pass(sub_checker) + # Make sure there is at least on failed check. 
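# Flipping any single cell to False is enough: checks_all_passed() scans every row and column of every sub-checker.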
+ checker.test_sut_level_checker.results[("sut1", "test1")][EachPromptQueuedOnce] = False + + summarize_consistency_check_results([checker]) + + captured = capsys.readouterr() + assert "✅" not in captured.out + assert "❌" in captured.out diff --git a/tests/modelgauge_tests/test_cli.py b/tests/modelgauge_tests/test_cli.py index 7818cada..3b1131f4 100644 --- a/tests/modelgauge_tests/test_cli.py +++ b/tests/modelgauge_tests/test_cli.py @@ -1,10 +1,11 @@ import csv -import jsonlines import re +from unittest.mock import patch + +import jsonlines import pytest from click.testing import CliRunner, Result -from unittest.mock import patch from modelgauge import main from modelgauge.prompt import SUTOptions @@ -126,15 +127,14 @@ def test_run_prompts_normal(tmp_path): with open(tmp_path / out_path, "r") as f: reader = csv.DictReader(f) - row1 = next(reader) - assert row1["UID"] == "p1" - assert row1["Text"] == "Say yes" - assert row1["demo_yes_no"] == "Yes" - - row2 = next(reader) - assert row2["UID"] == "p2" - assert row2["Text"] == "Refuse" - assert row2["demo_yes_no"] == "No" + rows = (next(reader), next(reader)) + rows = sorted(rows, key=lambda row: row["UID"]) + expected = ( + {"UID": "p1", "Text": "Say yes", "demo_yes_no": "Yes"}, + {"UID": "p2", "Text": "Refuse", "demo_yes_no": "No"}, + ) + assert rows[0] == expected[0] + assert rows[1] == expected[1] def test_run_prompts_with_annotators(tmp_path):
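For readers who want to exercise the reworked consistency-check command end to end, a minimal sketch follows. It assumes the click group `cli` defined in modelbench/run.py, click's default dash-separated command name, and a hypothetical runs/journals directory; none of these details are asserted by the patch itself.

from pathlib import Path

from click.testing import CliRunner

from modelbench.run import cli

runner = CliRunner()
# A directory argument is searched recursively for journal-run*.jsonl / *.zst files;
# when more than one journal is found, a pass/fail summary table is printed at the end.
result = runner.invoke(cli, ["consistency-check", str(Path("runs/journals")), "--verbose"])
print(result.output)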