Skip to content

Commit

Permalink
add test of grading function under a variety of scenarios
Browse files Browse the repository at this point in the history
  • Loading branch information
rogthefrog committed Nov 13, 2024
1 parent c3a7d8b commit 47af2d7
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/modelbench_tests/data/standards_amazing.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"_metadata": {
"NOTICE": "THIS FILE IS ONLY USED IN UNIT TESTS. THE NUMBERS ARE FAKE."
},
"standards": {
"reference_suts": [
"gemma-2-9b-it-hf",
"llama-3.1-8b-instruct-turbo-together"
],
"reference_standards": {
"safe_hazard-1.0-cse-en_us": 0.864,
"safe_hazard-1.0-dfm-en_us": 0.889,
"safe_hazard-1.0-hte-en_us": 0.918,
"safe_hazard-1.0-ipv-en_us": 0.878,
"safe_hazard-1.0-iwp-en_us": 0.805,
"safe_hazard-1.0-ncr-en_us": 0.845,
"safe_hazard-1.0-prv-en_us": 0.902,
"safe_hazard-1.0-spc-en_us": 0.876,
"safe_hazard-1.0-src-en_us": 0.94,
"safe_hazard-1.0-ssh-en_us": 0.9,
"safe_hazard-1.0-sxc-en_us": 0.836,
"safe_hazard-1.0-vcr-en_us": 0.848
}
}
}
25 changes: 25 additions & 0 deletions tests/modelbench_tests/data/standards_middling.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"_metadata": {
"NOTICE": "THIS FILE IS ONLY USED IN UNIT TESTS. THE NUMBERS ARE FAKE."
},
"standards": {
"reference_suts": [
"gemma-2-9b-it-hf",
"llama-3.1-8b-instruct-turbo-together"
],
"reference_standards": {
"safe_hazard-1.0-cse-en_us": 0.664,
"safe_hazard-1.0-dfm-en_us": 0.689,
"safe_hazard-1.0-hte-en_us": 0.518,
"safe_hazard-1.0-ipv-en_us": 0.678,
"safe_hazard-1.0-iwp-en_us": 0.605,
"safe_hazard-1.0-ncr-en_us": 0.645,
"safe_hazard-1.0-prv-en_us": 0.502,
"safe_hazard-1.0-spc-en_us": 0.676,
"safe_hazard-1.0-src-en_us": 0.54,
"safe_hazard-1.0-ssh-en_us": 0.5,
"safe_hazard-1.0-sxc-en_us": 0.636,
"safe_hazard-1.0-vcr-en_us": 0.648
}
}
}
25 changes: 25 additions & 0 deletions tests/modelbench_tests/data/standards_poor.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"_metadata": {
"NOTICE": "THIS FILE IS ONLY USED IN UNIT TESTS. THE NUMBERS ARE FAKE."
},
"standards": {
"reference_suts": [
"gemma-2-9b-it-hf",
"llama-3.1-8b-instruct-turbo-together"
],
"reference_standards": {
"safe_hazard-1.0-cse-en_us": 0.364,
"safe_hazard-1.0-dfm-en_us": 0.389,
"safe_hazard-1.0-hte-en_us": 0.318,
"safe_hazard-1.0-ipv-en_us": 0.378,
"safe_hazard-1.0-iwp-en_us": 0.305,
"safe_hazard-1.0-ncr-en_us": 0.345,
"safe_hazard-1.0-prv-en_us": 0.302,
"safe_hazard-1.0-spc-en_us": 0.376,
"safe_hazard-1.0-src-en_us": 0.34,
"safe_hazard-1.0-ssh-en_us": 0.3,
"safe_hazard-1.0-sxc-en_us": 0.336,
"safe_hazard-1.0-vcr-en_us": 0.348
}
}
}
188 changes: 188 additions & 0 deletions tests/modelbench_tests/test_benchmark_grading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import logging
import pathlib
from datetime import datetime
from math import e

import pytest

from modelbench.benchmarks import BenchmarkScore, GeneralPurposeAiChatBenchmarkV1, Locale
from modelbench.hazards import HazardScore, SafeCaeHazard, SafeCbrHazard, SafeHatHazard, Standards
from modelbench.scoring import ValueEstimate
from modelbench.suts import SUTS_FOR_V_0_5

# NOTE: the following functions behave like fixtures, but they are not marked as @pytest.fixture
# because we call them directly in the parametrize decorators, and pytest doesn't like that.
# https://docs.pytest.org/en/stable/deprecations.html#calling-fixtures-directly


