Skip to content

Commit

Permalink
Extract the cache stuff and make it available at the end of a run.
Browse files Browse the repository at this point in the history
  • Loading branch information
wpietri committed Oct 11, 2024
1 parent 6ac65c2 commit de24180
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 152 deletions.
190 changes: 93 additions & 97 deletions poetry.lock

Large diffs are not rendered by default.

62 changes: 39 additions & 23 deletions src/modelbench/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
from datetime import datetime
from typing import Mapping, Iterable, Sequence, List, Optional, Any

import diskcache
from tqdm import tqdm

from modelbench.benchmarks import (
BenchmarkDefinition,
BenchmarkScore,
)
from modelbench.cache import MBCache, DiskCache
from modelbench.suts import ModelGaugeSut
from modelgauge.annotation import Annotation
from modelgauge.annotator import CompletionAnnotator
from modelgauge.annotator_registry import ANNOTATORS
Expand All @@ -30,13 +37,6 @@
SUTCompletionAnnotations,
)
from modelgauge.sut import SUTResponse, SUTCompletion
from tqdm import tqdm

from modelbench.benchmarks import (
BenchmarkDefinition,
BenchmarkScore,
)
from modelbench.suts import ModelGaugeSut


class RunTracker:
Expand Down Expand Up @@ -176,7 +176,8 @@ def __init__(self, runner: "TestRunnerBase"):
super().__init__()
# copy the starting state
self.pipeline_segments = []
self.test_data_path = runner.data_dir / "tests"
self.data_dir = runner.data_dir
self.test_data_path = self.data_dir / "tests"
self.secrets = runner.secrets
self.suts = runner.suts
self.max_items = runner.max_items
Expand All @@ -186,6 +187,9 @@ def __init__(self, runner: "TestRunnerBase"):
self.run_tracker = runner.run_tracker
self.completed_item_count = 0

self.caches = {}
self.cache_starting_size = {}

# set up for result collection
self.finished_items = defaultdict(lambda: defaultdict(lambda: list()))
self.failed_items = defaultdict(lambda: defaultdict(lambda: list()))
Expand Down Expand Up @@ -228,6 +232,24 @@ def failed_items_for(self, sut, test) -> Sequence[TestItem]:
def annotators_for_test(self, test: PromptResponseTest) -> Sequence[CompletionAnnotator]:
return self.test_annotators[test.uid]

def cache_for(self, cache_name: str):
    """Create, register and return the cache called ``cache_name``.

    When ``self.data_dir`` is set, the cache is a ``DiskCache`` stored at
    ``self.data_dir / cache_name``; otherwise a ``NullCache`` is used. The
    cache and its size at creation time are recorded in ``self.caches`` and
    ``self.cache_starting_size`` so ``cache_info`` can report on them later.
    """
    cache = DiskCache(self.data_dir / cache_name) if self.data_dir else NullCache()

    self.caches[cache_name] = cache
    # Snapshot the entry count now so growth over the run can be reported.
    self.cache_starting_size[cache_name] = len(cache)
    return cache

def cache_info(self):
    """Return a multi-line, human-readable summary of every registered cache.

    For each entry in ``self.caches`` three lines are produced: the cache's
    string form, its recorded starting size, and its current size. Lines are
    joined with newlines; with no registered caches the result is ``""``.
    """
    lines = []
    for name, cache in self.caches.items():
        lines.extend(
            (
                f" {name}: {cache}",
                f" {name}: started with {self.cache_starting_size[name]}",
                f" {name}: finished with {len(cache)}",
            )
        )
    return "\n".join(lines)


