Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic reports #377

Merged
merged 6 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/modelbench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections import defaultdict
from datetime import datetime, timezone
from multiprocessing import Pool
from typing import Dict, List, Mapping
from typing import Dict, List, Mapping, Optional

import click
import termcolor
Expand Down Expand Up @@ -76,26 +76,33 @@ def cli() -> None:
multiple=True,
)
@click.option("--view-embed", default=False, is_flag=True, help="Render the HTML to be embedded in another view")
@click.option("--anonymize", type=int, help="Random number seed for consistent anonymization of SUTs")
@click.option("--parallel", default=False, help="experimentally run SUTs in parallel")
@click.option(
"--custom-branding",
type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
bkorycki marked this conversation as resolved.
Show resolved Hide resolved
help="Path to custom branding. Implicitly sets --generic-report to remove MLCommons copy.",
)
@click.option(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make more sense here to just have one CLI option that is either 1) --custom-branding which includes MLC branded stuff or 2) no --custom-branding which is always generic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't users need to know the path of MLC branding in that case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for now the only users running reports with MLC branding will be the three of us. And we'll eventually put that branding directory in modellab, where we'll be running the official reports from.

"--generic-report", default=False, is_flag=True, help="Generate a generic webpage without MLCommons branding."
)
@click.option("--anonymize", type=int, help="Random number seed for consistent anonymization of SUTs")
@click.option("--parallel", default=False, help="experimentally run SUTs in parallel")
@local_plugin_dir_option
def benchmark(
output_dir: pathlib.Path,
max_instances: int,
debug: bool,
sut: List[str],
view_embed: bool,
custom_branding: Optional[pathlib.Path] = None,
generic_report: bool = False,
anonymize=None,
parallel=False,
generic_report=False,
) -> None:
generic_report = generic_report or (custom_branding is not None)
suts = find_suts_for_sut_argument(sut)
benchmarks = [GeneralPurposeAiChatBenchmark()]
benchmark_scores = score_benchmarks(benchmarks, suts, max_instances, debug, parallel)
generate_content(benchmark_scores, output_dir, anonymize, view_embed, generic_report)
generate_content(benchmark_scores, output_dir, anonymize, view_embed, generic_report, custom_branding)


def find_suts_for_sut_argument(sut_args: List[str]):
Expand Down Expand Up @@ -165,8 +172,8 @@ def score_a_sut(benchmarks, max_instances, secrets, debug, sut):
return sut_scores


def generate_content(benchmark_scores, output_dir, anonymize, view_embed, generic):
static_site_generator = StaticSiteGenerator(view_embed=view_embed, generic=generic)
def generate_content(benchmark_scores, output_dir, anonymize, view_embed, generic, custom_branding=None):
static_site_generator = StaticSiteGenerator(view_embed=view_embed, generic=generic, custom_branding=custom_branding)
if anonymize:

class FakeSut(SutDescription):
Expand Down
35 changes: 18 additions & 17 deletions src/modelbench/static_site_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,49 +68,50 @@ def _error_bar(self, hazard_score: HazardScore) -> dict:

class StaticContent(dict):

def __init__(self, templates_path=pathlib.Path(__file__).parent / "templates", generic=False):
def __init__(self, path=pathlib.Path(__file__).parent / "templates" / "content"):
super().__init__()
self.update(self._load_content_dir(templates_path / "content"))
if not generic:
# Override generic content with MLCommons-branded content, where provided.
mlc_content = self._load_content_dir(templates_path / "content_mlc")
for table in mlc_content:
self[table].update(mlc_content[table])

@staticmethod
def _load_content_dir(path):
content = {}
for file in (path).rglob("*.toml"):
with open(file, "rb") as f:
try:
data = tomli.load(f)
except tomli.TOMLDecodeError as e:
raise ValueError(f"failure reading {file}") from e
duplicate_keys = set(content.keys()) & set(data.keys())
duplicate_keys = set(self.keys()) & set(data.keys())
if duplicate_keys:
raise Exception(f"Duplicate tables found in content files: {duplicate_keys}")
content.update(data)
return content
self.update(data)

def update_custom_content(self, custom_content_path: pathlib.Path):
custom_content = StaticContent(custom_content_path)
for table in custom_content:
if table not in self:
raise ValueError(f"Unknown table {table} in custom content")
self[table].update(custom_content[table])


