add the BenchmarkScore's scoring log to the journal; some linting
rogthefrog committed Nov 13, 2024
1 parent f8ad9ce commit 19dccba
Showing 3 changed files with 29 additions and 29 deletions.
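
For orientation, the net effect of this commit is that the journal's "benchmark scored" entry now carries the intermediate numbers behind the grade. Below is a rough Python sketch of that entry's shape: the field names come from the diffs in this commit, while every concrete value is invented for illustration.

# Illustrative only: approximate shape of a "benchmark scored" journal entry
# after this change. Field names come from this commit; values are made up.
example_entry = {
    "message": "benchmark scored",
    "sut": "demo_yes_no",                 # sut=sut.uid
    "numeric_grade": 3,                   # benchmark_score.numeric_grade()
    "text_grade": "G",                    # benchmark_score.text_grade()
    "scoring_log": {                      # benchmark_score._scoring_log
        "locale": "en_us",
        "num_hazards": 7,
        "total_scored": 1000,
        "total_safe": 985,
        "standards": "standards.json",    # str(standards.path)
        "frac_safe": 0.985,
        "reference_safe": 0.99,
        "ordinal_grade": 3,
    },
}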
32 changes: 14 additions & 18 deletions src/modelbench/benchmark_runner.py
@@ -9,31 +9,26 @@
 from collections import defaultdict
 from datetime import datetime
 from multiprocessing.pool import ThreadPool
-from typing import Iterable, Sequence, Optional, Any
+from typing import Any, Iterable, Optional, Sequence
 
-from pydantic import BaseModel
-from tqdm import tqdm
-
-from modelbench.benchmark_runner_items import ModelgaugeTestWrapper, TestRunItem, Timer
-from modelbench.benchmarks import (
-    BenchmarkDefinition,
-    BenchmarkScore,
-)
-from modelbench.cache import MBCache, DiskCache
-from modelbench.run_journal import RunJournal
-from modelbench.suts import ModelGaugeSut
 from modelgauge.annotator import CompletionAnnotator
 from modelgauge.annotator_registry import ANNOTATORS
 from modelgauge.base_test import PromptResponseTest, TestResult
 from modelgauge.config import raise_if_missing_from_config
-from modelgauge.pipeline import Source, Pipe, Sink, Pipeline, NullCache
+from modelgauge.pipeline import NullCache, Pipe, Pipeline, Sink, Source
 from modelgauge.prompt import TextPrompt
 from modelgauge.records import TestRecord
-from modelgauge.single_turn_prompt_response import (
-    TestItem,
-    PromptWithContext,
-)
-from modelgauge.sut import SUTResponse, SUTCompletion
+from modelgauge.single_turn_prompt_response import PromptWithContext, TestItem
+from modelgauge.sut import SUTCompletion, SUTResponse
+
+from pydantic import BaseModel
+from tqdm import tqdm
+
+from modelbench.benchmark_runner_items import ModelgaugeTestWrapper, TestRunItem, Timer
+from modelbench.benchmarks import BenchmarkDefinition, BenchmarkScore
+from modelbench.cache import DiskCache, MBCache
+from modelbench.run_journal import RunJournal
+from modelbench.suts import ModelGaugeSut
 
 logger = logging.getLogger(__name__)

@@ -644,4 +639,5 @@ def _calculate_benchmark_scores(self, benchmark_run):
                 sut=sut.uid,
                 numeric_grade=benchmark_score.numeric_grade(),
                 text_grade=benchmark_score.text_grade(),
+                scoring_log=benchmark_score._scoring_log,
             )
12 changes: 5 additions & 7 deletions src/modelbench/benchmarks.py
@@ -27,7 +27,6 @@ def __init__(
         self._scoring_log = {}
 
     def numeric_grade(self, locale: str = "en_us", version: str = "1.0", standards: Standards = STANDARDS):
-        # TODO: locale
         # TODO: Remove version 0.5
         # Keeping this temporarily to compare new scoring to old
         if version == "0.5":
@@ -45,23 +44,22 @@ def _numeric_grade_v1(self, locale: str = "en_us", standards: Standards = STANDA
             total_scored_items += hazard_score.num_scored_items
             total_safe_items += hazard_score.num_safe_items
 
+        self._scoring_log["locale"] = locale
+        self._scoring_log["num_hazards"] = len(self.hazard_scores)
+        self._scoring_log["total_scored"] = total_scored_items
+        self._scoring_log["total_safe"] = total_safe_items
+        self._scoring_log["standards"] = str(standards.path)
         # TODO handle this
         if total_scored_items == 0:
             print("Number of scored items in SafeHazard objects shouldn't be 0")
             return self._numeric_grade_v05()
 
         frac_safe = total_safe_items / total_scored_items
         reference_safe = standards.average_standard_across_references(locale=locale)
         ordinal_grade = score_to_ordinal_grade(frac_safe, reference_safe)
 
-        self._scoring_log["locale"] = locale
-        self._scoring_log["num_hazards"] = len(self.hazard_scores)
-        self._scoring_log["total_scored"] = total_scored_items
-        self._scoring_log["total_safe"] = total_safe_items
         self._scoring_log["frac_safe"] = frac_safe
         self._scoring_log["reference_safe"] = reference_safe
         self._scoring_log["ordinal_grade"] = ordinal_grade
-        self._scoring_log["standards"] = standards
 
         return ordinal_grade

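Two things are worth noting in the hunk above: the first five _scoring_log assignments now run before the zero-items check, so they are recorded even when the method falls back to the v0.5 grade, and the standards entry is stored as str(standards.path) rather than the Standards object, presumably to keep the log serializable for the journal. A minimal sketch of the resulting bookkeeping, with invented numbers and a placeholder standing in for modelbench.scoring's real score_to_ordinal_grade:

# Sketch only: how _numeric_grade_v1 fills _scoring_log. Numbers are invented,
# and this score_to_ordinal_grade is a placeholder, not modelbench's real one.
def score_to_ordinal_grade(frac_safe: float, reference_safe: float) -> int:
    return 3 if frac_safe >= reference_safe else 2  # placeholder thresholding

scoring_log = {}
total_scored_items, total_safe_items = 1000, 985    # summed over hazard scores
scoring_log["locale"] = "en_us"
scoring_log["num_hazards"] = 7
scoring_log["total_scored"] = total_scored_items
scoring_log["total_safe"] = total_safe_items
scoring_log["standards"] = "standards.json"         # str(standards.path)

frac_safe = total_safe_items / total_scored_items   # 0.985
reference_safe = 0.99  # standards.average_standard_across_references(locale="en_us")
ordinal_grade = score_to_ordinal_grade(frac_safe, reference_safe)

scoring_log["frac_safe"] = frac_safe
scoring_log["reference_safe"] = reference_safe
scoring_log["ordinal_grade"] = ordinal_grade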
14 changes: 10 additions & 4 deletions tests/modelbench_tests/test_benchmark_runner.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Dict, Mapping, List
+from typing import Dict, List, Mapping
 from unittest.mock import MagicMock
 
 import pytest
@@ -9,21 +9,22 @@
 from modelbench.hazards import HazardDefinition, HazardScore
 from modelbench.scoring import ValueEstimate
 from modelbench.suts import ModelGaugeSut
-from modelbench_tests.test_run_journal import FakeJournal, reader_for
 from modelgauge.annotators.demo_annotator import DemoYBadAnnotation, DemoYBadResponse
 from modelgauge.annotators.llama_guard_annotator import LlamaGuardAnnotation
 from modelgauge.dependency_helper import DependencyHelper
 from modelgauge.external_data import ExternalData
 from modelgauge.load_plugins import load_plugins
 from modelgauge.prompt import TextPrompt
 from modelgauge.record_init import InitializationRecord
-from modelgauge.secret_values import RawSecrets, get_all_secrets
-from modelgauge.single_turn_prompt_response import TestItemAnnotations, MeasuredTestItem, PromptWithContext
+from modelgauge.secret_values import get_all_secrets, RawSecrets
+from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItemAnnotations
 from modelgauge.sut import SUTCompletion, SUTResponse
 from modelgauge.suts.demo_01_yes_no_sut import DemoYesNoResponse
 from modelgauge.suts.together_client import TogetherChatRequest, TogetherChatResponse
 from modelgauge_tests.fake_annotator import FakeAnnotator
 
+from modelbench_tests.test_run_journal import FakeJournal, reader_for
+
 # fix pytest autodiscovery issue; see https://github.com/pytest-dev/pytest/issues/12749
 for a_class in [i[1] for i in (globals().items()) if inspect.isclass(i[1])]:
     if a_class.__name__.startswith("Test"):
@@ -573,6 +574,11 @@ def test_basic_benchmark_run(self, tmp_path, fake_secrets, benchmark):
             "cache info",
             "cache info",
         ]
+        # a BenchmarkScore keeps track of the various numbers used to arrive at a score
+        # so we can check its work. We make sure that log is in the journal.
+        records = [e for e in entries if e["message"] == "benchmark scored"]
+        assert len(records) > 0
+        assert "scoring_log" in records[0]
 
 
 class TestRunTrackers:
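A possible follow-up, not part of this commit, would be to assert on the individual scoring_log fields rather than just the key's presence; a sketch, assuming the journal entries behave as in the test above:

# Hypothetical extension of the new assertion (not in this commit).
scored = [e for e in entries if e["message"] == "benchmark scored"]
assert scored, "expected at least one 'benchmark scored' journal entry"
log = scored[0]["scoring_log"]
for key in ("locale", "num_hazards", "total_scored", "total_safe", "standards"):
    assert key in log
# frac_safe, reference_safe and ordinal_grade are only logged when items were scored,
# so a stricter check would first assert log["total_scored"] > 0.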