class TestRun(TestRunBase):
tests: list[ModelgaugeTestWrapper]
Expand Down Expand Up @@ -260,13 +282,9 @@ class IntermediateCachingPipe(Pipe):
this just makes a cache available for internal use to cache intermediate results.
"""

def __init__(self, thread_count=1, cache_path=None):
def __init__(self, cache: MBCache, thread_count=1):
super().__init__(thread_count)

if cache_path:
self.cache = diskcache.Cache(cache_path).__enter__()
else:
self.cache = NullCache()
self.cache = cache

def handle_item(self, item) -> Optional[Any]:
pass
Expand Down Expand Up @@ -312,8 +330,8 @@ def handle_item(self, item: TestRunItem):

class TestRunSutWorker(IntermediateCachingPipe):

def __init__(self, test_run: TestRunBase, thread_count=1, cache_path=None):
super().__init__(thread_count, cache_path=cache_path)
def __init__(self, test_run: TestRunBase, cache: MBCache, thread_count=1):
super().__init__(cache, thread_count)
self.test_run = test_run

def handle_item(self, item):
Expand All @@ -340,8 +358,8 @@ def handle_item(self, item):

class TestRunAnnotationWorker(IntermediateCachingPipe):

def __init__(self, test_run: TestRunBase, thread_count=1, cache_path=None):
super().__init__(thread_count, cache_path=cache_path)
def __init__(self, test_run: TestRunBase, cache: MBCache, thread_count=1, cache_path=None):
super().__init__(cache, thread_count)
self.test_run = test_run

def handle_item(self, item: TestRunItem) -> TestRunItem:
Expand Down Expand Up @@ -425,11 +443,9 @@ def _make_test_record(self, run, sut, test, test_result):
def _build_pipeline(self, run):
run.pipeline_segments.append(TestRunItemSource(run))
run.pipeline_segments.append(TestRunSutAssigner(run))
run.pipeline_segments.append(TestRunSutWorker(run, run.cache_for("sut_cache"), thread_count=self.thread_count))
run.pipeline_segments.append(
TestRunSutWorker(run, thread_count=self.thread_count, cache_path=self.data_dir / "sut_cache")
)
run.pipeline_segments.append(
TestRunAnnotationWorker(run, thread_count=self.thread_count, cache_path=self.data_dir / "annotator_cache")
TestRunAnnotationWorker(run, run.cache_for("annotator_cache"), thread_count=self.thread_count)
)
run.pipeline_segments.append(TestRunResultsCollector(run))
pipeline = Pipeline(
Expand Down
93 changes: 93 additions & 0 deletions src/modelbench/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import collections.abc
from abc import ABC, abstractmethod

import diskcache


class MBCache(ABC, collections.abc.Mapping):
    """Abstract base class for caches.

    A cache is a ``Mapping`` that additionally supports item assignment and
    can be used as a context manager. Concrete subclasses must implement
    ``__setitem__`` plus the abstract ``Mapping`` methods (``__getitem__``,
    ``__len__`` and ``__iter__``).
    """

    @abstractmethod
    def __setitem__(self, __key, __value):
        """Store ``__value`` under ``__key``; every concrete cache overrides this."""

    def __enter__(self):
        """Enter the context; by default the cache itself is the managed object."""
        return self

    def __exit__(self, __type, __value, __traceback):
        """Leave the context; nothing is released and exceptions are not suppressed."""
        return None


class NullCache(MBCache):
    """A cache that doesn't save anything.

    Writes are discarded, every lookup raises ``KeyError``, the length is
    always 0 and iteration yields nothing.
    """

    def __setitem__(self, __key, __value):
        # Intentionally drop the value: this cache never stores anything.
        pass

    def __getitem__(self, key, /):
        # Carry the key in the exception so a miss is identifiable in tracebacks.
        raise KeyError(key)

    def __len__(self):
        return 0

    def __iter__(self):
        # __iter__ must return an iterator; returning None makes iter(),
        # list(), keys(), items() and dict() raise TypeError.
        return iter(())


class InMemoryCache(MBCache):
    """Cache backed by a plain dict, so entries live only as long as the object."""

    def __init__(self):
        super().__init__()
        # All entries are kept in this dict.
        self.contents = {}

    def __setitem__(self, __key, __value):
        self.contents[__key] = __value

    def __getitem__(self, key, /):
        return self.contents[key]

    def __len__(self):
        return len(self.contents)

    def __iter__(self):
        return iter(self.contents)


class DiskCache(MBCache):
    """
    Holds stuff on disk, using a diskcache.Cache stored at cache_path.
    The diskcache docs recommend using it as a context manager
    in a threaded context:
    "Each thread that accesses a cache should also call close
    on the cache. Cache objects can be used in a with statement
    to safeguard calling close."

    NOTE(review): __enter__ returns whatever the underlying cache's
    __enter__ returns rather than this wrapper; kept as-is so existing
    callers see the same object.
    """

    def __init__(self, cache_path):
        super().__init__()
        self.cache_path = cache_path
        self.raw_cache = diskcache.Cache(cache_path)
        # All mapping operations are delegated to self.contents.
        self.contents = self.raw_cache

    def __enter__(self):
        self.contents = self.raw_cache.__enter__()
        return self.contents

    def __exit__(self, __type, __value, __traceback):
        self.raw_cache.__exit__(__type, __value, __traceback)
        # Point back at the raw cache so the wrapper stays usable after the with block.
        self.contents = self.raw_cache

    def __setitem__(self, __key, __value):
        self.contents[__key] = __value

    def __getitem__(self, key, /):
        return self.contents[key]

    def __len__(self):
        return len(self.contents)

    def __iter__(self):
        return iter(self.contents)

    def __str__(self):
        return f"{self.__class__.__name__}({self.cache_path})"
10 changes: 6 additions & 4 deletions src/modelbench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
import click
import termcolor
from click import echo
from modelgauge.config import load_secrets_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.sut_registry import SUTS
from modelgauge.tests.safe_v1 import Locale

from modelbench.benchmark_runner import BenchmarkRunner, TqdmRunTracker, JsonRunTracker
from modelbench.benchmarks import BenchmarkDefinition, GeneralPurposeAiChatBenchmark, GeneralPurposeAiChatBenchmarkV1
from modelbench.hazards import STANDARDS
from modelbench.record import dump_json
from modelbench.static_site_generator import StaticContent, StaticSiteGenerator
from modelbench.suts import ModelGaugeSut, SutDescription, SUTS_FOR_V_0_5
from modelgauge.config import load_secrets_from_config, write_default_config
from modelgauge.load_plugins import load_plugins
from modelgauge.sut_registry import SUTS
from modelgauge.tests.safe_v1 import Locale

_DEFAULT_SUTS = SUTS_FOR_V_0_5

Expand Down Expand Up @@ -175,6 +175,8 @@ def run_benchmarks_for_suts(benchmarks, suts, max_instances, debug=False, json_l
runner.thread_count = thread_count
runner.run_tracker = JsonRunTracker() if json_logs else TqdmRunTracker(0.5)
run = runner.run()
print("Cache info:")
print(run.cache_info())
return run


Expand Down
17 changes: 3 additions & 14 deletions src/modelgauge/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(self, ):

import diskcache # type: ignore

from modelbench.cache import DiskCache, NullCache


class PipelineSegment(ABC):
"""A segment of a Pipeline used for parallel processing."""
Expand Down Expand Up @@ -228,27 +230,14 @@ def join(self):
self._debug(f"join done")


class NullCache(dict):
"""Compatible with diskcache.Cache, but does nothing."""

def __setitem__(self, __key, __value):
pass

def __enter__(self):
return self

def __exit__(self, __type, __value, __traceback):
pass


class CachingPipe(Pipe):
"""A Pipe that optionally caches results the given directory. Implement key and handle_uncached_item."""

def __init__(self, thread_count=1, cache_path=None):
super().__init__(thread_count)

if cache_path:
self.cache = diskcache.Cache(cache_path).__enter__()
self.cache = DiskCache(cache_path)
else:
self.cache = NullCache()

Expand Down
27 changes: 13 additions & 14 deletions tests/modelbench_tests/test_benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
from unittest.mock import MagicMock

import pytest
from modelgauge_tests.fake_annotator import FakeAnnotator

from modelgauge.annotators.demo_annotator import DemoYBadAnnotation, DemoYBadAnnotator
from modelbench.benchmark_runner import *
from modelbench.hazards import HazardDefinition, HazardScore
from modelbench.scoring import ValueEstimate
from modelbench.suts import ModelGaugeSut
from modelgauge.annotators.demo_annotator import DemoYBadAnnotation
from modelgauge.annotators.llama_guard_annotator import LlamaGuardAnnotation
from modelgauge.dependency_helper import DependencyHelper
from modelgauge.external_data import ExternalData
Expand All @@ -13,12 +18,6 @@
from modelgauge.record_init import InitializationRecord
from modelgauge.secret_values import RawSecrets, get_all_secrets
from modelgauge.suts.together_client import TogetherChatRequest, TogetherChatResponse
from modelgauge_tests.fake_annotator import FakeAnnotator

from modelbench.benchmark_runner import *
from modelbench.hazards import HazardDefinition, HazardScore
from modelbench.scoring import ValueEstimate
from modelbench.suts import ModelGaugeSut

# fix pytest autodiscovery issue; see https://github.com/pytest-dev/pytest/issues/12749
for a_class in [i[1] for i in (globals().items()) if inspect.isclass(i[1])]:
Expand Down Expand Up @@ -242,7 +241,7 @@ def test_benchmark_sut_assigner(self, a_wrapped_test, tmp_path):
assert item_two.sut == sut_two

def test_benchmark_sut_worker(self, item_from_test, a_wrapped_test, tmp_path, a_sut):
bsw = TestRunSutWorker(self.a_run(tmp_path, suts=[a_sut]))
bsw = TestRunSutWorker(self.a_run(tmp_path, suts=[a_sut]), NullCache())

result = bsw.handle_item(TestRunItem(a_wrapped_test, item_from_test, a_sut))

Expand All @@ -252,7 +251,7 @@ def test_benchmark_sut_worker(self, item_from_test, a_wrapped_test, tmp_path, a_
assert result.sut_response.completions[0].text == "No"

def test_benchmark_sut_worker_throws_exception(self, item_from_test, a_wrapped_test, tmp_path, exploding_sut):
bsw = TestRunSutWorker(self.a_run(tmp_path, suts=[exploding_sut]))
bsw = TestRunSutWorker(self.a_run(tmp_path, suts=[exploding_sut]), NullCache())

result = bsw.handle_item(TestRunItem(a_wrapped_test, item_from_test, exploding_sut))

Expand All @@ -264,7 +263,7 @@ def test_benchmark_sut_worker_throws_exception(self, item_from_test, a_wrapped_t
def test_benchmark_annotation_worker(
self, a_wrapped_test, tmp_path, item_from_test, sut_response, a_sut, benchmark
):
baw = TestRunAnnotationWorker(self.a_run(tmp_path, suts=[a_sut], benchmarks=[benchmark]))
baw = TestRunAnnotationWorker(self.a_run(tmp_path, suts=[a_sut], benchmarks=[benchmark]), NullCache())
pipeline_item = TestRunItem(a_wrapped_test, item_from_test, a_sut, sut_response)

result = baw.handle_item(pipeline_item)
Expand All @@ -273,7 +272,7 @@ def test_benchmark_annotation_worker(
assert result.annotations["demo_annotator"].badness == 1.0

def test_test_annotation_worker(self, a_wrapped_test, tmp_path, item_from_test, sut_response, a_sut, a_test):
taw = TestRunAnnotationWorker(self.a_test_run(tmp_path, suts=[a_sut], tests=[a_test]))
taw = TestRunAnnotationWorker(self.a_test_run(tmp_path, suts=[a_sut], tests=[a_test]), NullCache())
pipeline_item = TestRunItem(a_wrapped_test, item_from_test, a_sut, sut_response)

result = taw.handle_item(pipeline_item)
Expand All @@ -282,7 +281,7 @@ def test_test_annotation_worker(self, a_wrapped_test, tmp_path, item_from_test,
assert result.annotations["demo_annotator"].badness == 1.0

def test_benchmark_annotation_worker_ignores_failed(self, a_wrapped_test, tmp_path, item_from_test, a_sut):
baw = TestRunAnnotationWorker(self.a_run(tmp_path, suts=[a_sut]))
baw = TestRunAnnotationWorker(self.a_run(tmp_path, suts=[a_sut]), NullCache())
pipeline_item = TestRunItem(a_wrapped_test, item_from_test, a_sut)
pipeline_item.exception = ValueError()

Expand All @@ -293,7 +292,7 @@ def test_benchmark_annotation_worker_ignores_failed(self, a_wrapped_test, tmp_pa
def test_benchmark_annotation_worker_throws_exception(
self, exploding_wrapped_test, tmp_path, item_from_test, sut_response, a_sut
):
baw = TestRunAnnotationWorker(self.a_run(tmp_path, suts=[a_sut]))
baw = TestRunAnnotationWorker(self.a_run(tmp_path, suts=[a_sut]), NullCache())
pipeline_item = TestRunItem(exploding_wrapped_test, item_from_test, a_sut, sut_response)

result = baw.handle_item(pipeline_item)
Expand Down Expand Up @@ -389,7 +388,7 @@ def test_sut_caching(self, item_from_test, a_wrapped_test, tmp_path):
object="foo",
)
run = self.a_run(tmp_path, suts=[sut])
bsw = TestRunSutWorker(run, cache_path=tmp_path)
bsw = TestRunSutWorker(run, DiskCache(tmp_path))

bsw.handle_item(TestRunItem(a_wrapped_test, item_from_test, sut))
assert sut.instance().evaluate.call_count == 1
Expand Down
Loading

0 comments on commit de24180

Please sign in to comment.