class StaticSiteGenerator:
def __init__(self, view_embed: bool = False, generic: bool = False) -> None:
def __init__(self, view_embed: bool = False, generic: bool = False, custom_branding: pathlib.Path = None) -> None:
"""Initialize the StaticSiteGenerator class for local file or website partial

Args:
view_embed (bool): Whether to generate local file or embedded view. Defaults to False.
generic (bool): If True, generate a generic report without MLCommons branding. Defaults to False.
custom_branding (Path): Path to custom branding directory. Optional.
"""
self.view_embed = view_embed
self.generic = generic
self.generic = generic or (custom_branding is not None)
self.env = Environment(loader=PackageLoader("modelbench"), autoescape=select_autoescape())
self.env.globals["hsp"] = HazardScorePositions(min_bar_width=0.04, lowest_bar_percent=0.2)
self.env.globals["root_path"] = self.root_path
self.env.globals["benchmarks_path"] = self.benchmarks_path
self.env.globals["benchmark_path"] = self.benchmark_path
self.env.globals["test_report_path"] = self.test_report_path
self.env.globals["content"] = self.content
self._content = StaticContent(generic=generic)
self._content = StaticContent()
if custom_branding is not None:
self._content.update_custom_content(custom_branding)
elif not self.generic:
self._content.update_custom_content(self._template_dir() / "content_mlc")

@singledispatchmethod
def content(self, item, key: str):
Expand Down
26 changes: 12 additions & 14 deletions src/modelbench/templates/benchmark.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@ <h1 class="mlc--header">{{ content(benchmark_definition, "name") }}</h1>
{{ use_hazards_limitations(benchmark_definition) }}

{{ interpret_safety_ratings() }}

{% if not generic %}
<div class="mlc--section__header">
<h2>AI Systems Evaluated</h2>
<p>
{{ content("general", "ai_systems_evaluated") }}
</p>
</div>
<div class="mlc--section__header">
<h2>AI Systems Evaluated</h2>
<p>
{{ content("general", "ai_systems_evaluated") }}
</p>
</div>
{% endif %}

<figure class="mlc--section overflow-auto mlc--table__box-shadow">
<table class="mlc--table__ai-systems">
<thead>
Expand Down Expand Up @@ -66,11 +64,11 @@ <h2>AI Systems Evaluated</h2>

<hr>
{% if not generic %}
<article class="mlc--card__muted-background">
<h4>Don't see the AI system you are looking for?</h4>
<p>
{{ content("general", "new_benchmarks") | safe }}
</p>
</article>
<article class="mlc--card__muted-background">
<h4>Don't see the AI system you are looking for?</h4>
<p>
{{ content("general", "new_benchmarks") | safe }}
</p>
</article>
{% endif %}
{% endblock %}
12 changes: 6 additions & 6 deletions src/modelbench/templates/benchmarks.html
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ <h1 class="mlc--header">AI Safety Benchmarks</h1>

<hr>
{% if not generic %}
<article class="mlc--card__muted-background">
<h4>Don't see the benchmark you are looking for?</h4>
<p>
{{ content("general", "new_benchmarks") | safe }}
</p>
</article>
<article class="mlc--card__muted-background">
<h4>Don't see the benchmark you are looking for?</h4>
<p>
{{ content("general", "new_benchmarks") | safe }}
</p>
</article>
{% endif %}
{% endblock %}
13 changes: 6 additions & 7 deletions src/modelbench/templates/test_report.html
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,11 @@ <h6 class="mlc--test-detail-header">Model UID</h6>

<hr>
{% if not generic %}
<article class="mlc--card__muted-background">
<h4>Don't see the tests you are looking for?</h4>
<p>
{{ content("general", "new_tests") | safe }}
</p>
</article>
<article class="mlc--card__muted-background">
<h4>Don't see the tests you are looking for?</h4>
<p>
{{ content("general", "new_tests") | safe }}
</p>
</article>
{% endif %}

{% endblock %}
9 changes: 0 additions & 9 deletions tests/data/content/file1.toml

This file was deleted.

2 changes: 0 additions & 2 deletions tests/data/content/file2.toml

This file was deleted.

5 changes: 0 additions & 5 deletions tests/data/content_mlc/file1.toml

This file was deleted.

2 changes: 2 additions & 0 deletions tests/data/custom_content/file1.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[general]
description = "new description"
112 changes: 68 additions & 44 deletions tests/test_static_site_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,53 +210,77 @@ def test_test_defaults(self, ssg):
assert ssg.content(test, "not_a_real_key") == ""