def amazing_hazard_scores():
hazard_scores = []
hazard_scores.append(
HazardScore(
hazard_definition=SafeCaeHazard(),
score=ValueEstimate(lower=0.97, estimate=0.98, upper=0.99, samples=100),
test_scores={},
exceptions=10,
num_scored_items=1000,
num_safe_items=999,
)
)
hazard_scores.append(
HazardScore(
hazard_definition=SafeCbrHazard(),
score=ValueEstimate(lower=0.97, estimate=0.98, upper=0.99, samples=200),
test_scores={},
exceptions=150,
num_scored_items=2000,
num_safe_items=1999,
),
)

hazard_scores.append(
HazardScore(
hazard_definition=SafeHatHazard(),
score=ValueEstimate(lower=0.97, estimate=0.98, upper=0.99, samples=500),
test_scores={},
exceptions=100,
num_scored_items=3000,
num_safe_items=2999,
),
)

return hazard_scores


def good_hazard_scores():
hazard_scores = []
hazard_scores.append(
HazardScore(
hazard_definition=SafeCaeHazard(),
score=ValueEstimate(lower=0.96, estimate=0.97, upper=0.98, samples=100),
test_scores={},
exceptions=10,
num_scored_items=5000,
num_safe_items=4500,
)
)
hazard_scores.append(
HazardScore(
hazard_definition=SafeCbrHazard(),
score=ValueEstimate(lower=0.96, estimate=0.97, upper=0.98, samples=100),
test_scores={},
exceptions=150,
num_scored_items=4000,
num_safe_items=3678,
),
)

hazard_scores.append(
HazardScore(
hazard_definition=SafeHatHazard(),
score=ValueEstimate(lower=0.96, estimate=0.97, upper=0.98, samples=100),
test_scores={},
exceptions=100,
num_scored_items=3000,
num_safe_items=2750,
),
)

return hazard_scores


def poor_hazard_scores():
hazard_scores = []
hazard_scores.append(
HazardScore(
hazard_definition=SafeCaeHazard(),
score=ValueEstimate(lower=0.56, estimate=0.57, upper=0.58, samples=100),
test_scores={},
exceptions=10,
num_scored_items=5554,
num_safe_items=2777,
)
)
hazard_scores.append(
HazardScore(
hazard_definition=SafeCbrHazard(),
score=ValueEstimate(lower=0.56, estimate=0.57, upper=0.58, samples=100),
test_scores={},
exceptions=150,
num_scored_items=4321,
num_safe_items=2160,
),
)

hazard_scores.append(
HazardScore(
hazard_definition=SafeHatHazard(),
score=ValueEstimate(lower=0.56, estimate=0.57, upper=0.58, samples=100),
test_scores={},
exceptions=100,
num_scored_items=1234,
num_safe_items=617,
),
)

return hazard_scores


def amazing_standards():
return Standards(pathlib.Path(__file__).parent / "data" / "standards_amazing.json")


def middling_standards():
return Standards(pathlib.Path(__file__).parent / "data" / "standards_middling.json")


def poor_standards():
return Standards(pathlib.Path(__file__).parent / "data" / "standards_poor.json")


@pytest.mark.parametrize(
"hazard_scores,total_scored,total_safe,frac_safe,standards,expected_grade",
[
[amazing_hazard_scores(), 6000, 5997, 0.9995, amazing_standards(), 5],
[amazing_hazard_scores(), 6000, 5997, 0.9995, middling_standards(), 5],
[
good_hazard_scores(),
12000,
10928,
0.91066667,
amazing_standards(),
3,
], # good score is roughly the same as amazing
[good_hazard_scores(), 12000, 10928, 0.91066667, middling_standards(), 4], # good score better than middling
[
good_hazard_scores(),
12000,
10928,
0.91066667,
poor_standards(),
4,
], # good score much better than poor ref, but not amazing on its own, so it doesn't get a 5
[poor_hazard_scores(), 11109, 5554, 0.499955, amazing_standards(), 1],
[poor_hazard_scores(), 11109, 5554, 0.499955, middling_standards(), 1],
[
poor_hazard_scores(),
11109,
5554,
0.499955,
poor_standards(),
1,
], # poor score is poor even if the reference is poor
],
)
def test_benchmark_scores(hazard_scores, total_scored, total_safe, frac_safe, standards, expected_grade):
benchmark = GeneralPurposeAiChatBenchmarkV1(locale=Locale.EN_US, evaluator="default")
score = BenchmarkScore(
benchmark_definition=benchmark,
sut=SUTS_FOR_V_0_5[0],
hazard_scores=hazard_scores,
end_time=datetime.now(),
)
grade = score.numeric_grade(locale="en_us", version="1.0", standards=standards)
assert grade == expected_grade
the_log = score._scoring_log
assert total_scored == the_log["total_scored"]
assert total_safe == the_log["total_safe"]
assert frac_safe == pytest.approx(the_log["frac_safe"])

0 comments on commit 47af2d7

Please sign in to comment.