@pytest.mark.parametrize(
"use_mlc_content,contested_value_a,contested_value_b", [(False, "generic a", "generic b"), (True, "mlc a", "mlc b")]
)
def test_static_content(use_mlc_content, contested_value_a, contested_value_b):
content = StaticContent(templates_path=pathlib.Path(__file__).parent / "data", generic=not use_mlc_content)
assert content == {
"a": {"key1": contested_value_a, "key2": "common a"},
"b": {"key1": contested_value_b},
"c": {"key1": "common c"},
"d": {"key_d": "common d"},
}


def test_content_correct_values():
ssg_mlc = StaticSiteGenerator(generic=False)
ssg_generic = StaticSiteGenerator(generic=True)
# Spot check that content is the same if not provided by MLC branding.
assert ssg_mlc.content("general", "tests_run") == ssg_generic.content("general", "tests_run")
assert ssg_mlc._content["grades"] == ssg_generic._content["grades"]
# Spot check content overridden by MLC branding.
assert ssg_mlc.content("general", "description") != ssg_generic.content("general", "description")
# Content unique to MLC branding.
assert len(ssg_mlc.content("general", "new_benchmarks"))
assert "new_benchmarks" not in ssg_generic._content.keys()


def test_generic_content_no_mention_mlc():
ssg = StaticSiteGenerator(generic=True)

def recurse(content):
for key, values in content.items():
if isinstance(values, dict):
recurse(values)
else:
if not isinstance(values, list):
values = [values]
for text in values:
text = text.lower()
assert "mlcommons" not in text
assert "ml commons" not in text
class TestBrandingArgs:
"""
Tests to check that StaticSiteGenerator is correctly handling the generic and custom_branding arguments.
"""

@pytest.fixture
def ssg_mlc(self):
_ssg = StaticSiteGenerator()
return _ssg

@pytest.fixture
def ssg_generic(self):
_ssg = StaticSiteGenerator(generic=True)
return _ssg

@pytest.fixture
def ssg_custom(self):
_ssg = StaticSiteGenerator(custom_branding=pathlib.Path(__file__).parent / "data" / "custom_content")
return _ssg

recurse(ssg._content)
def test_generic_content(self, ssg_mlc, ssg_generic):
# Spot check that content is the same if not provided by MLC branding.
assert ssg_mlc.content("general", "tests_run") == ssg_generic.content("general", "tests_run")
assert ssg_mlc._content["grades"] == ssg_generic._content["grades"]
# Spot check content overridden by MLC branding.
assert ssg_mlc.content("general", "description") != ssg_generic.content("general", "description")
# Content unique to MLC branding.
assert len(ssg_mlc.content("general", "new_benchmarks"))
assert "new_benchmarks" not in ssg_generic._content.keys()

def test_generic_content_no_mention_mlc(self, ssg_generic):
def recurse(content):
for key, values in content.items():
if isinstance(values, dict):
recurse(values)
else:
if not isinstance(values, list):
values = [values]
for text in values:
text = text.lower()
assert "mlcommons" not in text
assert "ml commons" not in text

recurse(ssg_generic._content)

def test_custom_content(self, ssg_mlc, ssg_custom, ssg_generic):
# Check that content uses the custom value.
assert ssg_custom.content("general", "description") == "new description"
# Check that content not specified by the custom branding falls back to generic values and not MLC branding.
assert ssg_custom.content("general", "provisional_disclaimer") == ssg_generic.content(
"general", "provisional_disclaimer"
)
assert ssg_custom.content("general", "provisional_disclaimer") != ssg_mlc.content(
"general", "provisional_disclaimer"
)

def test_static_site_generator_args(self):
# Check that generic is always set to True if custom_branding is provided.
ssg_a = StaticSiteGenerator(
generic=False, custom_branding=pathlib.Path(__file__).parent / "data" / "custom_content"
)
ssg_b = StaticSiteGenerator(
generic=True, custom_branding=pathlib.Path(__file__).parent / "data" / "custom_content"
)
assert ssg_a._content == ssg_b._content
assert ssg_a.generic == ssg_b.generic
assert ssg_a.generic is True


@pytest.mark.parametrize("generic", [True, False])
def test_sut_content_defaults(generic):
ssg = StaticSiteGenerator(generic=generic)
def test_sut_content_defaults():
ssg = StaticSiteGenerator()
a_dynamic_sut = SutDescription("fake", "Fake SUT")
assert ssg.content(a_dynamic_sut, "name") == "Fake SUT"

Expand Down
Loading